Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions docs/source/user_guide/benchmarks/physicality.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,36 @@ Data availability
-----------------

None required; diatomics are generated in ASE.


Water Slab Dipoles
==================

Summary
-------

Distribution of the total dipole of a water slab sampled during MD. The benchmark checks the width of the distribution and the fraction of structures that show dielectric breakdown.


Metrics
-------

1. Standard Deviation of Dipole Distribution

The total dipole is calculated for a number of samples from an MD simulation, and the standard deviation of the resulting distribution is compared to a reference from an LR model trained on revPBE-D3.

2. Fraction of structures with dielectric breakdown

The band gap of each structure is estimated from its dipole; the metric is the fraction of structures for which the estimated band gap disappears.


Computational Cost
------------------

High: around 500 ps of MD of a 40 Å slab is required to obtain a converged distribution.


Data availability
-----------------

Paper in preparation, contact Isaac Parker (ijp30@cam.ac.uk).
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""Analyse water slab dipole benchmark."""

from __future__ import annotations

from pathlib import Path

from ase import units
from ase.io import read
import numpy as np
import pytest
from scipy.constants import e, epsilon_0

from ml_peg.analysis.utils.decorators import build_table, plot_hist
from ml_peg.analysis.utils.utils import build_d3_name_map, load_metrics_config
from ml_peg.app import APP_ROOT
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

# Names of the currently selected models, and the D3 name map built from them
MODELS = get_model_names(current_models)
D3_MODEL_NAMES = build_d3_name_map(MODELS)
# Calculation outputs are read from CALC_PATH; figures and tables are written to OUT_PATH
CALC_PATH = CALCS_ROOT / "physicality" / "water_slab_dipoles" / "outputs"
OUT_PATH = APP_ROOT / "data" / "physicality" / "water_slab_dipoles"

# Thresholds, tooltips and weights are loaded from the metrics.yml next to this module
METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml")
DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config(
    METRICS_CONFIG_PATH
)

# Unit conversion
# NOTE(review): not referenced in the visible part of this module - confirm it is needed
EV_TO_KJ_PER_MOL = units.mol / units.kJ

# A dipole is considered "bad" if the expected band gap is <= 0.
# The expected band gap is 4.50 V - |P_z_per_unit_area| / epsilon_0,
# hence "bad" means |P_z_per_unit_area| > 4.50 V * epsilon_0.
# epsilon_0 is in F/m = C/(V*m); dividing by e and multiplying by 10^(-10)
# converts it to e/(V*A), so the threshold below is in e/A.
DIPOLE_BAD_THRESHOLD = 4.50 * epsilon_0 / e * 10 ** (-10)


def get_dipoles() -> dict[str, np.ndarray]:
    """
    Get total dipole per unit area in z direction.

    For each model, previously cached values are loaded from ``dipoles.npy`` if that
    file exists. Otherwise the dipoles are computed from the trajectory in
    ``slab.xyz`` using fixed point charges on O and H atoms, divided by the xy area
    of the first frame's cell, and cached to ``dipoles.npy``. Models without any
    usable output are left out of the result.

    Returns
    -------
    dict[str, np.ndarray]
        Dictionary with array of dipoles for each model.
    """
    # Fixed point charges (in e) assigned to every O and H atom.
    # NOTE(review): these values match the SPC/E water model - confirm the source.
    o_charge = -0.8476
    h_charge = 0.4238

    results = {}
    for model_name in MODELS:
        model_dir = CALC_PATH / model_name

        cache_file = model_dir / "dipoles.npy"
        if cache_file.is_file():
            results[model_name] = np.load(cache_file)
            continue

        # Skip models with no trajectory instead of raising: callers already
        # report models missing from the result as None.
        traj_file = model_dir / "slab.xyz"
        if not traj_file.is_file():
            continue

        frames = read(traj_file, ":")
        if not frames:
            # Empty trajectory: there is no cell to normalise by and no data
            continue

        dipoles = np.zeros(len(frames))
        for i, struc in enumerate(frames):
            z_coords = struc.positions[:, 2]
            numbers = struc.numbers
            dipoles[i] = (
                np.sum(z_coords[numbers == 8]) * o_charge
                + np.sum(z_coords[numbers == 1]) * h_charge
            )

        # Area taken from the diagonal cell entries of the first frame only.
        # NOTE(review): assumes a fixed, orthorhombic cell for the whole run.
        cell = frames[0].cell
        dipoles_unit_area = dipoles / cell[0, 0] / cell[1, 1]
        results[model_name] = dipoles_unit_area
        np.save(cache_file, dipoles_unit_area)
    return results


def plot_distribution(model: str) -> None:
    """
    Plot the dipole distribution of a single model as a histogram.

    Parameters
    ----------
    model
        Name of MLIP.
    """
    # The histogram spans 1.5x the breakdown threshold on either side, in 40 bins
    # one might want to consider reducing the bin size further...
    half_range = 1.5 * DIPOLE_BAD_THRESHOLD
    histogram_bins = {
        "start": -half_range,
        "end": half_range,
        "size": 3 * DIPOLE_BAD_THRESHOLD / 40,
    }

    @plot_hist(
        bins=histogram_bins,
        good=-DIPOLE_BAD_THRESHOLD,
        bad=DIPOLE_BAD_THRESHOLD,
        title=f"Dipole Distribution {model}",
        x_label="Total z-Dipole per unit area [e/A]",
        y_label="Probability Density",
        filename=OUT_PATH / f"figure_{model}_dipoledistr.json",
    )
    def plot_distr() -> dict[str, np.ndarray]:
        """
        Collect the dipoles of the selected model for plotting.

        Returns
        -------
        dict[str, np.ndarray]
            Dictionary mapping the model name to the array of all its dipoles.
        """
        model_dipoles = get_dipoles()[model]
        return {model: model_dipoles}

    plot_distr()


@pytest.fixture
def dipole_std() -> dict[str, float]:
    """
    Get standard deviation of total z dipole per unit area (in e/A).

    The dipole distribution of every model with data is also plotted.

    Returns
    -------
    dict[str, float]
        Dictionary of standard deviation of dipole distribution for all models.
        Models without dipole data map to None.
    """
    all_dipoles = get_dipoles()
    stds = {}
    for name in MODELS:
        if name not in all_dipoles:
            stds[name] = None
            continue
        plot_distribution(name)
        stds[name] = np.std(all_dipoles[name])
    return stds


@pytest.fixture
def n_bad() -> dict[str, float]:
    """
    Get fraction of dipoles that are bad.

    A dipole is bad when its absolute value exceeds ``DIPOLE_BAD_THRESHOLD``. The
    dipole distribution of every model with data is also plotted.

    Returns
    -------
    dict[str, float]
        Dictionary of fraction of breakdown candidates for all models. Models
        without dipole data map to None.
    """
    all_dipoles = get_dipoles()
    fractions = {}
    for name in MODELS:
        if name not in all_dipoles:
            fractions[name] = None
            continue
        plot_distribution(name)
        values = all_dipoles[name]
        exceeds_threshold = np.abs(values) > DIPOLE_BAD_THRESHOLD
        fractions[name] = exceeds_threshold.sum() / len(values)
    return fractions


@pytest.fixture
@build_table(
    thresholds=DEFAULT_THRESHOLDS,
    metric_tooltips=DEFAULT_TOOLTIPS,
    filename=OUT_PATH / "water_slab_dipoles_metrics_table.json",
)
def metrics(dipole_std: dict[str, float], n_bad: dict[str, float]) -> dict[str, dict]:
    """
    Collect all water slab dipoles metrics into one table.

    Parameters
    ----------
    dipole_std
        Standard deviation of dipole distribution, per model.
    n_bad
        Fraction of tested structures with dipole above the breakdown threshold,
        per model.

    Returns
    -------
    dict[str, dict]
        Metric names and values for all models.
    """
    table = {}
    table["sigma"] = dipole_std
    table["Fraction Breakdown Candidates"] = n_bad
    return table


def test_water_slab_dipoles(metrics: dict[str, dict]) -> None:
    """
    Run water slab dipoles test.

    Requesting the ``metrics`` fixture is what triggers the analysis; the test
    body itself makes no assertions.

    Parameters
    ----------
    metrics
        All water slab dipole metrics.
    """
13 changes: 13 additions & 0 deletions ml_peg/analysis/physicality/water_slab_dipoles/metrics.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
metrics:
sigma:
good: 0.007
bad: 0.015
unit: e/A
tooltip: Standard deviation of total dipole in z direction
level of theory: null
Fraction Breakdown Candidates:
good: 0
bad: 1
unit: null
tooltip: Fraction of structures with dipole larger than band gap
level of theory: revPBE-D3
143 changes: 143 additions & 0 deletions ml_peg/analysis/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,149 @@ def wrapper(*args, **kwargs):
return decorator


def plot_hist(
    *,
    bins: Any | None = None,
    good: float | None = None,
    bad: float | None = None,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    filename: str | Path,
) -> Callable:
    """
    Plot histogram of MLIP results.

    Parameters
    ----------
    bins
        Bins for histogram. Either a number, which is passed on as ``nbinsx``, or a
        dictionary with ``start``, ``end`` and ``size``, which is passed on as
        ``xbins``. Default is None.
    good
        Minimum threshold for good values. Requires `bad` and a bins dict.
        Default is None.
    bad
        Maximum threshold for good values. Requires `good` and a bins dict.
        Default is None.
    title
        Graph title. Default is `None`.
    x_label
        Label for x-axis. Default is `None`.
    y_label
        Label for y-axis. Default is `None`.
    filename
        Filename to save plot as JSON. Required; missing parent directories are
        created.

    Returns
    -------
    Callable
        Decorator to wrap function.
    """

    def plot_hist_decorator(func: Callable) -> Callable:
        """
        Decorate function to plot histogram.

        Parameters
        ----------
        func
            Function being wrapped. Must return a dictionary mapping a trace name
            to the values to be histogrammed.

        Returns
        -------
        Callable
            Wrapped function.
        """

        @functools.wraps(func)
        def plot_hist_wrapper(*args, **kwargs) -> dict[str, Any]:
            """
            Wrap function to plot histogram.

            Parameters
            ----------
            *args
                Arguments to pass to the function being wrapped.
            **kwargs
                Key word arguments to pass to the function being wrapped.

            Returns
            -------
            dict
                Results dictionary.
            """
            results = func(*args, **kwargs)

            # None or a plain number is forwarded as a bin count; anything else
            # is forwarded as an explicit bin specification.
            if bins is None or isinstance(bins, (int, float)):
                binning = {"nbinsx": bins}
            else:
                binning = {"xbins": bins, "autobinx": False}

            fig = go.Figure()
            data_all = []
            for model_name, hist_data in results.items():
                data_all.extend(hist_data)
                fig.add_trace(
                    go.Histogram(
                        x=hist_data,
                        histnorm="probability density",
                        name=model_name,
                        **binning,
                    )
                )

            # Colour bins by whether they lie inside [good, bad]. Skipped when
            # there is no data, as no bin positions can be derived then.
            if (
                good is not None
                and bad is not None
                and isinstance(bins, dict)
                and len(data_all) > 0
            ):
                # Step from the smallest to the largest data point in bin-sized
                # increments.
                # NOTE(review): positions start at the data minimum rather than
                # bins["start"], so they only approximate the drawn bin edges.
                data_max = max(data_all)
                actual_bins = [min(data_all)]
                point = actual_bins[0]
                while point < data_max:
                    point += bins["size"]
                    actual_bins.append(point)

                edges = np.asarray(actual_bins)
                out_of_range = (edges < good) | (edges > bad)

                # Numeric marker colours: out-of-range bins get bins["start"] and
                # in-range bins get bins["end"].
                # NOTE(review): presumably these map to opposite ends of the
                # active colour scale - confirm against the rendered figure.
                colors = np.zeros_like(edges)
                colors[out_of_range] = bins["start"]
                colors[~out_of_range] = bins["end"]
                if not out_of_range.any():
                    # All bins in range: use a single fixed colour instead
                    colors = "#276419"
                fig.update_traces(marker_color=colors)

            # Update layout
            fig.update_layout(
                title={"text": title},
                xaxis={"title": {"text": x_label}},
                yaxis={"title": {"text": y_label}},
            )

            # Write to file
            Path(filename).parent.mkdir(parents=True, exist_ok=True)
            fig.write_json(filename)

            return results

        return plot_hist_wrapper

    return plot_hist_decorator


def plot_scatter(
title: str | None = None,
x_label: str | None = None,
Expand Down
Loading