Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions src/material_hasher/dataset_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright 2025 Entalpic
from typing import Any, Optional, TypeVar

import numpy as np
from pymatgen.core import Structure

from material_hasher.hasher import HASHERS
from material_hasher.hasher.base import HasherBase
from material_hasher.similarity import SIMILARITY_MATCHERS
from material_hasher.similarity.base import SimilarityMatcherBase
from material_hasher.types import StructureEquivalenceChecker

ALL_EQUIVALENCE_CHECKERS = {
**{v.__name__: v for v in HASHERS.values()},
**{v.__name__: v for v in SIMILARITY_MATCHERS.values()},
}

EquivalenceCheckerType = TypeVar(
"EquivalenceCheckerType", bound=HasherBase | SimilarityMatcherBase
)


class DatasetStore:
"""Stores the hashes or embedding vectors of a dataset.

This is used for comparing structures against a reference dataset.

Parameters
----------
equivalence_checker_class : type[EquivalenceCheckerType]
The class of the equivalence checker to use (either a hasher or similarity matcher).
equivalence_checker_kwargs : dict[str, Any]
Keyword arguments to pass to the equivalence checker constructor.
"""

def __init__(
self,
equivalence_checker_class: type[EquivalenceCheckerType],
equivalence_checker_kwargs: dict[str, Any] = {},
):
self.equivalence_checker_class = equivalence_checker_class
self.equivalence_checker_kwargs = equivalence_checker_kwargs
self.equivalence_checker = self.equivalence_checker_class(
**self.equivalence_checker_kwargs
)
self.embeddings: list[np.ndarray | str] = []

@staticmethod
def _get_structure_embedding(
structure: Structure, equivalence_checker: StructureEquivalenceChecker
) -> np.ndarray | str:
"""Get the embedding or hash of a structure.

Parameters
----------
structure : Structure
The structure to get the embedding/hash for.
equivalence_checker : StructureEquivalenceChecker
The equivalence checker to use.

Returns
-------
np.ndarray | str
The embedding vector for similarity matchers or hash string for hashers.
"""
if isinstance(equivalence_checker, HasherBase):
return equivalence_checker.get_material_hash(structure)
elif isinstance(equivalence_checker, SimilarityMatcherBase):
return equivalence_checker.get_structure_embeddings(structure)
else:
raise ValueError(
f"Unsupported equivalence checker: {type(equivalence_checker)}"
)

def _get_structures_embeddings(
self, structures: list[Structure]
) -> list[np.ndarray | str]:
"""Get the embeddings of a list of structures.

Parameters
----------
structures : list[Structure]
The structures to get embeddings for.

Returns
-------
list[np.ndarray | str]
List of embeddings or hashes for each structure.
"""
return [
self._get_structure_embedding(structure, self.equivalence_checker)
for structure in structures
]

def compute_and_store_embeddings(self, structures: list[Structure]) -> None:
"""Compute the embeddings/hashes of the given structures and store them.

Parameters
----------
structures : list[Structure]
The structures to store embeddings/hashes for.
"""
self.embeddings.extend(self._get_structures_embeddings(structures))

def store_embeddings(self, embeddings: list[np.ndarray | str]) -> None:
"""Store the embeddings/hashes of the given structures.

Parameters
----------
embeddings : list[np.ndarray | str]
The embeddings/hashes to store.
"""
self.embeddings.extend(embeddings)

def is_equivalent(
self, structure: Structure, threshold: Optional[float] = None
) -> list[bool]:
"""Check if a structure is equivalent to any of the stored structures.

Parameters
----------
structure : Structure
The structure to check.
threshold : float, optional
Threshold for similarity matchers, by default None.

Returns
-------
list[bool]
List of boolean values indicating equivalence with each stored structure.
"""
query_embedding = self._get_structure_embedding(
structure, self.equivalence_checker
)

if isinstance(self.equivalence_checker, HasherBase):
return [
query_embedding == stored_embedding
for stored_embedding in self.embeddings
]
elif isinstance(self.equivalence_checker, SimilarityMatcherBase):
return [
self.equivalence_checker.get_similarity_embeddings(
query_embedding, stored_embedding
)
>= (
threshold
if threshold is not None
else self.equivalence_checker.threshold
)
for stored_embedding in self.embeddings
]
else:
raise ValueError(
f"Unsupported equivalence checker: {type(self.equivalence_checker)}"
)

def reset(self) -> None:
"""Reset the dataset store."""
self.embeddings = []

def save(self, path: str) -> None:
"""Save the dataset store to a file.

Parameters
----------
path : str
Path to save the dataset store to.
"""
save_data = {
"equivalence_checker_class": self.equivalence_checker_class.__name__,
"equivalence_checker_kwargs": self.equivalence_checker_kwargs,
"embeddings": self.embeddings,
}
np.save(path, save_data, allow_pickle=True)

@classmethod
def load(
cls,
path: str,
) -> "DatasetStore":
"""Load the dataset store from a file.

Parameters
----------
path : str
Path to load the dataset store from.
equivalence_checker_class : type[EquivalenceCheckerType]
The class of the equivalence checker to use.
equivalence_checker_kwargs : dict[str, Any]
Keyword arguments to pass to the equivalence checker constructor.

Returns
-------
DatasetStore
The loaded dataset store.
"""
save_data = np.load(path, allow_pickle=True).item()

equivalence_checker_class = ALL_EQUIVALENCE_CHECKERS[
save_data["equivalence_checker_class"]
]
store = cls(equivalence_checker_class, save_data["equivalence_checker_kwargs"])

store.embeddings = save_data["embeddings"]
return store
59 changes: 59 additions & 0 deletions src/material_hasher/similarity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,62 @@ def get_pairwise_equivalence(
Matrix of equivalence between structures.
"""
pass

def get_structure_embeddings(
self, structures: list[Structure] | Structure
) -> np.ndarray:
"""Get the embeddings of a list of structures.

This is not compatible with all the similarity matchers.

Parameters
----------
structures : list[Structure] | Structure
List of structures to get the embeddings of.

Returns
-------
np.ndarray
Embeddings of the structures.
"""
raise NotImplementedError(
"This method is not implemented for this similarity matcher."
)

def get_similarity_embeddings(
self, embeddings1: np.ndarray, embeddings2: np.ndarray
) -> float:
"""Get the similarity score between two embeddings.

Parameters
----------
embeddings1 : np.ndarray
First embeddings to compare.
embeddings2 : np.ndarray
Second embeddings to compare.

Returns
-------
float
Similarity score between the two embeddings.
"""
raise NotImplementedError(
"This method is not implemented for this similarity matcher."
)

def get_pairwise_similarity_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
"""Get the pairwise similarity embeddings of a list of embeddings."""
n = len(embeddings)
scores = np.zeros((n, n))

for i, embedding1 in enumerate(embeddings):
for j, embedding2 in enumerate(embeddings):
if i <= j:
scores[i, j] = self.get_similarity_embeddings(
embedding1, embedding2
)

# Fill tril
scores = scores + scores.T - np.diag(np.diag(scores))

return scores
Loading