From 136b48e5acf7df43a75d1acc7359879973590c76 Mon Sep 17 00:00:00 2001 From: martinsiron Date: Thu, 17 Apr 2025 15:52:16 +0200 Subject: [PATCH] reduced composition and primitive structure for BAWL --- src/material_hasher/hasher/bawl.py | 33 ++++++++++++------- .../hasher/utils/graph_structure.py | 15 +++++++-- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/material_hasher/hasher/bawl.py b/src/material_hasher/hasher/bawl.py index d19a5bb..148c320 100644 --- a/src/material_hasher/hasher/bawl.py +++ b/src/material_hasher/hasher/bawl.py @@ -52,6 +52,7 @@ def __init__( bonding_kwargs: dict = {"tol": 0.2, "cutoff": 10, "use_fictive_radius": True}, include_composition: bool = True, symmetry_labeling: str = "moyo", + primitive_reduction: bool = False, shorten_hash: bool = False, ): self.graphing_algorithm = graphing_algorithm @@ -59,6 +60,7 @@ def __init__( self.bonding_kwargs = bonding_kwargs self.include_composition = include_composition self.symmetry_labeling = symmetry_labeling + self.primitive_reduction = primitive_reduction self.shorten_hash = shorten_hash def get_bawl_materials_data( @@ -90,6 +92,7 @@ def get_bawl_materials_data( structure, bonding_kwargs=self.bonding_kwargs, bonding_algorithm=self.bonding_algorithm, + primitive_reduction=self.primitive_reduction, ) data["bonding_graph_hash"] = get_weisfeiler_lehman_hash(graph) else: @@ -97,19 +100,25 @@ def get_bawl_materials_data( "Graphing algorithm {} not implemented".format(self.graphing_algorithm) ) if not self.shorten_hash: - match (self.symmetry_labeling, symmetry_label): - case (_, label) if label is not None: - data["symmetry_label"] = label - case ("AFLOW", _): - data["symmetry_label"] = AFLOWSymmetry().get_symmetry_label(structure) - case ("SPGLib", _): - data["symmetry_label"] = SPGLibSymmetry().get_symmetry_label(structure) - case ("moyo", _): - data["symmetry_label"] = MoyoSymmetry().get_symmetry_label(structure) - case (unknown, _): - raise ValueError(f"Symmetry algorithm {unknown} not implemented") + match (self.symmetry_labeling, symmetry_label): + case (_, label) if label is not None: + data["symmetry_label"] = label + case ("AFLOW", _): + data["symmetry_label"] = AFLOWSymmetry().get_symmetry_label( + structure + ) + case ("SPGLib", _): + data["symmetry_label"] = SPGLibSymmetry().get_symmetry_label( + structure + ) + case ("moyo", _): + data["symmetry_label"] = MoyoSymmetry().get_symmetry_label( + structure + ) + case (unknown, _): + raise ValueError(f"Symmetry algorithm {unknown} not implemented") if self.include_composition: - data["composition"] = structure.composition.formula.replace(" ", "") + data["composition"] = structure.composition.reduced_formula.replace(" ", "") return data def get_material_hash(self, structure: Structure) -> str: diff --git a/src/material_hasher/hasher/utils/graph_structure.py b/src/material_hasher/hasher/utils/graph_structure.py index 349841d..58cc648 100644 --- a/src/material_hasher/hasher/utils/graph_structure.py +++ b/src/material_hasher/hasher/utils/graph_structure.py @@ -3,12 +3,16 @@ from pymatgen.analysis.local_env import EconNN, NearNeighbors from pymatgen.core import Structure from networkx import Graph +from moyopy import MoyoDataset +from moyopy.interface import MoyoAdapter +import warnings def get_structure_graph( structure: Structure, bonding_kwargs: dict = {}, bonding_algorithm: NearNeighbors = EconNN, + primitive_reduction: bool = False, ) -> Graph: """Method to build networkx graph object based on bonding algorithm from Pymatgen Structure @@ -23,11 +27,18 @@ class to build bonded structure. Defaults to EconNN. Returns: Graph: networkx Graph object """ + assess_structure = ( + MoyoAdapter.get_structure( + MoyoDataset(MoyoAdapter.from_structure(structure)).prim_std_cell + ) + if primitive_reduction + else structure.copy() + ) structure_graph = StructureGraph.with_local_env_strategy( - structure=structure, + structure=assess_structure, strategy=bonding_algorithm(**bonding_kwargs), ) - for n, site in zip(range(len(structure)), structure): + for n, site in zip(range(len(assess_structure)), assess_structure): structure_graph.graph.nodes[n]["specie"] = site.specie.name for edge in structure_graph.graph.edges: structure_graph.graph.edges[edge]["voltage"] = structure_graph.graph.edges[