# Reconstructed from a whitespace-collapsed patch of
# msp/structure/structure_util.py.  Only the imports and definitions added by
# that patch are reproduced here; `init_structure` and `dict_to_atoms` are
# visible in the patch only as partial context and are not repeated.
import functools
import itertools
import os
from itertools import chain, product

import numpy as np
from ase import Atoms, io
import torch
from torch_geometric.data import Data
from ase.data import chemical_symbols
import pymatgen
import pymatgen.io.ase
from pymatgen.core.structure import Element
from pymatgen.io.cif import CifWriter

# Directory the template files were originally read from.  Kept as the default
# of `substitution_discovery(data_dir=...)` so existing callers are unaffected.
_DEFAULT_TEMPLATE_DIR = '/Users/oscarrivera/MatStructPredict_2/data'

# Symbols considered as substitution candidates, in the order they are tested
# (and therefore the order in which candidates are returned).
_ELEMENT_SYMBOLS = (
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al",
    "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe",
    "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr",
    "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm",
    "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W",
    "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn",
    "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf",
    "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
)


def atoms_to_dict(atoms, loss=None):
    """
    Creates a list of dict from a list of ASE atoms objects

    Args:
        atoms (list): A list of ASE atoms objects
        loss (list, optional): One loss value per entry of ``atoms``. When
            omitted, every dict gets ``'loss': None``.

    Returns:
        list: Contains atoms represented as dicts with the keys ``n_atoms``,
        ``pos``, ``cell``, ``z``, ``atomic_numbers`` and ``loss``.

    Raises:
        IndexError: If ``loss`` is given but is shorter than ``atoms``.
    """
    res = []
    for i, structure in enumerate(atoms):
        numbers = structure.get_atomic_numbers()
        res.append({
            'n_atoms': len(numbers),
            'pos': structure.get_positions(),
            'cell': structure.get_cell(),
            'z': numbers,
            # Fetched separately so 'z' and 'atomic_numbers' do not share one
            # array object, matching the original behaviour.
            'atomic_numbers': structure.get_atomic_numbers(),
            'loss': None if loss is None else loss[i],
        })
    return res


def generate_replacements(base_list, replacement):
    """
    Builds every mix of the original element and a replacement element for a
    run of identical atoms.

    Args:
        base_list (list): Symbols of one run of atoms; only its first entry
            and its length are used.
        replacement (str): Symbol that may take the place of any entry.

    Returns:
        list: ``2 ** len(base_list)`` lists, each of length ``len(base_list)``,
        covering every per-position choice between ``base_list[0]`` and
        ``replacement``. The all-original combination comes first.

    Raises:
        IndexError: If ``base_list`` is empty.
    """
    choices = (base_list[0], replacement)
    return [list(combo) for combo in product(choices, repeat=len(base_list))]


@functools.lru_cache(maxsize=None)
def _replacement_candidates(atom):
    """Cached core of `generate_replacement_possibilities`; returns a tuple."""
    # Hoisted out of the loop: the reference element does not change.
    target_states = set(Element(atom).common_oxidation_states)
    return tuple(
        el for el in _ELEMENT_SYMBOLS
        if target_states.intersection(Element(el).common_oxidation_states)
    )


def generate_replacement_possibilities(atom):
    """
    Lists the elements that share a common oxidation state with ``atom``.

    Args:
        atom (str): Element symbol understood by pymatgen's ``Element``.

    Returns:
        list: Symbols, in periodic-table order, of every element whose
        ``common_oxidation_states`` intersect those of ``atom``. A fresh list
        is returned on every call, so callers may mutate it.
    """
    return list(_replacement_candidates(atom))


def recursively_create_combinations(ls, index=0):
    """
    Recursive substitution method.

    Substitutes the group at ``index`` in every possible way, then repeats the
    process on each result for the following groups.

    Args:
        ls (list): Structure as a list of groups, each group being a list of
            identical element symbols.
        index (int): Position of the first group to substitute.

    Returns:
        list: Structures in the same nested-list format as ``ls``. The
        structures produced at ``index`` come first, followed by the recursive
        results for each of them in order. Duplicates are not removed. An
        empty list is returned when ``index`` is past the last group (this
        includes an empty ``ls``).
    """
    if index >= len(ls):
        return []

    group = ls[index]
    current_level = []
    for atom in generate_replacement_possibilities(group[0]):
        for substituted_group in generate_replacements(group, atom):
            # Shallow copy is enough: groups are replaced, never mutated.
            new_structure = ls.copy()
            new_structure[index] = substituted_group
            current_level.append(new_structure)

    if index == len(ls) - 1:
        return current_level

    new_structures = list(current_level)
    for struct in current_level:
        new_structures.extend(recursively_create_combinations(struct, index + 1))
    return new_structures


def substitution_discovery(templates, data_dir=_DEFAULT_TEMPLATE_DIR):
    """
    Recursively finds new structure possibilities from template structure
    files.

    Args:
        templates (list): File names of the template structures, relative to
            ``data_dir``; each is loaded with ``ase.io.read``.
        data_dir (str): Directory containing the template files. Defaults to
            the location that was previously hard-coded.

    Returns:
        list: One dict per unique substituted structure (see
        ``atoms_to_dict``), for all templates in order.
    """
    all_new_structures = []
    for t in templates:
        symbols = io.read(os.path.join(data_dir, t)).get_chemical_symbols()

        # Split the symbol sequence into runs of consecutive identical atoms.
        grouped = [list(run) for _, run in itertools.groupby(symbols)]

        # Drop duplicates while keeping first-occurrence order; a set of
        # hashable keys avoids rescanning a growing list for each candidate.
        seen = set()
        unique = []
        for candidate in recursively_create_combinations(grouped):
            key = tuple(map(tuple, candidate))
            if key in seen:
                continue
            seen.add(key)
            # NOTE(review): only the symbols are passed on, so the template's
            # positions and cell are not carried over - confirm this is the
            # intended starting point for downstream steps.
            unique.append(Atoms(list(chain.from_iterable(candidate))))

        all_new_structures.extend(atoms_to_dict(unique))

    return all_new_structures