From 7cb9cb3c74d2384b2ed77b293369331ea13059d1 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 15 Aug 2024 11:32:29 +0200 Subject: [PATCH 01/35] Testing rdkit functions --- chebai/molecule.py | 95 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/chebai/molecule.py b/chebai/molecule.py index db191976..33c05f17 100644 --- a/chebai/molecule.py +++ b/chebai/molecule.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division import logging +from itertools import product from typing import Any, List, Optional, Union try: @@ -66,7 +67,7 @@ class Molecule: max_number_of_parents = 7 def __init__( - self, smile: str, logp: Optional[float] = None, contract_rings: bool = False + self, smile: str, logp: Optional[float] = None, contract_rings: bool = False ): """ Initializes a Molecule object. @@ -400,8 +401,8 @@ def num_of_features() -> int: int: Total number of features. """ return ( - Molecule.max_number_of_parents * Molecule.num_bond_features() - + Molecule.num_atom_features() + Molecule.max_number_of_parents * Molecule.num_bond_features() + + Molecule.num_atom_features() ) @staticmethod @@ -467,7 +468,95 @@ def num_bond_features() -> int: return len(Molecule.bond_features(simple_mol.GetBonds()[0])) + @staticmethod + def find_smile(): + original_smiles = "OC(=O)C(C(N)C(O)=O)C" + mol = Chem.MolFromSmiles(original_smiles) + smile=Chem.MolToSmiles(mol,doRandom=True,rootedAtAtom=4) + return smile + + + + @staticmethod + def find_smile1(): + original_smiles = "OC(=O)C(C(N)C(O)=O)C" + mol = Chem.MolFromSmiles(original_smiles) + + # Test combinations of doRandom and rootedAtAtom + for do_random in [True, False]: + for rooted_at_atom in [4, 3, 2, 1, 0, -1 - 2, -3, -4, -5]: + try: + smiles = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) + print(f"Configuration: doRandom={do_random}, rootedAtAtom={rooted_at_atom}\n{smiles}\n") + except Exception as e: + print(f"Error with configuration: doRandom={do_random}, rootedAtAtom={rooted_at_atom}\nError: {e}\n") + + + @staticmethod + def find_smiles(): + original_smiles = "[Cl-].[H][N+]([H])([H])[H]" + mol = Chem.MolFromSmiles(original_smiles) + + # Dictionary to store SMILES and the corresponding configurations + smiles_dict = {} + + # Test combinations of doRandom and rootedAtAtom + for do_random in [True, False]: + for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]: + try: + smiles = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) + config = f"doRandom={do_random}, rootedAtAtom={rooted_at_atom}" + + if smiles in smiles_dict: + smiles_dict[smiles].append(config) + else: + smiles_dict[smiles] = [config] + + print(f"Configuration: {config}\n{smiles}\n") + except Exception as e: + print(f"Error with configuration: doRandom={do_random}, rootedAtAtom={rooted_at_atom}\nError: {e}\n") + + # Print configurations that generated the same SMILES string + for smiles, configs in smiles_dict.items(): + if len(configs) > 1: + print(f"SMILES string '{smiles}' was generated by the following configurations:") + for config in configs: + print(f"- {config}") + print("\n") + + + + @staticmethod + def find_config(): + original_smiles = "Oc1ccc2nccc(O)c2c1" + mol = Chem.MolFromSmiles(original_smiles) + + # List of Boolean parameters to iterate over + boolean_params = ['isomericSmiles', 'kekuleSmiles', 'canonical', 'allBondsExplicit', 'allHsExplicit', + 'doRandom'] + + # Generate all combinations of True/False for these parameters + combinations = list(product([True, False], repeat=len(boolean_params))) + + # Dictionary to store the configuration and its generated SMILES + results = [] + + # Iterate through all combinations + for combination in combinations: + params = dict(zip(boolean_params, combination)) + smiles = Chem.MolToSmiles(mol, **params) + results.append({ + 'config': params, + 'generated_smiles': smiles, + 'matches_original': smiles == original_smiles + }) + + return results + + if __name__ == "__main__": + + # print(Molecule.find_smiles()) log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) logger = logging.getLogger(__name__) From 1f0daa4c8fa487bb4b4b88163d93a361213bc63b Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 15 Aug 2024 11:35:41 +0200 Subject: [PATCH 02/35] Added notebook For Testing --- .../preprocessing/datasets/augmentation.ipynb | 1851 +++++++++++++++++ 1 file changed, 1851 insertions(+) create mode 100644 chebai/preprocessing/datasets/augmentation.ipynb diff --git a/chebai/preprocessing/datasets/augmentation.ipynb b/chebai/preprocessing/datasets/augmentation.ipynb new file mode 100644 index 00000000..22b26cc6 --- /dev/null +++ b/chebai/preprocessing/datasets/augmentation.ipynb @@ -0,0 +1,1851 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 269, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from rdkit import Chem\n", + "from rdkit.Chem import AllChem" + ] + }, + { + "cell_type": "code", + "execution_count": 270, + "metadata": {}, + "outputs": [], + "source": [ + "# Path to the original data.pkl file\n", + "data_path = \"D:\\Knowledge\\Hiwi\\python-chebai\\data\\chebi_v231\\ChEBI50\\processed\\data.pkl\"\n", + "# data_path1=\"data\\chebi_v231\\ChEBI50\\processed\\data.pkl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 271, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
033429monoatomic monoanion[*-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
130151aluminide(1-)[Al-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
216042halide anion[*-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
317051fluoride[F-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
428741sodium fluoride[F-].[Na+]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", + "

5 rows × 1514 columns

\n", + "
" + ], + "text/plain": [ + " id name SMILES 1722 2440 2468 2571 2580 \\\n", + "0 33429 monoatomic monoanion [*-] False False False False False \n", + "1 30151 aluminide(1-) [Al-] False False False False False \n", + "2 16042 halide anion [*-] False False False False False \n", + "3 17051 fluoride [F-] False False False False False \n", + "4 28741 sodium fluoride [F-].[Na+] False False False False False \n", + "\n", + " 2634 3098 ... 176910 177333 183508 183509 189832 189840 192499 \\\n", + "0 False False ... False False False False False False False \n", + "1 False False ... False False False False False False False \n", + "2 False False ... False False False False False False False \n", + "3 False False ... False False False False False False False \n", + "4 False False ... False False False False False False False \n", + "\n", + " 194321 197504 229684 \n", + "0 False False False \n", + "1 False False False \n", + "2 False False False \n", + "3 False False False \n", + "4 False False False \n", + "\n", + "[5 rows x 1514 columns]" + ] + }, + "execution_count": 271, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_pickle(\n", + " open(data_path, \"rb\"\n", + " )\n", + ")\n", + "df[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 272, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
313565385gliotoxin[H][C@@]12[C@@H](O)C=CC=C1C[C@@]13SS[C@@](CO)(...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", + "

1 rows × 1514 columns

\n", + "
" + ], + "text/plain": [ + " id name SMILES \\\n", + "31356 5385 gliotoxin [H][C@@]12[C@@H](O)C=CC=C1C[C@@]13SS[C@@](CO)(... \n", + "\n", + " 1722 2440 2468 2571 2580 2634 3098 ... 176910 177333 \\\n", + "31356 False False False False False False False ... False False \n", + "\n", + " 183508 183509 189832 189840 192499 194321 197504 229684 \n", + "31356 False False False False False False False False \n", + "\n", + "[1 rows x 1514 columns]" + ] + }, + "execution_count": 272, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = df[df[\"id\"] == 5385]\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": 273, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a new empty DataFrame for storing new variations\n", + "new_df = pd.DataFrame(columns=df.columns)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 274, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to generate SMILES variations using different configurations\n", + "def generate_smiles_variations(smiles, num_variations=5):\n", + " mol = Chem.MolFromSmiles(smiles)\n", + " if mol is None:\n", + " return [] # Return an empty list if conversion fails\n", + "\n", + " variations = set()\n", + "\n", + " # Loop through all combinations of doRandom and rootedAtAtom values\n", + " for do_random in [True, False]:\n", + " for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]:\n", + " try:\n", + " # Generate SMILES with the given configuration\n", + " variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom)\n", + " if variant != smiles: # Avoid duplicates with the original SMILES\n", + " variations.add(variant)\n", + "\n", + " # Check the number of variations after adding\n", + " if len(variations) >= num_variations:\n", + " return list(variations) # Return immediately when enough variations are found\n", + "\n", + " except Exception as e:\n", + " # Skip invalid configurations\n", + " continue\n", + "\n", + " return list(variations)" + ] + }, + { + "cell_type": "code", + "execution_count": 275, + "metadata": {}, + "outputs": [], + "source": [ + "# smile=\"C(CC[N+]1(C)CCCC1)(O)(C2CCCCC2)C3=CC=CC=C3.[Cl-]\"\n", + "# sample_variation=generate_smiles_variations(smile)" + ] + }, + { + "cell_type": "code", + "execution_count": 276, + "metadata": {}, + "outputs": [], + "source": [ + "# sample_variation" + ] + }, + { + "cell_type": "code", + "execution_count": 277, + "metadata": {}, + "outputs": [], + "source": [ + "# Set to keep track of already seen SMILES to avoid duplicates\n", + "seen_smiles = set(df['SMILES'])" + ] + }, + { + "cell_type": "code", + "execution_count": 278, + "metadata": {}, + "outputs": [], + "source": [ + "test_df=df[-5::]" + ] + }, + { + "cell_type": "code", + "execution_count": 279, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
1850112295182-Amino-3-methylsuccinic acidOC(=O)C(C(N)C(O)=O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
18501283380dinocap-4C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
185013140503kaolin[Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
18501481948tralkoxydimCCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
185015140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", + "

5 rows × 1514 columns

\n", + "
" + ], + "text/plain": [ + " id name \\\n", + "185011 229518 2-Amino-3-methylsuccinic acid \n", + "185012 83380 dinocap-4 \n", + "185013 140503 kaolin \n", + "185014 81948 tralkoxydim \n", + "185015 140499 kaolinite \n", + "\n", + " SMILES 1722 2440 \\\n", + "185011 OC(=O)C(C(N)C(O)=O)C False False \n", + "185012 C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O... False False \n", + "185013 [Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O False False \n", + "185014 CCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1C False False \n", + "185015 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[... False False \n", + "\n", + " 2468 2571 2580 2634 3098 ... 176910 177333 183508 \\\n", + "185011 False False False False False ... False False False \n", + "185012 False False False False False ... False False False \n", + "185013 False False False False False ... False False False \n", + "185014 False False False False False ... False False False \n", + "185015 False False False False False ... False False False \n", + "\n", + " 183509 189832 189840 192499 194321 197504 229684 \n", + "185011 False False False False False False False \n", + "185012 False False False False False False False \n", + "185013 False False False False False False False \n", + "185014 False False False False False False False \n", + "185015 False False False False False False False \n", + "\n", + "[5 rows x 1514 columns]" + ] + }, + "execution_count": 279, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df" + ] + }, + { + "cell_type": "code", + "execution_count": 280, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "id 5\n", + "name 5\n", + "SMILES 5\n", + "1722 1\n", + "2440 1\n", + " ..\n", + "189840 1\n", + "192499 1\n", + "194321 1\n", + "197504 1\n", + "229684 1\n", + "Length: 1514, dtype: int64" + ] + }, + "execution_count": 280, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df.nunique()" + ] + }, + { + "cell_type": "code", + "execution_count": 281, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 5 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 4 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 3 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 2 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 1 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 5 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 4 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 3 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 2 < 1\n", + "****\n", + "\n", + "[20:01:38] \n", + "\n", + "****\n", + "Range Error\n", + "idx\n", + "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", + "Failed Expression: 1 < 1\n", + "****\n", + "\n" + ] + } + ], + "source": [ + "# Process each row in the original DataFrame\n", + "for _, row in test_df.iterrows():\n", + " original_smiles = row['SMILES']\n", + " \n", + " # Generate new SMILES variations\n", + " variations = generate_smiles_variations(original_smiles)\n", + " \n", + " # Filter out variations that are already seen\n", + " variations = [var for var in variations if var not in seen_smiles]\n", + " \n", + " for var in variations:\n", + " # Create a new row with the new SMILES and the rest of the features and labels unchanged\n", + " new_row = row.copy()\n", + " new_row['SMILES'] = var\n", + " new_df = pd.concat([new_df, pd.DataFrame([new_row])], ignore_index=True)\n", + " \n", + " # Add the new SMILES to the seen set to avoid duplicates\n", + " seen_smiles.add(var)" + ] + }, + { + "cell_type": "code", + "execution_count": 282, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "id 5\n", + "name 5\n", + "SMILES 24\n", + "1722 1\n", + "2440 1\n", + " ..\n", + "189840 1\n", + "192499 1\n", + "194321 1\n", + "197504 1\n", + "229684 1\n", + "Length: 1514, dtype: int64" + ] + }, + "execution_count": 282, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_df.nunique()" + ] + }, + { + "cell_type": "code", + "execution_count": 283, + "metadata": {}, + "outputs": [], + "source": [ + "# Append the new DataFrame (new_df) to the original DataFrame (df)\n", + "df_combined = pd.concat([test_df, new_df], ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 284, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(29, 1514)" + ] + }, + "execution_count": 284, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_combined.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 285, + "metadata": {}, + "outputs": [], + "source": [ + "new_data_path=\"test_data.pkl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 286, + "metadata": {}, + "outputs": [], + "source": [ + "pd.to_pickle(df_combined, open(new_data_path, \"wb\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 287, + "metadata": {}, + "outputs": [], + "source": [ + "test_data_df= pd.read_pickle(\n", + " open(\"test_data.pkl\", \"rb\"\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 288, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
02295182-Amino-3-methylsuccinic acidOC(=O)C(C(N)C(O)=O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
183380dinocap-4C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2140503kaolin[Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
381948tralkoxydimCCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
4140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
52295182-Amino-3-methylsuccinic acidO=C(C(C(C(O)=O)N)C)OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
62295182-Amino-3-methylsuccinic acidC(C(C(N)C(=O)O)C)(=O)OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
72295182-Amino-3-methylsuccinic acidNC(C(=O)O)C(C(O)=O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
82295182-Amino-3-methylsuccinic acidC(C(C(O)=O)C)(C(O)=O)NFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
92295182-Amino-3-methylsuccinic acidC(C(C(O)=O)N)(C(O)=O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1083380dinocap-4O=C(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/C=C/CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1183380dinocap-4C(=C/C(=O)Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1283380dinocap-4O(c1c(cc(*)cc1[N+](=O)[O-])[N+]([O-])=O)C(/C=C...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1383380dinocap-4C(\\C(=O)Oc1c([N+]([O-])=O)cc(*)cc1[N+](=O)[O-]...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1483380dinocap-4C(=O)(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
15140503kaolinO.O.O([Al]=O)[Si](=O)O[Si](=O)O[Al]=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
16140503kaolinO.O.O([Si](=O)O[Al]=O)[Si](O[Al]=O)=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
17140503kaolinO.O.[Si](=O)(O[Si](=O)O[Al]=O)O[Al]=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
18140503kaolinO.O.[Al](=O)O[Si](=O)O[Si](O[Al]=O)=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
19140503kaolinO.O.[Si](O[Al]=O)(O[Si](=O)O[Al]=O)=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2081948tralkoxydimC(C)O/N=C(/C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)CCFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2181948tralkoxydimC(=N\\OCC)(\\CC)C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2281948tralkoxydimO(/N=C(\\CC)C1C(CC(CC=1O)c1c(C)cc(C)cc1C)=O)CCFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2381948tralkoxydimN(\\OCC)=C(\\CC)C1C(CC(c2c(C)cc(cc2C)C)CC=1O)=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2481948tralkoxydimC(C)/C(C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C)=N\\OCCFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
25140499kaoliniteO=[Si]([O-])O[Si](=O)[O-].[Al+3].[Al+3].[OH-]....FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
26140499kaolinite[Al+3].[Al+3].[O-][Si](=O)O[Si]([O-])=O.[OH-]....FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
27140499kaoliniteO=[Si](O[Si]([O-])=O)[O-].[Al+3].[Al+3].[OH-]....FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
28140499kaolinite[Al+3].[Al+3].[O-][Si](=O)O[Si](=O)[O-].[OH-]....FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", + "

29 rows × 1514 columns

\n", + "
" + ], + "text/plain": [ + " id name \\\n", + "0 229518 2-Amino-3-methylsuccinic acid \n", + "1 83380 dinocap-4 \n", + "2 140503 kaolin \n", + "3 81948 tralkoxydim \n", + "4 140499 kaolinite \n", + "5 229518 2-Amino-3-methylsuccinic acid \n", + "6 229518 2-Amino-3-methylsuccinic acid \n", + "7 229518 2-Amino-3-methylsuccinic acid \n", + "8 229518 2-Amino-3-methylsuccinic acid \n", + "9 229518 2-Amino-3-methylsuccinic acid \n", + "10 83380 dinocap-4 \n", + "11 83380 dinocap-4 \n", + "12 83380 dinocap-4 \n", + "13 83380 dinocap-4 \n", + "14 83380 dinocap-4 \n", + "15 140503 kaolin \n", + "16 140503 kaolin \n", + "17 140503 kaolin \n", + "18 140503 kaolin \n", + "19 140503 kaolin \n", + "20 81948 tralkoxydim \n", + "21 81948 tralkoxydim \n", + "22 81948 tralkoxydim \n", + "23 81948 tralkoxydim \n", + "24 81948 tralkoxydim \n", + "25 140499 kaolinite \n", + "26 140499 kaolinite \n", + "27 140499 kaolinite \n", + "28 140499 kaolinite \n", + "\n", + " SMILES 1722 2440 2468 \\\n", + "0 OC(=O)C(C(N)C(O)=O)C False False False \n", + "1 C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O... False False False \n", + "2 [Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O False False False \n", + "3 CCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1C False False False \n", + "4 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[... False False False \n", + "5 O=C(C(C(C(O)=O)N)C)O False False False \n", + "6 C(C(C(N)C(=O)O)C)(=O)O False False False \n", + "7 NC(C(=O)O)C(C(O)=O)C False False False \n", + "8 C(C(C(O)=O)C)(C(O)=O)N False False False \n", + "9 C(C(C(O)=O)N)(C(O)=O)C False False False \n", + "10 O=C(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/C=C/C False False False \n", + "11 C(=C/C(=O)Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O... False False False \n", + "12 O(c1c(cc(*)cc1[N+](=O)[O-])[N+]([O-])=O)C(/C=C... False False False \n", + "13 C(\\C(=O)Oc1c([N+]([O-])=O)cc(*)cc1[N+](=O)[O-]... False False False \n", + "14 C(=O)(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/... False False False \n", + "15 O.O.O([Al]=O)[Si](=O)O[Si](=O)O[Al]=O False False False \n", + "16 O.O.O([Si](=O)O[Al]=O)[Si](O[Al]=O)=O False False False \n", + "17 O.O.[Si](=O)(O[Si](=O)O[Al]=O)O[Al]=O False False False \n", + "18 O.O.[Al](=O)O[Si](=O)O[Si](O[Al]=O)=O False False False \n", + "19 O.O.[Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O False False False \n", + "20 C(C)O/N=C(/C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)CC False False False \n", + "21 C(=N\\OCC)(\\CC)C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C False False False \n", + "22 O(/N=C(\\CC)C1C(CC(CC=1O)c1c(C)cc(C)cc1C)=O)CC False False False \n", + "23 N(\\OCC)=C(\\CC)C1C(CC(c2c(C)cc(cc2C)C)CC=1O)=O False False False \n", + "24 C(C)/C(C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C)=N\\OCC False False False \n", + "25 O=[Si]([O-])O[Si](=O)[O-].[Al+3].[Al+3].[OH-].... False False False \n", + "26 [Al+3].[Al+3].[O-][Si](=O)O[Si]([O-])=O.[OH-].... False False False \n", + "27 O=[Si](O[Si]([O-])=O)[O-].[Al+3].[Al+3].[OH-].... False False False \n", + "28 [Al+3].[Al+3].[O-][Si](=O)O[Si](=O)[O-].[OH-].... False False False \n", + "\n", + " 2571 2580 2634 3098 ... 176910 177333 183508 183509 189832 189840 \\\n", + "0 False False False False ... False False False False False False \n", + "1 False False False False ... False False False False False False \n", + "2 False False False False ... False False False False False False \n", + "3 False False False False ... False False False False False False \n", + "4 False False False False ... False False False False False False \n", + "5 False False False False ... False False False False False False \n", + "6 False False False False ... False False False False False False \n", + "7 False False False False ... False False False False False False \n", + "8 False False False False ... False False False False False False \n", + "9 False False False False ... False False False False False False \n", + "10 False False False False ... False False False False False False \n", + "11 False False False False ... False False False False False False \n", + "12 False False False False ... False False False False False False \n", + "13 False False False False ... False False False False False False \n", + "14 False False False False ... False False False False False False \n", + "15 False False False False ... False False False False False False \n", + "16 False False False False ... False False False False False False \n", + "17 False False False False ... False False False False False False \n", + "18 False False False False ... False False False False False False \n", + "19 False False False False ... False False False False False False \n", + "20 False False False False ... False False False False False False \n", + "21 False False False False ... False False False False False False \n", + "22 False False False False ... False False False False False False \n", + "23 False False False False ... False False False False False False \n", + "24 False False False False ... False False False False False False \n", + "25 False False False False ... False False False False False False \n", + "26 False False False False ... False False False False False False \n", + "27 False False False False ... False False False False False False \n", + "28 False False False False ... False False False False False False \n", + "\n", + " 192499 194321 197504 229684 \n", + "0 False False False False \n", + "1 False False False False \n", + "2 False False False False \n", + "3 False False False False \n", + "4 False False False False \n", + "5 False False False False \n", + "6 False False False False \n", + "7 False False False False \n", + "8 False False False False \n", + "9 False False False False \n", + "10 False False False False \n", + "11 False False False False \n", + "12 False False False False \n", + "13 False False False False \n", + "14 False False False False \n", + "15 False False False False \n", + "16 False False False False \n", + "17 False False False False \n", + "18 False False False False \n", + "19 False False False False \n", + "20 False False False False \n", + "21 False False False False \n", + "22 False False False False \n", + "23 False False False False \n", + "24 False False False False \n", + "25 False False False False \n", + "26 False False False False \n", + "27 False False False False \n", + "28 False False False False \n", + "\n", + "[29 rows x 1514 columns]" + ] + }, + "execution_count": 288, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data_df" + ] + }, + { + "cell_type": "code", + "execution_count": 289, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(29, 1514)" + ] + }, + "execution_count": 289, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data_df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 290, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "id 5\n", + "name 5\n", + "SMILES 29\n", + "1722 1\n", + "2440 1\n", + " ..\n", + "189840 1\n", + "192499 1\n", + "194321 1\n", + "197504 1\n", + "229684 1\n", + "Length: 1514, dtype: int64" + ] + }, + "execution_count": 290, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data_df.nunique()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each time we run it is generating different number of new variations" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From a9f57652f00b7210dec952710d0fbf74625319ce Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 15 Aug 2024 11:37:43 +0200 Subject: [PATCH 03/35] Added changes for new smiles variant generation --- chebai/preprocessing/datasets/chebi.py | 107 ++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 5876577f..71baaf3f 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -24,6 +24,7 @@ MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit, ) +from rdkit import Chem from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import XYBaseDataModule @@ -450,6 +451,21 @@ def setup_processed(self) -> None: os.path.join(self.processed_dir, processed_name), ) + if not os.path.isfile(os.path.join(self.processed_dir, "augmented_data.pt")): + print( + f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", + "augmented_data.pt", + ) + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + "augmented_data.pkl", + ) + ), + os.path.join(self.processed_dir, "augmented_data.pt"), + ) + # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist if self.chebi_version_train is not None and not os.path.isfile( os.path.join( @@ -694,6 +710,11 @@ def _load_chebi(self, version: int) -> str: url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo" r = requests.get(url, allow_redirects=True) open(chebi_path, "wb").write(r.content) + # # Define the source path of the backup file + # local_backup_path = r"D:\Knowledge\Hiwi\chebi_v231\raw\chebi.obo" + # # Copy the file from the local backup to the target directory + # shutil.copy(local_backup_path, chebi_path) + # print(f"Copied local backup to {chebi_path}.") return chebi_path def prepare_data(self, *args: Any, **kwargs: Any) -> None: @@ -759,7 +780,8 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: g = self.extract_class_hierarchy(chebi_path) df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"]) self.save_processed(df, filename=self.raw_file_names_dict["data"]) - + print("Reached before augment data") + self.augment_data(self.processed_dir_main) if self.chebi_version_train is not None: if not os.path.isfile( os.path.join( @@ -774,6 +796,86 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: # Generate the "chebi_version_train" data if it doesn't exist self._chebi_version_train_obj.prepare_data(*args, **kwargs) + def augment_data(self, path: str) -> None: + print(("inside_augment_data")) + if not os.path.isfile( + os.path.join( + path, "augmented_data.pkl")): + if os.path.isfile(os.path.join( + path, self.raw_file_names_dict["data"] + )): + data = self.read_file(os.path.join( + path, self.raw_file_names_dict["data"])) + print("Original Dataset size:",data.shape) + # Create a new empty DataFrame for storing new variations + new_df = pd.DataFrame(columns=data.columns) + + # Set to keep track of already seen SMILES to avoid duplicates + seen_smiles = set(data['SMILES']) + + # Process each row in the original DataFrame + print("Generating New SMILES") + for _, row in data.iterrows(): + original_smiles = row['SMILES'] + # Generate new SMILES variations + variations = self.generate_smiles_variations(original_smiles) + + # Filter out variations that are already seen + variations = [var for var in variations if var not in seen_smiles] + + for var in variations: + # Create a new row with the new SMILES and the rest of the features and labels unchanged + new_row = row.copy() + new_row['SMILES'] = var + new_df = pd.concat([new_df, pd.DataFrame([new_row])], ignore_index=True) + + # Add the new SMILES to the seen set to avoid duplicates + seen_smiles.add(var) + + new_dataset = pd.concat([data, new_df], ignore_index=True) + self.save_file(new_dataset, os.path.join(path, "augmented_data.pkl")) + + def save_file(self, dataset: pd.DataFrame, file_path: str): + pd.to_pickle(dataset, open(file_path, "wb")) + + # Function to generate SMILES variations using different configurations + def generate_smiles_variations(self, original_smiles): + num_variations=5 + print(type(original_smiles), original_smiles) + if not isinstance(original_smiles, str): + print(f"Non-string SMILES found: {original_smiles}") + mol = Chem.MolFromSmiles(original_smiles) + if mol is None: + return [] # Return an empty list if conversion fails + + variations = set() + + # Loop through all combinations of doRandom and rootedAtAtom values + for do_random in [True, False]: + for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]: + try: + # Generate SMILES with the given configuration + variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) + if variant != original_smiles: # Avoid duplicates with the original SMILES + variations.add(variant) + + # Check the number of variations after adding + if len(variations) >= num_variations: + return list(variations) # Return immediately when enough variations are found + + except Exception as e: + # Skip invalid configurations + continue + + return list(variations) + + def read_file(self,file_path : str): + df = pd.read_pickle( + open(file_path, "rb" + ) + ) + return df + def _generate_dynamic_splits(self) -> None: """ Generate data splits during runtime and save them in class variables. @@ -1929,3 +2031,6 @@ def term_callback(doc) -> dict: ] JCI_500_COLUMNS_INT = [int(n.split(":")[-1]) for n in JCI_500_COLUMNS] + + + From d84efad8671849ad113cf4aca872fc278668456c Mon Sep 17 00:00:00 2001 From: vidvath Date: Tue, 27 Aug 2024 18:05:29 +0200 Subject: [PATCH 04/35] changes-tqdm,randomized rootedAtAtom --- .../preprocessing/datasets/augmentation.ipynb | 519 ++++++++++++------ chebai/preprocessing/datasets/chebi.py | 72 ++- 2 files changed, 424 insertions(+), 167 deletions(-) diff --git a/chebai/preprocessing/datasets/augmentation.ipynb b/chebai/preprocessing/datasets/augmentation.ipynb index 22b26cc6..8a09e825 100644 --- a/chebai/preprocessing/datasets/augmentation.ipynb +++ b/chebai/preprocessing/datasets/augmentation.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 269, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 270, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 271, + "execution_count": 112, "metadata": {}, "outputs": [ { @@ -222,7 +222,7 @@ "[5 rows x 1514 columns]" ] }, - "execution_count": 271, + "execution_count": 112, "metadata": {}, "output_type": "execute_result" } @@ -237,7 +237,27 @@ }, { "cell_type": "code", - "execution_count": 272, + "execution_count": 113, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(185007, 1514)" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 114, "metadata": {}, "outputs": [ { @@ -286,10 +306,10 @@ " \n", " \n", " \n", - " 31356\n", - " 5385\n", - " gliotoxin\n", - " [H][C@@]12[C@@H](O)C=CC=C1C[C@@]13SS[C@@](CO)(...\n", + " 16992\n", + " 112763\n", + " N2,N4-bis[[4-(dimethylamino)phenyl]methylidene...\n", + " CC1=C(NC(=C1C(=O)NN=CC2=CC=C(C=C2)N(C)C)C)C(=O...\n", " False\n", " False\n", " False\n", @@ -315,31 +335,34 @@ "" ], "text/plain": [ - " id name SMILES \\\n", - "31356 5385 gliotoxin [H][C@@]12[C@@H](O)C=CC=C1C[C@@]13SS[C@@](CO)(... \n", + " id name \\\n", + "16992 112763 N2,N4-bis[[4-(dimethylamino)phenyl]methylidene... \n", + "\n", + " SMILES 1722 2440 2468 \\\n", + "16992 CC1=C(NC(=C1C(=O)NN=CC2=CC=C(C=C2)N(C)C)C)C(=O... False False False \n", "\n", - " 1722 2440 2468 2571 2580 2634 3098 ... 176910 177333 \\\n", - "31356 False False False False False False False ... False False \n", + " 2571 2580 2634 3098 ... 176910 177333 183508 183509 \\\n", + "16992 False False False False ... False False False False \n", "\n", - " 183508 183509 189832 189840 192499 194321 197504 229684 \n", - "31356 False False False False False False False False \n", + " 189832 189840 192499 194321 197504 229684 \n", + "16992 False False False False False False \n", "\n", "[1 rows x 1514 columns]" ] }, - "execution_count": 272, + "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "result = df[df[\"id\"] == 5385]\n", + "result = df[df[\"SMILES\"] == \"CC1=C(NC(=C1C(=O)NN=CC2=CC=C(C=C2)N(C)C)C)C(=O)NN=CC3=CC=C(C=C3)N(C)C\"]\n", "result" ] }, { "cell_type": "code", - "execution_count": 273, + "execution_count": 115, "metadata": {}, "outputs": [], "source": [ @@ -349,12 +372,12 @@ }, { "cell_type": "code", - "execution_count": 274, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ "# Function to generate SMILES variations using different configurations\n", - "def generate_smiles_variations(smiles, num_variations=5):\n", + "def generate_smiles_variations1(smiles, num_variations=5):\n", " mol = Chem.MolFromSmiles(smiles)\n", " if mol is None:\n", " return [] # Return an empty list if conversion fails\n", @@ -369,6 +392,7 @@ " variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom)\n", " if variant != smiles: # Avoid duplicates with the original SMILES\n", " variations.add(variant)\n", + " # print(\"len-variations:\", len(variations))\n", "\n", " # Check the number of variations after adding\n", " if len(variations) >= num_variations:\n", @@ -383,26 +407,122 @@ }, { "cell_type": "code", - "execution_count": 275, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ - "# smile=\"C(CC[N+]1(C)CCCC1)(O)(C2CCCCC2)C3=CC=CC=C3.[Cl-]\"\n", - "# sample_variation=generate_smiles_variations(smile)" + "import random\n", + "from rdkit import Chem\n", + "from tqdm import tqdm\n", + "\n", + "# Function to generate SMILES variations using different configurations\n", + "def generate_smiles_variations(smiles, num_variations=5):\n", + " \"\"\"\n", + " Generates a list of SMILES variations based on different configurations.\n", + "\n", + " Parameters:\n", + " smiles (str): The input SMILES string.\n", + " num_variations (int): The number of SMILES variations to generate.\n", + " canonical (bool): Whether to generate canonical SMILES.\n", + "\n", + " Returns:\n", + " list: A list of unique SMILES variations.\n", + " \"\"\"\n", + " mol = Chem.MolFromSmiles(smiles)\n", + " if mol is None:\n", + " return [] # Return an empty list if conversion fails\n", + "\n", + " variations = set()\n", + "\n", + " # List of rootedAtAtom values to pick from randomly\n", + " rooted_at_atoms = [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]\n", + " random.shuffle(rooted_at_atoms) # Randomize the order of rootedAtAtom values\n", + "\n", + " # Flag to track if we've already computed a SMILES with doRandom=False and a negative rootedAtAtom\n", + " already_computed_negative_rooted = False\n", + " # Initialize tqdm progress bar for SMILES variation generation\n", + " with tqdm(total=num_variations, desc=\"Generating SMILES Variations\", unit=\"variant\", leave=False) as pbar:\n", + " # Loop through all combinations of doRandom and rootedAtAtom values\n", + " for do_random in [True, False]:\n", + " for rooted_at_atom in rooted_at_atoms:\n", + " try:\n", + " # Skip redundant computations\n", + " if not do_random and rooted_at_atom < 0:\n", + " if already_computed_negative_rooted:\n", + " continue\n", + " already_computed_negative_rooted = True\n", + "\n", + " # Generate SMILES with the given configuration\n", + " variant = Chem.MolToSmiles(\n", + " mol, \n", + " doRandom=do_random, \n", + " rootedAtAtom=rooted_at_atom, \n", + " canonical=False\n", + " )\n", + "\n", + " # Print the configuration and the generated SMILES string\n", + " # print(f\"Config: doRandom={do_random}, rootedAtAtom={rooted_at_atom}, canonical={False} -> SMILES: {variant}\")\n", + " \n", + " # Avoid duplicates with the original SMILES\n", + " if variant != smiles:\n", + " variations.add(variant)\n", + " pbar.update(1) # Update tqdm progress bar with each new variant\n", + "\n", + " # Check the number of variations after adding\n", + " if len(variations) >= num_variations:\n", + " pbar.close() # Close the progress bar when done\n", + " return list(variations) # Return immediately when enough variations are found\n", + "\n", + " except Exception as e:\n", + " # Skip invalid configurations\n", + " continue\n", + " pbar.close() # Close the progress bar if not already closed\n", + " return list(variations)\n", + "\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 276, + "execution_count": 118, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['C(CC[N+]1(C)CCCC1)(C1CCCCC1)(c1ccccc1)O.[Cl-]', 'C(O)(C1CCCCC1)(c1ccccc1)CC[N+]1(CCCC1)C.[Cl-]', 'C(c1ccccc1)(C1CCCCC1)(CC[N+]1(C)CCCC1)O.[Cl-]', '[N+]1(C)(CCCC1)CCC(C1CCCCC1)(O)c1ccccc1.[Cl-]', 'C(O)(C1CCCCC1)(c1ccccc1)CC[N+]1(C)CCCC1.[Cl-]']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r" + ] + } + ], "source": [ - "# sample_variation" + "# Example usage\n", + "smile1=\"OC(=O)C(C(N)C(O)=O)C\"\n", + "smile2=\"[Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O\"\n", + "smile3=\"[Cl-].[H][N+]([H])([H])[H]\"\n", + "smile4=\"[Ca++].OC[C@@H](O)[C@@H](O)[C@H](O)[C@@H](O)C(O)C([O-])=O.OC[C@@H](O)[C@@H](O)[C@H](O)[C@@H](O)C(O)C([O-])=O\"\n", + "smile5=\"C(CC[N+]1(C)CCCC1)(O)(C2CCCCC2)C3=CC=CC=C3.[Cl-]\"\n", + "variations = generate_smiles_variations(smile5, num_variations=5)\n", + "print(variations)" ] }, { "cell_type": "code", - "execution_count": 277, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -412,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 278, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -421,7 +541,7 @@ }, { "cell_type": "code", - "execution_count": 279, + "execution_count": 121, "metadata": {}, "outputs": [ { @@ -626,7 +746,7 @@ "[5 rows x 1514 columns]" ] }, - "execution_count": 279, + "execution_count": 121, "metadata": {}, "output_type": "execute_result" } @@ -637,7 +757,7 @@ }, { "cell_type": "code", - "execution_count": 280, + "execution_count": 122, "metadata": {}, "outputs": [ { @@ -657,7 +777,7 @@ "Length: 1514, dtype: int64" ] }, - "execution_count": 280, + "execution_count": 122, "metadata": {}, "output_type": "execute_result" } @@ -668,23 +788,23 @@ }, { "cell_type": "code", - "execution_count": 281, + "execution_count": 123, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 124, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[20:01:38] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 5 < 1\n", - "****\n", - "\n", - "[20:01:38] \n", + "Processing Rows: 60%|██████ | 3/5 [00:00<00:00, 6.54row/s][17:59:42] \n", "\n", "****\n", "Range Error\n", @@ -693,34 +813,7 @@ "Failed Expression: 4 < 1\n", "****\n", "\n", - "[20:01:38] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 3 < 1\n", - "****\n", - "\n", - "[20:01:38] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 2 < 1\n", - "****\n", - "\n", - "[20:01:38] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 1 < 1\n", - "****\n", - "\n", - "[20:01:38] \n", + "[17:59:42] \n", "\n", "****\n", "Range Error\n", @@ -729,16 +822,16 @@ "Failed Expression: 5 < 1\n", "****\n", "\n", - "[20:01:38] \n", + "[17:59:42] \n", "\n", "****\n", "Range Error\n", "idx\n", "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 4 < 1\n", + "Failed Expression: 1 < 1\n", "****\n", "\n", - "[20:01:38] \n", + "[17:59:42] \n", "\n", "****\n", "Range Error\n", @@ -747,7 +840,7 @@ "Failed Expression: 3 < 1\n", "****\n", "\n", - "[20:01:38] \n", + "[17:59:42] \n", "\n", "****\n", "Range Error\n", @@ -756,21 +849,13 @@ "Failed Expression: 2 < 1\n", "****\n", "\n", - "[20:01:38] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 1 < 1\n", - "****\n", - "\n" + "Processing Rows: 100%|██████████| 5/5 [00:00<00:00, 7.81row/s]\n" ] } ], "source": [ "# Process each row in the original DataFrame\n", - "for _, row in test_df.iterrows():\n", + "for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc=\"Processing Rows\", unit=\"row\"):\n", " original_smiles = row['SMILES']\n", " \n", " # Generate new SMILES variations\n", @@ -791,7 +876,7 @@ }, { "cell_type": "code", - "execution_count": 282, + "execution_count": 125, "metadata": {}, "outputs": [ { @@ -799,7 +884,7 @@ "text/plain": [ "id 5\n", "name 5\n", - "SMILES 24\n", + "SMILES 25\n", "1722 1\n", "2440 1\n", " ..\n", @@ -811,7 +896,7 @@ "Length: 1514, dtype: int64" ] }, - "execution_count": 282, + "execution_count": 125, "metadata": {}, "output_type": "execute_result" } @@ -822,7 +907,7 @@ }, { "cell_type": "code", - "execution_count": 283, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -832,16 +917,16 @@ }, { "cell_type": "code", - "execution_count": 284, + "execution_count": 127, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(29, 1514)" + "(30, 1514)" ] }, - "execution_count": 284, + "execution_count": 127, "metadata": {}, "output_type": "execute_result" } @@ -852,16 +937,16 @@ }, { "cell_type": "code", - "execution_count": 285, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [ - "new_data_path=\"test_data.pkl\"" + "new_data_path=\"augmented_data.pkl\"" ] }, { "cell_type": "code", - "execution_count": 286, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -870,19 +955,19 @@ }, { "cell_type": "code", - "execution_count": 287, + "execution_count": 130, "metadata": {}, "outputs": [], "source": [ - "test_data_df= pd.read_pickle(\n", - " open(\"test_data.pkl\", \"rb\"\n", + "data_df= pd.read_pickle(\n", + " open(\"augmented_data.pkl\", \"rb\"\n", " )\n", ")" ] }, { "cell_type": "code", - "execution_count": 288, + "execution_count": 131, "metadata": {}, "outputs": [ { @@ -1054,7 +1139,7 @@ " 5\n", " 229518\n", " 2-Amino-3-methylsuccinic acid\n", - " O=C(C(C(C(O)=O)N)C)O\n", + " C(O)(C(C(N)C(O)=O)C)=O\n", " False\n", " False\n", " False\n", @@ -1078,7 +1163,7 @@ " 6\n", " 229518\n", " 2-Amino-3-methylsuccinic acid\n", - " C(C(C(N)C(=O)O)C)(=O)O\n", + " O=C(C(C(C(=O)O)N)C)O\n", " False\n", " False\n", " False\n", @@ -1102,7 +1187,7 @@ " 7\n", " 229518\n", " 2-Amino-3-methylsuccinic acid\n", - " NC(C(=O)O)C(C(O)=O)C\n", + " OC(=O)C(C(N)C(=O)O)C\n", " False\n", " False\n", " False\n", @@ -1126,7 +1211,7 @@ " 8\n", " 229518\n", " 2-Amino-3-methylsuccinic acid\n", - " C(C(C(O)=O)C)(C(O)=O)N\n", + " NC(C(C)C(=O)O)C(=O)O\n", " False\n", " False\n", " False\n", @@ -1150,7 +1235,7 @@ " 9\n", " 229518\n", " 2-Amino-3-methylsuccinic acid\n", - " C(C(C(O)=O)N)(C(O)=O)C\n", + " OC(=O)C(C)C(C(=O)O)N\n", " False\n", " False\n", " False\n", @@ -1174,7 +1259,7 @@ " 10\n", " 83380\n", " dinocap-4\n", - " O=C(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/C=C/C\n", + " C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(*)cc1[N+](=O)[O-]\n", " False\n", " False\n", " False\n", @@ -1198,7 +1283,7 @@ " 11\n", " 83380\n", " dinocap-4\n", - " C(=C/C(=O)Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O...\n", + " C/C=C/C(=O)Oc1c(cc(*)cc1[N+]([O-])=O)[N+]([O-])=O\n", " False\n", " False\n", " False\n", @@ -1222,7 +1307,7 @@ " 12\n", " 83380\n", " dinocap-4\n", - " O(c1c(cc(*)cc1[N+](=O)[O-])[N+]([O-])=O)C(/C=C...\n", + " C(=O)(/C=C/C)Oc1c(cc(*)cc1[N+](=O)[O-])[N+]([O...\n", " False\n", " False\n", " False\n", @@ -1246,7 +1331,7 @@ " 13\n", " 83380\n", " dinocap-4\n", - " C(\\C(=O)Oc1c([N+]([O-])=O)cc(*)cc1[N+](=O)[O-]...\n", + " C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(cc1[N+]([O-])=O)*\n", " False\n", " False\n", " False\n", @@ -1270,7 +1355,7 @@ " 14\n", " 83380\n", " dinocap-4\n", - " C(=O)(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/...\n", + " O(C(/C=C/C)=O)c1c([N+](=O)[O-])cc(cc1[N+](=O)[...\n", " False\n", " False\n", " False\n", @@ -1294,7 +1379,7 @@ " 15\n", " 140503\n", " kaolin\n", - " O.O.O([Al]=O)[Si](=O)O[Si](=O)O[Al]=O\n", + " O([Al]=O)[Si](O[Si](=O)O[Al]=O)=O.O.O\n", " False\n", " False\n", " False\n", @@ -1318,7 +1403,7 @@ " 16\n", " 140503\n", " kaolin\n", - " O.O.O([Si](=O)O[Al]=O)[Si](O[Al]=O)=O\n", + " [Si](=O)(O[Al]=O)O[Si](=O)O[Al]=O.O.O\n", " False\n", " False\n", " False\n", @@ -1342,7 +1427,7 @@ " 17\n", " 140503\n", " kaolin\n", - " O.O.[Si](=O)(O[Si](=O)O[Al]=O)O[Al]=O\n", + " O=[Si](O[Al]=O)O[Si](O[Al]=O)=O.O.O\n", " False\n", " False\n", " False\n", @@ -1366,7 +1451,7 @@ " 18\n", " 140503\n", " kaolin\n", - " O.O.[Al](=O)O[Si](=O)O[Si](O[Al]=O)=O\n", + " [Al](=O)O[Si](O[Si](O[Al]=O)=O)=O.O.O\n", " False\n", " False\n", " False\n", @@ -1390,7 +1475,7 @@ " 19\n", " 140503\n", " kaolin\n", - " O.O.[Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O\n", + " [Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O.O.O\n", " False\n", " False\n", " False\n", @@ -1414,7 +1499,7 @@ " 20\n", " 81948\n", " tralkoxydim\n", - " C(C)O/N=C(/C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)CC\n", + " CCO/N=C(/C1=C(CC(c2c(C)cc(C)cc2C)CC1=O)O)CC\n", " False\n", " False\n", " False\n", @@ -1438,7 +1523,7 @@ " 21\n", " 81948\n", " tralkoxydim\n", - " C(=N\\OCC)(\\CC)C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C\n", + " CCO/N=C(\\CC)C1=C(CC(c2c(cc(C)cc2C)C)CC1=O)O\n", " False\n", " False\n", " False\n", @@ -1462,7 +1547,7 @@ " 22\n", " 81948\n", " tralkoxydim\n", - " O(/N=C(\\CC)C1C(CC(CC=1O)c1c(C)cc(C)cc1C)=O)CC\n", + " CCO/N=C(\\CC)C1C(=O)CC(c2c(C)cc(C)cc2C)CC=1O\n", " False\n", " False\n", " False\n", @@ -1486,7 +1571,7 @@ " 23\n", " 81948\n", " tralkoxydim\n", - " N(\\OCC)=C(\\CC)C1C(CC(c2c(C)cc(cc2C)C)CC=1O)=O\n", + " C(C)/C(C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)=N\\OCC\n", " False\n", " False\n", " False\n", @@ -1510,7 +1595,7 @@ " 24\n", " 81948\n", " tralkoxydim\n", - " C(C)/C(C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C)=N\\OCC\n", + " O(/N=C(\\CC)C1=C(CC(CC1=O)c1c(C)cc(C)cc1C)O)CC\n", " False\n", " False\n", " False\n", @@ -1534,7 +1619,7 @@ " 25\n", " 140499\n", " kaolinite\n", - " O=[Si]([O-])O[Si](=O)[O-].[Al+3].[Al+3].[OH-]....\n", + " [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si]([O-]...\n", " False\n", " False\n", " False\n", @@ -1558,7 +1643,7 @@ " 26\n", " 140499\n", " kaolinite\n", - " [Al+3].[Al+3].[O-][Si](=O)O[Si]([O-])=O.[OH-]....\n", + " [OH-].[OH-].[OH-].[OH-].O=[Si](O[Si](=O)[O-])[...\n", " False\n", " False\n", " False\n", @@ -1582,7 +1667,7 @@ " 27\n", " 140499\n", " kaolinite\n", - " O=[Si](O[Si]([O-])=O)[O-].[Al+3].[Al+3].[OH-]....\n", + " [OH-].[OH-].[OH-].[OH-].[O-][Si](O[Si](=O)[O-]...\n", " False\n", " False\n", " False\n", @@ -1606,7 +1691,31 @@ " 28\n", " 140499\n", " kaolinite\n", - " [Al+3].[Al+3].[O-][Si](=O)O[Si](=O)[O-].[OH-]....\n", + " [OH-].[OH-].[OH-].[OH-].[Si](=O)(O[Si]([O-])=O...\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " ...\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " False\n", + " \n", + " \n", + " 29\n", + " 140499\n", + " kaolinite\n", + " [OH-].[OH-].[OH-].[OH-].O=[Si](O[Si]([O-])=O)[...\n", " False\n", " False\n", " False\n", @@ -1628,7 +1737,7 @@ " \n", " \n", "\n", - "

29 rows × 1514 columns

\n", + "

30 rows × 1514 columns

\n", "" ], "text/plain": [ @@ -1662,6 +1771,7 @@ "26 140499 kaolinite \n", "27 140499 kaolinite \n", "28 140499 kaolinite \n", + "29 140499 kaolinite \n", "\n", " SMILES 1722 2440 2468 \\\n", "0 OC(=O)C(C(N)C(O)=O)C False False False \n", @@ -1669,30 +1779,31 @@ "2 [Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O False False False \n", "3 CCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1C False False False \n", "4 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[... False False False \n", - "5 O=C(C(C(C(O)=O)N)C)O False False False \n", - "6 C(C(C(N)C(=O)O)C)(=O)O False False False \n", - "7 NC(C(=O)O)C(C(O)=O)C False False False \n", - "8 C(C(C(O)=O)C)(C(O)=O)N False False False \n", - "9 C(C(C(O)=O)N)(C(O)=O)C False False False \n", - "10 O=C(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/C=C/C False False False \n", - "11 C(=C/C(=O)Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O... False False False \n", - "12 O(c1c(cc(*)cc1[N+](=O)[O-])[N+]([O-])=O)C(/C=C... False False False \n", - "13 C(\\C(=O)Oc1c([N+]([O-])=O)cc(*)cc1[N+](=O)[O-]... False False False \n", - "14 C(=O)(Oc1c(cc(cc1[N+](=O)[O-])*)[N+](=O)[O-])/... False False False \n", - "15 O.O.O([Al]=O)[Si](=O)O[Si](=O)O[Al]=O False False False \n", - "16 O.O.O([Si](=O)O[Al]=O)[Si](O[Al]=O)=O False False False \n", - "17 O.O.[Si](=O)(O[Si](=O)O[Al]=O)O[Al]=O False False False \n", - "18 O.O.[Al](=O)O[Si](=O)O[Si](O[Al]=O)=O False False False \n", - "19 O.O.[Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O False False False \n", - "20 C(C)O/N=C(/C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)CC False False False \n", - "21 C(=N\\OCC)(\\CC)C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C False False False \n", - "22 O(/N=C(\\CC)C1C(CC(CC=1O)c1c(C)cc(C)cc1C)=O)CC False False False \n", - "23 N(\\OCC)=C(\\CC)C1C(CC(c2c(C)cc(cc2C)C)CC=1O)=O False False False \n", - "24 C(C)/C(C1=C(O)CC(CC1=O)c1c(C)cc(cc1C)C)=N\\OCC False False False \n", - "25 O=[Si]([O-])O[Si](=O)[O-].[Al+3].[Al+3].[OH-].... False False False \n", - "26 [Al+3].[Al+3].[O-][Si](=O)O[Si]([O-])=O.[OH-].... False False False \n", - "27 O=[Si](O[Si]([O-])=O)[O-].[Al+3].[Al+3].[OH-].... False False False \n", - "28 [Al+3].[Al+3].[O-][Si](=O)O[Si](=O)[O-].[OH-].... False False False \n", + "5 C(O)(C(C(N)C(O)=O)C)=O False False False \n", + "6 O=C(C(C(C(=O)O)N)C)O False False False \n", + "7 OC(=O)C(C(N)C(=O)O)C False False False \n", + "8 NC(C(C)C(=O)O)C(=O)O False False False \n", + "9 OC(=O)C(C)C(C(=O)O)N False False False \n", + "10 C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(*)cc1[N+](=O)[O-] False False False \n", + "11 C/C=C/C(=O)Oc1c(cc(*)cc1[N+]([O-])=O)[N+]([O-])=O False False False \n", + "12 C(=O)(/C=C/C)Oc1c(cc(*)cc1[N+](=O)[O-])[N+]([O... False False False \n", + "13 C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(cc1[N+]([O-])=O)* False False False \n", + "14 O(C(/C=C/C)=O)c1c([N+](=O)[O-])cc(cc1[N+](=O)[... False False False \n", + "15 O([Al]=O)[Si](O[Si](=O)O[Al]=O)=O.O.O False False False \n", + "16 [Si](=O)(O[Al]=O)O[Si](=O)O[Al]=O.O.O False False False \n", + "17 O=[Si](O[Al]=O)O[Si](O[Al]=O)=O.O.O False False False \n", + "18 [Al](=O)O[Si](O[Si](O[Al]=O)=O)=O.O.O False False False \n", + "19 [Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O.O.O False False False \n", + "20 CCO/N=C(/C1=C(CC(c2c(C)cc(C)cc2C)CC1=O)O)CC False False False \n", + "21 CCO/N=C(\\CC)C1=C(CC(c2c(cc(C)cc2C)C)CC1=O)O False False False \n", + "22 CCO/N=C(\\CC)C1C(=O)CC(c2c(C)cc(C)cc2C)CC=1O False False False \n", + "23 C(C)/C(C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)=N\\OCC False False False \n", + "24 O(/N=C(\\CC)C1=C(CC(CC1=O)c1c(C)cc(C)cc1C)O)CC False False False \n", + "25 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si]([O-]... False False False \n", + "26 [OH-].[OH-].[OH-].[OH-].O=[Si](O[Si](=O)[O-])[... False False False \n", + "27 [OH-].[OH-].[OH-].[OH-].[O-][Si](O[Si](=O)[O-]... False False False \n", + "28 [OH-].[OH-].[OH-].[OH-].[Si](=O)(O[Si]([O-])=O... False False False \n", + "29 [OH-].[OH-].[OH-].[OH-].O=[Si](O[Si]([O-])=O)[... False False False \n", "\n", " 2571 2580 2634 3098 ... 176910 177333 183508 183509 189832 189840 \\\n", "0 False False False False ... False False False False False False \n", @@ -1724,6 +1835,7 @@ "26 False False False False ... False False False False False False \n", "27 False False False False ... False False False False False False \n", "28 False False False False ... False False False False False False \n", + "29 False False False False ... False False False False False False \n", "\n", " 192499 194321 197504 229684 \n", "0 False False False False \n", @@ -1755,42 +1867,43 @@ "26 False False False False \n", "27 False False False False \n", "28 False False False False \n", + "29 False False False False \n", "\n", - "[29 rows x 1514 columns]" + "[30 rows x 1514 columns]" ] }, - "execution_count": 288, + "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "test_data_df" + "data_df" ] }, { "cell_type": "code", - "execution_count": 289, + "execution_count": 132, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(29, 1514)" + "(30, 1514)" ] }, - "execution_count": 289, + "execution_count": 132, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "test_data_df.shape" + "data_df.shape" ] }, { "cell_type": "code", - "execution_count": 290, + "execution_count": 133, "metadata": {}, "outputs": [ { @@ -1798,7 +1911,7 @@ "text/plain": [ "id 5\n", "name 5\n", - "SMILES 29\n", + "SMILES 30\n", "1722 1\n", "2440 1\n", " ..\n", @@ -1810,21 +1923,105 @@ "Length: 1514, dtype: int64" ] }, - "execution_count": 290, + "execution_count": 133, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "test_data_df.nunique()" + "data_df.nunique()" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": {}, + "outputs": [], + "source": [ + "def find_smiles_variations(smiles):\n", + " original_smiles = smiles\n", + " mol = Chem.MolFromSmiles(original_smiles)\n", + " smiles_variations=Chem.MolToSmiles(mol,doRandom=True,rootedAtAtom=2,canonical=False)\n", + " return smiles_variations\n" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:59:42] \n", + "\n", + "****\n", + "Pre-condition Violation\n", + "rootedAtomAtom must be less than the number of atoms\n", + "Violation occurred on line 534 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\SmilesParse\\SmilesWrite.cpp\n", + "Failed Expression: params.rootedAtAtom < 0 || static_cast(params.rootedAtAtom) < mol.getNumAtoms()\n", + "****\n", + "\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Pre-condition Violation\n\trootedAtomAtom must be less than the number of atoms\n\tViolation occurred on line 534 in file Code\\GraphMol\\SmilesParse\\SmilesWrite.cpp\n\tFailed Expression: params.rootedAtAtom < 0 || static_cast(params.rootedAtAtom) < mol.getNumAtoms()\n\tRDKIT: 2024.03.5\n\tBOOST: 1_85\n", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[135], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m smile_variations\u001b[38;5;241m=\u001b[39m \u001b[43mfind_smiles_variations\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m[Cl-].[H][N+]([H])([H])[H]\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[1;32mIn[134], line 4\u001b[0m, in \u001b[0;36mfind_smiles_variations\u001b[1;34m(smiles)\u001b[0m\n\u001b[0;32m 2\u001b[0m original_smiles \u001b[38;5;241m=\u001b[39m smiles\n\u001b[0;32m 3\u001b[0m mol \u001b[38;5;241m=\u001b[39m Chem\u001b[38;5;241m.\u001b[39mMolFromSmiles(original_smiles)\n\u001b[1;32m----> 4\u001b[0m smiles_variations\u001b[38;5;241m=\u001b[39m\u001b[43mChem\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMolToSmiles\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmol\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdoRandom\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43mrootedAtAtom\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mcanonical\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m smiles_variations\n", + "\u001b[1;31mRuntimeError\u001b[0m: Pre-condition Violation\n\trootedAtomAtom must be less than the number of atoms\n\tViolation occurred on line 534 in file Code\\GraphMol\\SmilesParse\\SmilesWrite.cpp\n\tFailed Expression: params.rootedAtAtom < 0 || static_cast(params.rootedAtAtom) < mol.getNumAtoms()\n\tRDKIT: 2024.03.5\n\tBOOST: 1_85\n" + ] + } + ], + "source": [ + "smile_variations= find_smiles_variations(\"[Cl-].[H][N+]([H])([H])[H]\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'C(C(O)=O)(C(C)C(O)=O)N'" + ] + }, + "execution_count": 565, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "Each time we run it is generating different number of new variations" + "smile_variations" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 71baaf3f..16b65d26 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -24,10 +24,15 @@ MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit, ) -from rdkit import Chem +from rdkit import Chem, RDLogger +from tqdm import tqdm from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import XYBaseDataModule +import random + +# Suppress RDKit warnings and errors +RDLogger.DisableLog('rdApp.*') # Disable all RDKit logging # exclude some entities from the dataset because the violate disjointness axioms CHEBI_BLACKLIST = [ @@ -815,7 +820,7 @@ def augment_data(self, path: str) -> None: # Process each row in the original DataFrame print("Generating New SMILES") - for _, row in data.iterrows(): + for _, row in tqdm(data.iterrows(), total=len(data), desc="Processing Rows", unit="row"): original_smiles = row['SMILES'] # Generate new SMILES variations variations = self.generate_smiles_variations(original_smiles) @@ -838,9 +843,40 @@ def augment_data(self, path: str) -> None: def save_file(self, dataset: pd.DataFrame, file_path: str): pd.to_pickle(dataset, open(file_path, "wb")) + # # Function to generate SMILES variations using different configurations + # def generate_smiles_variations1(self, original_smiles): + # num_variations=5 + # # print(type(original_smiles), original_smiles) + # if not isinstance(original_smiles, str): + # print(f"Non-string SMILES found: {original_smiles}") + # mol = Chem.MolFromSmiles(original_smiles) + # if mol is None: + # return [] # Return an empty list if conversion fails + # + # variations = set() + # + # # Loop through all combinations of doRandom and rootedAtAtom values + # for do_random in [True, False]: + # for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]: + # try: + # # Generate SMILES with the given configuration + # variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) + # if variant != original_smiles: # Avoid duplicates with the original SMILES + # variations.add(variant) + # + # # Check the number of variations after adding + # if len(variations) >= num_variations: + # return list(variations) # Return immediately when enough variations are found + # + # except Exception as e: + # # Skip invalid configurations + # continue + # + # return list(variations) + # Function to generate SMILES variations using different configurations def generate_smiles_variations(self, original_smiles): - num_variations=5 + num_variations = 5 print(type(original_smiles), original_smiles) if not isinstance(original_smiles, str): print(f"Non-string SMILES found: {original_smiles}") @@ -850,13 +886,37 @@ def generate_smiles_variations(self, original_smiles): variations = set() + # List of rootedAtAtom values to pick from randomly + rooted_at_atoms = [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5] + random.shuffle(rooted_at_atoms) # Randomize the order of rootedAtAtom values + + # Flag to track if we've already computed a SMILES with doRandom=False and a negative rootedAtAtom + already_computed_negative_rooted = False + # Loop through all combinations of doRandom and rootedAtAtom values for do_random in [True, False]: - for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]: + for rooted_at_atom in rooted_at_atoms: try: + # Skip redundant computations + if not do_random and rooted_at_atom < 0: + if already_computed_negative_rooted: + continue + already_computed_negative_rooted = True + # Generate SMILES with the given configuration - variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) - if variant != original_smiles: # Avoid duplicates with the original SMILES + variant = Chem.MolToSmiles( + mol, + doRandom=do_random, + rootedAtAtom=rooted_at_atom, + canonical=False + ) + + # # Print the configuration and the generated SMILES string + # print( + # f"Config: doRandom={do_random}, rootedAtAtom={rooted_at_atom}, canonical={False} -> SMILES: {variant}") + + # Avoid duplicates with the original SMILES + if variant != original_smiles: variations.add(variant) # Check the number of variations after adding From cba05620f92a4e486ed4e9744c397df8c3e3ad31 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 5 Sep 2024 16:48:45 +0200 Subject: [PATCH 05/35] Changed directory and added yml config --- chebai/preprocessing/datasets/chebi.py | 106 ++++++++++++++++--------- configs/data/chebi50.yml | 2 + 2 files changed, 70 insertions(+), 38 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 16b65d26..3fc0d4f2 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -142,6 +142,7 @@ def __init__( self, chebi_version_train: Optional[int] = None, single_class: Optional[int] = None, + aug_data: Optional[bool] = False, **kwargs, ): # predict only single class (given as id of one of the classes present in the raw data set) @@ -155,6 +156,7 @@ def __init__( self.dynamic_df_train = None self.dynamic_df_test = None self.dynamic_df_val = None + self.aug_data = aug_data if self.chebi_version_train is not None: # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given @@ -456,7 +458,10 @@ def setup_processed(self) -> None: os.path.join(self.processed_dir, processed_name), ) - if not os.path.isfile(os.path.join(self.processed_dir, "augmented_data.pt")): + augmented_dir = os.path.join("data", "augmented_dataset") + + # Define the augmented data file path + if not os.path.isfile(os.path.join(augmented_dir, "augmented_data.pt")): print( f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", "augmented_data.pt", @@ -464,13 +469,14 @@ def setup_processed(self) -> None: torch.save( self._load_data_from_file( os.path.join( - self.processed_dir_main, + augmented_dir, "augmented_data.pkl", ) ), - os.path.join(self.processed_dir, "augmented_data.pt"), + os.path.join(augmented_dir, "augmented_data.pt"), ) + # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist if self.chebi_version_train is not None and not os.path.isfile( os.path.join( @@ -622,6 +628,16 @@ def base_dir(self) -> str: """ return os.path.join("data", f"chebi_v{self.chebi_version}") + @property + def augmented_dir(self) -> str: + """ + Return the base directory path for data. + + Returns: + str: The base directory path for data. + """ + return os.path.join("data", "chebi_augmented") + @property def processed_file_names_dict(self) -> dict: """ @@ -803,42 +819,56 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: def augment_data(self, path: str) -> None: print(("inside_augment_data")) - if not os.path.isfile( - os.path.join( - path, "augmented_data.pkl")): + if self.aug_data: if os.path.isfile(os.path.join( - path, self.raw_file_names_dict["data"] - )): - data = self.read_file(os.path.join( - path, self.raw_file_names_dict["data"])) - print("Original Dataset size:",data.shape) - # Create a new empty DataFrame for storing new variations - new_df = pd.DataFrame(columns=data.columns) - - # Set to keep track of already seen SMILES to avoid duplicates - seen_smiles = set(data['SMILES']) - - # Process each row in the original DataFrame - print("Generating New SMILES") - for _, row in tqdm(data.iterrows(), total=len(data), desc="Processing Rows", unit="row"): - original_smiles = row['SMILES'] - # Generate new SMILES variations - variations = self.generate_smiles_variations(original_smiles) - - # Filter out variations that are already seen - variations = [var for var in variations if var not in seen_smiles] - - for var in variations: - # Create a new row with the new SMILES and the rest of the features and labels unchanged - new_row = row.copy() - new_row['SMILES'] = var - new_df = pd.concat([new_df, pd.DataFrame([new_row])], ignore_index=True) - - # Add the new SMILES to the seen set to avoid duplicates - seen_smiles.add(var) - - new_dataset = pd.concat([data, new_df], ignore_index=True) - self.save_file(new_dataset, os.path.join(path, "augmented_data.pkl")) + path, self.raw_file_names_dict["data"])): + + augmented_dir = os.path.join("data", "augmented_dataset") + + # Check if the augmented directory exists, if not, create it + os.makedirs(augmented_dir, exist_ok=True) + + # Define the augmented data file path + augmented_data_file = os.path.join(augmented_dir, "augmented_data.pkl") + + # If augmented_data.pkl does not already exist, proceed with the logic + if not os.path.isfile(augmented_data_file): + + data = self.read_file(os.path.join( + path, self.raw_file_names_dict["data"])) + print("Original Dataset size:", data.shape) + # Create a new empty DataFrame for storing new variations + # #testing + # data=data[:5] + # print("Test Dataset size:", data.shape) + new_df = pd.DataFrame(columns=data.columns) + + # Set to keep track of already seen SMILES to avoid duplicates + seen_smiles = set(data['SMILES']) + + # Process each row in the original DataFrame + print("Generating New SMILES") + for _, row in tqdm(data.iterrows(), total=len(data), desc="Processing Rows", unit="row"): + original_smiles = row['SMILES'] + # Generate new SMILES variations + variations = self.generate_smiles_variations(original_smiles) + + # Filter out variations that are already seen + variations = [var for var in variations if var not in seen_smiles] + + for var in variations: + # Create a new row with the new SMILES and the rest of the features and labels unchanged + new_row = row.copy() + new_row['SMILES'] = var + new_df = pd.concat([new_df, pd.DataFrame([new_row])], ignore_index=True) + + # Add the new SMILES to the seen set to avoid duplicates + seen_smiles.add(var) + + new_dataset = pd.concat([data, new_df], ignore_index=True) + self.save_file(new_dataset, os.path.join(augmented_dir, "augmented_data.pkl")) + else: + print("Data Augmentation config is False") def save_file(self, dataset: pd.DataFrame, file_path: str): pd.to_pickle(dataset, open(file_path, "wb")) diff --git a/configs/data/chebi50.yml b/configs/data/chebi50.yml index f89d5932..e0c075af 100644 --- a/configs/data/chebi50.yml +++ b/configs/data/chebi50.yml @@ -1 +1,3 @@ class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50 +init_args: + aug_data: True From 165606d9d1baeacf9b7504695de2e72e0b4de8dc Mon Sep 17 00:00:00 2001 From: vidvath Date: Fri, 6 Sep 2024 16:36:58 +0200 Subject: [PATCH 06/35] Code optimization & Batch processing changes --- chebai/preprocessing/datasets/chebi.py | 132 ++++++++++--------------- configs/data/chebi50.yml | 1 + 2 files changed, 55 insertions(+), 78 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 3fc0d4f2..4c466e40 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -139,11 +139,12 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC): """ def __init__( - self, - chebi_version_train: Optional[int] = None, - single_class: Optional[int] = None, - aug_data: Optional[bool] = False, - **kwargs, + self, + chebi_version_train: Optional[int] = None, + single_class: Optional[int] = None, + aug_data: Optional[bool] = False, + batch_size_:Optional[int]= 5000, + **kwargs, ): # predict only single class (given as id of one of the classes present in the raw data set) self.single_class = single_class @@ -157,6 +158,7 @@ def __init__( self.dynamic_df_test = None self.dynamic_df_val = None self.aug_data = aug_data + self.batch_size_=batch_size_ if self.chebi_version_train is not None: # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given @@ -801,14 +803,13 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: g = self.extract_class_hierarchy(chebi_path) df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"]) self.save_processed(df, filename=self.raw_file_names_dict["data"]) - print("Reached before augment data") - self.augment_data(self.processed_dir_main) + self.augment_data(self.processed_dir_main,self.batch_size_) if self.chebi_version_train is not None: if not os.path.isfile( - os.path.join( - self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.raw_file_names_dict["data"], - ) + os.path.join( + self._chebi_version_train_obj.processed_dir_main, + self._chebi_version_train_obj.raw_file_names_dict["data"], + ) ): print( f"Missing processed data related to train version: {self.chebi_version_train}" @@ -817,17 +818,14 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: # Generate the "chebi_version_train" data if it doesn't exist self._chebi_version_train_obj.prepare_data(*args, **kwargs) - def augment_data(self, path: str) -> None: + def augment_data(self, path: str, batch_size) -> None: print(("inside_augment_data")) if self.aug_data: if os.path.isfile(os.path.join( path, self.raw_file_names_dict["data"])): - augmented_dir = os.path.join("data", "augmented_dataset") - # Check if the augmented directory exists, if not, create it os.makedirs(augmented_dir, exist_ok=True) - # Define the augmented data file path augmented_data_file = os.path.join(augmented_dir, "augmented_data.pkl") @@ -837,73 +835,51 @@ def augment_data(self, path: str) -> None: data = self.read_file(os.path.join( path, self.raw_file_names_dict["data"])) print("Original Dataset size:", data.shape) - # Create a new empty DataFrame for storing new variations - # #testing - # data=data[:5] - # print("Test Dataset size:", data.shape) - new_df = pd.DataFrame(columns=data.columns) - - # Set to keep track of already seen SMILES to avoid duplicates - seen_smiles = set(data['SMILES']) - - # Process each row in the original DataFrame - print("Generating New SMILES") - for _, row in tqdm(data.iterrows(), total=len(data), desc="Processing Rows", unit="row"): - original_smiles = row['SMILES'] - # Generate new SMILES variations - variations = self.generate_smiles_variations(original_smiles) - - # Filter out variations that are already seen - variations = [var for var in variations if var not in seen_smiles] - - for var in variations: - # Create a new row with the new SMILES and the rest of the features and labels unchanged - new_row = row.copy() - new_row['SMILES'] = var - new_df = pd.concat([new_df, pd.DataFrame([new_row])], ignore_index=True) - - # Add the new SMILES to the seen set to avoid duplicates - seen_smiles.add(var) - - new_dataset = pd.concat([data, new_df], ignore_index=True) - self.save_file(new_dataset, os.path.join(augmented_dir, "augmented_data.pkl")) + total_rows = data.shape[0] + # Calculate the total number of batches + total_batches = (total_rows + batch_size - 1) // batch_size + + for batch_num, start in enumerate(range(0, total_rows, batch_size), start=1): + end = min(start + batch_size, total_rows) + batch = data[start:end] + print(f"Processing batch {batch_num}/{total_batches} ({start} to {end})") + + # Set to keep track of already seen SMILES + seen_smiles = set(batch['SMILES']) + + # Store new rows in a list instead of concatenating directly + new_rows = [] + + # Updated tqdm to show batch number and total batches + for _, row in tqdm(batch.iterrows(), total=len(batch), + desc=f"Batch {batch_num}/{total_batches}", unit="row"): + original_smiles = row['SMILES'] + variations = self.generate_smiles_variations(original_smiles) + variations = [var for var in variations if var not in seen_smiles] + + for var in variations: + new_row = row.copy() + new_row['SMILES'] = var + new_rows.append(new_row) + seen_smiles.add(var) + + # Create a DataFrame from the new rows + new_df = pd.DataFrame(new_rows) + + # Concatenate the batch and new variations, and save incrementally + batch_with_variations = pd.concat([batch, new_df], ignore_index=True) + + # Append batch to augmented file + self.save_file(batch_with_variations, augmented_data_file, append=True) else: print("Data Augmentation config is False") - def save_file(self, dataset: pd.DataFrame, file_path: str): + def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): + if append and os.path.isfile(file_path): + existing_data = pd.read_pickle(file_path) + dataset = pd.concat([existing_data, dataset], ignore_index=True) pd.to_pickle(dataset, open(file_path, "wb")) - # # Function to generate SMILES variations using different configurations - # def generate_smiles_variations1(self, original_smiles): - # num_variations=5 - # # print(type(original_smiles), original_smiles) - # if not isinstance(original_smiles, str): - # print(f"Non-string SMILES found: {original_smiles}") - # mol = Chem.MolFromSmiles(original_smiles) - # if mol is None: - # return [] # Return an empty list if conversion fails - # - # variations = set() - # - # # Loop through all combinations of doRandom and rootedAtAtom values - # for do_random in [True, False]: - # for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]: - # try: - # # Generate SMILES with the given configuration - # variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) - # if variant != original_smiles: # Avoid duplicates with the original SMILES - # variations.add(variant) - # - # # Check the number of variations after adding - # if len(variations) >= num_variations: - # return list(variations) # Return immediately when enough variations are found - # - # except Exception as e: - # # Skip invalid configurations - # continue - # - # return list(variations) - # Function to generate SMILES variations using different configurations def generate_smiles_variations(self, original_smiles): num_variations = 5 @@ -959,7 +935,7 @@ def generate_smiles_variations(self, original_smiles): return list(variations) - def read_file(self,file_path : str): + def read_file(self, file_path: str): df = pd.read_pickle( open(file_path, "rb" ) diff --git a/configs/data/chebi50.yml b/configs/data/chebi50.yml index e0c075af..d2ac29db 100644 --- a/configs/data/chebi50.yml +++ b/configs/data/chebi50.yml @@ -1,3 +1,4 @@ class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50 init_args: aug_data: True + batch_size_: 5000 From d42294818febe813f69ab3279f2e429de22ce030 Mon Sep 17 00:00:00 2001 From: vidvath Date: Fri, 13 Sep 2024 18:30:35 +0200 Subject: [PATCH 07/35] Reverted changes to original --- chebai/molecule.py | 89 - .../preprocessing/datasets/augmentation.ipynb | 2048 ----------------- 2 files changed, 2137 deletions(-) delete mode 100644 chebai/preprocessing/datasets/augmentation.ipynb diff --git a/chebai/molecule.py b/chebai/molecule.py index 33c05f17..5acd5546 100644 --- a/chebai/molecule.py +++ b/chebai/molecule.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division import logging -from itertools import product from typing import Any, List, Optional, Union try: @@ -468,95 +467,7 @@ def num_bond_features() -> int: return len(Molecule.bond_features(simple_mol.GetBonds()[0])) - @staticmethod - def find_smile(): - original_smiles = "OC(=O)C(C(N)C(O)=O)C" - mol = Chem.MolFromSmiles(original_smiles) - smile=Chem.MolToSmiles(mol,doRandom=True,rootedAtAtom=4) - return smile - - - - @staticmethod - def find_smile1(): - original_smiles = "OC(=O)C(C(N)C(O)=O)C" - mol = Chem.MolFromSmiles(original_smiles) - - # Test combinations of doRandom and rootedAtAtom - for do_random in [True, False]: - for rooted_at_atom in [4, 3, 2, 1, 0, -1 - 2, -3, -4, -5]: - try: - smiles = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) - print(f"Configuration: doRandom={do_random}, rootedAtAtom={rooted_at_atom}\n{smiles}\n") - except Exception as e: - print(f"Error with configuration: doRandom={do_random}, rootedAtAtom={rooted_at_atom}\nError: {e}\n") - - - @staticmethod - def find_smiles(): - original_smiles = "[Cl-].[H][N+]([H])([H])[H]" - mol = Chem.MolFromSmiles(original_smiles) - - # Dictionary to store SMILES and the corresponding configurations - smiles_dict = {} - - # Test combinations of doRandom and rootedAtAtom - for do_random in [True, False]: - for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]: - try: - smiles = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom) - config = f"doRandom={do_random}, rootedAtAtom={rooted_at_atom}" - - if smiles in smiles_dict: - smiles_dict[smiles].append(config) - else: - smiles_dict[smiles] = [config] - - print(f"Configuration: {config}\n{smiles}\n") - except Exception as e: - print(f"Error with configuration: doRandom={do_random}, rootedAtAtom={rooted_at_atom}\nError: {e}\n") - - # Print configurations that generated the same SMILES string - for smiles, configs in smiles_dict.items(): - if len(configs) > 1: - print(f"SMILES string '{smiles}' was generated by the following configurations:") - for config in configs: - print(f"- {config}") - print("\n") - - - - @staticmethod - def find_config(): - original_smiles = "Oc1ccc2nccc(O)c2c1" - mol = Chem.MolFromSmiles(original_smiles) - - # List of Boolean parameters to iterate over - boolean_params = ['isomericSmiles', 'kekuleSmiles', 'canonical', 'allBondsExplicit', 'allHsExplicit', - 'doRandom'] - - # Generate all combinations of True/False for these parameters - combinations = list(product([True, False], repeat=len(boolean_params))) - - # Dictionary to store the configuration and its generated SMILES - results = [] - - # Iterate through all combinations - for combination in combinations: - params = dict(zip(boolean_params, combination)) - smiles = Chem.MolToSmiles(mol, **params) - results.append({ - 'config': params, - 'generated_smiles': smiles, - 'matches_original': smiles == original_smiles - }) - - return results - - if __name__ == "__main__": - - # print(Molecule.find_smiles()) log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) logger = logging.getLogger(__name__) diff --git a/chebai/preprocessing/datasets/augmentation.ipynb b/chebai/preprocessing/datasets/augmentation.ipynb deleted file mode 100644 index 8a09e825..00000000 --- a/chebai/preprocessing/datasets/augmentation.ipynb +++ /dev/null @@ -1,2048 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 110, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from rdkit import Chem\n", - "from rdkit.Chem import AllChem" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [], - "source": [ - "# Path to the original data.pkl file\n", - "data_path = \"D:\\Knowledge\\Hiwi\\python-chebai\\data\\chebi_v231\\ChEBI50\\processed\\data.pkl\"\n", - "# data_path1=\"data\\chebi_v231\\ChEBI50\\processed\\data.pkl\"" - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
033429monoatomic monoanion[*-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
130151aluminide(1-)[Al-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
216042halide anion[*-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
317051fluoride[F-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
428741sodium fluoride[F-].[Na+]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", - "

5 rows × 1514 columns

\n", - "
" - ], - "text/plain": [ - " id name SMILES 1722 2440 2468 2571 2580 \\\n", - "0 33429 monoatomic monoanion [*-] False False False False False \n", - "1 30151 aluminide(1-) [Al-] False False False False False \n", - "2 16042 halide anion [*-] False False False False False \n", - "3 17051 fluoride [F-] False False False False False \n", - "4 28741 sodium fluoride [F-].[Na+] False False False False False \n", - "\n", - " 2634 3098 ... 176910 177333 183508 183509 189832 189840 192499 \\\n", - "0 False False ... False False False False False False False \n", - "1 False False ... False False False False False False False \n", - "2 False False ... False False False False False False False \n", - "3 False False ... False False False False False False False \n", - "4 False False ... False False False False False False False \n", - "\n", - " 194321 197504 229684 \n", - "0 False False False \n", - "1 False False False \n", - "2 False False False \n", - "3 False False False \n", - "4 False False False \n", - "\n", - "[5 rows x 1514 columns]" - ] - }, - "execution_count": 112, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pd.read_pickle(\n", - " open(data_path, \"rb\"\n", - " )\n", - ")\n", - "df[:5]" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(185007, 1514)" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
16992112763N2,N4-bis[[4-(dimethylamino)phenyl]methylidene...CC1=C(NC(=C1C(=O)NN=CC2=CC=C(C=C2)N(C)C)C)C(=O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", - "

1 rows × 1514 columns

\n", - "
" - ], - "text/plain": [ - " id name \\\n", - "16992 112763 N2,N4-bis[[4-(dimethylamino)phenyl]methylidene... \n", - "\n", - " SMILES 1722 2440 2468 \\\n", - "16992 CC1=C(NC(=C1C(=O)NN=CC2=CC=C(C=C2)N(C)C)C)C(=O... False False False \n", - "\n", - " 2571 2580 2634 3098 ... 176910 177333 183508 183509 \\\n", - "16992 False False False False ... False False False False \n", - "\n", - " 189832 189840 192499 194321 197504 229684 \n", - "16992 False False False False False False \n", - "\n", - "[1 rows x 1514 columns]" - ] - }, - "execution_count": 114, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "result = df[df[\"SMILES\"] == \"CC1=C(NC(=C1C(=O)NN=CC2=CC=C(C=C2)N(C)C)C)C(=O)NN=CC3=CC=C(C=C3)N(C)C\"]\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "metadata": {}, - "outputs": [], - "source": [ - "# Create a new empty DataFrame for storing new variations\n", - "new_df = pd.DataFrame(columns=df.columns)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 116, - "metadata": {}, - "outputs": [], - "source": [ - "# Function to generate SMILES variations using different configurations\n", - "def generate_smiles_variations1(smiles, num_variations=5):\n", - " mol = Chem.MolFromSmiles(smiles)\n", - " if mol is None:\n", - " return [] # Return an empty list if conversion fails\n", - "\n", - " variations = set()\n", - "\n", - " # Loop through all combinations of doRandom and rootedAtAtom values\n", - " for do_random in [True, False]:\n", - " for rooted_at_atom in [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]:\n", - " try:\n", - " # Generate SMILES with the given configuration\n", - " variant = Chem.MolToSmiles(mol, doRandom=do_random, rootedAtAtom=rooted_at_atom)\n", - " if variant != smiles: # Avoid duplicates with the original SMILES\n", - " variations.add(variant)\n", - " # print(\"len-variations:\", len(variations))\n", - "\n", - " # Check the number of variations after adding\n", - " if len(variations) >= num_variations:\n", - " return list(variations) # Return immediately when enough variations are found\n", - "\n", - " except Exception as e:\n", - " # Skip invalid configurations\n", - " continue\n", - "\n", - " return list(variations)" - ] - }, - { - "cell_type": "code", - "execution_count": 117, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "from rdkit import Chem\n", - "from tqdm import tqdm\n", - "\n", - "# Function to generate SMILES variations using different configurations\n", - "def generate_smiles_variations(smiles, num_variations=5):\n", - " \"\"\"\n", - " Generates a list of SMILES variations based on different configurations.\n", - "\n", - " Parameters:\n", - " smiles (str): The input SMILES string.\n", - " num_variations (int): The number of SMILES variations to generate.\n", - " canonical (bool): Whether to generate canonical SMILES.\n", - "\n", - " Returns:\n", - " list: A list of unique SMILES variations.\n", - " \"\"\"\n", - " mol = Chem.MolFromSmiles(smiles)\n", - " if mol is None:\n", - " return [] # Return an empty list if conversion fails\n", - "\n", - " variations = set()\n", - "\n", - " # List of rootedAtAtom values to pick from randomly\n", - " rooted_at_atoms = [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5]\n", - " random.shuffle(rooted_at_atoms) # Randomize the order of rootedAtAtom values\n", - "\n", - " # Flag to track if we've already computed a SMILES with doRandom=False and a negative rootedAtAtom\n", - " already_computed_negative_rooted = False\n", - " # Initialize tqdm progress bar for SMILES variation generation\n", - " with tqdm(total=num_variations, desc=\"Generating SMILES Variations\", unit=\"variant\", leave=False) as pbar:\n", - " # Loop through all combinations of doRandom and rootedAtAtom values\n", - " for do_random in [True, False]:\n", - " for rooted_at_atom in rooted_at_atoms:\n", - " try:\n", - " # Skip redundant computations\n", - " if not do_random and rooted_at_atom < 0:\n", - " if already_computed_negative_rooted:\n", - " continue\n", - " already_computed_negative_rooted = True\n", - "\n", - " # Generate SMILES with the given configuration\n", - " variant = Chem.MolToSmiles(\n", - " mol, \n", - " doRandom=do_random, \n", - " rootedAtAtom=rooted_at_atom, \n", - " canonical=False\n", - " )\n", - "\n", - " # Print the configuration and the generated SMILES string\n", - " # print(f\"Config: doRandom={do_random}, rootedAtAtom={rooted_at_atom}, canonical={False} -> SMILES: {variant}\")\n", - " \n", - " # Avoid duplicates with the original SMILES\n", - " if variant != smiles:\n", - " variations.add(variant)\n", - " pbar.update(1) # Update tqdm progress bar with each new variant\n", - "\n", - " # Check the number of variations after adding\n", - " if len(variations) >= num_variations:\n", - " pbar.close() # Close the progress bar when done\n", - " return list(variations) # Return immediately when enough variations are found\n", - "\n", - " except Exception as e:\n", - " # Skip invalid configurations\n", - " continue\n", - " pbar.close() # Close the progress bar if not already closed\n", - " return list(variations)\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 118, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['C(CC[N+]1(C)CCCC1)(C1CCCCC1)(c1ccccc1)O.[Cl-]', 'C(O)(C1CCCCC1)(c1ccccc1)CC[N+]1(CCCC1)C.[Cl-]', 'C(c1ccccc1)(C1CCCCC1)(CC[N+]1(C)CCCC1)O.[Cl-]', '[N+]1(C)(CCCC1)CCC(C1CCCCC1)(O)c1ccccc1.[Cl-]', 'C(O)(C1CCCCC1)(c1ccccc1)CC[N+]1(C)CCCC1.[Cl-]']\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r" - ] - } - ], - "source": [ - "# Example usage\n", - "smile1=\"OC(=O)C(C(N)C(O)=O)C\"\n", - "smile2=\"[Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O\"\n", - "smile3=\"[Cl-].[H][N+]([H])([H])[H]\"\n", - "smile4=\"[Ca++].OC[C@@H](O)[C@@H](O)[C@H](O)[C@@H](O)C(O)C([O-])=O.OC[C@@H](O)[C@@H](O)[C@H](O)[C@@H](O)C(O)C([O-])=O\"\n", - "smile5=\"C(CC[N+]1(C)CCCC1)(O)(C2CCCCC2)C3=CC=CC=C3.[Cl-]\"\n", - "variations = generate_smiles_variations(smile5, num_variations=5)\n", - "print(variations)" - ] - }, - { - "cell_type": "code", - "execution_count": 119, - "metadata": {}, - "outputs": [], - "source": [ - "# Set to keep track of already seen SMILES to avoid duplicates\n", - "seen_smiles = set(df['SMILES'])" - ] - }, - { - "cell_type": "code", - "execution_count": 120, - "metadata": {}, - "outputs": [], - "source": [ - "test_df=df[-5::]" - ] - }, - { - "cell_type": "code", - "execution_count": 121, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
1850112295182-Amino-3-methylsuccinic acidOC(=O)C(C(N)C(O)=O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
18501283380dinocap-4C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
185013140503kaolin[Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
18501481948tralkoxydimCCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
185015140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", - "

5 rows × 1514 columns

\n", - "
" - ], - "text/plain": [ - " id name \\\n", - "185011 229518 2-Amino-3-methylsuccinic acid \n", - "185012 83380 dinocap-4 \n", - "185013 140503 kaolin \n", - "185014 81948 tralkoxydim \n", - "185015 140499 kaolinite \n", - "\n", - " SMILES 1722 2440 \\\n", - "185011 OC(=O)C(C(N)C(O)=O)C False False \n", - "185012 C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O... False False \n", - "185013 [Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O False False \n", - "185014 CCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1C False False \n", - "185015 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[... False False \n", - "\n", - " 2468 2571 2580 2634 3098 ... 176910 177333 183508 \\\n", - "185011 False False False False False ... False False False \n", - "185012 False False False False False ... False False False \n", - "185013 False False False False False ... False False False \n", - "185014 False False False False False ... False False False \n", - "185015 False False False False False ... False False False \n", - "\n", - " 183509 189832 189840 192499 194321 197504 229684 \n", - "185011 False False False False False False False \n", - "185012 False False False False False False False \n", - "185013 False False False False False False False \n", - "185014 False False False False False False False \n", - "185015 False False False False False False False \n", - "\n", - "[5 rows x 1514 columns]" - ] - }, - "execution_count": 121, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_df" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id 5\n", - "name 5\n", - "SMILES 5\n", - "1722 1\n", - "2440 1\n", - " ..\n", - "189840 1\n", - "192499 1\n", - "194321 1\n", - "197504 1\n", - "229684 1\n", - "Length: 1514, dtype: int64" - ] - }, - "execution_count": 122, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_df.nunique()" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "metadata": {}, - "outputs": [], - "source": [ - "from tqdm import tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 124, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing Rows: 60%|██████ | 3/5 [00:00<00:00, 6.54row/s][17:59:42] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 4 < 1\n", - "****\n", - "\n", - "[17:59:42] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 5 < 1\n", - "****\n", - "\n", - "[17:59:42] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 1 < 1\n", - "****\n", - "\n", - "[17:59:42] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 3 < 1\n", - "****\n", - "\n", - "[17:59:42] \n", - "\n", - "****\n", - "Range Error\n", - "idx\n", - "Violation occurred on line 209 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\ROMol.cpp\n", - "Failed Expression: 2 < 1\n", - "****\n", - "\n", - "Processing Rows: 100%|██████████| 5/5 [00:00<00:00, 7.81row/s]\n" - ] - } - ], - "source": [ - "# Process each row in the original DataFrame\n", - "for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc=\"Processing Rows\", unit=\"row\"):\n", - " original_smiles = row['SMILES']\n", - " \n", - " # Generate new SMILES variations\n", - " variations = generate_smiles_variations(original_smiles)\n", - " \n", - " # Filter out variations that are already seen\n", - " variations = [var for var in variations if var not in seen_smiles]\n", - " \n", - " for var in variations:\n", - " # Create a new row with the new SMILES and the rest of the features and labels unchanged\n", - " new_row = row.copy()\n", - " new_row['SMILES'] = var\n", - " new_df = pd.concat([new_df, pd.DataFrame([new_row])], ignore_index=True)\n", - " \n", - " # Add the new SMILES to the seen set to avoid duplicates\n", - " seen_smiles.add(var)" - ] - }, - { - "cell_type": "code", - "execution_count": 125, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id 5\n", - "name 5\n", - "SMILES 25\n", - "1722 1\n", - "2440 1\n", - " ..\n", - "189840 1\n", - "192499 1\n", - "194321 1\n", - "197504 1\n", - "229684 1\n", - "Length: 1514, dtype: int64" - ] - }, - "execution_count": 125, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "new_df.nunique()" - ] - }, - { - "cell_type": "code", - "execution_count": 126, - "metadata": {}, - "outputs": [], - "source": [ - "# Append the new DataFrame (new_df) to the original DataFrame (df)\n", - "df_combined = pd.concat([test_df, new_df], ignore_index=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 127, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(30, 1514)" - ] - }, - "execution_count": 127, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_combined.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "metadata": {}, - "outputs": [], - "source": [ - "new_data_path=\"augmented_data.pkl\"" - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [], - "source": [ - "pd.to_pickle(df_combined, open(new_data_path, \"wb\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": {}, - "outputs": [], - "source": [ - "data_df= pd.read_pickle(\n", - " open(\"augmented_data.pkl\", \"rb\"\n", - " )\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 131, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idnameSMILES1722244024682571258026343098...176910177333183508183509189832189840192499194321197504229684
02295182-Amino-3-methylsuccinic acidOC(=O)C(C(N)C(O)=O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
183380dinocap-4C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2140503kaolin[Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
381948tralkoxydimCCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
4140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
52295182-Amino-3-methylsuccinic acidC(O)(C(C(N)C(O)=O)C)=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
62295182-Amino-3-methylsuccinic acidO=C(C(C(C(=O)O)N)C)OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
72295182-Amino-3-methylsuccinic acidOC(=O)C(C(N)C(=O)O)CFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
82295182-Amino-3-methylsuccinic acidNC(C(C)C(=O)O)C(=O)OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
92295182-Amino-3-methylsuccinic acidOC(=O)C(C)C(C(=O)O)NFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1083380dinocap-4C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(*)cc1[N+](=O)[O-]FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1183380dinocap-4C/C=C/C(=O)Oc1c(cc(*)cc1[N+]([O-])=O)[N+]([O-])=OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1283380dinocap-4C(=O)(/C=C/C)Oc1c(cc(*)cc1[N+](=O)[O-])[N+]([O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1383380dinocap-4C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(cc1[N+]([O-])=O)*FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
1483380dinocap-4O(C(/C=C/C)=O)c1c([N+](=O)[O-])cc(cc1[N+](=O)[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
15140503kaolinO([Al]=O)[Si](O[Si](=O)O[Al]=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
16140503kaolin[Si](=O)(O[Al]=O)O[Si](=O)O[Al]=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
17140503kaolinO=[Si](O[Al]=O)O[Si](O[Al]=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
18140503kaolin[Al](=O)O[Si](O[Si](O[Al]=O)=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
19140503kaolin[Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O.O.OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2081948tralkoxydimCCO/N=C(/C1=C(CC(c2c(C)cc(C)cc2C)CC1=O)O)CCFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2181948tralkoxydimCCO/N=C(\\CC)C1=C(CC(c2c(cc(C)cc2C)C)CC1=O)OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2281948tralkoxydimCCO/N=C(\\CC)C1C(=O)CC(c2c(C)cc(C)cc2C)CC=1OFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2381948tralkoxydimC(C)/C(C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)=N\\OCCFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
2481948tralkoxydimO(/N=C(\\CC)C1=C(CC(CC1=O)c1c(C)cc(C)cc1C)O)CCFalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
25140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si]([O-]...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
26140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si](O[Si](=O)[O-])[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
27140499kaolinite[OH-].[OH-].[OH-].[OH-].[O-][Si](O[Si](=O)[O-]...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
28140499kaolinite[OH-].[OH-].[OH-].[OH-].[Si](=O)(O[Si]([O-])=O...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
29140499kaolinite[OH-].[OH-].[OH-].[OH-].O=[Si](O[Si]([O-])=O)[...FalseFalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", - "

30 rows × 1514 columns

\n", - "
" - ], - "text/plain": [ - " id name \\\n", - "0 229518 2-Amino-3-methylsuccinic acid \n", - "1 83380 dinocap-4 \n", - "2 140503 kaolin \n", - "3 81948 tralkoxydim \n", - "4 140499 kaolinite \n", - "5 229518 2-Amino-3-methylsuccinic acid \n", - "6 229518 2-Amino-3-methylsuccinic acid \n", - "7 229518 2-Amino-3-methylsuccinic acid \n", - "8 229518 2-Amino-3-methylsuccinic acid \n", - "9 229518 2-Amino-3-methylsuccinic acid \n", - "10 83380 dinocap-4 \n", - "11 83380 dinocap-4 \n", - "12 83380 dinocap-4 \n", - "13 83380 dinocap-4 \n", - "14 83380 dinocap-4 \n", - "15 140503 kaolin \n", - "16 140503 kaolin \n", - "17 140503 kaolin \n", - "18 140503 kaolin \n", - "19 140503 kaolin \n", - "20 81948 tralkoxydim \n", - "21 81948 tralkoxydim \n", - "22 81948 tralkoxydim \n", - "23 81948 tralkoxydim \n", - "24 81948 tralkoxydim \n", - "25 140499 kaolinite \n", - "26 140499 kaolinite \n", - "27 140499 kaolinite \n", - "28 140499 kaolinite \n", - "29 140499 kaolinite \n", - "\n", - " SMILES 1722 2440 2468 \\\n", - "0 OC(=O)C(C(N)C(O)=O)C False False False \n", - "1 C\\C=C\\C(=O)Oc1c(cc([*])cc1[N+]([O-])=O)[N+]([O... False False False \n", - "2 [Al](O[Si](O[Si](O[Al]=O)=O)=O)=O.O.O False False False \n", - "3 CCO\\N=C(CC)\\C1=C(O)CC(CC1=O)c1c(C)cc(C)cc1C False False False \n", - "4 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si](=O)[... False False False \n", - "5 C(O)(C(C(N)C(O)=O)C)=O False False False \n", - "6 O=C(C(C(C(=O)O)N)C)O False False False \n", - "7 OC(=O)C(C(N)C(=O)O)C False False False \n", - "8 NC(C(C)C(=O)O)C(=O)O False False False \n", - "9 OC(=O)C(C)C(C(=O)O)N False False False \n", - "10 C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(*)cc1[N+](=O)[O-] False False False \n", - "11 C/C=C/C(=O)Oc1c(cc(*)cc1[N+]([O-])=O)[N+]([O-])=O False False False \n", - "12 C(=O)(/C=C/C)Oc1c(cc(*)cc1[N+](=O)[O-])[N+]([O... False False False \n", - "13 C/C=C/C(=O)Oc1c([N+](=O)[O-])cc(cc1[N+]([O-])=O)* False False False \n", - "14 O(C(/C=C/C)=O)c1c([N+](=O)[O-])cc(cc1[N+](=O)[... False False False \n", - "15 O([Al]=O)[Si](O[Si](=O)O[Al]=O)=O.O.O False False False \n", - "16 [Si](=O)(O[Al]=O)O[Si](=O)O[Al]=O.O.O False False False \n", - "17 O=[Si](O[Al]=O)O[Si](O[Al]=O)=O.O.O False False False \n", - "18 [Al](=O)O[Si](O[Si](O[Al]=O)=O)=O.O.O False False False \n", - "19 [Si](O[Al]=O)(O[Si](=O)O[Al]=O)=O.O.O False False False \n", - "20 CCO/N=C(/C1=C(CC(c2c(C)cc(C)cc2C)CC1=O)O)CC False False False \n", - "21 CCO/N=C(\\CC)C1=C(CC(c2c(cc(C)cc2C)C)CC1=O)O False False False \n", - "22 CCO/N=C(\\CC)C1C(=O)CC(c2c(C)cc(C)cc2C)CC=1O False False False \n", - "23 C(C)/C(C1=C(O)CC(CC1=O)c1c(cc(C)cc1C)C)=N\\OCC False False False \n", - "24 O(/N=C(\\CC)C1=C(CC(CC1=O)c1c(C)cc(C)cc1C)O)CC False False False \n", - "25 [OH-].[OH-].[OH-].[OH-].O=[Si]([O-])O[Si]([O-]... False False False \n", - "26 [OH-].[OH-].[OH-].[OH-].O=[Si](O[Si](=O)[O-])[... False False False \n", - "27 [OH-].[OH-].[OH-].[OH-].[O-][Si](O[Si](=O)[O-]... False False False \n", - "28 [OH-].[OH-].[OH-].[OH-].[Si](=O)(O[Si]([O-])=O... False False False \n", - "29 [OH-].[OH-].[OH-].[OH-].O=[Si](O[Si]([O-])=O)[... False False False \n", - "\n", - " 2571 2580 2634 3098 ... 176910 177333 183508 183509 189832 189840 \\\n", - "0 False False False False ... False False False False False False \n", - "1 False False False False ... False False False False False False \n", - "2 False False False False ... False False False False False False \n", - "3 False False False False ... False False False False False False \n", - "4 False False False False ... False False False False False False \n", - "5 False False False False ... False False False False False False \n", - "6 False False False False ... False False False False False False \n", - "7 False False False False ... False False False False False False \n", - "8 False False False False ... False False False False False False \n", - "9 False False False False ... False False False False False False \n", - "10 False False False False ... False False False False False False \n", - "11 False False False False ... False False False False False False \n", - "12 False False False False ... False False False False False False \n", - "13 False False False False ... False False False False False False \n", - "14 False False False False ... False False False False False False \n", - "15 False False False False ... False False False False False False \n", - "16 False False False False ... False False False False False False \n", - "17 False False False False ... False False False False False False \n", - "18 False False False False ... False False False False False False \n", - "19 False False False False ... False False False False False False \n", - "20 False False False False ... False False False False False False \n", - "21 False False False False ... False False False False False False \n", - "22 False False False False ... False False False False False False \n", - "23 False False False False ... False False False False False False \n", - "24 False False False False ... False False False False False False \n", - "25 False False False False ... False False False False False False \n", - "26 False False False False ... False False False False False False \n", - "27 False False False False ... False False False False False False \n", - "28 False False False False ... False False False False False False \n", - "29 False False False False ... False False False False False False \n", - "\n", - " 192499 194321 197504 229684 \n", - "0 False False False False \n", - "1 False False False False \n", - "2 False False False False \n", - "3 False False False False \n", - "4 False False False False \n", - "5 False False False False \n", - "6 False False False False \n", - "7 False False False False \n", - "8 False False False False \n", - "9 False False False False \n", - "10 False False False False \n", - "11 False False False False \n", - "12 False False False False \n", - "13 False False False False \n", - "14 False False False False \n", - "15 False False False False \n", - "16 False False False False \n", - "17 False False False False \n", - "18 False False False False \n", - "19 False False False False \n", - "20 False False False False \n", - "21 False False False False \n", - "22 False False False False \n", - "23 False False False False \n", - "24 False False False False \n", - "25 False False False False \n", - "26 False False False False \n", - "27 False False False False \n", - "28 False False False False \n", - "29 False False False False \n", - "\n", - "[30 rows x 1514 columns]" - ] - }, - "execution_count": 131, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_df" - ] - }, - { - "cell_type": "code", - "execution_count": 132, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(30, 1514)" - ] - }, - "execution_count": 132, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_df.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 133, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id 5\n", - "name 5\n", - "SMILES 30\n", - "1722 1\n", - "2440 1\n", - " ..\n", - "189840 1\n", - "192499 1\n", - "194321 1\n", - "197504 1\n", - "229684 1\n", - "Length: 1514, dtype: int64" - ] - }, - "execution_count": 133, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_df.nunique()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 134, - "metadata": {}, - "outputs": [], - "source": [ - "def find_smiles_variations(smiles):\n", - " original_smiles = smiles\n", - " mol = Chem.MolFromSmiles(original_smiles)\n", - " smiles_variations=Chem.MolToSmiles(mol,doRandom=True,rootedAtAtom=2,canonical=False)\n", - " return smiles_variations\n" - ] - }, - { - "cell_type": "code", - "execution_count": 135, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[17:59:42] \n", - "\n", - "****\n", - "Pre-condition Violation\n", - "rootedAtomAtom must be less than the number of atoms\n", - "Violation occurred on line 534 in file C:\\rdkit\\build\\temp.win-amd64-cpython-311\\Release\\rdkit\\Code\\GraphMol\\SmilesParse\\SmilesWrite.cpp\n", - "Failed Expression: params.rootedAtAtom < 0 || static_cast(params.rootedAtAtom) < mol.getNumAtoms()\n", - "****\n", - "\n" - ] - }, - { - "ename": "RuntimeError", - "evalue": "Pre-condition Violation\n\trootedAtomAtom must be less than the number of atoms\n\tViolation occurred on line 534 in file Code\\GraphMol\\SmilesParse\\SmilesWrite.cpp\n\tFailed Expression: params.rootedAtAtom < 0 || static_cast(params.rootedAtAtom) < mol.getNumAtoms()\n\tRDKIT: 2024.03.5\n\tBOOST: 1_85\n", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[135], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m smile_variations\u001b[38;5;241m=\u001b[39m \u001b[43mfind_smiles_variations\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m[Cl-].[H][N+]([H])([H])[H]\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[1;32mIn[134], line 4\u001b[0m, in \u001b[0;36mfind_smiles_variations\u001b[1;34m(smiles)\u001b[0m\n\u001b[0;32m 2\u001b[0m original_smiles \u001b[38;5;241m=\u001b[39m smiles\n\u001b[0;32m 3\u001b[0m mol \u001b[38;5;241m=\u001b[39m Chem\u001b[38;5;241m.\u001b[39mMolFromSmiles(original_smiles)\n\u001b[1;32m----> 4\u001b[0m smiles_variations\u001b[38;5;241m=\u001b[39m\u001b[43mChem\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMolToSmiles\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmol\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdoRandom\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43mrootedAtAtom\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mcanonical\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m smiles_variations\n", - "\u001b[1;31mRuntimeError\u001b[0m: Pre-condition Violation\n\trootedAtomAtom must be less than the number of atoms\n\tViolation occurred on line 534 in file Code\\GraphMol\\SmilesParse\\SmilesWrite.cpp\n\tFailed Expression: params.rootedAtAtom < 0 || static_cast(params.rootedAtAtom) < mol.getNumAtoms()\n\tRDKIT: 2024.03.5\n\tBOOST: 1_85\n" - ] - } - ], - "source": [ - "smile_variations= find_smiles_variations(\"[Cl-].[H][N+]([H])([H])[H]\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'C(C(O)=O)(C(C)C(O)=O)N'" - ] - }, - "execution_count": 565, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "smile_variations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From db4d16e35da4c74418ad378b9bd3099e36ddcf07 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 26 Sep 2024 09:35:33 +0200 Subject: [PATCH 08/35] Changes in rootedAtAtom,added smiles variation config --- chebai/preprocessing/datasets/chebi.py | 41 ++++++++++++++++---------- configs/data/chebi50.yml | 3 +- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 4c466e40..6ff45987 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -143,7 +143,8 @@ def __init__( chebi_version_train: Optional[int] = None, single_class: Optional[int] = None, aug_data: Optional[bool] = False, - batch_size_:Optional[int]= 5000, + augment_data_batch_size:Optional[int]= 10000, + num_smiles_variations:Optional[int]=7, **kwargs, ): # predict only single class (given as id of one of the classes present in the raw data set) @@ -158,7 +159,8 @@ def __init__( self.dynamic_df_test = None self.dynamic_df_val = None self.aug_data = aug_data - self.batch_size_=batch_size_ + self.augment_data_batch_size=augment_data_batch_size + self.num_smiles_variations=num_smiles_variations if self.chebi_version_train is not None: # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given @@ -803,7 +805,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: g = self.extract_class_hierarchy(chebi_path) df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"]) self.save_processed(df, filename=self.raw_file_names_dict["data"]) - self.augment_data(self.processed_dir_main,self.batch_size_) + self.augment_data(self.processed_dir_main, self.augment_data_batch_size) if self.chebi_version_train is not None: if not os.path.isfile( os.path.join( @@ -820,6 +822,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: def augment_data(self, path: str, batch_size) -> None: print(("inside_augment_data")) + print("batch_size",batch_size) if self.aug_data: if os.path.isfile(os.path.join( path, self.raw_file_names_dict["data"])): @@ -882,7 +885,8 @@ def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): # Function to generate SMILES variations using different configurations def generate_smiles_variations(self, original_smiles): - num_variations = 5 + num_variations = self.num_smiles_variations + print("num_variations",num_variations) print(type(original_smiles), original_smiles) if not isinstance(original_smiles, str): print(f"Non-string SMILES found: {original_smiles}") @@ -890,25 +894,30 @@ def generate_smiles_variations(self, original_smiles): if mol is None: return [] # Return an empty list if conversion fails - variations = set() + # Get the number of atoms in the molecule + num_atoms = mol.GetNumAtoms() + + print("num_atoms", num_atoms) + print("num_variations", num_variations) + + # Generate the rooted_at_atoms list based on the number of atoms + if num_atoms < num_variations: + rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) + else: + rooted_at_atoms = list(range(0, num_variations)) # [0, num_variations) - # List of rootedAtAtom values to pick from randomly - rooted_at_atoms = [5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5] - random.shuffle(rooted_at_atoms) # Randomize the order of rootedAtAtom values + print("rooted_at_atoms", rooted_at_atoms) - # Flag to track if we've already computed a SMILES with doRandom=False and a negative rootedAtAtom - already_computed_negative_rooted = False + # Shuffle the rooted_at_atoms list to randomize the order + random.shuffle(rooted_at_atoms) + print("shuffled rooted_at_atoms", rooted_at_atoms) + + variations = set() # Loop through all combinations of doRandom and rootedAtAtom values for do_random in [True, False]: for rooted_at_atom in rooted_at_atoms: try: - # Skip redundant computations - if not do_random and rooted_at_atom < 0: - if already_computed_negative_rooted: - continue - already_computed_negative_rooted = True - # Generate SMILES with the given configuration variant = Chem.MolToSmiles( mol, diff --git a/configs/data/chebi50.yml b/configs/data/chebi50.yml index d2ac29db..a88d4019 100644 --- a/configs/data/chebi50.yml +++ b/configs/data/chebi50.yml @@ -1,4 +1,5 @@ class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50 init_args: aug_data: True - batch_size_: 5000 + augment_data_batch_size: 5000 + num_smiles_variations: 5 From 5895d2395844bf0f6e8ef1155c08272587a57e41 Mon Sep 17 00:00:00 2001 From: vidvath Date: Sat, 28 Sep 2024 23:20:01 +0200 Subject: [PATCH 09/35] Changed the directory for augmented data files --- chebai/preprocessing/datasets/chebi.py | 68 +++++++++++++------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 6ff45987..c9df0ede 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -144,7 +144,7 @@ def __init__( single_class: Optional[int] = None, aug_data: Optional[bool] = False, augment_data_batch_size:Optional[int]= 10000, - num_smiles_variations:Optional[int]=7, + num_smiles_variations:Optional[int]=5, **kwargs, ): # predict only single class (given as id of one of the classes present in the raw data set) @@ -461,24 +461,24 @@ def setup_processed(self) -> None: ), os.path.join(self.processed_dir, processed_name), ) + if self.aug_data: + augmented_dir = self.augmented_dir_main - augmented_dir = os.path.join("data", "augmented_dataset") - - # Define the augmented data file path - if not os.path.isfile(os.path.join(augmented_dir, "augmented_data.pt")): - print( - f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", - "augmented_data.pt", - ) - torch.save( - self._load_data_from_file( - os.path.join( - augmented_dir, - "augmented_data.pkl", + # Define the augmented data file path + if not os.path.isfile(os.path.join(augmented_dir, "augmented_data.pt")): + print( + f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", + "augmented_data.pt", + ) + torch.save( + self._load_data_from_file( + os.path.join( + augmented_dir, + "augmented_data.pkl", + ) + ), + os.path.join(augmented_dir, "augmented_data.pt"), ) - ), - os.path.join(augmented_dir, "augmented_data.pt"), - ) # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist @@ -605,6 +605,20 @@ def processed_dir_main(self) -> str: "processed", ) + @property + def augmented_dir_main(self) -> str: + """ + Return the main directory path for processed data. + + Returns: + str: The path to the main processed data directory. + """ + return os.path.join( + self.base_dir, + self._name, + "augmented", + ) + @property def processed_dir(self) -> str: """ @@ -632,16 +646,6 @@ def base_dir(self) -> str: """ return os.path.join("data", f"chebi_v{self.chebi_version}") - @property - def augmented_dir(self) -> str: - """ - Return the base directory path for data. - - Returns: - str: The base directory path for data. - """ - return os.path.join("data", "chebi_augmented") - @property def processed_file_names_dict(self) -> dict: """ @@ -826,7 +830,7 @@ def augment_data(self, path: str, batch_size) -> None: if self.aug_data: if os.path.isfile(os.path.join( path, self.raw_file_names_dict["data"])): - augmented_dir = os.path.join("data", "augmented_dataset") + augmented_dir = self.augmented_dir_main # Check if the augmented directory exists, if not, create it os.makedirs(augmented_dir, exist_ok=True) # Define the augmented data file path @@ -837,7 +841,7 @@ def augment_data(self, path: str, batch_size) -> None: data = self.read_file(os.path.join( path, self.raw_file_names_dict["data"])) - print("Original Dataset size:", data.shape) + total_rows = data.shape[0] # Calculate the total number of batches total_batches = (total_rows + batch_size - 1) // batch_size @@ -886,7 +890,6 @@ def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): # Function to generate SMILES variations using different configurations def generate_smiles_variations(self, original_smiles): num_variations = self.num_smiles_variations - print("num_variations",num_variations) print(type(original_smiles), original_smiles) if not isinstance(original_smiles, str): print(f"Non-string SMILES found: {original_smiles}") @@ -897,20 +900,15 @@ def generate_smiles_variations(self, original_smiles): # Get the number of atoms in the molecule num_atoms = mol.GetNumAtoms() - print("num_atoms", num_atoms) - print("num_variations", num_variations) - # Generate the rooted_at_atoms list based on the number of atoms if num_atoms < num_variations: rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) else: rooted_at_atoms = list(range(0, num_variations)) # [0, num_variations) - print("rooted_at_atoms", rooted_at_atoms) # Shuffle the rooted_at_atoms list to randomize the order random.shuffle(rooted_at_atoms) - print("shuffled rooted_at_atoms", rooted_at_atoms) variations = set() From f303dc7be1995a405ab4f63373181e2bb9f10ec5 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 3 Oct 2024 22:49:03 +0200 Subject: [PATCH 10/35] Changed the file names for augmented data --- chebai/preprocessing/datasets/chebi.py | 71 +++++++++++++------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index c9df0ede..7fe7c5f8 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -28,6 +28,7 @@ from tqdm import tqdm from chebai.preprocessing import reader as dr +from chebai.preprocessing.datasets.aug import augment_data from chebai.preprocessing.datasets.base import XYBaseDataModule import random @@ -446,39 +447,40 @@ def setup_processed(self) -> None: # ) # Transform the processed data into encoded data - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print( - f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", - processed_name, - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.raw_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) - if self.aug_data: - augmented_dir = self.augmented_dir_main + if not self.aug_data: + processed_name = self.processed_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): + print( + f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", + processed_name, + ) + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.raw_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, processed_name), + ) + else: + augmented_dir = self.augmented_dir_main - # Define the augmented data file path - if not os.path.isfile(os.path.join(augmented_dir, "augmented_data.pt")): - print( - f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", - "augmented_data.pt", - ) - torch.save( - self._load_data_from_file( - os.path.join( - augmented_dir, - "augmented_data.pkl", - ) - ), - os.path.join(augmented_dir, "augmented_data.pt"), - ) + # Define the augmented data file path + if not os.path.isfile(os.path.join(augmented_dir, "augmented_data.pt")): + print( + f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", + "augmented_data.pt", + ) + torch.save( + self._load_data_from_file( + os.path.join( + augmented_dir, + "data.pkl", + ) + ), + os.path.join(augmented_dir, "data.pt"), + ) # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist @@ -826,7 +828,6 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: def augment_data(self, path: str, batch_size) -> None: print(("inside_augment_data")) - print("batch_size",batch_size) if self.aug_data: if os.path.isfile(os.path.join( path, self.raw_file_names_dict["data"])): @@ -834,9 +835,9 @@ def augment_data(self, path: str, batch_size) -> None: # Check if the augmented directory exists, if not, create it os.makedirs(augmented_dir, exist_ok=True) # Define the augmented data file path - augmented_data_file = os.path.join(augmented_dir, "augmented_data.pkl") + augmented_data_file = os.path.join(augmented_dir, "data.pkl") - # If augmented_data.pkl does not already exist, proceed with the logic + # If data.pkl(augmented) does not already exist in augmented dir, proceed with the logic if not os.path.isfile(augmented_data_file): data = self.read_file(os.path.join( From 32966e10d3450dd2e9d5b42780d62e9428501e44 Mon Sep 17 00:00:00 2001 From: vidvath Date: Fri, 18 Oct 2024 09:28:58 +0200 Subject: [PATCH 11/35] Added new config file for data augmentation --- configs/data/chebi_augmentation.yml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 configs/data/chebi_augmentation.yml diff --git a/configs/data/chebi_augmentation.yml b/configs/data/chebi_augmentation.yml new file mode 100644 index 00000000..a88d4019 --- /dev/null +++ b/configs/data/chebi_augmentation.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50 +init_args: + aug_data: True + augment_data_batch_size: 5000 + num_smiles_variations: 5 From 72bb33fd40892e9329b817f1af7933cbf9499c5e Mon Sep 17 00:00:00 2001 From: vidvath Date: Fri, 18 Oct 2024 09:29:45 +0200 Subject: [PATCH 12/35] Created new class for data augmentation --- chebai/preprocessing/datasets/chebi.py | 447 ++++++++++++++----------- 1 file changed, 257 insertions(+), 190 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 7fe7c5f8..3fec3b21 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -31,6 +31,8 @@ from chebai.preprocessing.datasets.aug import augment_data from chebai.preprocessing.datasets.base import XYBaseDataModule import random +import time + # Suppress RDKit warnings and errors RDLogger.DisableLog('rdApp.*') # Disable all RDKit logging @@ -145,7 +147,7 @@ def __init__( single_class: Optional[int] = None, aug_data: Optional[bool] = False, augment_data_batch_size:Optional[int]= 10000, - num_smiles_variations:Optional[int]=5, + num_smiles_variations:Optional[int]=10, **kwargs, ): # predict only single class (given as id of one of the classes present in the raw data set) @@ -447,42 +449,21 @@ def setup_processed(self) -> None: # ) # Transform the processed data into encoded data - if not self.aug_data: - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print( - f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", - processed_name, - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.raw_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) - else: - augmented_dir = self.augmented_dir_main - - # Define the augmented data file path - if not os.path.isfile(os.path.join(augmented_dir, "augmented_data.pt")): - print( - f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", - "augmented_data.pt", - ) - torch.save( - self._load_data_from_file( - os.path.join( - augmented_dir, - "data.pkl", - ) - ), - os.path.join(augmented_dir, "data.pt"), - ) - - + processed_name = self.processed_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): + print( + f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", + processed_name, + ) + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.raw_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, processed_name), + ) # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist if self.chebi_version_train is not None and not os.path.isfile( os.path.join( @@ -496,6 +477,8 @@ def setup_processed(self) -> None: print("Call the setup method related to it") self._chebi_version_train_obj.setup() + + def get_test_split( self, df: pd.DataFrame, seed: Optional[int] = None ) -> Tuple[pd.DataFrame, pd.DataFrame]: @@ -770,48 +753,17 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: for f in self.raw_file_names ): os.makedirs(self.processed_dir_main, exist_ok=True) + print("Created directory :",self.processed_dir_main) print("Missing raw data. Go fetch...") - # -------- Commented the code for Data Handling Restructure for Issue No.10 - # -------- https://github.com/ChEB-AI/python-chebai/issues/10 - # missing test set -> create - # if not os.path.isfile( - # os.path.join(self.raw_dir, self.raw_file_names_dict["test"]) - # ): - # chebi_path = self._load_chebi(self.chebi_version) - # g = self.extract_class_hierarchy(chebi_path) - # df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["test"]) - # _, test_df = self.get_test_split(df) - # self.save_raw(test_df, self.raw_file_names_dict["test"]) - # # load test_split from file - # else: - # with open( - # os.path.join(self.raw_dir, self.raw_file_names_dict["test"]), "rb" - # ) as input_file: - # test_df = pickle.load(input_file) - # # create train/val split based on test set - # chebi_path = self._load_chebi( - # self.chebi_version_train - # if self.chebi_version_train is not None - # else self.chebi_version - # ) - # g = self.extract_class_hierarchy(chebi_path) - # if self.use_inner_cross_validation: - # df = self.graph_to_raw_dataset( - # g, self.raw_file_names_dict[f"fold_0_train"] - # ) - # else: - # df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train"]) - # train_val_dict = self.get_train_val_splits_given_test(df, test_df) - # for name, df in train_val_dict.items(): - # self.save_raw(df, name) - # Data from chebi_version chebi_path = self._load_chebi(self.chebi_version) g = self.extract_class_hierarchy(chebi_path) df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"]) self.save_processed(df, filename=self.raw_file_names_dict["data"]) - self.augment_data(self.processed_dir_main, self.augment_data_batch_size) + + # **CALL AUGMENTATION CLASS HERE** + if self.chebi_version_train is not None: if not os.path.isfile( os.path.join( @@ -826,129 +778,21 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: # Generate the "chebi_version_train" data if it doesn't exist self._chebi_version_train_obj.prepare_data(*args, **kwargs) - def augment_data(self, path: str, batch_size) -> None: - print(("inside_augment_data")) + # Data augmentation if self.aug_data: - if os.path.isfile(os.path.join( - path, self.raw_file_names_dict["data"])): - augmented_dir = self.augmented_dir_main - # Check if the augmented directory exists, if not, create it - os.makedirs(augmented_dir, exist_ok=True) - # Define the augmented data file path - augmented_data_file = os.path.join(augmented_dir, "data.pkl") - - # If data.pkl(augmented) does not already exist in augmented dir, proceed with the logic - if not os.path.isfile(augmented_data_file): - - data = self.read_file(os.path.join( - path, self.raw_file_names_dict["data"])) - - total_rows = data.shape[0] - # Calculate the total number of batches - total_batches = (total_rows + batch_size - 1) // batch_size - - for batch_num, start in enumerate(range(0, total_rows, batch_size), start=1): - end = min(start + batch_size, total_rows) - batch = data[start:end] - print(f"Processing batch {batch_num}/{total_batches} ({start} to {end})") - - # Set to keep track of already seen SMILES - seen_smiles = set(batch['SMILES']) - - # Store new rows in a list instead of concatenating directly - new_rows = [] - - # Updated tqdm to show batch number and total batches - for _, row in tqdm(batch.iterrows(), total=len(batch), - desc=f"Batch {batch_num}/{total_batches}", unit="row"): - original_smiles = row['SMILES'] - variations = self.generate_smiles_variations(original_smiles) - variations = [var for var in variations if var not in seen_smiles] - - for var in variations: - new_row = row.copy() - new_row['SMILES'] = var - new_rows.append(new_row) - seen_smiles.add(var) - - # Create a DataFrame from the new rows - new_df = pd.DataFrame(new_rows) - - # Concatenate the batch and new variations, and save incrementally - batch_with_variations = pd.concat([batch, new_df], ignore_index=True) - - # Append batch to augmented file - self.save_file(batch_with_variations, augmented_data_file, append=True) - else: - print("Data Augmentation config is False") - - def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): - if append and os.path.isfile(file_path): - existing_data = pd.read_pickle(file_path) - dataset = pd.concat([existing_data, dataset], ignore_index=True) - pd.to_pickle(dataset, open(file_path, "wb")) - - # Function to generate SMILES variations using different configurations - def generate_smiles_variations(self, original_smiles): - num_variations = self.num_smiles_variations - print(type(original_smiles), original_smiles) - if not isinstance(original_smiles, str): - print(f"Non-string SMILES found: {original_smiles}") - mol = Chem.MolFromSmiles(original_smiles) - if mol is None: - return [] # Return an empty list if conversion fails - - # Get the number of atoms in the molecule - num_atoms = mol.GetNumAtoms() - - # Generate the rooted_at_atoms list based on the number of atoms - if num_atoms < num_variations: - rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) - else: - rooted_at_atoms = list(range(0, num_variations)) # [0, num_variations) - - - # Shuffle the rooted_at_atoms list to randomize the order - random.shuffle(rooted_at_atoms) - - variations = set() - - # Loop through all combinations of doRandom and rootedAtAtom values - for do_random in [True, False]: - for rooted_at_atom in rooted_at_atoms: - try: - # Generate SMILES with the given configuration - variant = Chem.MolToSmiles( - mol, - doRandom=do_random, - rootedAtAtom=rooted_at_atom, - canonical=False - ) - - # # Print the configuration and the generated SMILES string - # print( - # f"Config: doRandom={do_random}, rootedAtAtom={rooted_at_atom}, canonical={False} -> SMILES: {variant}") - - # Avoid duplicates with the original SMILES - if variant != original_smiles: - variations.add(variant) + if os.path.isfile(os.path.join(self.processed_dir_main, self.raw_file_names_dict["data"])): + augmenter = AugmentedDataExtractor(self.chebi_version, self.chebi_version_train, + self.use_inner_cross_validation, self.single_class, self.aug_data, + self.augment_data_batch_size, self.num_smiles_variations, + self.reader, **kwargs) + augmenter.augment_data(self.processed_dir_main, self.augment_data_batch_size) + augmenter.setup_processed() + else: + print("data.pkl(original) file is not found") - # Check the number of variations after adding - if len(variations) >= num_variations: - return list(variations) # Return immediately when enough variations are found - except Exception as e: - # Skip invalid configurations - continue - return list(variations) - def read_file(self, file_path: str): - df = pd.read_pickle( - open(file_path, "rb" - ) - ) - return df def _generate_dynamic_splits(self) -> None: """ @@ -1472,6 +1316,229 @@ def term_callback(doc) -> dict: "smiles": smiles, } +# New Class for Augmentation +class AugmentedDataExtractor(_ChEBIDataExtractor): + """ + A class for data augmentation that inherits from the main ChEBI data extractor class. + """ + + def __init__(self, chebi_version, chebi_version_train, use_inner_cross_validation, single_class, aug_data, + batch_size, num_smiles_variations,reader, **kwargs): + self.single_class = single_class + super(AugmentedDataExtractor, self).__init__(**kwargs) + self.chebi_version = chebi_version + self.aug_data = aug_data + self.chebi_version_train = chebi_version_train + self.use_inner_cross_validation = use_inner_cross_validation + self.num_smiles_variations = num_smiles_variations + self.aug_data = aug_data + self.reader=reader + print(f"Initializing AugmentedDataExtractor with chebi_version: {self.chebi_version}") + # other initializations + + # READER = dr.DataReader + READER: dr.ChemDataReader = dr.ChemDataReader + + @property + def _name(self): + return "ChEBI50" + + + + def augment_data(self, path: str, batch_size) -> None: + print(("Inside - AugmentedDataExtractor - augment_data()")) + if self.aug_data: + if os.path.isfile(os.path.join( + path, self.raw_file_names_dict["data"])): + augmented_dir = self.augmented_dir_main + # Check if the augmented directory exists, if not, create it + os.makedirs(augmented_dir, exist_ok=True) + print("Created augmented directory: ", augmented_dir) + # Define the augmented data file path + augmented_data_file = os.path.join(augmented_dir, "data.pkl") + print("Created augmented file directory: ", augmented_data_file) + # If data.pkl(augmented) does not already exist in augmented dir, proceed with the logic + if not os.path.isfile(augmented_data_file): + + # Start timing the augmentation process + start_time = time.time() + + data = self.read_file(os.path.join( + path, self.raw_file_names_dict["data"])) + + total_rows = data.shape[0] + #For testing + subset1 = data.iloc[:10000] # First 10,000 of the datapoints for testing only + data=subset1 + total_rows=subset1.shape[0] + # Calculate the total number of batches + total_batches = (total_rows + batch_size - 1) // batch_size + + for batch_num, start in enumerate(range(0, total_rows, batch_size), start=1): + end = min(start + batch_size, total_rows) + batch = data[start:end] + print(f"Processing batch {batch_num}/{total_batches} ({start} to {end})") + + # Set to keep track of already seen SMILES + seen_smiles = set(batch['SMILES']) + + # Store new rows in a list instead of concatenating directly + new_rows = [] + + # Updated tqdm to show batch number and total batches + for _, row in tqdm(batch.iterrows(), total=len(batch), + desc=f"Batch {batch_num}/{total_batches}", unit="row"): + original_smiles = row['SMILES'] + variations = self.generate_smiles_variations(original_smiles) + variations = [var for var in variations if var not in seen_smiles] + + for var in variations: + new_row = row.copy() + new_row['SMILES'] = var + new_rows.append(new_row) + seen_smiles.add(var) + + # Create a DataFrame from the new rows + new_df = pd.DataFrame(new_rows) + + # Concatenate the batch and new variations, and save incrementally + batch_with_variations = pd.concat([batch, new_df], ignore_index=True) + + # Append batch to augmented file + print("File going to be saved at :", augmented_data_file) + self.save_file(batch_with_variations, augmented_data_file, append=True) + + # End timing the augmentation process + end_time = time.time() + time_taken = end_time - start_time + time_taken_minutes=time_taken/60 + + # Load the augmented data to count the number of SMILES strings generated + augmented_data = self.read_file(augmented_data_file) + num_augmented_smiles = augmented_data.shape[0] + + print(f"Number of SMILES strings in original dataset: {total_rows}") + print(f"Number of SMILES strings in augmented dataset: {num_augmented_smiles}") + print(f"Time taken for augmentation process: {time_taken_minutes:.2f} minutes") + + else: + print("Original data.pkl file doesn't exist") + else: + print("Data Augmentation config is False") + + def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): + if append and os.path.isfile(file_path): + existing_data = pd.read_pickle(file_path) + dataset = pd.concat([existing_data, dataset], ignore_index=True) + pd.to_pickle(dataset, open(file_path, "wb")) + + # Function to generate SMILES variations using different configurations + def generate_smiles_variations(self, original_smiles): + num_variations = self.num_smiles_variations + print(type(original_smiles), original_smiles) + if not isinstance(original_smiles, str): + print(f"Non-string SMILES found: {original_smiles}") + mol = Chem.MolFromSmiles(original_smiles) + if mol is None: + return [] # Return an empty list if conversion fails + + # Get the number of atoms in the molecule + num_atoms = mol.GetNumAtoms() + + # Generate the rooted_at_atoms list based on the number of atoms + # if num_atoms < num_variations: + # rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) + # else: + # rooted_at_atoms = list(range(0, num_variations)) # [0, num_variations) + + rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) + + # Shuffle the rooted_at_atoms list to randomize the order + random.shuffle(rooted_at_atoms) + + variations = set() + + # Loop through all combinations of doRandom and rootedAtAtom values + for do_random in [True, False]: + for rooted_at_atom in rooted_at_atoms: + try: + # Generate SMILES with the given configuration + variant = Chem.MolToSmiles( + mol, + doRandom=do_random, + rootedAtAtom=rooted_at_atom, + canonical=False + ) + + # Avoid duplicates with the original SMILES + if variant != original_smiles: + variations.add(variant) + + # Check the number of variations after adding + if len(variations) >= num_atoms: + return list(variations) # Return immediately when enough variations are found + + except Exception as e: + # Skip invalid configurations + continue + + return list(variations) + + def read_file(self, file_path: str): + df = pd.read_pickle( + open(file_path, "rb" + ) + ) + return df + + @property + def augmented_dir_main(self) -> str: + """ + Return the main directory path for processed data. + + Returns: + str: The path to the main processed data directory. + """ + return os.path.join( + self.base_dir, + self._name, + "augmented", + ) + + def setup_processed(self) -> None: + """ + Transform and prepare processed data for the augmented ChEBI dataset. + + This method sets up the processed data directories and files for augmented data. + It ensures that the required processed data files exist by loading raw data, transforming it into + augmented format, and saving it. + + If augmented data already exists, it skips the transformation step. + """ + + print("Inside AugmentedDataExtractor - setup_processed()") + augmented_dir = self.augmented_dir_main + # augmented_dir = r'data\chebi_v231\ChEBI50\augmented' + + if os.path.isfile(os.path.join(augmented_dir, "data.pkl")): + # Define the augmented data file path + if not os.path.isfile(os.path.join(augmented_dir, "data.pt")): + print( + f"Missing encoded data related to version {self.chebi_version}, transform augmented data into encoded data:", + "data.pt(augmented)", + ) + torch.save( + self._load_data_from_file( + os.path.join( + augmented_dir, + "data.pkl", + ) + ), + os.path.join(augmented_dir, "data.pt"), + ) + else: + print("data.pkl (augmented) file doesn't exist") + atom_index = ( "\*", From 5916420f65dba68b64bbedb1d2e823645f3242e8 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 24 Oct 2024 15:46:50 +0200 Subject: [PATCH 13/35] Add new folder --- logs/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 logs/.gitkeep diff --git a/logs/.gitkeep b/logs/.gitkeep new file mode 100644 index 00000000..e69de29b From 5c48e4df13923f8c1958de3e7fa25e47d7456b59 Mon Sep 17 00:00:00 2001 From: vidvath Date: Sun, 3 Nov 2024 17:32:14 +0100 Subject: [PATCH 14/35] Added directory for augmented directory for splitting --- chebai/preprocessing/datasets/chebi.py | 61 ++++++++++++++------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 3fec3b21..ffc2d6a4 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -449,33 +449,34 @@ def setup_processed(self) -> None: # ) # Transform the processed data into encoded data - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print( - f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", - processed_name, - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.raw_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) - # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist - if self.chebi_version_train is not None and not os.path.isfile( - os.path.join( - self._chebi_version_train_obj.processed_dir, - self._chebi_version_train_obj.raw_file_names_dict["data"], - ) - ): - print( - f"Missing encoded data related to train version: {self.chebi_version_train}" - ) - print("Call the setup method related to it") - self._chebi_version_train_obj.setup() + if not self.aug_data: + processed_name = self.processed_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): + print( + f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", + processed_name, + ) + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.raw_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, processed_name), + ) + # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist + if self.chebi_version_train is not None and not os.path.isfile( + os.path.join( + self._chebi_version_train_obj.processed_dir, + self._chebi_version_train_obj.raw_file_names_dict["data"], + ) + ): + print( + f"Missing encoded data related to train version: {self.chebi_version_train}" + ) + print("Call the setup method related to it") + self._chebi_version_train_obj.setup() @@ -810,9 +811,13 @@ def _generate_dynamic_splits(self) -> None: """ print("Generate dynamic splits...") # Load encoded data derived from "chebi_version" + # Determine the directory for loading encoded data based on the aug_data flag + data_dir = self.augmented_dir_main if self.aug_data else self.processed_dir + try: filename = self.processed_file_names_dict["data"] - data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) + print("Directory:",os.path.join(data_dir, filename)) + data_chebi_version = torch.load(os.path.join(data_dir, filename)) except FileNotFoundError: raise FileNotFoundError( f"File data.pt doesn't exists. " From 99d36965fa88af7488eab9287056f344e1838617 Mon Sep 17 00:00:00 2001 From: vidvath Date: Sun, 3 Nov 2024 23:11:36 +0100 Subject: [PATCH 15/35] Removed changes made for testing with subset --- chebai/preprocessing/datasets/chebi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index ffc2d6a4..0a4f8f77 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -1372,10 +1372,10 @@ def augment_data(self, path: str, batch_size) -> None: path, self.raw_file_names_dict["data"])) total_rows = data.shape[0] - #For testing - subset1 = data.iloc[:10000] # First 10,000 of the datapoints for testing only - data=subset1 - total_rows=subset1.shape[0] + # #For testing + # subset1 = data.iloc[:10000] # First 10,000 of the datapoints for testing only + # data=subset1 + # total_rows=subset1.shape[0] # Calculate the total number of batches total_batches = (total_rows + batch_size - 1) // batch_size From 4cf83d19c0c68ebc26d6257bce4783fa87193ffc Mon Sep 17 00:00:00 2001 From: vidvath Date: Mon, 4 Nov 2024 18:19:08 +0100 Subject: [PATCH 16/35] removed import --- chebai/preprocessing/datasets/chebi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 0a4f8f77..728c0fed 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -28,7 +28,6 @@ from tqdm import tqdm from chebai.preprocessing import reader as dr -from chebai.preprocessing.datasets.aug import augment_data from chebai.preprocessing.datasets.base import XYBaseDataModule import random import time From 43360bc700d51fd26747b6143c890ba127bda67b Mon Sep 17 00:00:00 2001 From: vidvath Date: Tue, 12 Nov 2024 18:09:52 +0100 Subject: [PATCH 17/35] changes for lightning error --- chebai/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 362731df..dd375c7f 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -46,7 +46,7 @@ def __init__( super().__init__() self.criterion = criterion self.save_hyperparameters( - ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"] + ignore=["criterion", "train_metrics", "val_metrics", "test_metrics","_class_path"] ) self.out_dim = out_dim if optimizer_kwargs: From cef2ea2a59587ee42aef09dd7b15c45a9d3e10df Mon Sep 17 00:00:00 2001 From: vidvath Date: Sat, 16 Nov 2024 00:49:19 +0100 Subject: [PATCH 18/35] Changes in yml to ChEBI100 --- configs/data/chebi_augmentation.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/data/chebi_augmentation.yml b/configs/data/chebi_augmentation.yml index a88d4019..77efb664 100644 --- a/configs/data/chebi_augmentation.yml +++ b/configs/data/chebi_augmentation.yml @@ -1,5 +1,6 @@ -class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50 +class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100 init_args: aug_data: True augment_data_batch_size: 5000 num_smiles_variations: 5 + From dff800bcceb2ad4cf621435597c1d69db03455eb Mon Sep 17 00:00:00 2001 From: vidvath Date: Sat, 16 Nov 2024 01:07:31 +0100 Subject: [PATCH 19/35] Removed whitespaces --- configs/data/chebi_augmentation.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/data/chebi_augmentation.yml b/configs/data/chebi_augmentation.yml index 77efb664..5fee5ff7 100644 --- a/configs/data/chebi_augmentation.yml +++ b/configs/data/chebi_augmentation.yml @@ -3,4 +3,3 @@ init_args: aug_data: True augment_data_batch_size: 5000 num_smiles_variations: 5 - From 131ea90df24471c52b8349167ca58f9563c03e40 Mon Sep 17 00:00:00 2001 From: vidvath Date: Sat, 16 Nov 2024 01:15:47 +0100 Subject: [PATCH 20/35] Changed augmented path to ChEBI100 --- chebai/preprocessing/datasets/chebi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 728c0fed..3a938d03 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -1345,7 +1345,7 @@ def __init__(self, chebi_version, chebi_version_train, use_inner_cross_validatio @property def _name(self): - return "ChEBI50" + return "ChEBI100" From 7056187395601c137a6ca3e9c557d8c5b0459148 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 28 Nov 2024 19:02:53 +0100 Subject: [PATCH 21/35] Changes for creating splits.csv,added lines for debugging --- chebai/preprocessing/datasets/chebi.py | 35 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 3a938d03..4e627813 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -513,6 +513,9 @@ def get_test_split( df_train = df.iloc[train_indices] df_test = df.iloc[test_indices] + print("Inside get_test_split") + print("Train Split : ", df_train.shape) + print("Test Split : ", df_test.shape) return df_train, df_test def get_train_val_splits_given_test( @@ -534,14 +537,15 @@ def get_train_val_splits_given_test( are the corresponding DataFrames. """ print(f"Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - # ---- list comprehension degrades performance, dataframe operations are faster - # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] - # df_trainval = df_trainval[mask] - df_trainval = df[~df["ident"].isin(test_ids)] + df_trainval = df + if self.aug_data==False: + test_ids = test_df["ident"].tolist() + # ---- list comprehension degrades performance, dataframe operations are faster + # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] + # df_trainval = df_trainval[mask] + df_trainval = df[~df["ident"].isin(test_ids)] labels_list_trainval = df_trainval["labels"].tolist() - + print("df_trainval.shape after removing overlapping points:",df_trainval.shape) if self.use_inner_cross_validation: folds = {} kfold = MultilabelStratifiedKFold( @@ -571,9 +575,13 @@ def get_train_val_splits_given_test( train_indices, validation_indices = next( msss.split(labels_list_trainval, labels_list_trainval) ) + print("train_indices.shape : ", train_indices.shape) + print("validation_indices.shape : ", validation_indices.shape) df_validation = df_trainval.iloc[validation_indices] df_train = df_trainval.iloc[train_indices] + print("df_train :",df_train.shape) + print("df_validation :",df_validation.shape) return df_train, df_validation @property @@ -815,7 +823,9 @@ def _generate_dynamic_splits(self) -> None: try: filename = self.processed_file_names_dict["data"] - print("Directory:",os.path.join(data_dir, filename)) + print("Directory of data.pt:",os.path.join(data_dir, filename)) + #loading of data.pt + print("Loading : ", filename ) data_chebi_version = torch.load(os.path.join(data_dir, filename)) except FileNotFoundError: raise FileNotFoundError( @@ -824,10 +834,14 @@ def _generate_dynamic_splits(self) -> None: ) df_chebi_version = pd.DataFrame(data_chebi_version) + print("Created dataframe for data.pt :",df_chebi_version) + print("Created dataframe size:",df_chebi_version.shape) train_df_chebi_ver, df_test_chebi_ver = self.get_test_split( df_chebi_version, seed=self.dynamic_data_split_seed ) - + print("get_test_split done, splits size train: ", train_df_chebi_ver.shape) + print("get_test_split done, splits size test: ", df_test_chebi_ver.shape) + print("chebi_version_train : ",self.chebi_version_train) if self.chebi_version_train is not None: # Load encoded data derived from "chebi_version_train" try: @@ -872,9 +886,11 @@ def _generate_dynamic_splits(self) -> None: pd.DataFrame({"id": df_test["ident"], "split": "test"}), ] combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + #Saving csv combined_split_assignment.to_csv( os.path.join(self.processed_dir_main, "splits.csv") ) + print("Saving splits.csv") # Store the splits in class variables self.dynamic_df_train = df_train @@ -934,6 +950,7 @@ def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: ): if self.splits_file_path is None: # Generate splits based on given seed, create csv file to records the splits + print("no splits_file_path provided by the user") self._generate_dynamic_splits() else: # If user has provided splits file path, use it to get the splits from the data From 419d603d43c88c9df95a2fe6c8e68f35ebebf3a0 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 5 Dec 2024 13:25:02 +0100 Subject: [PATCH 22/35] Added new file for Evaluation --- chebai/result/utils.py | 4 +++ eval.py | 73 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 eval.py diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 31063747..b6d56a5e 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -94,6 +94,7 @@ def evaluate_model( Returns: Tensors with predictions and labels. """ + print("Start of evaluate_model") model.eval() collate = data_module.reader.COLLATOR() @@ -157,6 +158,7 @@ def evaluate_model( torch.cat(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) + print("End of evaluate_model") def load_results_from_buffer( @@ -172,6 +174,7 @@ def load_results_from_buffer( Returns: Tensors with predictions and labels. """ + print("Start of load_results_from_buffer") preds_list = [] labels_list = [] @@ -208,6 +211,7 @@ def load_results_from_buffer( else: test_labels = None + print("End of load_results_from_buffer") return test_preds, test_labels diff --git a/eval.py b/eval.py new file mode 100644 index 00000000..28157903 --- /dev/null +++ b/eval.py @@ -0,0 +1,73 @@ +import pandas as pd + +from chebai.result.utils import ( + evaluate_model, + load_results_from_buffer, +) +from chebai.result.classification import print_metrics +from chebai.models.electra import Electra +from chebai.preprocessing.datasets.chebi import ChEBIOver50, ChEBIOver100 +import os +import tqdm +import torch +import pickle + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print(DEVICE) + + +# Specify paths and parameters +checkpoint_name = "best_epoch=14_val_loss=0.0017_val_macro-f1=0.9226_val_micro-f1=0.9847.ckpt" +print("checkpoint_name",checkpoint_name) +checkpoint_path = os.path.join("logs/wandb/run-20241128_214007-ukcabied/files/checkpoints", f"{checkpoint_name}.ckpt") +print("checkpoint_path",checkpoint_path) +kind = "test" # Change to "train" or "validation" as needed +buffer_dir = os.path.join("results_buffer", checkpoint_name, kind) +print("buffer_dir",buffer_dir) +batch_size = 10 # Set batch size + +# Load data module +data_module = ChEBIOver100(chebi_version=231) + +model_class = Electra + +# evaluates model, stores results in buffer_dir +model = model_class.load_from_checkpoint(checkpoint_path) +if buffer_dir is None: + preds, labels = evaluate_model( + model, + data_module, + buffer_dir=buffer_dir, + # No need to provide this parameter for Chebi dataset, "kind" parameter should be provided + # filename=data_module.processed_file_names_dict[kind], + batch_size=10, + kind=kind, + ) +else: + evaluate_model( + model, + data_module, + buffer_dir=buffer_dir, + # No need to provide this parameter for Chebi dataset, "kind" parameter should be provided + # filename=data_module.processed_file_names_dict[kind], + batch_size=10, + kind=kind, + ) + # load data from buffer_dir + preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE) + + +# Load classes from the classes.txt +with open(os.path.join(data_module.processed_dir_main, "classes.txt"), "r") as f: + classes = [line.strip() for line in f.readlines()] + + +# output relevant metrics +print_metrics( + preds, + labels.to(torch.int), + DEVICE, + classes=classes, + markdown_output=False, + top_k=10, +) From c04a5957de4d9bff3c991327f6771a026dcbf154 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 5 Dec 2024 13:50:06 +0100 Subject: [PATCH 23/35] Corrected directory --- eval.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/eval.py b/eval.py index 28157903..fbc73c6d 100644 --- a/eval.py +++ b/eval.py @@ -17,13 +17,13 @@ # Specify paths and parameters -checkpoint_name = "best_epoch=14_val_loss=0.0017_val_macro-f1=0.9226_val_micro-f1=0.9847.ckpt" -print("checkpoint_name",checkpoint_name) +checkpoint_name = "best_epoch=14_val_loss=0.0017_val_macro-f1=0.9226_val_micro-f1=0.9847" +print("checkpoint_name : ",checkpoint_name) checkpoint_path = os.path.join("logs/wandb/run-20241128_214007-ukcabied/files/checkpoints", f"{checkpoint_name}.ckpt") -print("checkpoint_path",checkpoint_path) +print("checkpoint_path : ",checkpoint_path) kind = "test" # Change to "train" or "validation" as needed buffer_dir = os.path.join("results_buffer", checkpoint_name, kind) -print("buffer_dir",buffer_dir) +print("buffer_dir : ",buffer_dir) batch_size = 10 # Set batch size # Load data module From a52c6356af7b42d92be5a2d7f6cd3042c46b5a27 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 5 Dec 2024 16:19:43 +0100 Subject: [PATCH 24/35] Changes for evaluation-set splits_file_path --- chebai/preprocessing/datasets/chebi.py | 2 ++ eval.py | 1 + 2 files changed, 3 insertions(+) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 4e627813..0b84619a 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -940,6 +940,8 @@ def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: dict: A dictionary containing the dynamic train, validation, and test DataFrames. Keys are 'train', 'validation', and 'test'. """ + print("Inside dynamic_split_dfs") + print("splits_file_path : ",self.splits_file_path) if any( split is None for split in [ diff --git a/eval.py b/eval.py index fbc73c6d..b2538bd9 100644 --- a/eval.py +++ b/eval.py @@ -29,6 +29,7 @@ # Load data module data_module = ChEBIOver100(chebi_version=231) +data_module.splits_file_path="python-chebai/data/chebi_v231/ChEBI100/processed/splits.csv" model_class = Electra # evaluates model, stores results in buffer_dir From 6ba39212eed9557f7d081545ddb1b2cf1391a8af Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 5 Dec 2024 16:22:38 +0100 Subject: [PATCH 25/35] Corrected splits_file_path --- eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval.py b/eval.py index b2538bd9..072405b7 100644 --- a/eval.py +++ b/eval.py @@ -29,7 +29,7 @@ # Load data module data_module = ChEBIOver100(chebi_version=231) -data_module.splits_file_path="python-chebai/data/chebi_v231/ChEBI100/processed/splits.csv" +data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/splits.csv" model_class = Electra # evaluates model, stores results in buffer_dir From 6e9e098218644187d6356ec0582228d323ab4c71 Mon Sep 17 00:00:00 2001 From: vidvath <73276923+vidvath7@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:20:21 +0100 Subject: [PATCH 26/35] Update eval.py --- eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval.py b/eval.py index 072405b7..2acf2f58 100644 --- a/eval.py +++ b/eval.py @@ -12,7 +12,7 @@ import torch import pickle -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") print(DEVICE) From b38cce7f904456c9fefdded627e3a6dc8712f421 Mon Sep 17 00:00:00 2001 From: vidvath Date: Fri, 6 Dec 2024 09:49:55 +0100 Subject: [PATCH 27/35] printing batch size --- chebai/result/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index b6d56a5e..a323724f 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -95,6 +95,7 @@ def evaluate_model( Tensors with predictions and labels. """ print("Start of evaluate_model") + print("batch_size: ", batch_size) model.eval() collate = data_module.reader.COLLATOR() From 58ef43ecfc2e0e5f09ae36748052bdaeaab381bb Mon Sep 17 00:00:00 2001 From: vidvath Date: Fri, 6 Dec 2024 10:07:04 +0100 Subject: [PATCH 28/35] changing batch size --- chebai/result/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index a323724f..4ed8c7cd 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -95,6 +95,7 @@ def evaluate_model( Tensors with predictions and labels. """ print("Start of evaluate_model") + batch_size=5 print("batch_size: ", batch_size) model.eval() collate = data_module.reader.COLLATOR() From ae43441c1b8777769eaa2bff8ed340fdf21a6d16 Mon Sep 17 00:00:00 2001 From: vidvath Date: Sat, 7 Dec 2024 20:12:39 +0100 Subject: [PATCH 29/35] Changes for Augmentation after splitting --- chebai/preprocessing/datasets/chebi.py | 129 +++++++++++++++---------- 1 file changed, 79 insertions(+), 50 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 0b84619a..1711b4d7 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -448,34 +448,34 @@ def setup_processed(self) -> None: # ) # Transform the processed data into encoded data - if not self.aug_data: - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print( - f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", - processed_name, - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.raw_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) - # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist - if self.chebi_version_train is not None and not os.path.isfile( - os.path.join( - self._chebi_version_train_obj.processed_dir, - self._chebi_version_train_obj.raw_file_names_dict["data"], - ) - ): - print( - f"Missing encoded data related to train version: {self.chebi_version_train}" - ) - print("Call the setup method related to it") - self._chebi_version_train_obj.setup() + # if not self.aug_data: + processed_name = self.processed_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): + print( + f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", + processed_name, + ) + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.raw_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, processed_name), + ) + # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist + if self.chebi_version_train is not None and not os.path.isfile( + os.path.join( + self._chebi_version_train_obj.processed_dir, + self._chebi_version_train_obj.raw_file_names_dict["data"], + ) + ): + print( + f"Missing encoded data related to train version: {self.chebi_version_train}" + ) + print("Call the setup method related to it") + self._chebi_version_train_obj.setup() @@ -537,13 +537,14 @@ def get_train_val_splits_given_test( are the corresponding DataFrames. """ print(f"Split dataset into train / val with given test set") - df_trainval = df - if self.aug_data==False: - test_ids = test_df["ident"].tolist() + # df_trainval = df + # if self.aug_data==False: + test_ids = test_df["ident"].tolist() + print("test_ids size : ",len(test_ids)) # ---- list comprehension degrades performance, dataframe operations are faster # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] # df_trainval = df_trainval[mask] - df_trainval = df[~df["ident"].isin(test_ids)] + df_trainval = df[~df["ident"].isin(test_ids)] labels_list_trainval = df_trainval["labels"].tolist() print("df_trainval.shape after removing overlapping points:",df_trainval.shape) if self.use_inner_cross_validation: @@ -786,23 +787,9 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: # Generate the "chebi_version_train" data if it doesn't exist self._chebi_version_train_obj.prepare_data(*args, **kwargs) - # Data augmentation - if self.aug_data: - if os.path.isfile(os.path.join(self.processed_dir_main, self.raw_file_names_dict["data"])): - augmenter = AugmentedDataExtractor(self.chebi_version, self.chebi_version_train, - self.use_inner_cross_validation, self.single_class, self.aug_data, - self.augment_data_batch_size, self.num_smiles_variations, - self.reader, **kwargs) - augmenter.augment_data(self.processed_dir_main, self.augment_data_batch_size) - augmenter.setup_processed() - else: - print("data.pkl(original) file is not found") - - - - def _generate_dynamic_splits(self) -> None: + def _generate_dynamic_splits(self,**kwargs: Any) -> None: """ Generate data splits during runtime and save them in class variables. @@ -819,7 +806,8 @@ def _generate_dynamic_splits(self) -> None: print("Generate dynamic splits...") # Load encoded data derived from "chebi_version" # Determine the directory for loading encoded data based on the aug_data flag - data_dir = self.augmented_dir_main if self.aug_data else self.processed_dir + # data_dir = self.augmented_dir_main if self.aug_data else self.processed_dir + data_dir=self.processed_dir try: filename = self.processed_file_names_dict["data"] @@ -892,6 +880,40 @@ def _generate_dynamic_splits(self) -> None: ) print("Saving splits.csv") + # Data augmentation + if self.aug_data: + if os.path.isfile(os.path.join(self.processed_dir_main, self.raw_file_names_dict["data"])): + augmenter = AugmentedDataExtractor(self.chebi_version, self.chebi_version_train, + self.use_inner_cross_validation, self.single_class, self.aug_data, + self.augment_data_batch_size, self.num_smiles_variations, + self.reader, **kwargs) + augmenter.augment_data(self.processed_dir_main, self.augment_data_batch_size,df_train) + augmenter.setup_processed() + + print("augmented file directory:",os.path.join(self.augmented_dir_main,self.processed_file_names_dict["data"])) + augmented_data = torch.load(os.path.join(self.augmented_dir_main, self.processed_file_names_dict["data"])) + df_train = pd.DataFrame(augmented_data) + else: + print("data.pkl(original) file is not found") + + print("df_train(augmented)",df_train.shape) + print("df_val",df_val.shape) + print("df_test",df_test.shape) + + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": df_train["ident"], "split": "train"}), + pd.DataFrame({"id": df_val["ident"], "split": "validation"}), + pd.DataFrame({"id": df_test["ident"], "split": "test"}), + ] + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + # Saving csv + combined_split_assignment.to_csv( + os.path.join(self.processed_dir_main, "augmented_splits.csv") + ) + print("Saving augmented_splits.csv") + + + # Store the splits in class variables self.dynamic_df_train = df_train self.dynamic_df_val = df_val @@ -994,6 +1016,8 @@ def load_processed_data( ] else: data_df = self.dynamic_split_dfs[kind] + # Print the dataset size for the given kind + print(f"Dataset size for kind '{kind}': {data_df.shape}") return data_df.to_dict(orient="records") except KeyError: kind = f"{kind}" @@ -1368,7 +1392,7 @@ def _name(self): - def augment_data(self, path: str, batch_size) -> None: + def augment_data(self, path: str, batch_size,df_train) -> None: print(("Inside - AugmentedDataExtractor - augment_data()")) if self.aug_data: if os.path.isfile(os.path.join( @@ -1386,10 +1410,15 @@ def augment_data(self, path: str, batch_size) -> None: # Start timing the augmentation process start_time = time.time() - data = self.read_file(os.path.join( + #Loading the original dataset + full_data = self.read_file(os.path.join( path, self.raw_file_names_dict["data"])) + # Filter the `data` DataFrame using the 'ident' feature values from `df_train` + data = full_data[full_data['id'].isin(df_train['ident'])] + total_rows = data.shape[0] + print("Shape of the Data to be augmented : ",data.shape) # #For testing # subset1 = data.iloc[:10000] # First 10,000 of the datapoints for testing only # data=subset1 From ecaeef795168e678f6b8afde91437ddbff7bc09f Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 12 Dec 2024 00:59:10 +0100 Subject: [PATCH 30/35] Changes for evaluation --- eval.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval.py b/eval.py index 2acf2f58..5af9ab56 100644 --- a/eval.py +++ b/eval.py @@ -17,9 +17,9 @@ # Specify paths and parameters -checkpoint_name = "best_epoch=14_val_loss=0.0017_val_macro-f1=0.9226_val_micro-f1=0.9847" +checkpoint_name = "best_epoch=09_val_loss=0.0217_val_macro-f1=0.7101_val_micro-f1=0.9091" print("checkpoint_name : ",checkpoint_name) -checkpoint_path = os.path.join("logs/wandb/run-20241128_214007-ukcabied/files/checkpoints", f"{checkpoint_name}.ckpt") +checkpoint_path = os.path.join("logs/wandb/run-20241207_192102-nug9ndqi/files/checkpoints", f"{checkpoint_name}.ckpt") print("checkpoint_path : ",checkpoint_path) kind = "test" # Change to "train" or "validation" as needed buffer_dir = os.path.join("results_buffer", checkpoint_name, kind) @@ -29,7 +29,7 @@ # Load data module data_module = ChEBIOver100(chebi_version=231) -data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/splits.csv" +data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/augmented_splits.csv" model_class = Electra # evaluates model, stores results in buffer_dir From f4a34d25094dac8dedf892e8f362467c32d02346 Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 12 Dec 2024 01:06:25 +0100 Subject: [PATCH 31/35] Changes for evaluation- splits file changed --- eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval.py b/eval.py index 5af9ab56..ef39bdd9 100644 --- a/eval.py +++ b/eval.py @@ -29,7 +29,7 @@ # Load data module data_module = ChEBIOver100(chebi_version=231) -data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/augmented_splits.csv" +data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/splits.csv" model_class = Electra # evaluates model, stores results in buffer_dir From 8850f3a3af5ce1274b4bc150dc96bde0f16a3fbb Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 12 Dec 2024 01:34:21 +0100 Subject: [PATCH 32/35] Generating SMILES based on no. of variations --- chebai/preprocessing/datasets/chebi.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 1711b4d7..e59846fd 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -1487,6 +1487,7 @@ def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): # Function to generate SMILES variations using different configurations def generate_smiles_variations(self, original_smiles): num_variations = self.num_smiles_variations + print("no. of variations: ",num_variations) print(type(original_smiles), original_smiles) if not isinstance(original_smiles, str): print(f"Non-string SMILES found: {original_smiles}") @@ -1498,12 +1499,12 @@ def generate_smiles_variations(self, original_smiles): num_atoms = mol.GetNumAtoms() # Generate the rooted_at_atoms list based on the number of atoms - # if num_atoms < num_variations: - # rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) - # else: - # rooted_at_atoms = list(range(0, num_variations)) # [0, num_variations) + if num_atoms < num_variations: + rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) + else: + rooted_at_atoms = list(range(0, num_variations)) # [0, num_variations) - rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) + # rooted_at_atoms = list(range(0, num_atoms)) # [0, num_atoms) # Shuffle the rooted_at_atoms list to randomize the order random.shuffle(rooted_at_atoms) From 83d3db891be065b5db76b834a257e2bf1f7caa0a Mon Sep 17 00:00:00 2001 From: vidvath Date: Thu, 12 Dec 2024 01:42:42 +0100 Subject: [PATCH 33/35] Removed print statement --- chebai/preprocessing/datasets/chebi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index e59846fd..1381888e 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -1487,7 +1487,6 @@ def save_file(self, dataset: pd.DataFrame, file_path: str, append=False): # Function to generate SMILES variations using different configurations def generate_smiles_variations(self, original_smiles): num_variations = self.num_smiles_variations - print("no. of variations: ",num_variations) print(type(original_smiles), original_smiles) if not isinstance(original_smiles, str): print(f"Non-string SMILES found: {original_smiles}") From c38157ea1785eafbb3280219aaafbdf7abfd5ee5 Mon Sep 17 00:00:00 2001 From: vidvath Date: Wed, 18 Dec 2024 22:17:07 +0100 Subject: [PATCH 34/35] Changes for testing- changed checkpoint --- eval.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval.py b/eval.py index ef39bdd9..4c8345ed 100644 --- a/eval.py +++ b/eval.py @@ -17,9 +17,9 @@ # Specify paths and parameters -checkpoint_name = "best_epoch=09_val_loss=0.0217_val_macro-f1=0.7101_val_micro-f1=0.9091" +checkpoint_name = "best_epoch=31_val_loss=0.0204_val_macro-f1=0.7655_val_micro-f1=0.9246.ckpt" print("checkpoint_name : ",checkpoint_name) -checkpoint_path = os.path.join("logs/wandb/run-20241207_192102-nug9ndqi/files/checkpoints", f"{checkpoint_name}.ckpt") +checkpoint_path = os.path.join("logs/wandb/run-20241212_003611-8yohluv6/files/checkpoints", f"{checkpoint_name}.ckpt") print("checkpoint_path : ",checkpoint_path) kind = "test" # Change to "train" or "validation" as needed buffer_dir = os.path.join("results_buffer", checkpoint_name, kind) @@ -29,7 +29,7 @@ # Load data module data_module = ChEBIOver100(chebi_version=231) -data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/splits.csv" +data_module.splits_file_path="data/chebi_v231/ChEBI100/processed/augmented_splits.csv" model_class = Electra # evaluates model, stores results in buffer_dir From 5c4f7b3c3a8a48efce47ac27c53dc19b3beb4d5f Mon Sep 17 00:00:00 2001 From: vidvath Date: Wed, 18 Dec 2024 22:22:18 +0100 Subject: [PATCH 35/35] Checkpoint name correction --- eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval.py b/eval.py index 4c8345ed..aeccdd5c 100644 --- a/eval.py +++ b/eval.py @@ -17,7 +17,7 @@ # Specify paths and parameters -checkpoint_name = "best_epoch=31_val_loss=0.0204_val_macro-f1=0.7655_val_micro-f1=0.9246.ckpt" +checkpoint_name = "best_epoch=31_val_loss=0.0204_val_macro-f1=0.7655_val_micro-f1=0.9246" print("checkpoint_name : ",checkpoint_name) checkpoint_path = os.path.join("logs/wandb/run-20241212_003611-8yohluv6/files/checkpoints", f"{checkpoint_name}.ckpt") print("checkpoint_path : ",checkpoint_path)