Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 118 additions & 27 deletions msp/structure/structure_util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,128 @@
import numpy as np
from ase import Atoms
from ase import Atoms, io
import torch
from torch_geometric.data import Data
from ase.data import chemical_symbols
import smact
from smact.screening import pauling_test
import pymatgen, pymatgen.io.ase, pymatgen
from pymatgen.core.structure import Element
from ase.data import chemical_symbols
import itertools
from itertools import product, chain
from pymatgen.io.cif import CifWriter

def atoms_to_dict(atoms, loss=None):
"""
Creates a list of dict from a list of ASE atoms objects

Args:
atoms (list): A list of ASE atoms objects
energy (list): A list of predicted energies for each ASE atoms object.

Returns:
list: Contains atoms represented as dicts
"""
res = [{} for _ in atoms]
for i, d in enumerate(res):
d['n_atoms'] = len(atoms[i].get_atomic_numbers())
d['pos'] = atoms[i].get_positions()
d['cell'] = atoms[i].get_cell()
d['z'] = atoms[i].get_atomic_numbers()
d['atomic_numbers'] = atoms[i].get_atomic_numbers()
if loss is None:
d['loss'] = None
else:
d['loss'] = loss[i]
return res

# takes a list of similar elements, base_list, and creates all possible permutations with the replacement elements
def generate_replacements(base_list, replacement):
replacements = product((base_list[0], replacement), repeat=len(base_list))

possibilities = []

for i in replacements:
possibilities.append(list(i))

return possibilities

# for a given atom, returns a list of all possible replacements atoms based on common oxidation states
def generate_replacement_possibilities(atom):
element_abbreviations = [
"H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
"Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y",
"Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce",
"Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir",
"Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm",
"Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc",
"Lv", "Ts", "Og"
]

possibilities = []

for el in element_abbreviations:
if len(set(Element(atom).common_oxidation_states).intersection(set(Element(el).common_oxidation_states))) > 0:
possibilities.append(el)

return possibilities

# recursive substitution method
def recursively_create_combinations(ls, index=0):
new_structures = []

replacement_possibilities = generate_replacement_possibilities(ls[index][0])

replacements = []

for atom in replacement_possibilities:
replacements.append(generate_replacements(ls[index], atom))

for x in replacements:
for y in x:
new_structure = ls.copy()
new_structure[index] = y
new_structures.append(new_structure)

for struct in new_structures.copy():
if index != len(ls) - 1:
new_structures.extend(recursively_create_combinations(struct, index + 1))

return new_structures

# recursively finds new structure possibilities based on a list of ase Atoms object templates
def substitution_discovery(templates):
all_new_structures = []
for t in templates:
crystal = io.read(f'/Users/oscarrivera/MatStructPredict_2/data/{t}')
crystal = crystal.get_chemical_symbols()

formatted_crystal = []
temp_ls = [crystal[0]]
previous_atom = crystal[0]
for i in crystal[1:]:
if i == previous_atom:
temp_ls.append(i)
else:
previous_atom = i
formatted_crystal.append(temp_ls)
temp_ls = [i]

formatted_crystal.append(temp_ls)

new_structures = recursively_create_combinations(formatted_crystal)

unique = []
for e in new_structures:
if e not in unique:
unique.append(e)

for n, u in enumerate(unique):
flattened_u = chain(*u)
unique[n] = Atoms(flattened_u)

unique = atoms_to_dict(unique)
all_new_structures.extend(unique)

return all_new_structures

def init_structure(composition, pyxtal=False, density=.2):
"""
Expand Down Expand Up @@ -81,30 +196,6 @@ def init_structure(composition, pyxtal=False, density=.2):

return atoms_to_dict([atoms], [None])[0]

def atoms_to_dict(atoms, loss=None):
"""
Creates a list of dict from a list of ASE atoms objects

Args:
atoms (list): A list of ASE atoms objects
energy (list): A list of predicted energies for each ASE atoms object.

Returns:
list: Contains atoms represented as dicts
"""
res = [{} for _ in atoms]
for i, d in enumerate(res):
d['n_atoms'] = len(atoms[i].get_atomic_numbers())
d['pos'] = atoms[i].get_positions()
d['cell'] = atoms[i].get_cell()
d['z'] = atoms[i].get_atomic_numbers()
d['atomic_numbers'] = atoms[i].get_atomic_numbers()
if loss is None:
d['loss'] = None
else:
d['loss'] = loss[i]
return res

def dict_to_atoms(dictionaries):
"""
Creates ASE atoms objects from a list of dictionaries
Expand Down