Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Analyse aqueous Iron Chloride oxidation states."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from ml_peg.analysis.utils.decorators import build_table, plot_scatter
from ml_peg.analysis.utils.utils import load_metrics_config
from ml_peg.app import APP_ROOT
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

MODELS = get_model_names(current_models)

CALC_PATH = CALCS_ROOT / "physicality" / "oxidation_states" / "outputs"
OUT_PATH = APP_ROOT / "data" / "physicality" / "oxidation_states"

METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml")
DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, _ = load_metrics_config(METRICS_CONFIG_PATH)

IRON_SALTS = ["Fe2Cl", "Fe3Cl"]
TESTS = ["Fe-O RDF Peak Split", "Peak Within Experimental Ref"]
REF_PEAK_RANGE = {
"Fe<sup>+2</sup><br>Ref": [2.0, 2.2],
"Fe<sup>+3</sup><br>Ref": [1.9, 2.0],
}


def get_rdf_results(
model: str,
) -> dict[str, tuple[list[float], list[float]]]:
"""
Get a model's Fe-O RDFs for the aqueous Fe2Cl and Fe3Cl MD.

Parameters
----------
model
Name of MLIP.

Returns
-------
results
RDF Radii and intensities for the aqueous Fe2Cl and Fe3Cl systems.
"""
results = {salt: [] for salt in IRON_SALTS}

model_calc_path = CALC_PATH / model

for salt in IRON_SALTS:
rdf_file = model_calc_path / f"O-Fe_{salt}_{model}.rdf"

fe_o_rdf = np.loadtxt(rdf_file)
r = list(fe_o_rdf[:, 0])
g_r = list(fe_o_rdf[:, 1])

results[salt].append(r)
results[salt].append(g_r)

return results


def plot_rdfs(model: str, results: dict[str, tuple[list[float], list[float]]]) -> None:
"""
Plot Fe-O RDFs.

Parameters
----------
model
Name of MLIP.
results
RDF Radii and intensities for the aqueous Fe2Cl and Fe3Cl systems.
"""

@plot_scatter(
filename=OUT_PATH / f"Fe-O_{model}_RDF_scatter.json",
title=f"<b>{model} MD</b>",
x_label="r [Å]",
y_label="Fe-O G(r)",
show_line=True,
show_markers=False,
highlight_range=REF_PEAK_RANGE,
)
def plot_result() -> dict[str, tuple[list[float], list[float]]]:
"""
Plot the RDFs.

Returns
-------
model_results
Dictionary of model Fe-O RDFs for the aqueous Fe2Cl and Fe3Cl systems.
"""
return results

plot_result()


@pytest.fixture
def get_oxidation_states_passfail() -> dict[str, dict]:
"""
Test whether model RDF peaks are split and they fall within the reference range.

Returns
-------
oxidation_states_passfail
Dictionary of pass fail per model.
"""
oxidation_state_passfail = {test: {} for test in TESTS}

fe_2_ref = [2.0, 2.2]
fe_3_ref = [1.9, 2.0]

for model in MODELS:
peak_position = {}
results = get_rdf_results(model)
plot_rdfs(model, results)

for salt in IRON_SALTS:
r = results[salt][0]
g_r = results[salt][1]
peak_position[salt] = r[g_r.index(max(g_r))]

peak_difference = abs(peak_position["Fe2Cl"] - peak_position["Fe3Cl"])

oxidation_state_passfail["Fe-O RDF Peak Split"][model] = 0.0
oxidation_state_passfail["Peak Within Experimental Ref"][model] = 0.0

if peak_difference > 0.07:
oxidation_state_passfail["Fe-O RDF Peak Split"][model] = 1.0

if fe_2_ref[0] <= peak_position["Fe2Cl"] <= fe_2_ref[1]:
oxidation_state_passfail["Peak Within Experimental Ref"][model] += 0.5

if fe_3_ref[0] <= peak_position["Fe3Cl"] <= fe_3_ref[1]:
oxidation_state_passfail["Peak Within Experimental Ref"][model] += 0.5

return oxidation_state_passfail


@pytest.fixture
@build_table(
filename=OUT_PATH / "oxidation_states_table.json",
metric_tooltips=DEFAULT_TOOLTIPS,
thresholds=DEFAULT_THRESHOLDS,
)
def oxidation_states_passfail_metrics(
get_oxidation_states_passfail: dict[str, dict],
) -> dict[str, dict]:
"""
Get all oxidation states pass fail metrics.

Parameters
----------
get_oxidation_states_passfail
Dictionary of pass fail per model.

Returns
-------
dict[str, dict]
Dictionary of pass fail per model.
"""
return get_oxidation_states_passfail


