diff --git a/moleculeresolver/moleculeresolver.py b/moleculeresolver/moleculeresolver.py index a795674..3255d39 100644 --- a/moleculeresolver/moleculeresolver.py +++ b/moleculeresolver/moleculeresolver.py @@ -6491,7 +6491,7 @@ def get_prop_value( } else: if len(props_found) != 1: - raise + raise ValueError(f"Expected exactly 1 property, found {len(props_found)}") prop_vals = list(props_found[0].values()) return conversion_funtion(prop_vals[0]) diff --git a/tests/responses.json b/tests/responses.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/tests/responses.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/tests/test_integration.py b/tests/test_integration.py index dc6ef58..90efc8f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,9 +1,11 @@ +from datetime import datetime import pytest from moleculeresolver import MoleculeResolver import json import os from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union +from tqdm import tqdm # IUPAC names @@ -11,26 +13,82 @@ with open(dir_path / "benchmark_component_molecules_iupac.json", "r") as f: benchmark = json.load(f) +RESPONSES_PATH = dir_path / "responses.json" SMILES = "SMILES" -# PATCH_STATE = "SAVE" +class MoleculeResolverPatched(MoleculeResolver): + def __init__( + self, + json_path: Optional[str] = None, + available_service_API_keys: Optional[dict[str, Optional[str]]] = None, + molecule_cache_db_path: Optional[str] = None, + molecule_cache_expiration: Optional[datetime] = None, + differentiate_isotopes: bool = False, + ): + super().__init__( + available_service_API_keys, + molecule_cache_db_path, + molecule_cache_expiration, + differentiate_isotopes, + ) + self.json_path = json_path + if self.json_path is not None: + self.json_path = Path(self.json_path) + + if self.json_path is not None and self.json_path.exists(): + with open(self.json_path, "r") as f: + self.json_data = json.load(f) + else: + self.json_data = {} + + def _resilient_request( + self, + url: str, + kwargs: Optional[dict[str, Any]] = None, + request_type: Optional[str] = "get", + accepted_status_codes: list[int] = [200], + rejected_status_codes: list[int] = [404], + offline_status_codes: list[int] = [], + max_retries: Optional[int] = 10, + sleep_time: Union[int, float] = 2, + allow_redirects: Optional[bool] = False, + json: Optional[str] = None, + return_response: Optional[bool] = False, + ) -> Optional[str]: + if self.json_path is not None and url in self.json_data: + res = self.json_data[url] + else: + res = super()._resilient_request( + url, + kwargs, + request_type, + accepted_status_codes, + rejected_status_codes, + offline_status_codes, + max_retries, + sleep_time, + allow_redirects, + json, + return_response, + ) + if self.json_path is not None and url not in self.json_data: + self.json_data[url] = res + return res -# class PatchResilientRequest: -# def __init__(self, json_data, patch_state): -# self.json_data = json_data -# self.patch_state = patch_state -# def __call__(self, url: str, **kwargs) -> str: -# if self.patch_state == "SAVE": -# self.json_data[url] = kwargs["json"] -# elif self.patch_state == "LOAD": -# return self.json_data[url] + def __exit__(self, exc_type, exc_value, exc_traceback): + if self.json_path is not None: + with open(self.json_path, "w") as f: + json.dump(self.json_data, f) + return super().__exit__(exc_type, exc_value, exc_traceback) -@pytest.mark.parametrize("data", benchmark.values()) + +@pytest.mark.parametrize("data", benchmark.values()) class TestServices: + json_path: Optional[str] = RESPONSES_PATH @staticmethod def _test_service( @@ -47,7 +105,7 @@ def _test_service( ---------- call_method : Callable The method to call - input_identifier : str + input_identifier : str The input identifier output_identifier_type : str The type of the output identifier @@ -55,8 +113,8 @@ def _test_service( The expected output identifier kwargs : Optional[Dict], optional Additional keyword arguments to pass to the call method, by default None - - + + """ if kwargs is None: kwargs = {} @@ -71,7 +129,7 @@ def _test_service( raise ValueError(f"Expected {output_identifier} but got {res_txt}") def test_opsin(self, data): - with MoleculeResolver() as mr: + with MoleculeResolverPatched(json_path=self.json_path) as mr: iupac_name = data["iupac_name"] self._test_service( mr.get_molecule_from_OPSIN, @@ -80,24 +138,134 @@ def test_opsin(self, data): data["SMILES"], ) + def test_pubchem(self, data): + with MoleculeResolverPatched(json_path=self.json_path) as mr: + iupac_name = data["iupac_name"] + self._test_service( + mr.get_molecule_from_pubchem, + iupac_name, + SMILES, + data["SMILES"], + {"mode": "name"}, + ) + + def test_comptox(self, data): + with MoleculeResolverPatched(json_path=self.json_path) as mr: + iupac_name = data["iupac_name"] + self._test_service( + mr.get_molecule_from_CompTox, + iupac_name, + SMILES, + data["SMILES"], + {"mode": "name"}, + ) + + def test_cts(self, data): + with MoleculeResolverPatched(json_path=self.json_path) as mr: + iupac_name = data["iupac_name"] + self._test_service( + mr.get_molecule_from_CTS, + iupac_name, + SMILES, + data["SMILES"], + {"mode": "name"}, + ) + + # Need API key + # def test_chemeo(self, data): + # with MoleculeResolverPatched(json_path=self.json_path) as mr: + # iupac_name = data["iupac_name"] + # self._test_service( + # mr.get_molecule_from_Chemeo, + # iupac_name, + # SMILES, + # data["SMILES"], + # {"mode": "name"}, + # ) + + def test_cas(self, data): + with MoleculeResolverPatched(json_path=self.json_path) as mr: + iupac_name = data["iupac_name"] + self._test_service( + mr.get_molecule_from_CAS_registry, + iupac_name, + SMILES, + data["SMILES"], + {"mode": "name"}, + ) + + def test_cir(self, data): + with MoleculeResolverPatched(json_path=self.json_path) as mr: + iupac_name = data["iupac_name"] + self._test_service( + mr.get_molecule_from_CIR, + iupac_name, + SMILES, + data["SMILES"], + {"mode": "name"}, + ) + + def test_nist(self, data): + with MoleculeResolverPatched(json_path=self.json_path) as mr: + iupac_name = data["iupac_name"] + self._test_service( + mr.get_molecule_from_NIST, + iupac_name, + SMILES, + data["SMILES"], + {"mode": "name"}, + ) + + # ChEBI test disabled - benchmark data doesn't have chebi_id field + # def test_chebi(self, data): + # with MoleculeResolverPatched(json_path=self.json_path) as mr: + # iupac_name = data["iupac_name"] + # self._test_service( + # mr.get_molecule_from_ChEBI, iupac_name, "chebi_id", data["chebi_id"] + # ) + + +def test_opsin_batchmode(): + names = [d["iupac_name"] for d in benchmark.values()] + smiles = [d["SMILES"] for d in benchmark.values()] + with MoleculeResolver() as mr: + res = mr.get_molecule_from_OPSIN_batchmode(names) + for i, r in enumerate(res): + if r[0].SMILES == smiles[i]: + continue + else: + raise ValueError("Expected " + smiles[i] + " but got " + r[0].SMILES) + + +def generate_data(json_path, overwrite_json: bool = False): + """ Generate data for unit tests + Parameters + ---------- + json_path : str + Path to JSON file to save data to + overwrite_json : bool, optional + Whether to overwrite the JSON file, by default False -# def test_opsin_batchmode(): -# names = [d["iupac_name"] for d in benchmark.values()] -# smiles = [d["SMILES"] for d in benchmark.values()] -# with MoleculeResolver() as mr: -# res = mr.get_molecule_from_OPSIN_batchmode(names) -# for i, r in enumerate(res): -# if r[0].SMILES == smiles[i]: -# continue -# else: -# raise ValueError("Expected " + smiles[i] + " but got " + r.SMILES) + """ + # Remove JSON path + if overwrite_json: + if os.path.exists(json_path): + os.remove(json_path) -def generate_data(): - # Run each test with a patch of resilient request that saves response - pass + # Get all test methods + test_services = TestServices() + methods = dir(TestServices) + test_methods = [m for m in methods if m.startswith("test_")] + # Run tests + bar = tqdm(test_methods, desc="Running tests") + for m in bar: + method = getattr(test_services, m) + bar.set_description(f"Running {m}") + for data in benchmark.values(): + method(data) if __name__ == "__main__": - generate_data() + generate_data(RESPONSES_PATH, overwrite_json=True)