94 changes: 93 additions & 1 deletion README.md
@@ -1 +1,93 @@
# MatStructPredict
# MatStructPredict: An Open Source Library for GNN-Powered Structure Prediction

MatStructPredict is a machine learning library that offers simple, flexible pipelines for structure prediction and active learning.

## Table of Contents
- [MatStructPredict: An Open Source Library for GNN-Powered Structure Prediction](#matstructpredict-an-open-source-library-for-gnn-powered-structure-prediction)
- [Table of Contents](#table-of-contents)
- [Motivation](#motivation)
- [Features](#features)
- [Installation](#installation)
- [Quick Start: Structure Prediction](#quick-start-structure-prediction)
- [Contributing](#contributing)
- [License](#license)
- [Citation](#citation)


## Motivation

With increasingly powerful and accurate Graph Neural Networks (GNNs) coming into play, GNN-based structure prediction has become a fast and effective method for generating structures at scale. On multiple occasions, vast numbers of structures have been generated with GNNs for property optimization. However, writing programs for GNN-based structure prediction requires specialized knowledge of PyTorch and machine learning. To help people with varying levels of machine learning experience generate structures, we created MatStructPredict.

MatStructPredict is a library that offers simple, customizable pipelines for structure prediction. The library offers the following features:
- Training and evaluating ML models
- Composition generation
- Global optimization
  - Basin hopping
- Structure prediction
  - Optimize structures for multiple objectives
  - Optimize atomic positions and the unit cell

By simplifying the process of predicting structures, MatStructPredict gives researchers the ability to generate structures for their own use cases, regardless of whether or not they are familiar with machine learning.

## Features

- **Pre-trained Model Support**: Use multiple pre-trained models for ASE optimization:
  - CHGNet
- MACE
- M3GNet

- **MatDeepLearn Model Features**: Use all models supported by MatDeepLearn for:
- Training
- Evaluating
- Batch Optimization
- Custom Objective Structure Prediction

- **Flexible Properties**: Support for various molecular and materials properties:
- Energy prediction
- Force prediction (both conservative and non-conservative)
- Stress tensor prediction

- **Flexible Objectives**: Support for various optimization objectives:
- Energy
- Novelty
- Embedding Distance
- Uncertainty
- LJR Loss

- **Structure Prediction**: Pipelines for Structure Prediction from start to finish:
- SMACT Valid Composition generation
- Custom compositions
- Random Lithium Compositions
- Random Generic Compositions
- Global Optimization with Basin Hopping
  - Includes the following perturbations:
- Cell
- Positions
- Atomic Numbers
- Add/remove/swap atoms
- Saving structures
- Finetuning model on new structures
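
The basin-hopping loop behind the global optimization feature can be sketched in a few lines. This is a simplified, greedy illustration in plain NumPy, not the library's API; the double-well objective below is a hypothetical stand-in for a learned energy model:

```python
import numpy as np

def basin_hop(objective, x0, steps=200, step_size=0.5, seed=0):
    # Simplified basin hopping: random perturbation followed by a
    # greedy accept/reject step, tracking the best minimum seen.
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    f = objective(x)
    best_x, best_f = x, f
    for _ in range(steps):
        cand = x + rng.normal(0.0, step_size, size=x.shape)
        f_cand = objective(cand)
        if f_cand < f:  # accept only downhill moves (no Metropolis step)
            x, f = cand, f_cand
            if f < best_f:
                best_x, best_f = x, f
    return best_x, best_f

# Double-well objective with minima at x = +/-1; start far from both.
x_min, f_min = basin_hop(lambda v: float((v[0] ** 2 - 1) ** 2), [3.0])
```

A real run would replace the greedy acceptance with a Metropolis criterion and follow each hop with a local relaxation (e.g. an ASE optimizer driven by one of the pre-trained models above).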

## Installation
TODO
```bash
pip install matstructpredict
```

## Quick Start: Structure Prediction

Use example.py and mdl_config.yml for a quick structure-prediction run with a MatDeepLearn model. Remember to adjust the file paths to point to your dataset and desired save locations.

Example.ipynb provides a Jupyter notebook that walks users through each step of the structure prediction process.

## Contributing
TODO
## License
TODO
## Citation
TODO
If you use MatStructPredict in your research, please cite:

TODO
```bibtex
```
4 changes: 2 additions & 2 deletions msp/composition/__init__.py
@@ -1,3 +1,3 @@
__all__ = ["generate_random_compositions", "sample_random_composition"]
__all__ = ["generate_random_compositions", "sample_random_composition", "generate_random_lithium_compositions"]

from .composition import generate_random_compositions, sample_random_composition
from .composition import generate_random_compositions, sample_random_composition, generate_random_lithium_compositions
79 changes: 72 additions & 7 deletions msp/composition/composition.py
@@ -16,7 +16,6 @@ def hash_structure(atomic_numbers):
"""
counts = Counter(atomic_numbers)
sorted_counts = sorted(counts.items())
# divide the counts by the gcd of the counts
gcd = sorted_counts[0][1]
for elem in sorted_counts:
gcd = math.gcd(gcd, elem[1])
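
The GCD reduction above maps any multiple of a formula unit to the same key, so e.g. Li2O2 and Li4O4 hash identically. A self-contained sketch of the idea (the function name here is illustrative, not the library's):

```python
import math
from collections import Counter

def reduced_formula_key(atomic_numbers):
    # Count each species, then divide all counts by their GCD so that
    # any multiple of the same formula unit yields the same key.
    counts = sorted(Counter(atomic_numbers).items())
    g = 0
    for _, c in counts:
        g = math.gcd(g, c)
    return tuple((z, c // g) for z, c in counts)

# Li2O2 and Li4O4 reduce to the same (Li, O) key.
assert reduced_formula_key([3, 3, 8, 8]) == reduced_formula_key([3] * 4 + [8] * 4)
```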
@@ -43,8 +42,9 @@ def generate_random_compositions(dataset, n=5, max_elements=5, max_atoms=20, ele
Args:
dataset (dict): dictionary of dataset
n (int): number of compositions to generate
max_elements (int): maximum number of elements in composition
max_atoms (int): maximum number of atoms per element
max_elements (int): maximum number of unique elements in composition
max_atoms (int): maximum number of atoms in composition
elems_to_sample (list): list of elements to sample from

Returns:
compositions (list): list of compositions
@@ -55,8 +55,7 @@ def generate_random_compositions(dataset, n=5, max_elements=5, max_atoms=20, ele
if len(elems_to_sample) == 0:
elems_to_sample = [1, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
89, 90, 91, 92, 93, 94]
48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83]
for i in range(n):
while True:
comp = []
@@ -85,7 +84,7 @@ def generate_random_compositions(dataset, n=5, max_elements=5, max_atoms=20, ele
print('Potential composition: ', comp)
smact_valid = smact_validity(rand_elems, freq)
print('SMACT validity: ', smact_valid)
if not smact_validity(rand_elems, freq):
if not smact_valid:
print('Invalid composition')
continue
comp_hash = hash_structure(comp)
@@ -95,10 +94,76 @@ def generate_random_compositions(dataset, n=5, max_elements=5, max_atoms=20, ele
comp_hashes.append(comp_hash)
break
else:
print('Invalid compositon, already occurs')
print('Invalid composition, already occurs in dataset')
return compositions
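
The generation loop above is a rejection sampler: propose a composition under the size limits, then discard it unless it passes the SMACT check and is new. A toy version of just the proposal step (without the SMACT and dataset checks), with illustrative element choices and defaults:

```python
import numpy as np

def propose_composition(max_elements=5, max_atoms=20,
                        elems=(3, 8, 16, 26), rng=None):
    # Pick up to max_elements distinct species, then give each a random
    # count so the total number of atoms stays within max_atoms.
    rng = rng or np.random.default_rng()
    n_species = int(rng.integers(1, min(max_elements, len(elems)) + 1))
    chosen = rng.choice(elems, size=n_species, replace=False)
    comp = []
    for z in chosen:
        comp.extend([int(z)] * int(rng.integers(1, max_atoms // n_species + 1)))
    return sorted(comp)
```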

def generate_random_lithium_compositions(dataset, n=5, max_elements=6, max_atoms=20, li_ratio_lower=.2, li_ratio_upper=.4, halide_ratio_lower=.2, halide_ratio_upper=.5):
"""
Randomly generate n unique lithium compositions that do not appear in the dataset
Args:
dataset (dict): dictionary of dataset
n (int): number of compositions to generate
max_elements (int): maximum number of unique elements in composition
max_atoms (int): maximum number of atoms in composition
li_ratio_lower (float): lower bound for lithium ratio
li_ratio_upper (float): upper bound for lithium ratio
halide_ratio_lower (float): lower bound for halide ratio
halide_ratio_upper (float): upper bound for halide ratio

Returns:
compositions (list): list of compositions
"""
compositions = []
comp_hashes = []
hashed_dataset = hash_dataset(dataset)
halides = [9, 17]
metals = [39, 13, 22, 21, 31, 49, 40, 12, 30, 32, 57, 58, 41]
for i in range(n):
while True:
comp = []
total_atoms = np.random.randint(5, max_atoms + 1)
num_lithium = np.random.randint(max(1, round(total_atoms * li_ratio_lower)), round(total_atoms * li_ratio_upper) + 1)
comp.extend([3] * num_lithium)
num_halides = np.random.randint(round(total_atoms * halide_ratio_lower), round(total_atoms * halide_ratio_upper) + 1)
comp.extend(np.random.choice(halides, num_halides, replace=True))
if len(comp) >= total_atoms:
print('Invalid composition, no space for metals: ', comp)
continue
temp_metals = []
while np.unique(comp).size < max_elements and len(comp) < total_atoms:
temp_metals = np.random.choice(metals, 1, replace=True)
comp.append(temp_metals[-1])
if len(comp) < total_atoms:
comp.extend(np.random.choice(temp_metals, total_atoms - len(comp), replace=True))
elems = np.unique(comp)
freq = [comp.count(elem) for elem in elems]
smact_valid = smact_validity(elems, freq)
print('SMACT validity: ', smact_valid)
if not smact_valid:
print('Invalid composition')
continue
comp_hash = hash_structure(comp)
if comp_hash not in hashed_dataset and comp_hash not in comp_hashes:
print('Accepted composition', i, ':', comp)
comp.sort()
compositions.append(comp)
comp_hashes.append(comp_hash)
break
else:
print('Invalid composition, already occurs in dataset')
return compositions
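
The count bounds in generate_random_lithium_compositions follow directly from the ratio arguments: at least one Li atom, a Li fraction in [li_ratio_lower, li_ratio_upper], and a halide fraction in [halide_ratio_lower, halide_ratio_upper], with the remainder of the cell left for metals. A minimal sketch of just that arithmetic (the helper name is hypothetical):

```python
import numpy as np

def lithium_halide_counts(total_atoms, li_lo=0.2, li_hi=0.4,
                          hal_lo=0.2, hal_hi=0.5, rng=None):
    # Mirror the sampling above: draw the Li and halide counts from
    # the closed ranges implied by the ratio bounds.
    rng = rng or np.random.default_rng()
    n_li = int(rng.integers(max(1, round(total_atoms * li_lo)),
                            round(total_atoms * li_hi) + 1))
    n_hal = int(rng.integers(round(total_atoms * hal_lo),
                             round(total_atoms * hal_hi) + 1))
    return n_li, n_hal

n_li, n_hal = lithium_halide_counts(20, rng=np.random.default_rng(0))
```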


def sample_random_composition(dataset, n=5):
"""
Sample n random compositions from the dataset
Args:
dataset (dict): dictionary of dataset
n (int): number of compositions to sample

Returns:
dataset_comps (list): list of compositions
"""
dataset_comps = []
for data in dataset:
data['atomic_numbers'].sort()