diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fb0723..0de2d37 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,11 +61,11 @@ jobs: with: name: dist - - name: Install package and pytest + - name: Install package and test dependencies shell: bash run: | WHL_NAME=$(ls molecule_resolver-*.whl) - pip install ${WHL_NAME} pytest + pip install ${WHL_NAME} pytest pytest-mock - name: Run tests shell: bash diff --git a/docs/context_lifecycle.md b/docs/context_lifecycle.md new file mode 100644 index 0000000..1222778 --- /dev/null +++ b/docs/context_lifecycle.md @@ -0,0 +1,21 @@ +# Runtime Context Lifecycle + +`MoleculeResolver` context management is intentionally limited to runtime resources. + +## Enter + +The `with MoleculeResolver() as mr:` enter step performs only: + +1. RDKit log suppression context setup. +2. Molecule cache context setup. +3. OPSIN temp folder creation (when OPSIN batch support is enabled). + +## Exit + +The exit step performs only: + +1. Molecule cache context teardown. +2. RDKit log suppression context teardown. +3. OPSIN temp folder cleanup (skipped when an exception is bubbling up). + +This separation keeps search and resolution logic outside lifecycle handling, so runtime setup/teardown stays predictable and easy to maintain. diff --git a/docs/evidence.md b/docs/evidence.md new file mode 100644 index 0000000..3bcef47 --- /dev/null +++ b/docs/evidence.md @@ -0,0 +1,20 @@ +# Evidence Output (`include_evidence`) + +`find_single_molecule_crosschecked(..., include_evidence=True)` returns a +`ResolutionResult` instead of only a `Molecule`. + +## Response fields + +- `best_molecule`: Selected molecule (or `None` for unresolved ties). +- `ranked_candidates`: List of `CandidateEvidence` entries sorted by score. +- `grouped_by_structure`: Raw grouped molecules by SMILES. +- `selected_smiles`: SMILES selected as best candidate (or `None` when no single structure was selected). +- `selection_reason`: Why this result form was returned, e.g. `best_structure_selected` or `multiple_structures_tied`. 
+ +## CandidateEvidence fields + +- `service_agreement_count`: Number of unique supporting services. +- `identifier_concordance_count`: Number of unique supporting identifiers. +- `synonym_overlap_count`: Repeated synonym matches across supporting molecules. +- `score_breakdown`: Explicit points per evidence component. +- `total_score`: Sum of `score_breakdown`. diff --git a/docs/scoring.md b/docs/scoring.md new file mode 100644 index 0000000..0eaab87 --- /dev/null +++ b/docs/scoring.md @@ -0,0 +1,27 @@ +# Consensus Scoring + +`find_single_molecule_crosschecked(..., resolution_mode="consensus")` uses an +explicit score model for grouped structure candidates. + +## Score Components + +- `crosscheck_points`: `crosscheck_count * 100` +- `opsin_bonus`: `30` if at least one supporting OPSIN name match exists +- `stereo_specificity_bonus`: RDKit-derived stereochemistry signal: + - defined chiral centers: `20` points each + - stereochemical bonds: `15` points each + - unresolved chiral centers: `5` points each + +## Deterministic Tie-Breaking + +When total score ties, the selector orders by: + +1. Higher crosscheck count +2. Higher OPSIN bonus +3. Higher stereo specificity bonus +4. Higher defined chirality bonus +5. Higher bond stereo bonus +6. Lexicographical SMILES fallback + +This removes raw SMILES length from score weighting while preserving a stable +deterministic result. diff --git a/docs/search_strategy.md b/docs/search_strategy.md new file mode 100644 index 0000000..47a891c --- /dev/null +++ b/docs/search_strategy.md @@ -0,0 +1,25 @@ +# Search Strategy + +`find_single_molecule`, `find_single_molecule_crosschecked`, and +`find_multiple_molecules_parallelized` now support `search_strategy`. + +## Options + +- `first_hit` (default): preserves the previous behavior and returns as soon as + the first matching candidate is found in service order. 
+- `exhaustive`: evaluates all provided identifier/mode combinations across the + selected services, deduplicates candidates, and chooses the structure group + backed by the most candidates (ties go to the group that was found first). + +## Example + +```python +from moleculeresolver import MoleculeResolver + +with MoleculeResolver() as mr: + molecule = mr.find_single_molecule( + identifiers=["ethanol", "ethyl alcohol"], + modes=["name", "name"], + search_strategy="exhaustive", + ) +``` diff --git a/moleculeresolver/moleculeresolver.py b/moleculeresolver/moleculeresolver.py index 50f7647..32b6af5 100644 --- a/moleculeresolver/moleculeresolver.py +++ b/moleculeresolver/moleculeresolver.py @@ -15,7 +15,7 @@ import tempfile import time from types import SimpleNamespace -from typing import Any, Generator, Optional, Union +from typing import Any, Generator, Literal, Optional, Union import traceback import unicodedata import urllib @@ -39,7 +39,31 @@ import urllib3 import xmltodict from moleculeresolver.molecule import Molecule +from moleculeresolver.resolution import ( + CandidateEvidence, + ResolutionResult, + build_structure_group_candidates, + score_structure_groups, + select_best_scored_structure, +) from moleculeresolver.SqliteMoleculeCache import SqliteMoleculeCache +from moleculeresolver.services import ( + CASRegistryServiceAdapter, + CIRServiceAdapter, + CTSServiceAdapter, + ChEBIServiceAdapter, + ChemeoServiceAdapter, + CompToxServiceAdapter, + NISTServiceAdapter, + OPSINServiceAdapter, + PubChemServiceAdapter, + SRSServiceAdapter, + ServiceAdapterRegistry, + ServiceSearchResult, +) + +SearchStrategy = Literal["first_hit", "exhaustive"] +ResolutionMode = Literal["legacy", "consensus", "strict_isomer"] class EmptyResonanceMolSupplierCallback(ResonanceMolSupplierCallback): @@ -341,6 +365,8 @@ def __init__( for k in sorted(self.supported_services_by_mode) } self.supported_modes = sorted(list(set(self.supported_modes))) + self._service_adapters = ServiceAdapterRegistry() + self._register_default_service_adapters() 
self.CAS_regex_with_groups = regex.compile(r"^(\d{2,7})-(\d{2})-(\d)$") self.CAS_regex = r"(\d{2,7}-\d{2}-\d)" @@ -405,18 +431,41 @@ def __enter__(self) -> "MoleculeResolver": 2. Initializes the molecule cache. 3. Sets up a temporary folder for OPSIN if it's available. """ + self._enter_rdkit_log_context() + self._enter_molecule_cache_context() + self._create_opsin_tempfolder() + return self + + def _enter_rdkit_log_context(self) -> None: + """Start suppressing RDKit logs for the resolver runtime.""" self._disabling_rdkit_logger = rdBase.BlockLogs() self._disabling_rdkit_logger.__enter__() + + def _enter_molecule_cache_context(self) -> None: + """Create and enter the molecule cache context when needed.""" if not self.molecule_cache: self.molecule_cache = SqliteMoleculeCache( self.molecule_cache_db_path, self.molecule_cache_expiration ) self.molecule_cache.__enter__() + + def _create_opsin_tempfolder(self) -> None: + """Create the OPSIN temp folder when OPSIN batch mode is enabled.""" if "opsin" in self._available_services_with_batch_capabilities: self._OPSIN_tempfolder = tempfile.TemporaryDirectory( prefix="OPSIN_tempfolder_", ignore_cleanup_errors=True ) - return self + + def _cleanup_opsin_tempfolder(self, *, error_ocurred: bool) -> None: + """Cleanup OPSIN temp folder unless an exception is currently bubbling up.""" + if self._OPSIN_tempfolder and not error_ocurred: + self._OPSIN_tempfolder.cleanup() + + def _teardown_runtime_contexts(self, *, error_ocurred: bool) -> None: + """Teardown all runtime contexts in a single lifecycle helper.""" + self.molecule_cache.__exit__(None, None, None) + self._disabling_rdkit_logger.__exit__(None, None, None) + self._cleanup_opsin_tempfolder(error_ocurred=error_ocurred) def __exit__(self, exception_type, exception_value, exception_traceback) -> None: """ @@ -448,10 +497,84 @@ def __exit__(self, exception_type, exception_value, exception_traceback) -> None or exception_traceback is not None ) - 
self.molecule_cache.__exit__(None, None, None) - self._disabling_rdkit_logger.__exit__(None, None, None) - if self._OPSIN_tempfolder and not error_ocurred: - self._OPSIN_tempfolder.cleanup() + self._teardown_runtime_contexts(error_ocurred=error_ocurred) + + def _register_default_service_adapters(self) -> None: + """Register built-in adapters used by find_single_molecule.""" + adapters = [ + CASRegistryServiceAdapter(), + ChEBIServiceAdapter(), + ChemeoServiceAdapter(), + CIRServiceAdapter(), + CompToxServiceAdapter(), + CTSServiceAdapter(), + NISTServiceAdapter(), + OPSINServiceAdapter(), + PubChemServiceAdapter(), + SRSServiceAdapter(), + ] + for adapter in adapters: + self._service_adapters.register(adapter) + + def _resolve_service_with_adapter( + self, + service: str, + flattened_identifiers: list[str], + flattened_modes: list[str], + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + """Resolve one service by delegating to its configured adapter.""" + adapter = self._service_adapters.get(service) + if adapter is None: + return None + return adapter.resolve( + self, + flattened_identifiers, + flattened_modes, + required_formula, + required_charge, + required_structure_type, + ) + + def _resolve_identifier_with_adapter( + self, + service: str, + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + """Resolve one identifier/mode pair via the configured service adapter.""" + adapter = self._service_adapters.get(service) + if adapter is None: + return None + return adapter.resolve_one( + self, + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + + @staticmethod + def _service_result_to_exhaustive_candidate( + service: str, result: ServiceSearchResult + ) -> dict[str, Any]: + """Normalize a service 
adapter result into the exhaustive candidate payload.""" + return { + "SMILES": result.molecule.SMILES, + "synonyms": list(result.synonyms), + "CAS": set(result.cas), + "additional_information": result.additional_information, + "mode_used": result.mode_used, + "identifier_used": result.identifier_used, + "service": service, + "cas_is_authoritative": service == "cas_registry", + } @contextmanager def query_molecule_cache( @@ -1213,6 +1336,41 @@ def _check_and_flatten_identifiers_and_modes( return flattened_identifiers, flattened_modes, synonyms, CAS, given_SMILES + def _expand_identifier_mode_pairs( + self, + flattened_identifiers: list[str], + flattened_modes: list[str], + search_strategy: SearchStrategy, + ) -> list[tuple[str, str]]: + """Expand identifier/mode pairs for single-molecule search strategies.""" + if search_strategy not in {"first_hit", "exhaustive"}: + raise ValueError( + "search_strategy can only be one of: 'first_hit', 'exhaustive'." + ) + return list(zip(flattened_identifiers, flattened_modes, strict=True)) + + def _resolve_single_service_candidate( + self, + service: str, + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[dict[str, Any]]: + """Backwards-compatible wrapper for single-pair adapter resolution.""" + result = self._resolve_identifier_with_adapter( + service, + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + if result is None: + return None + return self._service_result_to_exhaustive_candidate(service, result) + def _is_list_of_list_of_str(self, value: list[list[str]]) -> bool: """ Check if the input is a valid list of lists of strings. 
@@ -7974,6 +8132,38 @@ def find_salt_molecules_and_stoichometric_coefficients( return all_molecules, stoichometric_coefficients + @staticmethod + def _validate_resolution_mode(resolution_mode: ResolutionMode) -> None: + """Validate the chosen resolution mode.""" + valid_modes = {"legacy", "consensus", "strict_isomer"} + if resolution_mode not in valid_modes: + raise ValueError( + "resolution_mode can only be one of: 'legacy', 'consensus', 'strict_isomer'." + ) + + def _collect_opsin_isomer_matches( + self, + grouped_molecules: dict[str, list[Molecule]], + candidate_smiles: list[str], + ) -> dict[str, bool]: + """Check whether candidate groups have at least one OPSIN-confirmed isomeric match.""" + matches = {smiles: False for smiles in candidate_smiles} + for smiles in candidate_smiles: + target_smiles = self.standardize_SMILES(smiles) + names = set() + for molecule in grouped_molecules[smiles]: + names.update(molecule.synonyms) + + for name in names: + opsin_candidate = self.get_molecule_from_OPSIN(name) + if opsin_candidate is None: + continue + opsin_smiles = self.standardize_SMILES(opsin_candidate.SMILES) + if opsin_smiles == target_smiles: + matches[smiles] = True + break + return matches + def find_single_molecule( self, identifiers: list[str], @@ -7985,6 +8175,8 @@ def find_single_molecule( search_iupac_name: Optional[bool] = False, interactive: Optional[bool] = False, ignore_exceptions: Optional[bool] = False, + search_strategy: SearchStrategy = "first_hit", + resolution_mode: ResolutionMode = "legacy", ) -> Optional[Molecule]: """Searches for a single molecule across multiple chemical databases and services. @@ -8012,6 +8204,11 @@ def find_single_molecule( ignore_exceptions (Optional[bool]): Whether to ignore exceptions during the search. Defaults to False. + search_strategy (str): Search strategy. "first_hit" keeps legacy behavior; + "exhaustive" evaluates all identifier/service combinations. + resolution_mode (str): Resolution mode. 
Included for API consistency and + future expansion. Accepted values are "legacy", "consensus", "strict_isomer". + Returns: Optional[Molecule]: A Molecule object if found, None otherwise. @@ -8025,6 +8222,7 @@ def find_single_molecule( """ if services_to_use is None: services_to_use = self._available_services + self._validate_resolution_mode(resolution_mode) ( flattened_identifiers, @@ -8033,6 +8231,9 @@ def find_single_molecule( CAS, given_SMILES, ) = self._check_and_flatten_identifiers_and_modes(identifiers, modes) + flattened_identifier_mode_pairs = self._expand_identifier_mode_pairs( + flattened_identifiers, flattened_modes, search_strategy + ) self._check_parameters( services=services_to_use, required_formulas=required_formula, @@ -8049,224 +8250,97 @@ def find_single_molecule( mode_used = None identifier_used = None current_service = None + exhaustive_candidates = [] try: for service in services_to_use: current_service = service - if service == "cas_registry": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_CAS_registry( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS = set( - cmp.CAS - ) # overwrite CAS with data from the CAS registry - additional_information = cmp.service - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "pubchem": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_pubchem( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS.update(cmp.CAS) - additional_information = ( - f"{cmp.service} id: 
{cmp.additional_information}" - ) - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "cir": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_CIR( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - additional_information = cmp.service - mode_used = mode - identifier_used = cmp.identifier - break - elif service == "opsin": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_OPSIN( - identifier, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - additional_information = cmp.additional_information - mode_used = mode - identifier_used = cmp.identifier - break - elif service == "chebi": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_ChEBI( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS.update(cmp.CAS) - additional_information = ( - f"{cmp.service} id: {cmp.additional_information}" - ) - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "srs": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_SRS( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - 
CAS.update(cmp.CAS) - additional_information = ( - f"{cmp.service} id: {cmp.additional_information}" - ) - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "comptox": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_CompTox( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS.update(cmp.CAS) - additional_information = ( - f"{cmp.service} id: {cmp.additional_information}" - ) - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "chemeo": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_Chemeo( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS.update(cmp.CAS) - additional_information = ( - f"{cmp.service} id: {cmp.additional_information}" - ) - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "cts": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_CTS( - identifier, - mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS.update(cmp.CAS) - additional_information = "cts" - mode_used = cmp.mode - identifier_used = cmp.identifier - break - elif service == "nist": - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): - if mode in self.supported_modes_by_services[service]: - cmp = self.get_molecule_from_NIST( - identifier, 
- mode, - required_formula, - required_charge, - required_structure_type, - ) - if cmp is not None: - SMILES = cmp.SMILES - synonyms.extend(cmp.synonyms) - CAS.update(cmp.CAS) - additional_information = ( - f"{cmp.service} id: {cmp.additional_information}" - ) - mode_used = cmp.mode - identifier_used = cmp.identifier - break + if search_strategy == "exhaustive": + for identifier, mode in flattened_identifier_mode_pairs: + adapter_result = self._resolve_identifier_with_adapter( + service, + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + if adapter_result is None: + continue + resolved_candidate = self._service_result_to_exhaustive_candidate( + service, adapter_result + ) + exhaustive_candidates.append(resolved_candidate) + else: + adapter_result = self._resolve_service_with_adapter( + service, + flattened_identifiers, + flattened_modes, + required_formula, + required_charge, + required_structure_type, + ) + if adapter_result is not None: + SMILES = adapter_result.molecule.SMILES + synonyms.extend(adapter_result.synonyms) + additional_information = adapter_result.additional_information + mode_used = adapter_result.mode_used + identifier_used = adapter_result.identifier_used + if service == "cas_registry": + CAS = adapter_result.cas + else: + CAS.update(adapter_result.cas) - if SMILES is not None: + if search_strategy == "first_hit" and SMILES is not None: break + if search_strategy == "exhaustive" and exhaustive_candidates: + grouped_candidates = collections.defaultdict(list) + first_index_by_group = {} + for i, candidate in enumerate(exhaustive_candidates): + standardized = self.standardize_SMILES(candidate["SMILES"]) + dedupe_key = ( + standardized, + candidate["service"], + candidate["identifier_used"], + ) + if any(dedupe_key == c["dedupe_key"] for c in grouped_candidates[standardized]): + continue + candidate["dedupe_key"] = dedupe_key + grouped_candidates[standardized].append(candidate) + if standardized not in 
first_index_by_group: + first_index_by_group[standardized] = i + + best_group_smiles = sorted( + grouped_candidates, + key=lambda smi: ( + len(grouped_candidates[smi]), + -first_index_by_group[smi], + len(smi), + smi, + ), + reverse=True, + )[0] + best_group = grouped_candidates[best_group_smiles] + chosen_candidate = best_group[0] + + SMILES = chosen_candidate["SMILES"] + additional_information = chosen_candidate["additional_information"] + mode_used = chosen_candidate["mode_used"] + identifier_used = chosen_candidate["identifier_used"] + current_service = chosen_candidate["service"] + + synonyms = [] + CAS = set() + authoritative_cas = None + for candidate in best_group: + synonyms.extend(candidate["synonyms"]) + if candidate["cas_is_authoritative"]: + authoritative_cas = set(candidate["CAS"]) + else: + CAS.update(candidate["CAS"]) + if authoritative_cas is not None: + CAS = authoritative_cas + if SMILES is None: if given_SMILES is not None: if self.check_SMILES( @@ -8286,9 +8360,7 @@ def find_single_molecule( # if searching for an ion # search for salts in pubchem and extract the single ions, take the one found most often if required_charge != "zero" and required_charge != 0: - for identifier, mode in zip( - flattened_identifiers, flattened_modes, strict=True - ): + for identifier, mode in flattened_identifier_mode_pairs: molecules = ( self.get_molecule_for_ion_from_partial_pubchem_search( identifier, required_formula, required_charge @@ -8584,6 +8656,79 @@ def find_single_molecule_interactively( 1, ) + @staticmethod + def _rank_candidate_evidence( + candidate_evidence: list[CandidateEvidence], + ) -> list[CandidateEvidence]: + """Rank candidate evidence by agreement and concordance.""" + return sorted( + candidate_evidence, + key=lambda evidence: ( + -evidence.total_score, + -evidence.service_agreement_count, + -evidence.identifier_concordance_count, + -evidence.synonym_overlap_count, + evidence.smiles, + ), + ) + + def _build_candidate_evidence( + self, + 
grouped_molecules: dict[str, list[Molecule]], + ) -> list[CandidateEvidence]: + """Create base evidence objects with service and identifier concordance.""" + evidence = [] + for smiles, molecules in grouped_molecules.items(): + service_names = sorted( + {molecule.service for molecule in molecules if molecule.service} + ) + identifiers = sorted( + {molecule.identifier for molecule in molecules if molecule.identifier} + ) + normalized_synonyms = [ + synonym.strip().casefold() + for molecule in molecules + for synonym in molecule.synonyms + if synonym + ] + synonym_overlap_count = len(normalized_synonyms) - len( + set(normalized_synonyms) + ) + score_breakdown = { + "service_agreement": len(service_names) * 100, + "identifier_concordance": len(identifiers) * 20, + "synonym_overlap": synonym_overlap_count * 5, + } + evidence.append( + CandidateEvidence( + smiles=smiles, + service_agreement_count=len(service_names), + service_names=service_names, + identifiers=identifiers, + identifier_concordance_count=len(identifiers), + synonym_overlap_count=synonym_overlap_count, + score_breakdown=score_breakdown, + total_score=sum(score_breakdown.values()), + ) + ) + return self._rank_candidate_evidence(evidence) + + def _build_resolution_result( + self, + best_molecule: Optional[Molecule], + grouped_molecules: dict[str, list[Molecule]], + selected_smiles: Optional[str], + selection_reason: str, + ) -> ResolutionResult: + """Build full include_evidence payload from grouped candidate molecules.""" + return ResolutionResult( + best_molecule=best_molecule, + ranked_candidates=self._build_candidate_evidence(grouped_molecules), + grouped_by_structure=grouped_molecules, + selected_smiles=selected_smiles, + selection_reason=selection_reason, + ) + def find_single_molecule_crosschecked( self, identifiers: list[str], @@ -8596,7 +8741,10 @@ def find_single_molecule_crosschecked( minimum_number_of_crosschecks: Optional[int] = 1, try_to_choose_best_structure: Optional[bool] = True, 
ignore_exceptions: Optional[bool] = False, - ) -> Union[Optional[Molecule], list[Optional[Molecule]]]: + search_strategy: SearchStrategy = "first_hit", + resolution_mode: ResolutionMode = "legacy", + include_evidence: bool = False, + ) -> Union[Optional[Molecule], list[Optional[Molecule]], ResolutionResult]: """Finds a single molecule with cross-checking across multiple services. This method searches for a molecule using the provided identifiers and modes, @@ -8613,6 +8761,11 @@ def find_single_molecule_crosschecked( minimum_number_of_crosschecks (Optional[int]): Minimum number of services that must agree. Defaults to 1. try_to_choose_best_structure (Optional[bool]): Whether to attempt to select the best structure. Defaults to True. ignore_exceptions (Optional[bool]): Whether to ignore exceptions during search. Defaults to False. + search_strategy (str): Search strategy. "first_hit" keeps legacy behavior; + "exhaustive" evaluates all identifier/service combinations. + resolution_mode (str): Resolution mode. Accepted values are "legacy", + "consensus", "strict_isomer". + include_evidence (bool): If True, return ResolutionResult with evidence payload. 
Returns: Union[Optional[Molecule], list[Optional[Molecule]]]: A single Molecule object if a best structure is chosen, @@ -8632,6 +8785,7 @@ def find_single_molecule_crosschecked( """ if services_to_use is None: services_to_use = self._available_services + self._validate_resolution_mode(resolution_mode) if minimum_number_of_crosschecks is None: minimum_number_of_crosschecks = 1 @@ -8652,6 +8806,8 @@ def find_single_molecule_crosschecked( services_to_use=[service], search_iupac_name=search_iupac_name, ignore_exceptions=ignore_exceptions, + search_strategy=search_strategy, + resolution_mode=resolution_mode, ) molecules.append(molecule) @@ -8678,90 +8834,54 @@ def find_single_molecule_crosschecked( if len(group_molecules) == maximum_number_of_crosschecks_found: SMILES_with_highest_number_of_crosschecks.append(group_SMILES) - if try_to_choose_best_structure: - - SMILES_preferred = sorted(SMILES_with_highest_number_of_crosschecks)[0] - if len(SMILES_with_highest_number_of_crosschecks) > 1: - # if SMILES are the same ignoring isomeric info, use the more specific one: - unique_non_isomeric_SMILES = set( - [ - self.to_SMILES(self.get_from_SMILES(smi), isomeric=False) - for smi in SMILES_with_highest_number_of_crosschecks - ] - ) - if len(unique_non_isomeric_SMILES) == 1: - SMILES_preferred = sorted( - SMILES_with_highest_number_of_crosschecks, key=len - )[-1] - else: - # trust opsin algorithm: if not sure and opsin available - SMILES_preferred_by_opsin = None - for SMILES in SMILES_with_highest_number_of_crosschecks: - for molecule in grouped_molecules[SMILES]: - if molecule.mode == "name" and molecule.service == "opsin": - SMILES_preferred_by_opsin = SMILES - - # if opsin result not available, or searched by another mode - # try getting all structures from the names and see if they agree - # with the SMILES found - if not SMILES_preferred_by_opsin: - SMILES_map = [] - names_map = [] - for SMILES in SMILES_with_highest_number_of_crosschecks: - for molecule in 
grouped_molecules[SMILES]: - for name in molecule.synonyms: - SMILES_map.append(SMILES) - names_map.append(name) - - SMILES_found_by_opsin_from_synonyms = [ - self.get_molecule_from_OPSIN(name) for name in names_map - ] - - SMILES_preferred_by_opsin = [] - for ( - original_SMILES_found, - SMILES_found_by_opsin_from_synonym, - ) in zip( - SMILES_map, SMILES_found_by_opsin_from_synonyms, strict=True - ): - if SMILES_found_by_opsin_from_synonym: - if ( - original_SMILES_found - == SMILES_found_by_opsin_from_synonym - ): - SMILES_preferred_by_opsin.append( - original_SMILES_found - ) - - SMILES_preferred_by_opsin = set(SMILES_preferred_by_opsin) - if len(SMILES_preferred_by_opsin) == 1: - SMILES_preferred_by_opsin = SMILES_preferred_by_opsin.pop() - else: - SMILES_preferred_by_opsin = None + opsin_isomer_matches = {} + if resolution_mode == "strict_isomer": + opsin_isomer_matches = self._collect_opsin_isomer_matches( + grouped_molecules, SMILES_with_highest_number_of_crosschecks + ) + SMILES_with_highest_number_of_crosschecks = [ + smiles + for smiles in SMILES_with_highest_number_of_crosschecks + if opsin_isomer_matches.get(smiles, False) + ] + if not SMILES_with_highest_number_of_crosschecks: + return None - if SMILES_preferred_by_opsin: - SMILES_preferred = SMILES_preferred_by_opsin - else: - c = [] - if len(c) == 1 or (len(c) > 1 and c[0][1] > c[1][1]): - SMILES_preferred = c[0][0] - else: - if self._show_warning_if_non_unique_structure_was_found: - temp = len(SMILES_with_highest_number_of_crosschecks) - warnings.warn( - f"\n\n{temp} molecules were found equally as often. 
First one sorted by SMILES was taken: \n{grouped_molecules}\n" - ) + if try_to_choose_best_structure: + candidate_groups = { + smiles: grouped_molecules[smiles] + for smiles in SMILES_with_highest_number_of_crosschecks + } + candidates = build_structure_group_candidates(self, candidate_groups) + scored_candidates = score_structure_groups(candidates) + best_scored_candidate = select_best_scored_structure(scored_candidates) + SMILES_preferred = best_scored_candidate.smiles molec = self.combine_molecules( SMILES_preferred, grouped_molecules[SMILES_preferred] ) molec.found_molecules.append(grouped_molecules) + if include_evidence: + return self._build_resolution_result( + best_molecule=molec, + grouped_molecules=grouped_molecules, + selected_smiles=SMILES_preferred, + selection_reason="best_structure_selected", + ) return molec else: - return [ + unresolved_molecules = [ self.combine_molecules(SMILES, grouped_molecules[SMILES]) for SMILES in SMILES_with_highest_number_of_crosschecks ] + if include_evidence: + return self._build_resolution_result( + best_molecule=None, + grouped_molecules=grouped_molecules, + selected_smiles=None, + selection_reason="multiple_structures_tied", + ) + return unresolved_molecules def find_multiple_molecules_parallelized( self, @@ -8777,6 +8897,8 @@ def find_multiple_molecules_parallelized( progressbar: Optional[bool] = True, max_workers: Optional[int] = 5, ignore_exceptions: bool = True, + search_strategy: SearchStrategy = "first_hit", + resolution_mode: ResolutionMode = "legacy", ) -> list[Optional[Molecule]]: """Finds multiple molecules in parallel based on provided identifiers and criteria. @@ -8822,6 +8944,11 @@ def find_multiple_molecules_parallelized( ignore_exceptions (Optional[bool]): If True, ignores exceptions that may occur during the search process. Defaults to True. + search_strategy (str): Search strategy. "first_hit" keeps legacy behavior; + "exhaustive" evaluates all identifier/service combinations. 
+ resolution_mode (str): Resolution mode. Accepted values are "legacy", + "consensus", "strict_isomer". + Returns: list[Optional[Molecule]]: A list of found molecules, where each molecule is represented as an instance of the Molecule class, or None if not found. @@ -8840,6 +8967,7 @@ def find_multiple_molecules_parallelized( # reinitialize session self._session = None self._init_session(pool_maxsize=max_workers * 2) + self._validate_resolution_mode(resolution_mode) if isinstance(modes, str): temp_modes = [] @@ -8942,6 +9070,8 @@ def _find(generator): search_iupac_name, False, ignore_exceptions, + search_strategy, + resolution_mode, ) ) else: @@ -8957,6 +9087,8 @@ def _find(generator): minimum_number_of_crosschecks, try_to_choose_best_structure, ignore_exceptions, + search_strategy, + resolution_mode, ) ) diff --git a/moleculeresolver/resolution/__init__.py b/moleculeresolver/resolution/__init__.py new file mode 100644 index 0000000..c273f88 --- /dev/null +++ b/moleculeresolver/resolution/__init__.py @@ -0,0 +1,18 @@ +from .evidence import CandidateEvidence, ResolutionResult +from .models import ( + StructureGroupCandidate, + WeightedStructureScore, + build_structure_group_candidates, +) +from .scorer import score_structure_groups +from .selector import select_best_scored_structure + +__all__ = [ + "CandidateEvidence", + "ResolutionResult", + "StructureGroupCandidate", + "WeightedStructureScore", + "build_structure_group_candidates", + "score_structure_groups", + "select_best_scored_structure", +] diff --git a/moleculeresolver/resolution/evidence.py b/moleculeresolver/resolution/evidence.py new file mode 100644 index 0000000..eb9145f --- /dev/null +++ b/moleculeresolver/resolution/evidence.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from moleculeresolver.molecule import Molecule + + +@dataclass +class CandidateEvidence: + """Evidence and scoring payload 
for one candidate structure group.""" + + smiles: str + service_agreement_count: int + service_names: list[str] + identifiers: list[str] + identifier_concordance_count: int + synonym_overlap_count: int = 0 + score_breakdown: dict[str, int] = field(default_factory=dict) + total_score: int = 0 + + +@dataclass +class ResolutionResult: + """Extended result payload for include_evidence=True requests.""" + + best_molecule: "Molecule | None" + ranked_candidates: list[CandidateEvidence] + grouped_by_structure: dict[str, list["Molecule"]] + selected_smiles: str | None + selection_reason: str diff --git a/moleculeresolver/resolution/models.py b/moleculeresolver/resolution/models.py new file mode 100644 index 0000000..e249344 --- /dev/null +++ b/moleculeresolver/resolution/models.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING +from rdkit import Chem +from rdkit.Chem.rdchem import BondStereo + +if TYPE_CHECKING: + from moleculeresolver.molecule import Molecule + from moleculeresolver.moleculeresolver import MoleculeResolver + + +@dataclass +class StructureGroupCandidate: + """Single grouped-structure candidate used in consensus scoring.""" + + smiles: str + molecules: list["Molecule"] + crosscheck_count: int + non_isomeric_smiles: str + supports_opsin_name_match: bool + chiral_center_count: int + defined_chiral_center_count: int + bond_stereo_count: int + stereo_signal_count: int + + +@dataclass +class WeightedStructureScore: + """Explicit weighted score and breakdown for one structure group.""" + + smiles: str + total_score: int + crosscheck_count: int + opsin_bonus: int + stereo_specificity_bonus: int + defined_chirality_bonus: int + bond_stereo_bonus: int + + # Backward-compatible aliases for earlier score fields. 
+ @property + def isomer_specificity_bonus(self) -> int: + return self.stereo_specificity_bonus + + @property + def smiles_length_bonus(self) -> int: + return 0 + + +def build_structure_group_candidates( + resolver: "MoleculeResolver", + grouped_molecules: dict[str, list["Molecule"]], +) -> list[StructureGroupCandidate]: + """Build normalized candidates from grouped molecules for scoring.""" + candidates = [] + for smiles, molecules in grouped_molecules.items(): + mol = resolver.get_from_SMILES(smiles) + non_isomeric = ( + resolver.to_SMILES(mol, isomeric=False) if mol is not None else smiles + ) + chiral_center_count = 0 + defined_chiral_center_count = 0 + bond_stereo_count = 0 + if mol is not None: + chiral_centers = Chem.FindMolChiralCenters( + mol, includeUnassigned=True, useLegacyImplementation=False + ) + chiral_center_count = len(chiral_centers) + defined_chiral_center_count = sum( + 1 for _, label in chiral_centers if label != "?" + ) + bond_stereo_count = sum( + 1 + for bond in mol.GetBonds() + if bond.GetStereo() not in {BondStereo.STEREONONE, BondStereo.STEREOANY} + ) + stereo_signal_count = defined_chiral_center_count + bond_stereo_count + candidates.append( + StructureGroupCandidate( + smiles=smiles, + molecules=molecules, + crosscheck_count=len(molecules), + non_isomeric_smiles=non_isomeric, + supports_opsin_name_match=any( + molecule.mode == "name" and molecule.service == "opsin" + for molecule in molecules + ), + chiral_center_count=chiral_center_count, + defined_chiral_center_count=defined_chiral_center_count, + bond_stereo_count=bond_stereo_count, + stereo_signal_count=stereo_signal_count, + ) + ) + return candidates diff --git a/moleculeresolver/resolution/scorer.py b/moleculeresolver/resolution/scorer.py new file mode 100644 index 0000000..5714134 --- /dev/null +++ b/moleculeresolver/resolution/scorer.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from moleculeresolver.resolution.models import ( + StructureGroupCandidate, + 
WeightedStructureScore, +) + + +def score_structure_groups( + candidates: list[StructureGroupCandidate], +) -> list[WeightedStructureScore]: + """Score grouped structures with explicit weighted components.""" + scored = [] + for candidate in candidates: + crosscheck_points = candidate.crosscheck_count * 100 + opsin_bonus = 30 if candidate.supports_opsin_name_match else 0 + unresolved_chiral_centers = max( + 0, candidate.chiral_center_count - candidate.defined_chiral_center_count + ) + defined_chirality_bonus = candidate.defined_chiral_center_count * 20 + bond_stereo_bonus = candidate.bond_stereo_count * 15 + stereo_specificity_bonus = ( + defined_chirality_bonus + bond_stereo_bonus + unresolved_chiral_centers * 5 + ) + + scored.append( + WeightedStructureScore( + smiles=candidate.smiles, + total_score=( + crosscheck_points + + opsin_bonus + + stereo_specificity_bonus + ), + crosscheck_count=candidate.crosscheck_count, + opsin_bonus=opsin_bonus, + stereo_specificity_bonus=stereo_specificity_bonus, + defined_chirality_bonus=defined_chirality_bonus, + bond_stereo_bonus=bond_stereo_bonus, + ) + ) + + return scored diff --git a/moleculeresolver/resolution/selector.py b/moleculeresolver/resolution/selector.py new file mode 100644 index 0000000..1b41b7f --- /dev/null +++ b/moleculeresolver/resolution/selector.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from moleculeresolver.resolution.models import WeightedStructureScore + + +def select_best_scored_structure( + scored_structures: list[WeightedStructureScore], +) -> WeightedStructureScore: + """Deterministically select the strongest structure candidate.""" + if not scored_structures: + raise ValueError("select_best_scored_structure requires a non-empty list.") + + return sorted( + scored_structures, + key=lambda score: ( + -score.total_score, + -score.crosscheck_count, + -score.opsin_bonus, + -score.stereo_specificity_bonus, + -score.defined_chirality_bonus, + -score.bond_stereo_bonus, + score.smiles, + ), + 
)[0] diff --git a/moleculeresolver/services/__init__.py b/moleculeresolver/services/__init__.py new file mode 100644 index 0000000..f2125b2 --- /dev/null +++ b/moleculeresolver/services/__init__.py @@ -0,0 +1,30 @@ +from .base import ServiceAdapter, ServiceSearchResult +from .cir_adapter import CIRServiceAdapter +from .opsin_adapter import OPSINServiceAdapter +from .other_adapters import ( + CASRegistryServiceAdapter, + ChEBIServiceAdapter, + ChemeoServiceAdapter, + CTSServiceAdapter, + CompToxServiceAdapter, + NISTServiceAdapter, + SRSServiceAdapter, +) +from .pubchem_adapter import PubChemServiceAdapter +from .registry import ServiceAdapterRegistry + +__all__ = [ + "CASRegistryServiceAdapter", + "CIRServiceAdapter", + "ChEBIServiceAdapter", + "ChemeoServiceAdapter", + "CTSServiceAdapter", + "CompToxServiceAdapter", + "NISTServiceAdapter", + "OPSINServiceAdapter", + "PubChemServiceAdapter", + "SRSServiceAdapter", + "ServiceAdapter", + "ServiceSearchResult", + "ServiceAdapterRegistry", +] diff --git a/moleculeresolver/services/base.py b/moleculeresolver/services/base.py new file mode 100644 index 0000000..e4fe0ad --- /dev/null +++ b/moleculeresolver/services/base.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from moleculeresolver.molecule import Molecule + from moleculeresolver.moleculeresolver import MoleculeResolver + + +@dataclass +class ServiceSearchResult: + """Normalized single-service search result consumed by MoleculeResolver.""" + + molecule: "Molecule" + mode_used: str + identifier_used: str + additional_information: Optional[str] + current_service: str + synonyms: list[str] + cas: set[str] + + +class ServiceAdapter(ABC): + """Contract for a resolver service adapter.""" + + name: str + + def resolve( + self, + resolver: "MoleculeResolver", + flattened_identifiers: list[str], + flattened_modes: list[str], + 
required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + """Resolve by trying each identifier/mode pair in order.""" + for identifier, mode in zip(flattened_identifiers, flattened_modes, strict=True): + result = self.resolve_one( + resolver, + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + if result is not None: + return result + return None + + @abstractmethod + def resolve_one( + self, + resolver: "MoleculeResolver", + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + """Resolve one identifier/mode pair for this adapter or return None.""" + raise NotImplementedError diff --git a/moleculeresolver/services/cir_adapter.py b/moleculeresolver/services/cir_adapter.py new file mode 100644 index 0000000..0e12f2f --- /dev/null +++ b/moleculeresolver/services/cir_adapter.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Optional + +from moleculeresolver.services.base import ServiceAdapter, ServiceSearchResult + + +class CIRServiceAdapter(ServiceAdapter): + name = "cir" + + def resolve_one( + self, + resolver, + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + if mode not in resolver.supported_modes_by_services[self.name]: + return None + cmp = resolver.get_molecule_from_CIR( + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + if cmp is None: + return None + return ServiceSearchResult( + molecule=cmp, + mode_used=mode, + identifier_used=cmp.identifier, + additional_information=cmp.service, + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(), + ) diff --git 
a/moleculeresolver/services/opsin_adapter.py b/moleculeresolver/services/opsin_adapter.py new file mode 100644 index 0000000..5650370 --- /dev/null +++ b/moleculeresolver/services/opsin_adapter.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Optional + +from moleculeresolver.services.base import ServiceAdapter, ServiceSearchResult + + +class OPSINServiceAdapter(ServiceAdapter): + name = "opsin" + + def resolve_one( + self, + resolver, + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + if mode not in resolver.supported_modes_by_services[self.name]: + return None + cmp = resolver.get_molecule_from_OPSIN( + identifier, + required_formula, + required_charge, + required_structure_type, + ) + if cmp is None: + return None + return ServiceSearchResult( + molecule=cmp, + mode_used=mode, + identifier_used=cmp.identifier, + additional_information=cmp.additional_information, + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(), + ) diff --git a/moleculeresolver/services/other_adapters.py b/moleculeresolver/services/other_adapters.py new file mode 100644 index 0000000..27a72e2 --- /dev/null +++ b/moleculeresolver/services/other_adapters.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from typing import Optional + +from moleculeresolver.services.base import ServiceAdapter, ServiceSearchResult + + +class _ModeDrivenAdapter(ServiceAdapter): + fetch_method_name: str + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + raise NotImplementedError + + def resolve_one( + self, + resolver, + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + if mode not in resolver.supported_modes_by_services[self.name]: + return None + fetch_method = getattr(resolver, 
self.fetch_method_name) + cmp = fetch_method( + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + if cmp is not None: + return self._build_result(cmp, mode) + return None + + +class CASRegistryServiceAdapter(_ModeDrivenAdapter): + name = "cas_registry" + fetch_method_name = "get_molecule_from_CAS_registry" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=cmp.service, + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) + + +class ChEBIServiceAdapter(_ModeDrivenAdapter): + name = "chebi" + fetch_method_name = "get_molecule_from_ChEBI" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=f"{cmp.service} id: {cmp.additional_information}", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) + + +class SRSServiceAdapter(_ModeDrivenAdapter): + name = "srs" + fetch_method_name = "get_molecule_from_SRS" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=f"{cmp.service} id: {cmp.additional_information}", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) + + +class CompToxServiceAdapter(_ModeDrivenAdapter): + name = "comptox" + fetch_method_name = "get_molecule_from_CompTox" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=f"{cmp.service} id: {cmp.additional_information}", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) + + +class 
ChemeoServiceAdapter(_ModeDrivenAdapter): + name = "chemeo" + fetch_method_name = "get_molecule_from_Chemeo" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=f"{cmp.service} id: {cmp.additional_information}", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) + + +class CTSServiceAdapter(_ModeDrivenAdapter): + name = "cts" + fetch_method_name = "get_molecule_from_CTS" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information="cts", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) + + +class NISTServiceAdapter(_ModeDrivenAdapter): + name = "nist" + fetch_method_name = "get_molecule_from_NIST" + + def _build_result(self, cmp, mode: str) -> ServiceSearchResult: + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=f"{cmp.service} id: {cmp.additional_information}", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) diff --git a/moleculeresolver/services/pubchem_adapter.py b/moleculeresolver/services/pubchem_adapter.py new file mode 100644 index 0000000..0a673f4 --- /dev/null +++ b/moleculeresolver/services/pubchem_adapter.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Optional + +from moleculeresolver.services.base import ServiceAdapter, ServiceSearchResult + + +class PubChemServiceAdapter(ServiceAdapter): + name = "pubchem" + + def resolve_one( + self, + resolver, + identifier: str, + mode: str, + required_formula: Optional[str], + required_charge: Optional[int], + required_structure_type: Optional[str], + ) -> Optional[ServiceSearchResult]: + if mode not in 
resolver.supported_modes_by_services[self.name]: + return None + cmp = resolver.get_molecule_from_pubchem( + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ) + if cmp is None: + return None + return ServiceSearchResult( + molecule=cmp, + mode_used=cmp.mode, + identifier_used=cmp.identifier, + additional_information=f"{cmp.service} id: {cmp.additional_information}", + current_service=self.name, + synonyms=list(cmp.synonyms), + cas=set(cmp.CAS), + ) diff --git a/moleculeresolver/services/registry.py b/moleculeresolver/services/registry.py new file mode 100644 index 0000000..9b6611e --- /dev/null +++ b/moleculeresolver/services/registry.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Optional + +from moleculeresolver.services.base import ServiceAdapter + + +class ServiceAdapterRegistry: + """Simple runtime registry for all service adapters.""" + + def __init__(self): + self._adapters: dict[str, ServiceAdapter] = {} + + def register(self, adapter: ServiceAdapter) -> None: + self._adapters[adapter.name] = adapter + + def get(self, service_name: str) -> Optional[ServiceAdapter]: + return self._adapters.get(service_name) + + def names(self) -> list[str]: + return sorted(self._adapters) diff --git a/tests/test_context_lifecycle.py b/tests/test_context_lifecycle.py new file mode 100644 index 0000000..73e714e --- /dev/null +++ b/tests/test_context_lifecycle.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from moleculeresolver import MoleculeResolver + + +@dataclass +class _DummyContext: + exited: bool = False + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.exited = True + + +@dataclass +class _DummyTempFolder: + cleaned: bool = False + + def cleanup(self): + self.cleaned = True + + +def test_enter_delegates_to_lifecycle_helpers(mocker): + mr = MoleculeResolver() + + called = [] + mocker.patch.object( + mr, "_enter_rdkit_log_context", 
side_effect=lambda: called.append("rdkit") + ) + mocker.patch.object( + mr, "_enter_molecule_cache_context", side_effect=lambda: called.append("cache") + ) + mocker.patch.object( + mr, "_create_opsin_tempfolder", side_effect=lambda: called.append("opsin") + ) + + result = mr.__enter__() + + assert result is mr + assert called == ["rdkit", "cache", "opsin"] + + +def test_exit_delegates_to_teardown_helper(mocker): + mr = MoleculeResolver() + + teardown_mock = mocker.patch.object(mr, "_teardown_runtime_contexts") + + mr.__exit__(None, None, None) + teardown_mock.assert_called_once_with(error_ocurred=False) + + teardown_mock.reset_mock() + mr.__exit__(RuntimeError, RuntimeError("boom"), object()) + teardown_mock.assert_called_once_with(error_ocurred=True) + + +def test_teardown_runtime_contexts_cleans_tempfolder_on_success(): + mr = MoleculeResolver() + cache_ctx = _DummyContext() + rdkit_ctx = _DummyContext() + tempfolder = _DummyTempFolder() + + mr.molecule_cache = cache_ctx + mr._disabling_rdkit_logger = rdkit_ctx + mr._OPSIN_tempfolder = tempfolder + + mr._teardown_runtime_contexts(error_ocurred=False) + + assert cache_ctx.exited + assert rdkit_ctx.exited + assert tempfolder.cleaned + + +def test_teardown_runtime_contexts_keeps_tempfolder_on_error(): + mr = MoleculeResolver() + cache_ctx = _DummyContext() + rdkit_ctx = _DummyContext() + tempfolder = _DummyTempFolder() + + mr.molecule_cache = cache_ctx + mr._disabling_rdkit_logger = rdkit_ctx + mr._OPSIN_tempfolder = tempfolder + + mr._teardown_runtime_contexts(error_ocurred=True) + + assert cache_ctx.exited + assert rdkit_ctx.exited + assert not tempfolder.cleaned diff --git a/tests/test_evidence_model.py b/tests/test_evidence_model.py new file mode 100644 index 0000000..b11cca2 --- /dev/null +++ b/tests/test_evidence_model.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from moleculeresolver import MoleculeResolver +from moleculeresolver.molecule import Molecule +from moleculeresolver.resolution import 
ResolutionResult + + +def test_include_evidence_returns_resolution_result(monkeypatch): + mr = MoleculeResolver() + mr._available_services = ["svc1", "svc2"] + + mol_a = Molecule("CCO", ["Ethanol", "EtOH"], ["64-17-5"], "svc1", "name", "svc1", 1, "a") + mol_b = Molecule("CCO", ["ethanol"], [], "svc2", "name", "svc2", 1, "b") + + def fake_find_single_molecule(*args, **kwargs): + service = kwargs["services_to_use"][0] + return mol_a if service == "svc1" else mol_b + + monkeypatch.setattr(mr, "find_single_molecule", fake_find_single_molecule) + monkeypatch.setattr( + mr, + "filter_molecules", + lambda molecules, *_: [molecule for molecule in molecules if molecule is not None], + ) + monkeypatch.setattr( + mr, "group_molecules_by_structure", lambda molecules, *_: {"CCO": [mol_a, mol_b]} + ) + + result = mr.find_single_molecule_crosschecked( + identifiers=["ethanol"], + modes=["name"], + services_to_use=["svc1", "svc2"], + include_evidence=True, + ) + + assert isinstance(result, ResolutionResult) + assert result.best_molecule is not None + assert result.selected_smiles == "CCO" + assert len(result.ranked_candidates) == 1 + evidence = result.ranked_candidates[0] + assert evidence.service_agreement_count == 2 + assert evidence.identifier_concordance_count == 2 + assert evidence.synonym_overlap_count >= 1 + assert "service_agreement" in evidence.score_breakdown + assert evidence.total_score == sum(evidence.score_breakdown.values()) + + +def test_include_evidence_with_tied_structures_has_no_best_molecule(monkeypatch): + mr = MoleculeResolver() + mr._available_services = ["svc1", "svc2"] + + mol_a = Molecule("CCO", ["ethanol"], [], "svc1", "name", "svc1", 1, "a") + mol_b = Molecule("CCC", ["propane"], [], "svc2", "name", "svc2", 1, "b") + + def fake_find_single_molecule(*args, **kwargs): + service = kwargs["services_to_use"][0] + return mol_a if service == "svc1" else mol_b + + monkeypatch.setattr(mr, "find_single_molecule", fake_find_single_molecule) + monkeypatch.setattr( 
+ mr, + "filter_molecules", + lambda molecules, *_: [molecule for molecule in molecules if molecule is not None], + ) + monkeypatch.setattr( + mr, + "group_molecules_by_structure", + lambda molecules, *_: {"CCO": [mol_a], "CCC": [mol_b]}, + ) + + result = mr.find_single_molecule_crosschecked( + identifiers=["ethanol", "propane"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + include_evidence=True, + try_to_choose_best_structure=False, + ) + + assert isinstance(result, ResolutionResult) + assert result.best_molecule is None + assert result.selection_reason == "multiple_structures_tied" diff --git a/tests/test_integration.py b/tests/test_integration.py index fc3664f..6051615 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,12 +1,13 @@ from datetime import datetime import pytest -from moleculeresolver import MoleculeResolver +from moleculeresolver import Molecule, MoleculeResolver import json import os from pathlib import Path from typing import Any, Callable, Dict, Optional, Union from tqdm import tqdm from dataclasses import dataclass +from types import SimpleNamespace # IUPAC names @@ -265,5 +266,145 @@ def generate_data(json_path, overwrite_json: bool = False): for data in benchmark.values(): method(data) +def test_multi_name_exhaustive_benchmark_case(monkeypatch): + mr = MoleculeResolver() + mr.supported_modes_by_services = {"svc1": ["name"], "svc2": ["name"]} + mr._available_services = ["svc1", "svc2"] + + monkeypatch.setattr(mr, "_check_parameters", lambda **_: None) + monkeypatch.setattr( + mr, + "_check_and_flatten_identifiers_and_modes", + lambda identifiers, modes: ( + ["isopropanol", "2-propanol"], + ["name", "name"], + ["isopropanol", "2-propanol"], + set(), + None, + ), + ) + monkeypatch.setattr(mr, "standardize_SMILES", lambda smiles: smiles) + + def _adapter_result(smiles, synonym, identifier): + return SimpleNamespace( + molecule=SimpleNamespace(SMILES=smiles), + synonyms=[synonym], + cas={"67-63-0"}, + 
additional_information="mock", + mode_used="name", + identifier_used=identifier, + ) + + result_map = { + ("svc1", "isopropanol"): _adapter_result("CC(O)C", "isopropanol", "isopropanol"), + ("svc1", "2-propanol"): _adapter_result("CC(C)O", "2-propanol", "2-propanol"), + ("svc2", "isopropanol"): _adapter_result("CC(C)O", "isopropanol", "isopropanol"), + } + + monkeypatch.setattr( + mr, + "_resolve_identifier_with_adapter", + lambda service, identifier, mode, *_: result_map.get((service, identifier)), + ) + + result = mr.find_single_molecule( + identifiers=["isopropanol", "2-propanol"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + search_strategy="exhaustive", + ) + + assert result is not None + assert result.SMILES == "CC(C)O" + + +def test_consensus_vs_legacy_mode_comparison(monkeypatch): + mr = MoleculeResolver() + mr._available_services = ["svc1", "svc2"] + + legacy_winner = Molecule( + "CCO", ["ethanol"], [], "svc1", "name", "svc1", 1, "ethanol" + ) + consensus_winner = Molecule( + "CCO", ["ethyl alcohol"], [], "svc2", "name", "svc2", 1, "ethyl alcohol" + ) + + def fake_find_single_molecule(*args, **kwargs): + service = kwargs["services_to_use"][0] + if service == "svc1": + return legacy_winner + return consensus_winner + + monkeypatch.setattr(mr, "find_single_molecule", fake_find_single_molecule) + monkeypatch.setattr( + mr, + "filter_molecules", + lambda molecules, *_: [molecule for molecule in molecules if molecule is not None], + ) + monkeypatch.setattr( + mr, + "group_molecules_by_structure", + lambda molecules, *_: {"CCO": molecules}, + ) + + legacy_result = mr.find_single_molecule_crosschecked( + identifiers=["ethanol"], + modes=["name"], + services_to_use=["svc1", "svc2"], + resolution_mode="legacy", + ) + consensus_result = mr.find_single_molecule_crosschecked( + identifiers=["ethanol"], + modes=["name"], + services_to_use=["svc1", "svc2"], + resolution_mode="consensus", + ) + + assert legacy_result is not None + assert 
consensus_result is not None + assert legacy_result.SMILES == consensus_result.SMILES + + +def test_strict_isomer_acceptance_case(monkeypatch): + mr = MoleculeResolver() + mr._available_services = ["svc1", "svc2"] + + mol_a = Molecule("CCO", ["ethanol"], [], "svc1", "name", "svc1", 1, "ethanol") + mol_b = Molecule( + "C[C@H](O)C", ["(S)-2-butanol"], [], "svc2", "name", "svc2", 1, "(S)-2-butanol" + ) + + def fake_find_single_molecule(*args, **kwargs): + service = kwargs["services_to_use"][0] + return mol_a if service == "svc1" else mol_b + + monkeypatch.setattr(mr, "find_single_molecule", fake_find_single_molecule) + monkeypatch.setattr( + mr, + "filter_molecules", + lambda molecules, *_: [molecule for molecule in molecules if molecule is not None], + ) + monkeypatch.setattr( + mr, + "group_molecules_by_structure", + lambda molecules, *_: {"CCO": [mol_a], "C[C@H](O)C": [mol_b]}, + ) + monkeypatch.setattr( + mr, + "_collect_opsin_isomer_matches", + lambda grouped, smiles: {smiles[0]: False, smiles[1]: True}, + ) + + result = mr.find_single_molecule_crosschecked( + identifiers=["ethanol", "(S)-2-butanol"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + resolution_mode="strict_isomer", + ) + + assert result is not None + assert result.SMILES == "C[C@H](O)C" + + if __name__ == "__main__": generate_data(RESPONSES_PATH, overwrite_json=True) diff --git a/tests/test_scoring_engine.py b/tests/test_scoring_engine.py new file mode 100644 index 0000000..aab487f --- /dev/null +++ b/tests/test_scoring_engine.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from dataclasses import dataclass +from rdkit import Chem + +from moleculeresolver.resolution import ( + build_structure_group_candidates, + score_structure_groups, + select_best_scored_structure, +) + + +@dataclass +class _FakeMolecule: + mode: str + service: str + + +class _FakeResolver: + @staticmethod + def get_from_SMILES(smiles): + return Chem.MolFromSmiles(smiles) + + @staticmethod + def 
to_SMILES(smiles, isomeric=False): + return Chem.MolToSmiles(smiles, isomericSmiles=isomeric) + + +def test_weighted_scoring_prefers_higher_crosscheck_count(): + resolver = _FakeResolver() + grouped = { + "CCO": [_FakeMolecule(mode="name", service="opsin"), _FakeMolecule(mode="name", service="pubchem")], + "CCC": [_FakeMolecule(mode="name", service="pubchem")], + } + + candidates = build_structure_group_candidates(resolver, grouped) + scored = score_structure_groups(candidates) + best = select_best_scored_structure(scored) + + assert best.smiles == "CCO" + assert best.crosscheck_count == 2 + + +def test_tie_breaker_is_deterministic_for_equal_scores(): + # Same score components. Lexicographical SMILES decides. + scored = [ + type( + "Score", + (), + { + "smiles": "CCN", + "total_score": 200, + "crosscheck_count": 2, + "opsin_bonus": 0, + "stereo_specificity_bonus": 10, + "defined_chirality_bonus": 10, + "bond_stereo_bonus": 0, + }, + )(), + type( + "Score", + (), + { + "smiles": "CCO", + "total_score": 200, + "crosscheck_count": 2, + "opsin_bonus": 0, + "stereo_specificity_bonus": 10, + "defined_chirality_bonus": 10, + "bond_stereo_bonus": 0, + }, + )(), + ] + + best = select_best_scored_structure(scored) + + assert best.smiles == "CCN" + + +def test_stereo_specific_candidate_wins_when_crosscheck_is_equal(): + resolver = _FakeResolver() + grouped = { + "CCO": [_FakeMolecule(mode="name", service="pubchem")], + "C[C@H](O)F": [_FakeMolecule(mode="name", service="pubchem")], + } + + candidates = build_structure_group_candidates(resolver, grouped) + scored = score_structure_groups(candidates) + best = select_best_scored_structure(scored) + + assert best.smiles == "C[C@H](O)F" + assert best.stereo_specificity_bonus > 0 diff --git a/tests/test_search_strategy.py b/tests/test_search_strategy.py new file mode 100644 index 0000000..bf26341 --- /dev/null +++ b/tests/test_search_strategy.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from types import 
SimpleNamespace + +from moleculeresolver import MoleculeResolver + + +def _make_exhaustive_adapter_result( + smiles: str, + synonym: str, + *, + identifier: str, + mode: str = "name", +): + return SimpleNamespace( + molecule=SimpleNamespace(SMILES=smiles), + synonyms=[synonym], + cas=set(), + additional_information=f"{identifier}:{mode}", + mode_used=mode, + identifier_used=identifier, + ) + + +def _make_strategy_test_resolver(monkeypatch) -> MoleculeResolver: + mr = MoleculeResolver() + mr.supported_modes_by_services = { + "svc1": ["name"], + "svc2": ["name"], + } + mr._available_services = ["svc1", "svc2"] + + monkeypatch.setattr(mr, "_check_parameters", lambda **_: None) + monkeypatch.setattr( + mr, + "_check_and_flatten_identifiers_and_modes", + lambda identifiers, modes: ( + ["alpha", "beta"], + ["name", "name"], + ["alpha", "beta"], + set(), + None, + ), + ) + monkeypatch.setattr(mr, "standardize_SMILES", lambda smi: smi) + return mr + + +def test_first_hit_stops_after_first_match(monkeypatch): + mr = _make_strategy_test_resolver(monkeypatch) + calls = [] + + def fake_resolve_with_adapter(service, identifiers, modes, *_): + calls.append((service, tuple(identifiers), tuple(modes))) + if service == "svc1": + return SimpleNamespace( + molecule=SimpleNamespace(SMILES="CCO"), + synonyms=["alpha"], + additional_information="svc1:alpha", + mode_used="name", + identifier_used="alpha", + cas=set(), + ) + return None + + monkeypatch.setattr(mr, "_resolve_service_with_adapter", fake_resolve_with_adapter) + + result = mr.find_single_molecule( + identifiers=["alpha", "beta"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + search_strategy="first_hit", + ) + + assert result is not None + assert result.SMILES == "CCO" + assert calls == [("svc1", ("alpha", "beta"), ("name", "name"))] + + +def test_exhaustive_search_checks_all_pairs_and_uses_consensus(monkeypatch): + mr = _make_strategy_test_resolver(monkeypatch) + calls = [] + response_map = { + ("svc1", 
"alpha"): _make_exhaustive_adapter_result("CCO", "alpha", identifier="alpha"), + ("svc1", "beta"): _make_exhaustive_adapter_result("CCC", "beta", identifier="beta"), + ("svc2", "alpha"): _make_exhaustive_adapter_result("CCO", "alpha", identifier="alpha"), + ("svc2", "beta"): _make_exhaustive_adapter_result("CCO", "beta", identifier="beta"), + } + + def fake_resolve(service, identifier, mode, *_): + calls.append((service, identifier, mode)) + return response_map.get((service, identifier)) + + monkeypatch.setattr(mr, "_resolve_identifier_with_adapter", fake_resolve) + + result = mr.find_single_molecule( + identifiers=["alpha", "beta"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + search_strategy="exhaustive", + ) + + assert result is not None + assert result.SMILES == "CCO" + assert len(calls) == 4 + assert set(result.synonyms) == {"alpha", "beta"} diff --git a/tests/test_service_adapters.py b/tests/test_service_adapters.py new file mode 100644 index 0000000..28e7d07 --- /dev/null +++ b/tests/test_service_adapters.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from moleculeresolver.services import ( + CASRegistryServiceAdapter, + OPSINServiceAdapter, + PubChemServiceAdapter, + ServiceAdapterRegistry, +) + + +@dataclass +class _FakeMolecule: + SMILES: str + synonyms: list[str] + CAS: list[str] + service: str + mode: str + identifier: str + additional_information: str + + +class _FakeResolver: + def __init__(self): + self.supported_modes_by_services = { + "opsin": ["name"], + "pubchem": ["name", "cas"], + "cas_registry": ["name", "cas"], + } + + def get_molecule_from_OPSIN( + self, + identifier, + required_formula, + required_charge, + required_structure_type, + ): + if identifier == "ethanol": + return _FakeMolecule( + "CCO", + ["ethanol"], + [], + "opsin", + "name", + identifier, + "opsin resolution", + ) + return None + + def get_molecule_from_pubchem( + self, + identifier, + mode, + 
required_formula, + required_charge, + required_structure_type, + ): + if identifier == "ethanol": + return _FakeMolecule( + "CCO", + ["ethyl alcohol"], + ["64-17-5"], + "pubchem", + mode, + identifier, + "702", + ) + return None + + def get_molecule_from_CAS_registry( + self, + identifier, + mode, + required_formula, + required_charge, + required_structure_type, + ): + if identifier == "64-17-5": + return _FakeMolecule( + "CCO", + ["ethanol"], + ["64-17-5"], + "cas_registry", + mode, + identifier, + "cas_registry", + ) + return None + + +def test_service_adapter_registry_roundtrip(): + registry = ServiceAdapterRegistry() + adapter = OPSINServiceAdapter() + registry.register(adapter) + + assert registry.get("opsin") is adapter + assert registry.names() == ["opsin"] + + +def test_opsin_adapter_resolves_name_only(): + resolver = _FakeResolver() + adapter = OPSINServiceAdapter() + + result = adapter.resolve( + resolver, + flattened_identifiers=["ethanol"], + flattened_modes=["name"], + required_formula=None, + required_charge=None, + required_structure_type=None, + ) + assert result is not None + assert result.molecule.SMILES == "CCO" + assert result.mode_used == "name" + assert result.current_service == "opsin" + + +def test_pubchem_adapter_formats_additional_information(): + resolver = _FakeResolver() + adapter = PubChemServiceAdapter() + + result = adapter.resolve( + resolver, + flattened_identifiers=["ethanol"], + flattened_modes=["name"], + required_formula=None, + required_charge=None, + required_structure_type=None, + ) + assert result is not None + assert result.additional_information == "pubchem id: 702" + assert "64-17-5" in result.cas + + +def test_cas_registry_adapter_provides_authoritative_cas_set(): + resolver = _FakeResolver() + adapter = CASRegistryServiceAdapter() + + result = adapter.resolve( + resolver, + flattened_identifiers=["64-17-5"], + flattened_modes=["cas"], + required_formula=None, + required_charge=None, + required_structure_type=None, + ) 
+ assert result is not None + assert result.current_service == "cas_registry" + assert result.cas == {"64-17-5"} + + +def test_resolve_one_contract_for_opsin_adapter(): + resolver = _FakeResolver() + adapter = OPSINServiceAdapter() + + result = adapter.resolve_one( + resolver, + identifier="ethanol", + mode="name", + required_formula=None, + required_charge=None, + required_structure_type=None, + ) + + assert result is not None + assert result.molecule.SMILES == "CCO" + assert result.identifier_used == "ethanol" + + +def test_resolve_one_returns_none_for_unsupported_mode(): + resolver = _FakeResolver() + adapter = PubChemServiceAdapter() + + result = adapter.resolve_one( + resolver, + identifier="ethanol", + mode="inchikey", + required_formula=None, + required_charge=None, + required_structure_type=None, + ) + + assert result is None diff --git a/tests/test_strict_isomer_mode.py b/tests/test_strict_isomer_mode.py new file mode 100644 index 0000000..37aa353 --- /dev/null +++ b/tests/test_strict_isomer_mode.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from moleculeresolver import MoleculeResolver +from moleculeresolver.molecule import Molecule + + +def _build_stubbed_crosscheck_resolver(monkeypatch): + mr = MoleculeResolver() + mr._available_services = ["svc1", "svc2"] + + mol_a = Molecule("CCO", ["alpha"], [], "svc1", "name", "svc1", 1, "alpha") + mol_b = Molecule("C[C@H](O)C", ["beta"], [], "svc2", "name", "svc2", 1, "beta") + + def fake_find_single_molecule(*args, **kwargs): + service = kwargs["services_to_use"][0] + return mol_a if service == "svc1" else mol_b + + monkeypatch.setattr(mr, "find_single_molecule", fake_find_single_molecule) + monkeypatch.setattr( + mr, + "filter_molecules", + lambda molecules, *_: [molecule for molecule in molecules if molecule is not None], + ) + monkeypatch.setattr( + mr, + "group_molecules_by_structure", + lambda molecules, *_: {"CCO": [mol_a], "C[C@H](O)C": [mol_b]}, + ) + return mr + + +def 
test_strict_isomer_rejects_when_no_opsin_match(monkeypatch): + mr = _build_stubbed_crosscheck_resolver(monkeypatch) + monkeypatch.setattr( + mr, + "_collect_opsin_isomer_matches", + lambda grouped, smiles: {smiles[0]: False, smiles[1]: False}, + ) + + result = mr.find_single_molecule_crosschecked( + identifiers=["alpha", "beta"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + resolution_mode="strict_isomer", + ) + + assert result is None + + +def test_strict_isomer_returns_verified_candidate(monkeypatch): + mr = _build_stubbed_crosscheck_resolver(monkeypatch) + monkeypatch.setattr( + mr, + "_collect_opsin_isomer_matches", + lambda grouped, smiles: {smiles[0]: False, smiles[1]: True}, + ) + + result = mr.find_single_molecule_crosschecked( + identifiers=["alpha", "beta"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + resolution_mode="strict_isomer", + ) + + assert result is not None + assert result.SMILES == "C[C@H](O)C" + + +def test_legacy_mode_skips_strict_isomer_filter(monkeypatch): + mr = _build_stubbed_crosscheck_resolver(monkeypatch) + + def fail_if_called(*args, **kwargs): + raise AssertionError("strict OPSIN filtering should not run in legacy mode") + + monkeypatch.setattr(mr, "_collect_opsin_isomer_matches", fail_if_called) + + result = mr.find_single_molecule_crosschecked( + identifiers=["alpha", "beta"], + modes=["name", "name"], + services_to_use=["svc1", "svc2"], + resolution_mode="legacy", + ) + + assert result is not None