def test_oxidation_states_passfail_metrics(
oxidation_states_passfail_metrics: dict[str, dict],
) -> None:
"""
Run oxidation states test.

Parameters
----------
oxidation_states_passfail_metrics
All oxidation states pass fail.
"""
return
13 changes: 13 additions & 0 deletions ml_peg/analysis/physicality/oxidation_states/metrics.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
metrics:
Fe-O RDF Peak Split:
good: 1.0
bad: 0.0
unit: Yes(1)/No(0)
tooltip: Whether there is a split between Fe-O RDF peaks for different iron oxidation states
level_of_theory: Experimental
Peak Within Experimental Ref:
good: 1.0
bad: 0.0
unit: Yes(1)/No(0)
tooltip: Whether the RDF peak positions match experimental peaks
level_of_theory: Experimental
29 changes: 28 additions & 1 deletion ml_peg/analysis/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dash import dash_table
import numpy as np
import pandas as pd
import plotly.colors as pc
import plotly.graph_objects as go

from ml_peg.analysis.utils.utils import calc_table_scores
Expand Down Expand Up @@ -431,8 +432,10 @@ def plot_scatter(
x_label: str | None = None,
y_label: str | None = None,
show_line: bool = False,
show_markers: bool = True,
hoverdata: dict | None = None,
filename: str = "scatter.json",
highlight_range: dict = None,
) -> Callable:
"""
Plot scatter plot of MLIP results.
Expand All @@ -447,10 +450,14 @@ def plot_scatter(
Label for y-axis. Default is `None`.
show_line
Whether to show line between points. Default is False.
show_markers
Whether to show markers on the plot. Default is True.
hoverdata
Hover data dictionary. Default is `{}`.
filename
Filename to save plot as JSON. Default is "scatter.json".
highlight_range
Dictionary of rectangle title and x-axis endpoints.

Returns
-------
Expand Down Expand Up @@ -499,7 +506,13 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]:
hovertemplate += f"<b>{key}: </b>%{{customdata[{i}]}}<br>"
customdata = list(zip(*hoverdata.values(), strict=True))

mode = "lines+markers" if show_line else "markers"
modes = []
if show_line:
modes.append("lines")
if show_markers:
modes.append("markers")

mode = "+".join(modes)

fig = go.Figure()
for mlip, value in results.items():
Expand All @@ -515,6 +528,20 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]:
)
)

colors = pc.qualitative.Plotly

if highlight_range:
for i, (h_text, range) in enumerate(highlight_range.items()):
fig.add_vrect(
x0=range[0],
x1=range[1],
annotation_text=h_text,
annotation_position="top",
fillcolor=colors[i],
opacity=0.25,
line_width=0,
)

fig.update_layout(
title={"text": title},
xaxis={"title": {"text": x_label}},
Expand Down
89 changes: 89 additions & 0 deletions ml_peg/app/physicality/oxidation_states/app_oxidation_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Run oxidation states app."""

from __future__ import annotations

from dash import Dash
from dash.html import Div

from ml_peg.app import APP_ROOT
from ml_peg.app.base_app import BaseApp
from ml_peg.app.utils.build_callbacks import (
plot_from_table_cell,
)
from ml_peg.app.utils.load import read_plot
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

MODELS = get_model_names(current_models)

BENCHMARK_NAME = "Iron Oxidation States"
DATA_PATH = APP_ROOT / "data" / "physicality" / "oxidation_states"
REF_PATH = CALCS_ROOT / "physicality" / "oxidation_states" / "data"


class FeOxidationStatesApp(BaseApp):
"""Fe Oxidation States benchmark app layout and callbacks."""

def register_callbacks(self) -> None:
"""Register callbacks to app."""
scatter_plots = {
model: {
"Fe-O RDF Peak Split": read_plot(
DATA_PATH / f"Fe-O_{model}_RDF_scatter.json",
id=f"{BENCHMARK_NAME}-{model}-figure-Fe-O-RDF",
),
"Peak Within Experimental Ref": read_plot(
DATA_PATH / f"Fe-O_{model}_RDF_scatter.json",
id=f"{BENCHMARK_NAME}-{model}-figure-Fe-O-RDF",
),
}
for model in MODELS
}

plot_from_table_cell(
table_id=self.table_id,
plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
cell_to_plot=scatter_plots,
)


def get_app() -> FeOxidationStatesApp:
"""
Get Fe Oxidation States benchmark app layout and callback registration.

Returns
-------
FeOxidationStatesApp
Benchmark layout and callback registration.
"""
return FeOxidationStatesApp(
name=BENCHMARK_NAME,
description=(
"Evaluate model ability to capture different oxidation states of Fe"
"from aqueous Fe 2Cl and Fe 3Cl MD RDFs"
),
# docs_url=DOCS_URL,
table_path=DATA_PATH / "oxidation_states_table.json",
extra_components=[
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
],
)


if __name__ == "__main__":
# Create Dash app
full_app = Dash(
__name__,
assets_folder=DATA_PATH.parent.parent,
suppress_callback_exceptions=True,
)

# Construct layout and register callbacks
FeOxidationStatesApp = get_app()
full_app.layout = FeOxidationStatesApp.layout
FeOxidationStatesApp.register_callbacks()

# Run app
full_app.run(port=8054, debug=True)
Loading