diff --git a/pyproject.toml b/pyproject.toml index 583a31b..ef89118 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = [ "ipython>=8.29.0", "pre-commit>=4.0.1", "ruff>=0.8.0", + "pytest>=8.3", "shibuya>=2024.10.15", "sphinx-autoapi>=3.3.2", "sphinx-autodoc-typehints>=2.5.0", @@ -38,5 +39,6 @@ dev = [ "sphinx-design>=0.6.1", "sphinx-math-dollar>=1.2.1", "sphinxawesome-theme>=5.3.2", + "ipdb>=0.13.13", ] diff --git a/src/material_hasher/benchmark/disordered.py b/src/material_hasher/benchmark/disordered.py index bc437db..cdaa4a0 100644 --- a/src/material_hasher/benchmark/disordered.py +++ b/src/material_hasher/benchmark/disordered.py @@ -1,7 +1,7 @@ # Copyright 2025 Entalpic +import logging from itertools import combinations from typing import Dict, List, Tuple -import logging import numpy as np import pandas as pd diff --git a/src/material_hasher/benchmark/run_disordered.py b/src/material_hasher/benchmark/run_disordered.py index 151b5a8..29a9c2b 100644 --- a/src/material_hasher/benchmark/run_disordered.py +++ b/src/material_hasher/benchmark/run_disordered.py @@ -1,11 +1,11 @@ # Copyright 2025 Entalpic import datetime +import logging import os import time from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple -import logging import numpy as np import pandas as pd @@ -13,8 +13,6 @@ import yaml from pymatgen.core import Structure -logger = logging.getLogger(__name__) - from material_hasher.benchmark.disordered import ( download_disordered_structures, get_classification_results_dissimilar, @@ -25,6 +23,8 @@ from material_hasher.similarity import SIMILARITY_MATCHERS from material_hasher.types import StructureEquivalenceChecker +logger = logging.getLogger(__name__) + STRUCTURE_CHECKERS = {**HASHERS, **SIMILARITY_MATCHERS} diff --git a/src/material_hasher/benchmark/run_transformations.py b/src/material_hasher/benchmark/run_transformations.py index d6c0805..9510bf7 100644 --- 
a/src/material_hasher/benchmark/run_transformations.py +++ b/src/material_hasher/benchmark/run_transformations.py @@ -1,17 +1,17 @@ # Copyright 2025 Entalpic import datetime import json +import logging import os import time from pathlib import Path from typing import Optional -import logging import matplotlib.pyplot as plt import numpy as np import pandas as pd import yaml -from datasets import Dataset, VerificationMode, concatenate_datasets, load_dataset +from datasets import Dataset, VerificationMode, load_dataset from pymatgen.core import Structure from material_hasher.benchmark.transformations import ALL_TEST_CASES, get_test_case @@ -27,7 +27,9 @@ STRUCTURE_CHECKERS = {**HASHERS, **SIMILARITY_MATCHERS} -def get_hugging_face_dataset(token: Optional[str] = None) -> Dataset: +def get_hugging_face_dataset( + token: Optional[str] = None, n_rows: Optional[int] = None +) -> Dataset: """ Only returns the dataset from Hugging Face where all the subsets are concatenated. @@ -36,6 +38,8 @@ def get_hugging_face_dataset(token: Optional[str] = None) -> Dataset: token : str, optional The authentication token required to access the dataset. Optional if the dataset is public or you have already configured the Hugging Face CLI. + n_rows : int, optional + Number of rows to load from the dataset. Returns ------- @@ -43,28 +47,24 @@ def get_hugging_face_dataset(token: Optional[str] = None) -> Dataset: The concatenated dataset from Hugging Face. 
""" - subsets = [ + split = "train" + if n_rows is not None: + split += f"[:{n_rows}]" + + return load_dataset( + "LeMaterial/LeMat-Bulk", "compatible_pbe", - "compatible_scan", - "compatible_pbesol", - "non_compatible", - ] - dss = [] - for subset in subsets: - dss.append( - load_dataset( - "LeMaterial/LeMat-Bulk", - subset, - token=token, - verification_mode=VerificationMode.NO_CHECKS, - )["train"] - ) - ds = concatenate_datasets(dss) - return ds + split=split, + token=token, + verification_mode=VerificationMode.NO_CHECKS, + ) def get_data_from_hugging_face( - token: Optional[str] = None, n_test_elements: int = 100, seed: int = 0 + token: Optional[str] = None, + n_test_elements: int = 100, + n_rows: Optional[int] = None, + seed: int = 0, ) -> list[Structure]: """ Downloads and processes structural data from the Hugging Face `datasets` library. @@ -80,6 +80,8 @@ def get_data_from_hugging_face( n_test_elements : int Number of elements to select from the dataset to run the benchmark on. Default is 100. This is used to run the transformation benchmark only a subset of LeMat-Bulk. + n_rows : int, optional + Number of rows to load from the dataset. This will load them in sequential order. seed : int Random seed for selecting a subset of the dataset. Default is 0. @@ -101,7 +103,7 @@ def get_data_from_hugging_face( - Errors during the transformation process are logged but do not halt execution. 
""" - ds = get_hugging_face_dataset(token) + ds = get_hugging_face_dataset(token, n_rows=n_rows) # Convert dataset to Pandas DataFrame logger.info("Loaded dataset:", len(ds)) @@ -232,7 +234,11 @@ def hasher_sensitivity( else: raise ValueError("Unknown structure checker") - return matching_hashes / len(transformed_structures) if len(transformed_structures) > 0 else 0 + return ( + matching_hashes / len(transformed_structures) + if len(transformed_structures) > 0 + else 0 + ) def mean_sensitivity( diff --git a/src/material_hasher/benchmark/transformations.py b/src/material_hasher/benchmark/transformations.py index 40d1cb1..5b4ff61 100644 --- a/src/material_hasher/benchmark/transformations.py +++ b/src/material_hasher/benchmark/transformations.py @@ -1,10 +1,9 @@ # Copyright 2025 Entalpic import inspect -import random -from typing import Optional, Union +from typing import Optional import numpy as np -from pymatgen.core import Structure, SymmOp +from pymatgen.core import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer ALL_TEST_CASES = [ @@ -16,10 +15,10 @@ ] PARAMETERS = { - "gaussian_noise": {"sigma": np.logspace(0.0001, 0.5,15, base=0.0000001)}, - "isometric_strain": {"pct": [1,1.05,1.1,1.2,1.5]}, - "strain": {"sigma": np.logspace(0.001, 0.5,10, base=0.0000001)}, - "translation": {"sigma": np.logspace(0.0001, 0.5,15, base=0.0000001)}, + "gaussian_noise": {"sigma": np.logspace(0.0001, 0.5, 15, base=0.0000001)}, + "isometric_strain": {"pct": [1, 1.05, 1.1, 1.2, 1.5]}, + "strain": {"sigma": np.logspace(0.001, 0.5, 10, base=0.0000001)}, + "translation": {"sigma": np.logspace(0.0001, 0.5, 15, base=0.0000001)}, "symm_ops": {"structure_symmetries": ["all_symmetries_found"]}, } diff --git a/src/material_hasher/hasher/__init__.py b/src/material_hasher/hasher/__init__.py index 2b997fe..9e873a9 100644 --- a/src/material_hasher/hasher/__init__.py +++ b/src/material_hasher/hasher/__init__.py @@ -1,10 +1,13 @@ # Copyright 2025 Entalpic -from 
material_hasher.hasher.entalpic import EntalpicMaterialsHasher, ShortenedEntalpicMaterialsHasher -from material_hasher.hasher.example import SimpleCompositionHasher +import warnings + +from material_hasher.hasher.entalpic import ( + EntalpicMaterialsHasher, + ShortenedEntalpicMaterialsHasher, +) from material_hasher.hasher.pdd import PointwiseDistanceDistributionHasher -import warnings -warnings.filterwarnings('always') +warnings.filterwarnings("always") __all__ = ["EntalpicMaterialsHasher"] @@ -17,6 +20,10 @@ try: from material_hasher.hasher.slices import SLICESHasher + HASHERS.update({"SLICES": SLICESHasher}) except ImportError: - warnings.warn('Failed to import SLICES. If you would like to use this module, please consider running uv pip install -r requirements_slices.txt', ImportWarning) + warnings.warn( + "Failed to import SLICES. If you would like to use this module, please consider running uv pip install -r requirements_slices.txt", + ImportWarning, + ) diff --git a/src/material_hasher/hasher/entalpic.py b/src/material_hasher/hasher/entalpic.py index 3d15006..71faa99 100644 --- a/src/material_hasher/hasher/entalpic.py +++ b/src/material_hasher/hasher/entalpic.py @@ -13,6 +13,7 @@ class EntalpicMaterialsHasher(HasherBase): Returns hash based on bonding graph structure, composition, and symmetry. """ + def __init__( self, graphing_algorithm: str = "WL", diff --git a/src/material_hasher/hasher/pdd.py b/src/material_hasher/hasher/pdd.py index 555f3e6..a960e80 100644 --- a/src/material_hasher/hasher/pdd.py +++ b/src/material_hasher/hasher/pdd.py @@ -77,9 +77,7 @@ def get_material_hash(self, structure: Structure) -> str: """ periodic_set = self.periodicset_from_structure(structure) - pdd = PDD( - periodic_set, int(self.cutoff), collapse=False - ) + pdd = PDD(periodic_set, int(self.cutoff), collapse=False) # Round the PDD values to 4 decimal places for numerical stability and consistency. 
pdd = np.round(pdd, decimals=4) diff --git a/src/material_hasher/hasher/slices.py b/src/material_hasher/hasher/slices.py index 69ddc3c..e451f7c 100644 --- a/src/material_hasher/hasher/slices.py +++ b/src/material_hasher/hasher/slices.py @@ -4,11 +4,14 @@ # uv pip install -r requirements_slices.txt +import tensorflow as tf from pymatgen.core.structure import Structure from slices.core import SLICES from material_hasher.hasher.base import HasherBase +tf.get_logger().setLevel("ERROR") + class SLICESHasher(HasherBase): def __init__(self): @@ -32,4 +35,3 @@ def get_material_hash(self, structure: Structure) -> str: The SLICES string representation of the structure. """ return self.backend.structure2SLICES(structure) - diff --git a/src/material_hasher/hasher/utils/graph_structure.py b/src/material_hasher/hasher/utils/graph_structure.py index 349841d..2c4723b 100644 --- a/src/material_hasher/hasher/utils/graph_structure.py +++ b/src/material_hasher/hasher/utils/graph_structure.py @@ -1,8 +1,7 @@ # Copyright 2025 Entalpic -from pymatgen.analysis.graphs import StructureGraph +from networkx import Graph from pymatgen.analysis.local_env import EconNN, NearNeighbors from pymatgen.core import Structure -from networkx import Graph def get_structure_graph( @@ -23,10 +22,8 @@ class to build bonded structure. Defaults to EconNN. 
Returns: Graph: networkx Graph object """ - structure_graph = StructureGraph.with_local_env_strategy( - structure=structure, - strategy=bonding_algorithm(**bonding_kwargs), - ) + bonding = bonding_algorithm(**bonding_kwargs) + structure_graph = bonding.get_bonded_structure(structure) for n, site in zip(range(len(structure)), structure): structure_graph.graph.nodes[n]["specie"] = site.specie.name for edge in structure_graph.graph.edges: diff --git a/src/material_hasher/similarity/structure_matchers.py b/src/material_hasher/similarity/structure_matchers.py index e39036d..1a699c9 100644 --- a/src/material_hasher/similarity/structure_matchers.py +++ b/src/material_hasher/similarity/structure_matchers.py @@ -33,12 +33,18 @@ def is_equivalent( First structure to compare. structure2 : Structure Second structure to compare. + threshold : Optional[float] + Optional threshold to override the default tolerance. Returns ------- bool True if the two structures are similar, False otherwise. """ + if threshold is not None: + # Create a temporary matcher with the new threshold + temp_matcher = StructureMatcher(ltol=threshold) + return temp_matcher.fit(structure1, structure2) return self.matcher.fit(structure1, structure2) def get_similarity_score( @@ -60,7 +66,11 @@ def get_similarity_score( float Similarity score between the two structures. 
""" - return self.matcher.get_rms_dist(structure1, structure2) + # RMS displacement is normalized by (Vol / nsites) ** (1/3) in PMG + distance = self.matcher.get_rms_dist(structure1, structure2) + if distance is None: # No alignment found + return 0.0 + return 1.0 - distance[0] def get_pairwise_equivalence( self, structures: list[Structure], threshold: Optional[float] = None diff --git a/tests/benchmark/test_dataset_benchmarks.py b/tests/benchmark/test_dataset_benchmarks.py new file mode 100644 index 0000000..ca1fb4f --- /dev/null +++ b/tests/benchmark/test_dataset_benchmarks.py @@ -0,0 +1,51 @@ +import pytest +from datasets import Dataset +from material_hasher.benchmark.disordered import download_disordered_structures +from material_hasher.benchmark.run_transformations import get_data_from_hugging_face + + +@pytest.fixture +def small_test_dataset(): + """Create a small synthetic dataset for testing benchmark functions""" + data = { + "lattice_vectors": [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], + "species_at_sites": [["Si"]], + "cartesian_site_positions": [[[0, 0, 0]]], + } + return Dataset.from_dict(data) + + +def test_hugging_face_data_loading(small_test_dataset, monkeypatch): + """Test that data loading works with a small test dataset""" + + def mock_load(*args, **kwargs): + return small_test_dataset + + # Mock the dataset loading + monkeypatch.setattr( + "material_hasher.benchmark.run_transformations.load_dataset", mock_load + ) + + structures = get_data_from_hugging_face(n_test_elements=1) + assert len(structures) == 1 + assert structures[0].formula == "Si1" + + +@pytest.mark.integration +def test_download_transformations_dataset(): + """ + Download the HF dataset + """ + data = get_data_from_hugging_face( + n_test_elements=2, n_rows=2 + ) # Use small number for test + assert len(data) == 2 + + +@pytest.mark.integration +def test_download_disordered_structures(): + """ + Download the HF dataset + """ + structures = download_disordered_structures() + assert 
len(structures) > 0 diff --git a/tests/benchmark/test_run_disordered.py b/tests/benchmark/test_run_disordered.py new file mode 100644 index 0000000..4b3f42a --- /dev/null +++ b/tests/benchmark/test_run_disordered.py @@ -0,0 +1,164 @@ +import numpy as np +import pandas as pd +import pytest +from material_hasher.benchmark.run_disordered import ( + benchmark_disordered_structures, + run_group_structures_benchmark, +) +from material_hasher.hasher.base import HasherBase +from pymatgen.core import Lattice, Structure + + +class DummyConstantHasher(HasherBase): + """A dummy hasher that always returns the same hash, simulating complete insensitivity.""" + + def get_material_hash(self, structure: Structure) -> str: + return "constant_hash" + + def is_equivalent(self, structure1: Structure, structure2: Structure) -> bool: + return True + + +class DummyRandomHasher(HasherBase): + """A dummy hasher that always returns different hashes, simulating complete sensitivity.""" + + def get_material_hash(self, structure: Structure) -> str: + return f"random_hash_{np.random.rand()}" + + def is_equivalent(self, structure1: Structure, structure2: Structure) -> bool: + return False + + +@pytest.fixture +def dummy_structures(): + """Create a list of simple cubic structures for testing.""" + lattice = Lattice.cubic(1.0) + structures = [] + for i in range(5): + coords = [[0.0, 0.0, float(i) / 10]] # Slightly different z coordinates + structures.append(Structure(lattice, ["Si"], coords)) + return structures + + +@pytest.fixture +def constant_hasher(): + """Create a constant hasher instance.""" + return DummyConstantHasher() + + +@pytest.fixture +def random_hasher(): + """Create a random hasher instance.""" + return DummyRandomHasher() + + +def test_run_group_structures_benchmark_constant(dummy_structures, constant_hasher): + """Test group structures benchmark with constant hasher.""" + metrics = run_group_structures_benchmark( + constant_hasher, + "test_group", + dummy_structures, + 
n_pick_random=3, + n_random_structures=2, + seeds=[0], + ) + + # Constant hasher should always consider structures equivalent + assert len(metrics["success_rate"]) == 1 # One seed + assert metrics["success_rate"][0] == 1.0 + + +def test_run_group_structures_benchmark_random(dummy_structures, random_hasher): + """Test group structures benchmark with random hasher.""" + metrics = run_group_structures_benchmark( + random_hasher, + "test_group", + dummy_structures, + n_pick_random=3, + n_random_structures=2, + seeds=[0], + ) + + # Random hasher should never consider structures equivalent + assert len(metrics["success_rate"]) == 1 # One seed + assert metrics["success_rate"][0] == 0.0 + + +def test_run_group_structures_benchmark_small_group(dummy_structures, constant_hasher): + """Test group structures benchmark with a small group (less than n_pick_random).""" + small_group = dummy_structures[:2] # Only use 2 structures + metrics = run_group_structures_benchmark( + constant_hasher, + "small_group", + small_group, + n_pick_random=3, # Larger than group size + n_random_structures=2, + seeds=[0], + ) + + assert len(metrics["success_rate"]) == 1 + assert metrics["success_rate"][0] == 1.0 + + +@pytest.fixture +def mock_disordered_structures(dummy_structures): + """Create a mock dictionary of disordered structures for testing.""" + return { + "group1": dummy_structures[:3], + "group2": dummy_structures[3:], + } + + +# Mock the download_disordered_structures function +@pytest.fixture +def mock_download(monkeypatch, mock_disordered_structures): + """Mock the download_disordered_structures function.""" + + def mock_download_fn(): + return mock_disordered_structures + + monkeypatch.setattr( + "material_hasher.benchmark.run_disordered.download_disordered_structures", + mock_download_fn, + ) + + +def test_benchmark_disordered_structures(mock_download, constant_hasher): + """Test the full benchmark with mocked data.""" + df_results, total_time = benchmark_disordered_structures( + 
constant_hasher, + seeds=[0, 1], # Use two seeds for testing + ) + + # Check that we have results for all groups + assert isinstance(df_results, pd.DataFrame) + assert "group1" in df_results.index + assert "group2" in df_results.index + assert "dissimilar_case" in df_results.index + + # Check that total time is recorded + assert "total_time (s)" in df_results.index + assert isinstance(total_time, float) + assert total_time > 0 + + # For constant hasher, all success rates should be 1.0, except for dissimilar_case + for idx in ["group1", "group2", "dissimilar_case"]: + if idx == "dissimilar_case": + assert all(rate == 0.0 for rate in df_results.loc[idx, "success_rate"]) + else: + assert all(rate == 1.0 for rate in df_results.loc[idx, "success_rate"]) + + +def test_benchmark_disordered_structures_random(mock_download, random_hasher): + """Test the full benchmark with random hasher.""" + df_results, total_time = benchmark_disordered_structures( + random_hasher, + seeds=[0], # Use only one seed for testing + ) + + # For random hasher, all success rates should be 0.0 + for idx in ["group1", "group2", "dissimilar_case"]: + if idx == "dissimilar_case": # We have a 1 chance to generate different hashes + assert all(rate == 1.0 for rate in df_results.loc[idx, "success_rate"]) + else: + assert all(rate == 0.0 for rate in df_results.loc[idx, "success_rate"]) diff --git a/tests/benchmark/test_run_transformations.py b/tests/benchmark/test_run_transformations.py new file mode 100644 index 0000000..7fb3bf7 --- /dev/null +++ b/tests/benchmark/test_run_transformations.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +from material_hasher.benchmark.run_transformations import ( + hasher_sensitivity, + mean_sensitivity, + sensitivity_over_parameter_range, +) +from material_hasher.hasher.base import HasherBase +from pymatgen.core import Lattice, Structure + + +class DummyConstantHasher(HasherBase): + """A dummy hasher that always returns the same hash, simulating complete 
insensitivity to transformations.""" + + def get_material_hash(self, structure: Structure) -> str: + return "constant_hash" + + +class DummyRandomHasher(HasherBase): + """A dummy hasher that always returns different hashes, simulating complete sensitivity to transformations.""" + + def get_material_hash(self, structure: Structure) -> str: + return f"random_hash_{np.random.rand()}" + + +@pytest.fixture +def dummy_structure(): + """Create a simple cubic structure for testing.""" + lattice = Lattice.cubic(1.0) + coords = [[0.0, 0.0, 0.0]] + return Structure(lattice, ["Si"], coords) + + +@pytest.fixture +def constant_hasher(): + """Create a constant hasher instance.""" + return DummyConstantHasher() + + +@pytest.fixture +def random_hasher(): + """Create a random hasher instance.""" + return DummyRandomHasher() + + +def test_hasher_sensitivity_constant(dummy_structure, constant_hasher): + """Test that constant hasher always returns 1.0 sensitivity.""" + transformed_structures = [dummy_structure.copy(), dummy_structure.copy()] + sensitivity = hasher_sensitivity( + dummy_structure, transformed_structures, constant_hasher + ) + assert sensitivity == 1.0 + + +def test_hasher_sensitivity_random(dummy_structure, random_hasher): + """Test that random hasher returns very low sensitivity.""" + transformed_structures = [dummy_structure.copy(), dummy_structure.copy()] + sensitivity = hasher_sensitivity( + dummy_structure, transformed_structures, random_hasher + ) + assert sensitivity == 0.0 + + +def test_mean_sensitivity(dummy_structure, constant_hasher, random_hasher): + """Test mean sensitivity calculation with both hashers.""" + structures = [dummy_structure.copy() for _ in range(3)] + test_case = "gaussian_noise" + parameter = ("sigma", 0.0001) + + # Test constant hasher + mean_sens_constant = mean_sensitivity( + structures, test_case, parameter, constant_hasher + ) + assert mean_sens_constant == 1.0 + + # Test random hasher + mean_sens_random = mean_sensitivity(structures, 
test_case, parameter, random_hasher) + assert mean_sens_random == 0.0 + + +def test_sensitivity_over_parameter_range( + dummy_structure, constant_hasher, random_hasher +): + """Test sensitivity over parameter range with both hashers.""" + structures = [dummy_structure.copy() for _ in range(3)] + test_case = "gaussian_noise" + + # Test constant hasher + results_constant = sensitivity_over_parameter_range( + structures, test_case, constant_hasher + ) + assert all(value == 1.0 for value in results_constant.values()) + + # Test random hasher + results_random = sensitivity_over_parameter_range( + structures, test_case, random_hasher + ) + assert all(value == 0.0 for value in results_random.values()) diff --git a/tests/benchmark/test_transformations.py b/tests/benchmark/test_transformations.py new file mode 100644 index 0000000..aee0844 --- /dev/null +++ b/tests/benchmark/test_transformations.py @@ -0,0 +1,136 @@ +import numpy as np +import pytest +from material_hasher.benchmark.transformations import ( + get_new_structure_with_gaussian_noise, + get_new_structure_with_isometric_strain, + get_new_structure_with_strain, + get_new_structure_with_symm_ops, + get_new_structure_with_translation, +) +from pymatgen.core import Lattice, Structure + + +@pytest.fixture +def simple_cubic_structure(): + """Create a simple cubic structure for testing.""" + lattice = Lattice.cubic(1.0) + coords = [[0.0, 0.0, 0.0]] + return Structure(lattice, ["Si"], coords) + + +@pytest.fixture +def complex_structure(): + """Create a more complex structure with multiple atoms for testing.""" + lattice = Lattice.cubic(5.0) + coords = [ + [0.0, 0.0, 0.0], # Si at origin + [0.5, 0.2, 0.3], # O offset from center + [0.25, 0.4, 0.1], # H creating asymmetry + ] + species = ["Si", "O", "H"] + return Structure(lattice, species, coords) + + +def test_gaussian_noise_transformation(simple_cubic_structure): + """Test gaussian noise transformation.""" + # Test with small noise + sigma = 0.0001 + transformed = 
get_new_structure_with_gaussian_noise(simple_cubic_structure, sigma) + + # Check that structure is modified but not too much + assert transformed != simple_cubic_structure + assert np.allclose( + transformed.cart_coords, simple_cubic_structure.cart_coords, atol=sigma * 10 + ) + + # Test with larger noise + sigma = 0.1 + transformed = get_new_structure_with_gaussian_noise(simple_cubic_structure, sigma) + assert not np.allclose( + transformed.cart_coords, simple_cubic_structure.cart_coords, atol=sigma / 10 + ) + + +def test_isometric_strain_transformation(simple_cubic_structure): + """Test isometric strain transformation.""" + # Test expansion + pct = 1.1 + transformed = get_new_structure_with_isometric_strain(simple_cubic_structure, pct) + assert transformed.volume > simple_cubic_structure.volume + assert np.isclose(transformed.volume, simple_cubic_structure.volume * pct) + + # Test compression + pct = 0.9 + transformed = get_new_structure_with_isometric_strain(simple_cubic_structure, pct) + assert transformed.volume < simple_cubic_structure.volume + assert np.isclose(transformed.volume, simple_cubic_structure.volume * pct) + + +def test_strain_transformation(simple_cubic_structure): + """Test strain transformation.""" + sigma = 0.01 + transformed = get_new_structure_with_strain(simple_cubic_structure, sigma) + + # Check that structure is modified + assert transformed != simple_cubic_structure + + # Check that volume is changed (strain should affect volume) + assert transformed.volume != simple_cubic_structure.volume + + # Check that atomic positions are still valid + assert all( + coord >= 0 and coord <= 1 for site in transformed.frac_coords for coord in site + ) + + +def test_translation_transformation(complex_structure): + """Test translation transformation.""" + sigma = 0.1 + transformed = get_new_structure_with_translation(complex_structure, sigma) + + # Check that structure is modified + assert transformed != complex_structure + + # Check that relative 
distances between atoms are preserved + original_distances = complex_structure.distance_matrix + transformed_distances = transformed.distance_matrix + assert np.allclose(original_distances, transformed_distances) + + +def test_symm_ops_transformation(complex_structure): + """Test symmetry operations transformation.""" + transformed_structures = get_new_structure_with_symm_ops( + complex_structure, "all_symmetries_found" + ) + + # Check that we get multiple structures + assert len(transformed_structures) > 0 + + # Check that all transformed structures have the same number of sites + assert all(len(s) == len(complex_structure) for s in transformed_structures) + + # Check that all transformed structures have the same volume + assert all( + np.isclose(s.volume, complex_structure.volume) for s in transformed_structures + ) + + # Check that the structures are actually different + if len(transformed_structures) > 1: + assert not np.allclose( + transformed_structures[0].cart_coords, transformed_structures[1].cart_coords + ) + + +def test_edge_cases(): + """Test edge cases for transformations.""" + + # Test with zero parameters + structure = Structure(Lattice.cubic(1.0), ["Si"], [[0, 0, 0]]) + + # Zero noise should return effectively the same structure + transformed = get_new_structure_with_gaussian_noise(structure, 0.0) + assert np.allclose(transformed.cart_coords, structure.cart_coords) + + # Zero strain should return the same volume + transformed = get_new_structure_with_isometric_strain(structure, 1.0) + assert np.isclose(transformed.volume, structure.volume) diff --git a/tests/hasher/test_base.py b/tests/hasher/test_base.py new file mode 100644 index 0000000..d9fa397 --- /dev/null +++ b/tests/hasher/test_base.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest +from material_hasher.hasher.base import HasherBase +from pymatgen.core import Lattice, Structure + + +class SimpleHasher(HasherBase): + """A simple concrete implementation of HasherBase for testing.""" + + 
def get_material_hash(self, structure: Structure) -> str: + """Simple hash based on number of sites and species.""" + species = sorted([str(site.specie) for site in structure]) + return f"{len(structure)}_{'-'.join(species)}" + + +@pytest.fixture +def simple_hasher(): + return SimpleHasher() + + +@pytest.fixture +def simple_structure(): + lattice = Lattice.cubic(4.0) + species = ["Si", "Si"] + coords = [[0, 0, 0], [0.5, 0.5, 0.5]] + return Structure(lattice, species, coords) + + +@pytest.fixture +def different_structure(): + lattice = Lattice.cubic(4.0) + species = ["Si", "Ge"] + coords = [[0, 0, 0], [0.5, 0.5, 0.5]] + return Structure(lattice, species, coords) + + +def test_get_material_hash(simple_hasher, simple_structure): + """Test that get_material_hash returns expected format.""" + hash_value = simple_hasher.get_material_hash(simple_structure) + assert isinstance(hash_value, str) + assert hash_value == "2_Si-Si" + + +def test_is_equivalent_same_structure(simple_hasher, simple_structure): + """Test that identical structures are equivalent.""" + assert simple_hasher.is_equivalent(simple_structure, simple_structure) + + +def test_is_equivalent_different_structures( + simple_hasher, simple_structure, different_structure +): + """Test that different structures are not equivalent.""" + assert not simple_hasher.is_equivalent(simple_structure, different_structure) + + +def test_get_materials_hashes(simple_hasher, simple_structure, different_structure): + """Test getting hashes for multiple structures.""" + structures = [simple_structure, different_structure] + hashes = simple_hasher.get_materials_hashes(structures) + + assert len(hashes) == 2 + assert hashes[0] == "2_Si-Si" + assert hashes[1] == "2_Ge-Si" + + +def test_get_pairwise_equivalence(simple_hasher, simple_structure, different_structure): + """Test pairwise equivalence matrix generation.""" + structures = [simple_structure, different_structure] + equivalence_matrix = 
simple_hasher.get_pairwise_equivalence(structures) + + expected = np.array([[True, False], [False, True]]) + + assert isinstance(equivalence_matrix, np.ndarray) + assert equivalence_matrix.shape == (2, 2) + assert np.array_equal(equivalence_matrix, expected) + + +def test_get_pairwise_equivalence_single_structure(simple_hasher, simple_structure): + """Test pairwise equivalence matrix for a single structure.""" + structures = [simple_structure] + equivalence_matrix = simple_hasher.get_pairwise_equivalence(structures) + + expected = np.array([[True]]) + + assert isinstance(equivalence_matrix, np.ndarray) + assert equivalence_matrix.shape == (1, 1) + assert np.array_equal(equivalence_matrix, expected) + + +def test_get_pairwise_equivalence_symmetry( + simple_hasher, simple_structure, different_structure +): + """Test that the equivalence matrix is symmetric.""" + structures = [simple_structure, different_structure] + equivalence_matrix = simple_hasher.get_pairwise_equivalence(structures) + + assert np.array_equal(equivalence_matrix, equivalence_matrix.T) diff --git a/tests/hasher/test_bawl.py b/tests/hasher/test_bawl.py new file mode 100644 index 0000000..bedc75c --- /dev/null +++ b/tests/hasher/test_bawl.py @@ -0,0 +1,134 @@ +import pytest +from material_hasher.hasher.entalpic import ( + EntalpicMaterialsHasher, + ShortenedEntalpicMaterialsHasher, +) +from pymatgen.analysis.local_env import VoronoiNN +from pymatgen.core import Lattice, Structure + + +@pytest.fixture +def entalpic_hasher(): + return EntalpicMaterialsHasher() + + +@pytest.fixture +def shortened_hasher(): + return ShortenedEntalpicMaterialsHasher() + + +@pytest.fixture +def si_structure(): + lattice = Lattice.cubic(5.43) + species = ["Si"] * 8 + coords = [ + [0.0, 0.0, 0.0], + [0.25, 0.25, 0.25], + [0.0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0.0, 0.5], + [0.75, 0.25, 0.75], + [0.5, 0.5, 0.0], + [0.75, 0.75, 0.25], + ] + return Structure(lattice, species, coords) + + +@pytest.fixture +def 
ge_structure(): + lattice = Lattice.cubic(5.43) + species = ["Ge"] * 8 + coords = [ + [0.0, 0.0, 0.0], + [0.25, 0.25, 0.25], + [0.0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0.0, 0.5], + [0.75, 0.25, 0.75], + [0.5, 0.5, 0.0], + [0.75, 0.75, 0.25], + ] + return Structure(lattice, species, coords) + + +def test_entalpic_hasher_initialization(): + """Test different initialization parameters.""" + # Test with different bonding algorithm + hasher1 = EntalpicMaterialsHasher( + bonding_algorithm=VoronoiNN, bonding_kwargs={"tol": 0.1} + ) + assert isinstance(hasher1.bonding_algorithm, type) + assert hasher1.bonding_algorithm == VoronoiNN + + # Test with different bonding kwargs + hasher2 = EntalpicMaterialsHasher(bonding_kwargs={"tol": 0.1}) + assert hasher2.bonding_kwargs == {"tol": 0.1} + + # Test with composition disabled + hasher3 = EntalpicMaterialsHasher(include_composition=False) + assert not hasher3.include_composition + + +def test_entalpic_hash_components(entalpic_hasher, si_structure): + """Test that all hash components are present.""" + data = entalpic_hasher.get_entalpic_materials_data(si_structure) + + assert "bonding_graph_hash" in data + assert "symmetry_label" in data + assert "composition" in data + + # Test types + assert isinstance(data["bonding_graph_hash"], str) + assert isinstance(data["symmetry_label"], int) + assert isinstance(data["composition"], str) + + +def test_shortened_vs_full_hash(shortened_hasher, entalpic_hasher, si_structure): + """Test that shortened hash is different from full hash.""" + short_hash = shortened_hasher.get_material_hash(si_structure) + full_hash = entalpic_hasher.get_material_hash(si_structure) + + assert short_hash != full_hash + assert len(short_hash.split("_")) < len(full_hash.split("_")) + + +def test_hash_consistency(entalpic_hasher, si_structure): + """Test that the same structure always gets the same hash.""" + hash1 = entalpic_hasher.get_material_hash(si_structure) + hash2 = 
entalpic_hasher.get_material_hash(si_structure) + + assert hash1 == hash2 + + +def test_equivalent_structures(entalpic_hasher, si_structure): + """Test that equivalent structures are identified as such.""" + # Create a shifted version of the same structure + shifted_structure = si_structure.copy() + shifted_structure.translate_sites(range(len(shifted_structure)), [0.5, 0.5, 0.5]) + + assert entalpic_hasher.is_equivalent(si_structure, shifted_structure) + + +def test_symmetry_detection(entalpic_hasher, si_structure): + """Test that symmetry is correctly detected.""" + data = entalpic_hasher.get_entalpic_materials_data(si_structure) + + # Silicon has space group 227 (Fd-3m) + assert data["symmetry_label"] == 227 + + +def test_different_materials_have_different_hashes( + entalpic_hasher, si_structure, ge_structure +): + """Test that different materials have different hashes.""" + assert entalpic_hasher.get_material_hash( + si_structure + ) != entalpic_hasher.get_material_hash(ge_structure) + + noised_ge_structure = ge_structure.copy() + + # perturb the structure by 0.5 Angstrom (a lot!) 
+ noised_ge_structure.perturb(0.5) + assert entalpic_hasher.get_material_hash( + si_structure + ) != entalpic_hasher.get_material_hash(noised_ge_structure) diff --git a/tests/hasher/test_example.py b/tests/hasher/test_example.py new file mode 100644 index 0000000..f8a8af9 --- /dev/null +++ b/tests/hasher/test_example.py @@ -0,0 +1,87 @@ +import pytest +from material_hasher.hasher.example import SimpleCompositionHasher +from pymatgen.core import Lattice, Structure + + +@pytest.fixture +def composition_hasher(): + return SimpleCompositionHasher() + + +@pytest.fixture +def si_structure(): + lattice = Lattice.cubic(5.43) + species = ["Si"] * 8 + coords = [ + [0.0, 0.0, 0.0], + [0.25, 0.25, 0.25], + [0.0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0.0, 0.5], + [0.75, 0.25, 0.75], + [0.5, 0.5, 0.0], + [0.75, 0.75, 0.25], + ] + return Structure(lattice, species, coords) + + +@pytest.fixture +def sio2_structure(): + lattice = Lattice.hexagonal(4.91, 5.40) + species = ["Si"] * 3 + ["O"] * 6 + coords = [ + [0.4697, 0.0000, 0.0000], + [0.0000, 0.4697, 0.6667], + [0.5303, 0.5303, 0.3333], + [0.4135, 0.2669, 0.1191], + [0.2669, 0.4135, 0.5476], + [0.7331, 0.1466, 0.7857], + [0.5865, 0.8534, 0.8810], + [0.8534, 0.5865, 0.4524], + [0.1466, 0.7331, 0.2143], + ] + return Structure(lattice, species, coords) + + +def test_get_material_hash_si(composition_hasher, si_structure): + """Test that silicon structure returns Si as hash.""" + hash_value = composition_hasher.get_material_hash(si_structure) + assert hash_value == "Si" + + +def test_get_material_hash_sio2(composition_hasher, sio2_structure): + """Test that silicon dioxide structure returns SiO2 as hash.""" + hash_value = composition_hasher.get_material_hash(sio2_structure) + assert hash_value == "SiO2" + + +def test_is_equivalent_same_composition(composition_hasher, sio2_structure): + """Test that structures with same composition are equivalent.""" + # Create a different SiO2 structure with same composition + lattice = 
Lattice.cubic(4.91) + species = ["Si"] * 3 + ["O"] * 6 + coords = [ + [x / 8, y / 8, z / 8] for x in range(3) for y in range(3) for z in range(3) + ][:9] + different_sio2 = Structure(lattice, species, coords) + + assert composition_hasher.is_equivalent(sio2_structure, different_sio2) + + +def test_is_equivalent_different_composition( + composition_hasher, si_structure, sio2_structure +): + """Test that structures with different compositions are not equivalent.""" + assert not composition_hasher.is_equivalent(si_structure, sio2_structure) + + +def test_get_materials_hashes_multiple( + composition_hasher, si_structure, sio2_structure +): + """Test getting hashes for multiple structures.""" + structures = [si_structure, sio2_structure] + hashes = composition_hasher.get_materials_hashes(structures) + + assert len(hashes) == 2 + assert hashes[0] == "Si" + assert hashes[1] == "SiO2" diff --git a/tests/hasher/test_pdd.py b/tests/hasher/test_pdd.py new file mode 100644 index 0000000..c3c805f --- /dev/null +++ b/tests/hasher/test_pdd.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest +from amd import PeriodicSet +from material_hasher.hasher.pdd import PointwiseDistanceDistributionHasher +from pymatgen.core import Lattice, Structure + + +@pytest.fixture +def pdd_hasher(): + return PointwiseDistanceDistributionHasher(cutoff=10.0) + + +@pytest.fixture +def si_structure(): + lattice = Lattice.cubic(5.43) + species = ["Si"] * 8 + coords = [ + [0.0, 0.0, 0.0], + [0.25, 0.25, 0.25], + [0.0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0.0, 0.5], + [0.75, 0.25, 0.75], + [0.5, 0.5, 0.0], + [0.75, 0.75, 0.25], + ] + return Structure(lattice, species, coords) + + +def test_periodicset_conversion(pdd_hasher, si_structure): + """Test conversion from pymatgen Structure to PeriodicSet.""" + periodic_set = pdd_hasher.periodicset_from_structure(si_structure) + + assert isinstance(periodic_set, PeriodicSet) + assert len(periodic_set.motif) == len(si_structure) + assert all(n == 14 for n in 
periodic_set.types) # Si atomic number is 14 + + +def test_hash_is_string(pdd_hasher, si_structure): + """Test that the hash is a string with expected format.""" + hash_value = pdd_hasher.get_material_hash(si_structure) + + assert isinstance(hash_value, str) + assert len(hash_value) == 64 # SHA256 hash length + + +def test_hash_consistency(pdd_hasher, si_structure): + """Test that the same structure always gets the same hash.""" + hash1 = pdd_hasher.get_material_hash(si_structure) + hash2 = pdd_hasher.get_material_hash(si_structure) + + assert hash1 == hash2 + + +def test_equivalent_structures(pdd_hasher, si_structure): + """Test that equivalent structures are identified as such.""" + # Create a shifted version of the same structure + shifted_structure = si_structure.copy() + shifted_structure.translate_sites(range(len(shifted_structure)), [0.5, 0.5, 0.5]) + + assert pdd_hasher.is_equivalent(si_structure, shifted_structure) + + +def test_different_cutoffs(si_structure): + """Test that significantly different cutoffs give different hashes.""" + hasher1 = PointwiseDistanceDistributionHasher(cutoff=5.0) + hasher2 = PointwiseDistanceDistributionHasher(cutoff=10.0) + + hash1 = hasher1.get_material_hash(si_structure) + hash2 = hasher2.get_material_hash(si_structure) + + assert hash1 != hash2 + + +def test_empty_structure(): + """Test that empty structure raises ValueError.""" + hasher = PointwiseDistanceDistributionHasher() + empty_structure = Structure(Lattice.cubic(1.0), [], []) + + with pytest.raises(ValueError): + hasher.periodicset_from_structure(empty_structure) + + +def test_pairwise_equivalence(pdd_hasher, si_structure): + """Test pairwise equivalence matrix generation.""" + # Create a different structure + different_structure = Structure( + Lattice.cubic(5.43), + ["Si"] * 4, + [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ) + + structures = [si_structure, different_structure] + equivalence_matrix = pdd_hasher.get_pairwise_equivalence(structures) + 
+ expected = np.array([[True, False], [False, True]]) + + assert isinstance(equivalence_matrix, np.ndarray) + assert equivalence_matrix.shape == (2, 2) + assert np.array_equal(equivalence_matrix, expected) diff --git a/tests/hasher/test_slices.py b/tests/hasher/test_slices.py new file mode 100644 index 0000000..bb83c07 --- /dev/null +++ b/tests/hasher/test_slices.py @@ -0,0 +1,98 @@ +import pytest +from pymatgen.core import Lattice, Structure + +# Try to import SLICES, skip tests if not available +slices_available = False +try: + from material_hasher.hasher.slices import SLICESHasher + + slices_available = True +except ImportError: + pass + +# Skip all tests if SLICES is not available +pytestmark = pytest.mark.skipif( + not slices_available, + reason="SLICES package not installed. Install with 'uv pip install -r requirements_slices.txt'", +) + + +@pytest.fixture +def slices_hasher(): + return SLICESHasher() + + +@pytest.fixture +def si_structure(): + lattice = Lattice.cubic(5.43) + species = ["Si"] * 8 + coords = [ + [0.0, 0.0, 0.0], + [0.25, 0.25, 0.25], + [0.0, 0.5, 0.5], + [0.25, 0.75, 0.75], + [0.5, 0.0, 0.5], + [0.75, 0.25, 0.75], + [0.5, 0.5, 0.0], + [0.75, 0.75, 0.25], + ] + return Structure(lattice, species, coords) + + +def test_hash_is_string(slices_hasher, si_structure): + """Test that the hash is a string.""" + hash_value = slices_hasher.get_material_hash(si_structure) + assert isinstance(hash_value, str) + + +def test_hash_consistency(slices_hasher, si_structure): + """Test that the same structure gets the same hash.""" + hash1 = slices_hasher.get_material_hash(si_structure) + hash2 = slices_hasher.get_material_hash(si_structure) + + assert hash1 == hash2 + + +# TODO(Ramlaoui): This does not pass with SLICES; is that expected? +# cf. paper?
+# def test_equivalent_structures(slices_hasher, si_structure): +# """Test that equivalent structures are identified as such.""" +# # Create a shifted version of the same structure +# shifted_structure = si_structure.copy() +# shifted_structure.translate_sites(range(len(shifted_structure)), [0.5, 0.5, 0.5]) + +# assert slices_hasher.is_equivalent(si_structure, shifted_structure) + + +def test_different_structures(slices_hasher, si_structure): + """Test that different structures get different hashes.""" + # Create a different structure (different number of sites) + different_structure = Structure( + Lattice.cubic(5.43), + ["Si"] * 4, + [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ) + + hash1 = slices_hasher.get_material_hash(si_structure) + hash2 = slices_hasher.get_material_hash(different_structure) + + assert hash1 != hash2 + + +def test_pairwise_equivalence(slices_hasher, si_structure): + """Test pairwise equivalence matrix generation.""" + # Create a different structure + different_structure = Structure( + Lattice.cubic(5.43), + ["Si"] * 4, + [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ) + + structures = [si_structure, different_structure] + equivalence_matrix = slices_hasher.get_pairwise_equivalence(structures) + + # Structures should only be equivalent to themselves + assert equivalence_matrix[0, 0] # si_structure with itself + assert equivalence_matrix[1, 1] # different_structure with itself + assert not equivalence_matrix[0, 1] # si_structure with different_structure + assert not equivalence_matrix[1, 0] # different_structure with si_structure diff --git a/tests/similarity/test_base_similarity.py b/tests/similarity/test_base_similarity.py new file mode 100644 index 0000000..5d31950 --- /dev/null +++ b/tests/similarity/test_base_similarity.py @@ -0,0 +1,178 @@ +import numpy as np +import pytest +from material_hasher.similarity.base import SimilarityMatcherBase +from pymatgen.core.lattice import Lattice +from 
pymatgen.core.structure import Structure + + +class PerfectMatcher(SimilarityMatcherBase): + """A dummy matcher that only considers structures equivalent if they're identical.""" + + def get_similarity_score( + self, structure1: Structure, structure2: Structure + ) -> float: + """Return 0.0 for identical structures, 1.0 otherwise.""" + return 0.0 if structure1 == structure2 else 1.0 + + def is_equivalent( + self, structure1: Structure, structure2: Structure, threshold=None + ) -> bool: + """Return True only for identical structures.""" + return structure1 == structure2 + + def get_pairwise_equivalence( + self, structures: list[Structure], threshold=None + ) -> np.ndarray: + """Get pairwise equivalence matrix.""" + n = len(structures) + matrix = np.zeros((n, n), dtype=bool) + for i in range(n): + for j in range(i, n): + matrix[i, j] = self.is_equivalent(structures[i], structures[j]) + + matrix = matrix | matrix.T + return matrix + + +class DistanceBasedMatcher(SimilarityMatcherBase): + """A dummy matcher that uses lattice parameter differences for similarity.""" + + def __init__(self, threshold: float = 0.1): + self.threshold = threshold + + def get_similarity_score( + self, structure1: Structure, structure2: Structure + ) -> float: + """Return the relative difference in lattice parameters.""" + a1 = structure1.lattice.a + a2 = structure2.lattice.a + return abs(a1 - a2) / max(a1, a2) + + def is_equivalent( + self, structure1: Structure, structure2: Structure, threshold=None + ) -> bool: + """Return True if lattice parameters are within threshold.""" + threshold = threshold if threshold is not None else self.threshold + return self.get_similarity_score(structure1, structure2) <= threshold + + def get_pairwise_equivalence( + self, structures: list[Structure], threshold=None + ) -> np.ndarray: + """Get pairwise equivalence matrix.""" + n = len(structures) + matrix = np.zeros((n, n), dtype=bool) + for i in range(n): + for j in range(i, n): + matrix[i, j] = 
self.is_equivalent( + structures[i], structures[j], threshold=threshold + ) + + matrix = matrix | matrix.T + return matrix + + +@pytest.fixture +def structures(): + """Create a list of test structures with different lattice parameters.""" + return [ + Structure(Lattice.cubic(4.0), ["Fe"], [[0, 0, 0]]), + Structure(Lattice.cubic(4.0), ["Fe"], [[0, 0, 0]]), # Identical to first + Structure(Lattice.cubic(4.1), ["Fe"], [[0, 0, 0]]), # Similar + Structure(Lattice.cubic(5.0), ["Fe"], [[0, 0, 0]]), # Different + ] + + +def test_perfect_matcher(structures): + matcher = PerfectMatcher() + + # Test similarity scores + assert matcher.get_similarity_score(structures[0], structures[1]) == 0.0 + assert matcher.get_similarity_score(structures[0], structures[2]) == 1.0 + + # Test equivalence + assert matcher.is_equivalent(structures[0], structures[1]) + assert not matcher.is_equivalent(structures[0], structures[2]) + + # Test pairwise similarity scores + scores = matcher.get_pairwise_similarity_scores(structures[:3]) + assert scores.shape == (3, 3) + assert scores[0, 1] == 0.0 # Identical structures + assert scores[0, 2] == 1.0 # Different structures + + # Test pairwise equivalence + equiv = matcher.get_pairwise_equivalence(structures[:3]) + assert equiv.shape == (3, 3) + assert equiv[0, 1] and equiv[1, 0] # Symmetric + assert not equiv[0, 2] and not equiv[2, 0] # Not equivalent + + +def test_distance_based_matcher(structures): + matcher = DistanceBasedMatcher(threshold=0.05) # 5% threshold + + # Test similarity scores + score01 = matcher.get_similarity_score(structures[0], structures[1]) + score02 = matcher.get_similarity_score(structures[0], structures[2]) + assert score01 == 0.0 # Identical + assert 0.0 < score02 < 0.1 # Similar but different + + # Test equivalence with different thresholds + assert matcher.is_equivalent( + structures[0], structures[2], threshold=0.1 + ) # Should be equivalent with 10% threshold + assert not matcher.is_equivalent( + structures[0], 
structures[2], threshold=0.01 + ) # Should not be equivalent with 1% threshold + + # Test pairwise similarity scores + scores = matcher.get_pairwise_similarity_scores(structures) + assert scores.shape == (4, 4) + assert np.allclose(scores, scores.T) # Symmetric + assert np.all(np.diag(scores) == 0) # Zero on diagonal + + # Test pairwise equivalence + equiv = matcher.get_pairwise_equivalence(structures, threshold=0.05) + assert equiv.shape == (4, 4) + assert np.all(np.diag(equiv)) # True on diagonal + assert equiv[0, 1] and equiv[1, 0] # Symmetric + assert not equiv[0, 3] # Far structures not equivalent + + +def test_interface_properties(): + """Test that the similarity interface maintains expected properties.""" + matcher = DistanceBasedMatcher() + structures = [ + Structure(Lattice.cubic(4.0), ["Fe"], [[0, 0, 0]]), + Structure(Lattice.cubic(4.1), ["Fe"], [[0, 0, 0]]), + ] + + # Test symmetry of similarity scores + score_ab = matcher.get_similarity_score(structures[0], structures[1]) + score_ba = matcher.get_similarity_score(structures[1], structures[0]) + assert np.isclose(score_ab, score_ba) + + # Test symmetry of equivalence + equiv_ab = matcher.is_equivalent(structures[0], structures[1]) + equiv_ba = matcher.is_equivalent(structures[1], structures[0]) + assert equiv_ab == equiv_ba + + # Test pairwise matrices are symmetric + scores = matcher.get_pairwise_similarity_scores(structures) + equiv = matcher.get_pairwise_equivalence(structures) + assert np.allclose(scores, scores.T) + assert np.array_equal(equiv, equiv.T) + + +def test_threshold_behavior(): + """Test that threshold parameter works correctly.""" + matcher = DistanceBasedMatcher(threshold=0.05) + struct1 = Structure(Lattice.cubic(4.0), ["Fe"], [[0, 0, 0]]) + struct2 = Structure(Lattice.cubic(4.04), ["Fe"], [[0, 0, 0]]) # 1% difference + + # Default threshold (5%) should make these equivalent + assert matcher.is_equivalent(struct1, struct2) + + # Stricter threshold should make them different + assert not 
matcher.is_equivalent(struct1, struct2, threshold=0.005) + + # More lenient threshold should keep them equivalent + assert matcher.is_equivalent(struct1, struct2, threshold=0.1) diff --git a/tests/similarity/test_structure_matchers.py b/tests/similarity/test_structure_matchers.py new file mode 100644 index 0000000..ac5245b --- /dev/null +++ b/tests/similarity/test_structure_matchers.py @@ -0,0 +1,91 @@ +import numpy as np +import pytest +from material_hasher.similarity.structure_matchers import PymatgenStructureSimilarity +from pymatgen.core.lattice import Lattice +from pymatgen.core.structure import Structure + + +@pytest.fixture +def simple_cubic_structure(): + """Create a simple cubic structure.""" + lattice = Lattice.cubic(4.0) + species = ["Fe"] + coords = [[0, 0, 0]] + return Structure(lattice, species, coords) + + +@pytest.fixture +def slightly_distorted_cubic(): + """Create a slightly distorted cubic structure.""" + lattice = Lattice.cubic(4.0 * 1.02) + species = ["Fe"] + coords = [[0.015, 0.015, 0.015]] + return Structure(lattice, species, coords) + + +@pytest.fixture +def different_structure(): + """Create a different structure (BCC instead of simple cubic).""" + lattice = Lattice.cubic(4.0) + species = ["Fe", "Fe"] + coords = [[0, 0, 0], [0.5, 0.5, 0.5]] # BCC structure + return Structure(lattice, species, coords) + + +def test_structure_equivalence( + simple_cubic_structure, slightly_distorted_cubic, different_structure +): + matcher = PymatgenStructureSimilarity(tolerance=0.02) + + # Test similar structures + assert matcher.is_equivalent(simple_cubic_structure, slightly_distorted_cubic) + + # Test different structures + assert not matcher.is_equivalent(simple_cubic_structure, different_structure) + + # Test self-equivalence + assert matcher.is_equivalent(simple_cubic_structure, simple_cubic_structure) + + +def test_similarity_score( + simple_cubic_structure, slightly_distorted_cubic, different_structure +): + matcher = 
PymatgenStructureSimilarity(tolerance=0.02) + + # Test similar structures - should have a high similarity score (small RMSD) + score_similar = matcher.get_similarity_score( + simple_cubic_structure, slightly_distorted_cubic + ) + assert score_similar > 0.9 # High score (small RMSD) for similar structures + + # Test different structures - should have a lower score (larger RMSD) + score_different = matcher.get_similarity_score( + simple_cubic_structure, different_structure + ) + assert ( + score_different < score_similar + ) # Different structures should score lower (larger RMSD) + + +def test_pairwise_equivalence( + simple_cubic_structure, slightly_distorted_cubic, different_structure +): + matcher = PymatgenStructureSimilarity(tolerance=0.02) + structures = [simple_cubic_structure, slightly_distorted_cubic, different_structure] + + matrix = matcher.get_pairwise_equivalence(structures) + + # Expected matrix shape + assert matrix.shape == (3, 3) + + # Matrix should be symmetric + assert np.array_equal(matrix, matrix.T) + + # Diagonal should be True (self-equivalence) + assert np.all(np.diag(matrix)) + + # First two structures should be equivalent + assert matrix[0, 1] and matrix[1, 0] + + # Different structure should not be equivalent to others + assert not matrix[0, 2] and not matrix[1, 2] diff --git a/uv.lock b/uv.lock index e902b06..bbdf0fd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11' and sys_platform == 'win32'", @@ -505,7 +506,7 @@ wheels = [ [[package]] name = "e3nn" -version = "0.5.4" +version = "0.5.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opt-einsum-fx" }, @@ -513,9 +514,9 @@ dependencies = [ { name = "sympy" }, { name = "torch" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/c8/59d2e0a890c8a6a8733b38dc40c48288730ced3c78e52cef3a19ebff1d4f/e3nn-0.5.4.tar.gz", hash = "sha256:66b4ea3fd1145ea11de8d91361c0367de2241999803686edeb2d374559e5f023", size =
434506 } +sdist = { url = "https://files.pythonhosted.org/packages/97/ac/ac6120e6e001677b6bbaeee3f7c40af784a1758568ae3246df7b5c3b9552/e3nn-0.5.6.tar.gz", hash = "sha256:e28aa6f67d9c090300d390f9e08fb57211d60562971ae3237e30023ff17093b1", size = 435651 } wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/74/0ddc27c458c48890e5e668879df2987e1126a60529f175843cfc0fa42b4f/e3nn-0.5.4-py3-none-any.whl", hash = "sha256:4c449f727fb72037908e2eddfb0479ef756268a29b3b0ddb00ed008f0dc638f3", size = 447203 }, + { url = "https://files.pythonhosted.org/packages/60/b4/c5101bb2a6417506241a9b1beb39ce7bce18c652bcf3ecd84d9158a4d53c/e3nn-0.5.6-py3-none-any.whl", hash = "sha256:172b450cf9c9cab0f341826e963a331675c085c502412ac9ab9d804107913c4c", size = 448027 }, ] [[package]] @@ -849,6 +850,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769 }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, +] + +[[package]] +name = "ipdb" +version = "0.13.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "decorator" }, + { name = "ipython" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/3d/1b/7e07e7b752017f7693a0f4d41c13e5ca29ce8cbcfdcc1fd6c4ad8c0a27a0/ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726", size = 17042 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/4c/b075da0092003d9a55cf2ecc1cae9384a1ca4f650d51b00fc59875fe76f6/ipdb-0.13.13-py3-none-any.whl", hash = "sha256:45529994741c4ab6d2388bfa5d7b725c2cf7fe9deffabdb8a6113aa5ed449ed4", size = 12130 }, +] + [[package]] name = "ipython" version = "8.29.0" @@ -1137,8 +1161,10 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "ipdb" }, { name = "ipython" }, { name = "pre-commit" }, + { name = "pytest" }, { name = "ruff" }, { name = "shibuya" }, { name = "sphinx-autoapi" }, @@ -1164,8 +1190,10 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "ipdb", specifier = ">=0.13.13" }, { name = "ipython", specifier = ">=8.29.0" }, { name = "pre-commit", specifier = ">=4.0.1" }, + { name = "pytest", specifier = ">=8.3" }, { name = "ruff", specifier = ">=0.8.0" }, { name = "shibuya", specifier = ">=2024.10.15" }, { name = "sphinx-autoapi", specifier = ">=3.3.2" }, @@ -1527,7 +1555,6 @@ name = "nvidia-nccl-cu12" version = "2.20.5" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 }, { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 }, ] @@ -1537,7 +1564,6 @@ version = "12.6.85" source = { registry = "https://pypi.org/simple" } wheels = [ { url = 
"https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 }, - { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338 }, ] [[package]] @@ -1825,6 +1851,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/ae/580600f441f6fc05218bd6c9d5794f4aef072a7d9093b291f1c50a9db8bc/plotly-5.24.1-py3-none-any.whl", hash = "sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089", size = 19054220 }, ] +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + [[package]] name = "pre-commit" version = "4.0.1" @@ -2175,6 +2210,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/ec/2eb3cd785efd67806c46c13a17339708ddc346cbb684eade7a6e6f79536a/pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84", size = 106921 }, ] +[[package]] +name = "pytest" +version = "8.3.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker 
= "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"