From f4dfed25ad6e2e391aedc948a8155d7090478346 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Wed, 23 Apr 2025 22:23:04 +0200 Subject: [PATCH 1/4] feat: Add helper methods for dataset store with similarity --- src/material_hasher/similarity/base.py | 59 ++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/material_hasher/similarity/base.py b/src/material_hasher/similarity/base.py index e9c0851..6b87283 100644 --- a/src/material_hasher/similarity/base.py +++ b/src/material_hasher/similarity/base.py @@ -110,3 +110,62 @@ def get_pairwise_equivalence( Matrix of equivalence between structures. """ pass + + def get_structure_embeddings( + self, structures: list[Structure] | Structure + ) -> np.ndarray: + """Get the embeddings of a list of structures. + + This is not compatible with all the similarity matchers. + + Parameters + ---------- + structures : list[Structure] | Structure + List of structures to get the embeddings of. + + Returns + ------- + np.ndarray + Embeddings of the structures. + """ + raise NotImplementedError( + "This method is not implemented for this similarity matcher." + ) + + def get_similarity_embeddings( + self, embeddings1: np.ndarray, embeddings2: np.ndarray + ) -> float: + """Get the similarity score between two embeddings. + + Parameters + ---------- + embeddings1 : np.ndarray + First embeddings to compare. + embeddings2 : np.ndarray + Second embeddings to compare. + + Returns + ------- + float + Similarity score between the two embeddings. + """ + raise NotImplementedError( + "This method is not implemented for this similarity matcher." + ) + + def get_pairwise_similarity_embeddings(self, embeddings: np.ndarray) -> np.ndarray: + """Get the pairwise similarity embeddings of a list of embeddings.""" + n = len(embeddings) + scores = np.zeros((n, n)) + + for i, embedding1 in enumerate(embeddings): + for j, embedding2 in enumerate(embeddings): + if i <= j: + scores[i, j] = self.get_similarity_embeddings( + embedding1, embedding2 + ) + + # Fill tril + scores = scores + scores.T - np.diag(np.diag(scores)) + + return scores From 6a6fed0592fa911e16092bbf944c87b94279ee72 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Wed, 23 Apr 2025 22:24:38 +0200 Subject: [PATCH 2/4] feat: Add dataset store --- src/material_hasher/dataset_store.py | 194 +++++++++++++++++++++++++++ tests/test_dataset_store.py | 160 ++++++++++++++++++++++ 2 files changed, 354 insertions(+) create mode 100644 src/material_hasher/dataset_store.py create mode 100644 tests/test_dataset_store.py diff --git a/src/material_hasher/dataset_store.py b/src/material_hasher/dataset_store.py new file mode 100644 index 0000000..5a35b4e --- /dev/null +++ b/src/material_hasher/dataset_store.py @@ -0,0 +1,194 @@ +# Copyright 2025 Entalpic +from typing import Any, Optional, TypeVar + +import numpy as np +from pymatgen.core import Structure + +from material_hasher.hasher.base import HasherBase +from material_hasher.similarity.base import SimilarityMatcherBase +from material_hasher.types import StructureEquivalenceChecker + +EquivalenceCheckerType = TypeVar( + "EquivalenceCheckerType", bound=HasherBase | SimilarityMatcherBase +) + + +class DatasetStore: + """Stores the hashes or embedding vectors of a dataset. + + This is used for comparing structures against a reference dataset. + + Parameters + ---------- + equivalence_checker_class : type[EquivalenceCheckerType] + The class of the equivalence checker to use (either a hasher or similarity matcher). + equivalence_checker_kwargs : dict[str, Any] + Keyword arguments to pass to the equivalence checker constructor. + """ + + def __init__( + self, + equivalence_checker_class: type[EquivalenceCheckerType], + equivalence_checker_kwargs: dict[str, Any] = {}, + ): + self.equivalence_checker_class = equivalence_checker_class + self.equivalence_checker_kwargs = equivalence_checker_kwargs + self.equivalence_checker = self.equivalence_checker_class( + **self.equivalence_checker_kwargs + ) + self.embeddings: list[np.ndarray | str] = [] + + @staticmethod + def _get_structure_embedding( + structure: Structure, equivalence_checker: StructureEquivalenceChecker + ) -> np.ndarray | str: + """Get the embedding or hash of a structure. + + Parameters + ---------- + structure : Structure + The structure to get the embedding/hash for. + equivalence_checker : StructureEquivalenceChecker + The equivalence checker to use. + + Returns + ------- + np.ndarray | str + The embedding vector for similarity matchers or hash string for hashers. + """ + if isinstance(equivalence_checker, HasherBase): + return equivalence_checker.get_material_hash(structure) + elif isinstance(equivalence_checker, SimilarityMatcherBase): + return equivalence_checker.get_structure_embeddings(structure) + else: + raise ValueError( + f"Unsupported equivalence checker: {type(equivalence_checker)}" + ) + + def _get_structures_embeddings( + self, structures: list[Structure] + ) -> list[np.ndarray | str]: + """Get the embeddings of a list of structures. + + Parameters + ---------- + structures : list[Structure] + The structures to get embeddings for. + + Returns + ------- + list[np.ndarray | str] + List of embeddings or hashes for each structure. + """ + return [ + self._get_structure_embedding(structure, self.equivalence_checker) + for structure in structures + ] + + def store_embeddings(self, structures: list[Structure]) -> None: + """Store the embeddings/hashes of the given structures. + + Parameters + ---------- + structures : list[Structure] + The structures to store embeddings/hashes for. + """ + self.embeddings.extend(self._get_structures_embeddings(structures)) + + def is_equivalent( + self, structure: Structure, threshold: Optional[float] = None + ) -> list[bool]: + """Check if a structure is equivalent to any of the stored structures. + + Parameters + ---------- + structure : Structure + The structure to check. + threshold : float, optional + Threshold for similarity matchers, by default None. + + Returns + ------- + list[bool] + List of boolean values indicating equivalence with each stored structure. + """ + query_embedding = self._get_structure_embedding( + structure, self.equivalence_checker + ) + + if isinstance(self.equivalence_checker, HasherBase): + return [ + query_embedding == stored_embedding + for stored_embedding in self.embeddings + ] + elif isinstance(self.equivalence_checker, SimilarityMatcherBase): + return [ + self.equivalence_checker.get_similarity_embeddings( + query_embedding, stored_embedding + ) + >= ( + threshold + if threshold is not None + else self.equivalence_checker.threshold + ) + for stored_embedding in self.embeddings + ] + else: + raise ValueError( + f"Unsupported equivalence checker: {type(self.equivalence_checker)}" + ) + + def reset(self) -> None: + """Reset the dataset store.""" + self.embeddings = [] + + def save(self, path: str) -> None: + """Save the dataset store to a file. + + Parameters + ---------- + path : str + Path to save the dataset store to. + """ + save_data = { + "equivalence_checker_class": self.equivalence_checker_class.__name__, + "equivalence_checker_kwargs": self.equivalence_checker_kwargs, + "embeddings": self.embeddings, + } + np.save(path, save_data, allow_pickle=True) + + @classmethod + def load( + cls, + path: str, + equivalence_checker_class: type[EquivalenceCheckerType], + equivalence_checker_kwargs: dict[str, Any] = {}, + ) -> "DatasetStore": + """Load the dataset store from a file. + + Parameters + ---------- + path : str + Path to load the dataset store from. + equivalence_checker_class : type[EquivalenceCheckerType] + The class of the equivalence checker to use. + equivalence_checker_kwargs : dict[str, Any] + Keyword arguments to pass to the equivalence checker constructor. + + Returns + ------- + DatasetStore + The loaded dataset store. + """ + save_data = np.load(path, allow_pickle=True).item() + + # Verify the equivalence checker class matches + if save_data["equivalence_checker_class"] != equivalence_checker_class.__name__: + raise ValueError( + f"Loaded equivalence checker class {save_data['equivalence_checker_class']} " + f"does not match provided class {equivalence_checker_class.__name__}" + ) + + store = cls(equivalence_checker_class, equivalence_checker_kwargs) + store.embeddings = save_data["embeddings"] + return store diff --git a/tests/test_dataset_store.py b/tests/test_dataset_store.py new file mode 100644 index 0000000..41c5bc2 --- /dev/null +++ b/tests/test_dataset_store.py @@ -0,0 +1,160 @@ +# Copyright 2025 Entalpic +import os +from typing import Optional + +import numpy as np +import pytest +from pymatgen.core import Lattice, Structure + +from material_hasher.dataset_store import DatasetStore +from material_hasher.hasher.base import HasherBase +from material_hasher.similarity.base import SimilarityMatcherBase + + +class DummyHasher(HasherBase): + """A dummy hasher that just uses the number of sites as a hash.""" + + def get_material_hash(self, structure: Structure) -> str: + return str(len(structure)) + + +class DummySimilarityMatcher(SimilarityMatcherBase): + """A dummy similarity matcher that compares number of sites.""" + + def __init__(self, threshold: float = 0.49): + self.threshold = threshold + + def get_similarity_score( + self, structure1: Structure, structure2: Structure + ) -> float: + # Return the absolute difference between the number of sites + return 1 / (np.abs(len(structure1) - len(structure2)) + 1) + + def get_similarity_embeddings( + self, embeddings1: np.ndarray, embeddings2: np.ndarray + ) -> float: + return 1 / (np.abs(embeddings1 - embeddings2) + 1) + + def is_equivalent( + self, + structure1: Structure, + structure2: Structure, + threshold: Optional[float] = None, + ) -> bool: + score = self.get_similarity_score(structure1, structure2) + return score >= (threshold if threshold is not None else self.threshold) + + def get_pairwise_equivalence( + self, structures: list[Structure], threshold: Optional[float] = None + ) -> np.ndarray: + n = len(structures) + result = np.zeros((n, n), dtype=bool) + for i in range(n): + for j in range(n): + result[i, j] = self.is_equivalent( + structures[i], structures[j], threshold + ) + return result + + def get_structure_embeddings(self, structure: Structure) -> np.ndarray: + # Return a 1D array with the number of sites + return np.array([len(structure)], dtype=float) + + +@pytest.fixture +def simple_structures(): + """Create a few simple test structures.""" + lattice = Lattice.cubic(1.0) + structures = [ + Structure(lattice, ["H"], [[0, 0, 0]]), # 1 site + Structure(lattice, ["H", "He"], [[0, 0, 0], [0.5, 0.5, 0.5]]), # 2 sites + Structure(lattice, ["H"], [[0, 0, 0]]), # 1 site + ] + return structures + + +def test_hasher_store(simple_structures): + """Test storing and comparing structures using the dummy hasher.""" + store = DatasetStore(DummyHasher) + + # Store first two structures + store.store_embeddings(simple_structures[:2]) + + # Compare third structure (should match first but not second) + results = store.is_equivalent(simple_structures[2]) + assert len(results) == 2 + assert results[0] + assert not results[1] + + +def test_similarity_store(simple_structures): + """Test storing and comparing structures using the dummy similarity matcher.""" + store = DatasetStore(DummySimilarityMatcher, {"threshold": 0.95}) + + store.store_embeddings(simple_structures[:2]) + + results = store.is_equivalent(simple_structures[2]) + assert len(results) == 2 + assert results[0] + assert not results[1] + + +def test_save_load_hasher(simple_structures, tmp_path): + """Test saving and loading a hasher store.""" + save_path = os.path.join(tmp_path, "store.npy") + + # Create and save store + store = DatasetStore(DummyHasher) + store.store_embeddings(simple_structures[:2]) + store.save(save_path) + + # Load store and verify + loaded_store = DatasetStore.load(save_path, DummyHasher) + results = loaded_store.is_equivalent(simple_structures[2]) + assert len(results) == 2 + assert results[0] # Should match first structure (1 site) + assert not results[1] # Should not match second structure (2 sites) + + +def test_save_load_similarity(simple_structures, tmp_path): + """Test saving and loading a similarity matcher store.""" + save_path = os.path.join(tmp_path, "store.npy") + + store = DatasetStore(DummySimilarityMatcher, {"threshold": 0.95}) + store.store_embeddings(simple_structures[:2]) + store.save(save_path) + + loaded_store = DatasetStore.load( + save_path, DummySimilarityMatcher, {"threshold": 0.95} + ) + results = loaded_store.is_equivalent(simple_structures[2]) + assert len(results) == 2 + assert results[0] + assert not results[1] + + +def test_reset(simple_structures): + """Test resetting the store.""" + store = DatasetStore(DummyHasher) + + # Store structures and verify + store.store_embeddings(simple_structures[:2]) + assert len(store.embeddings) == 2 + + # Reset and verify + store.reset() + assert len(store.embeddings) == 0 + + +def test_incompatible_checker_load(simple_structures, tmp_path): + """Test that loading with incompatible checker raises error.""" + save_path = os.path.join(tmp_path, "store.npy") + + # Save with hasher + store = DatasetStore(DummyHasher) + store.store_embeddings(simple_structures[:2]) + store.save(save_path) + + # Try to load with similarity matcher + with pytest.raises(ValueError, match="does not match provided class"): + DatasetStore.load(save_path, DummySimilarityMatcher) From 2fc27b751a5cd14dd3af8ce1972c02b8113aec58 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Tue, 6 May 2025 04:55:54 +0200 Subject: [PATCH 3/4] chore: store embeddings directly --- src/material_hasher/dataset_store.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/material_hasher/dataset_store.py b/src/material_hasher/dataset_store.py index 5a35b4e..316487f 100644 --- a/src/material_hasher/dataset_store.py +++ b/src/material_hasher/dataset_store.py @@ -85,8 +85,8 @@ def _get_structures_embeddings( for structure in structures ] - def store_embeddings(self, structures: list[Structure]) -> None: - """Store the embeddings/hashes of the given structures. + def compute_and_store_embeddings(self, structures: list[Structure]) -> None: + """Compute the embeddings/hashes of the given structures and store them. Parameters ---------- @@ -95,6 +95,16 @@ def store_embeddings(self, structures: list[Structure]) -> None: """ self.embeddings.extend(self._get_structures_embeddings(structures)) + def store_embeddings(self, embeddings: list[np.ndarray | str]) -> None: + """Store the embeddings/hashes of the given structures. + + Parameters + ---------- + embeddings : list[np.ndarray | str] + The embeddings/hashes to store. + """ + self.embeddings.extend(embeddings) + def is_equivalent( self, structure: Structure, threshold: Optional[float] = None ) -> list[bool]: From 33c8652a2187a62226f0a177fc74d0fd72c97399 Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Tue, 6 May 2025 14:13:32 +0200 Subject: [PATCH 4/4] chore: Load class from store directly --- src/material_hasher/dataset_store.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/material_hasher/dataset_store.py b/src/material_hasher/dataset_store.py index 316487f..4a9c460 100644 --- a/src/material_hasher/dataset_store.py +++ b/src/material_hasher/dataset_store.py @@ -4,10 +4,17 @@ import numpy as np from pymatgen.core import Structure +from material_hasher.hasher import HASHERS from material_hasher.hasher.base import HasherBase +from material_hasher.similarity import SIMILARITY_MATCHERS from material_hasher.similarity.base import SimilarityMatcherBase from material_hasher.types import StructureEquivalenceChecker +ALL_EQUIVALENCE_CHECKERS = { + **{v.__name__: v for v in HASHERS.values()}, + **{v.__name__: v for v in SIMILARITY_MATCHERS.values()}, +} + EquivalenceCheckerType = TypeVar( "EquivalenceCheckerType", bound=HasherBase | SimilarityMatcherBase ) @@ -171,8 +178,6 @@ def save(self, path: str) -> None: def load( cls, path: str, - equivalence_checker_class: type[EquivalenceCheckerType], - equivalence_checker_kwargs: dict[str, Any] = {}, ) -> "DatasetStore": """Load the dataset store from a file. @@ -192,13 +197,10 @@ def load( """ save_data = np.load(path, allow_pickle=True).item() - # Verify the equivalence checker class matches - if save_data["equivalence_checker_class"] != equivalence_checker_class.__name__: - raise ValueError( - f"Loaded equivalence checker class {save_data['equivalence_checker_class']} " - f"does not match provided class {equivalence_checker_class.__name__}" - ) + equivalence_checker_class = ALL_EQUIVALENCE_CHECKERS[ + save_data["equivalence_checker_class"] + ] + store = cls(equivalence_checker_class, save_data["equivalence_checker_kwargs"]) - store = cls(equivalence_checker_class, equivalence_checker_kwargs) store.embeddings = save_data["embeddings"] return store