From c62059493fc8e018e02a9b2930fd18b1f1067d30 Mon Sep 17 00:00:00 2001 From: dwl38 Date: Fri, 30 Jan 2026 18:14:05 +0000 Subject: [PATCH 1/3] Added new callback for plot_from_struct --- ml_peg/app/utils/build_callbacks.py | 46 +++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/ml_peg/app/utils/build_callbacks.py b/ml_peg/app/utils/build_callbacks.py index f3aab2ee2..5714c71a9 100644 --- a/ml_peg/app/utils/build_callbacks.py +++ b/ml_peg/app/utils/build_callbacks.py @@ -130,6 +130,52 @@ def show_plot(active_cell, current_table_data) -> Div: return Div("Click on a metric to view plot.") +def plot_from_scatter( + scatter_id: str, + plot_id: str, + plots_list: list[Graph], +) -> None: + """ + Attach callback to show plot when a table cell is clicked. + + Parameters + ---------- + scatter_id + ID for Dash scatter being clicked. + plot_id + ID for Dash plot placeholder Div where new plot will be rendered. + plots_list + List of plots to show, in same order as scatter data. + """ + + @callback( + Output(plot_id, "children", allow_duplicate=True), + Input(scatter_id, "clickData"), + prevent_initial_call="initial_duplicate", + ) + def show_plot(click_data) -> Div: + """ + Register callback to show plot when a scatter point is clicked. + + Parameters + ---------- + click_data + Clicked data point in scatter plot. + + Returns + ------- + Div + Plot on scatter click. + """ + if not click_data: + return Div("Click on a metric to view plot.") + idx = click_data["points"][0]["pointNumber"] + + if idx >= 0 and idx < len(plots_list): + return Div(plots_list[idx]) + return Div("Click on a metric to view plot.") + + def struct_from_scatter( scatter_id: str, struct_id: str, From def095ca62e115d4c94aa3cf3b1fb845e9c85225 Mon Sep 17 00:00:00 2001 From: dwl38 Date: Sat, 31 Jan 2026 18:59:53 +0000 Subject: [PATCH 2/3] Added graphene wetting under strain benchmark --- .../source/user_guide/benchmarks/surfaces.rst | 47 ++ .../analyse_graphene_wetting_under_strain.py | 462 ++++++++++++++++++ .../graphene_wetting_under_strain/metrics.yml | 22 + .../app_graphene_wetting_under_strain.py | 373 ++++++++++++++ ml_peg/app/utils/build_callbacks.py | 2 +- .../calc_graphene_wetting_under_strain.py | 85 ++++ .../database_info.yml | 8 + 7 files changed, 998 insertions(+), 1 deletion(-) create mode 100644 ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py create mode 100644 ml_peg/analysis/surfaces/graphene_wetting_under_strain/metrics.yml create mode 100644 ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py create mode 100644 ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py create mode 100644 ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml diff --git a/docs/source/user_guide/benchmarks/surfaces.rst b/docs/source/user_guide/benchmarks/surfaces.rst index 780118c7c..d2fe4225c 100644 --- a/docs/source/user_guide/benchmarks/surfaces.rst +++ b/docs/source/user_guide/benchmarks/surfaces.rst @@ -102,6 +102,7 @@ Reference data: * Same as input data * PBE-D3(BJ), MPRelaxSet settings + Elemental Slab Oxygen Adsorption ================================ @@ -145,3 +146,49 @@ Reference data: * S. P. Ong, W. D. Richards, A. Jain, G. Hautier, M. Kocher, S. Cholia, D. Gunter, V. Chevrier, K. A. Persson, G. Ceder, "Python Materials Genomics (pymatgen): A Robust, Open-Source Python Library for Materials Analysis," Comput. Mater. Sci., 2013, 68, 314–319. https://doi.org/10.1016/j.commatsci.2012.10.028 * Tran et al. relaxed the slabs using spin-polarized PBE calculations performed in VASP, with a cutoff energy of 400 eV. + + +Graphene Wetting Under Strain +============================= + +Summary +------- + +Performance in predicting adsorption energies for a water molecule on graphene under varying strain conditions. + +Metrics +------- + +MAE of adsorption energies + +For each combination of water molecule orientation, water-graphene distance, and strain +condition, the adsorption energy is calculated by taking the difference between the +energy of the combined water + graphene system and the sum of individual water and +graphene energies. This is compared to the reference adsorption energy, calculated in the +same way. + +MAE of binding energies & lengths + +The adsorption energies calculated above are fitted to Morse potentials, to obtain an +effective binding energy and binding length (i.e. minimum of adsorption energy curve) for +each strain condition. This is compared to the reference binding energy & length, +calculated in the same way. + +Computational cost +------------------ + +Very low: tests are likely to take less than a minute to run on CPU. + +Data availability +----------------- + +Input data: + +* Structures were taken from: + + * D. W. Lim, X. R. Advincula, W. C. Witt, F. L. Thiemann, C. Schran, “Revealing Strain Effects on the Graphene-Water Contact Angle Using a Machine Learning Potential,” *awaiting publication* (arXiv:2601.20134) + +Reference data: + +* Same as input data +* PBE (with D3 dispersion correction), FHI-aims "intermediate" settings diff --git a/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py b/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py new file mode 100644 index 000000000..90d079b00 --- /dev/null +++ b/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py @@ -0,0 +1,462 @@ +"""Analyse graphene wetting under strain benchmark.""" + +from __future__ import annotations + +from pathlib import Path + +from ase import Atoms +import ase.io +import numpy as np +import pytest +from scipy.optimize import curve_fit +import yaml + +from ml_peg.analysis.utils.decorators import build_table, plot_parity, plot_scatter +from ml_peg.analysis.utils.utils import load_metrics_config +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +CALC_PATH = CALCS_ROOT / "surfaces" / "graphene_wetting_under_strain" / "outputs" +OUT_PATH = APP_ROOT / "data" / "surfaces" / "graphene_wetting_under_strain" + +METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml") +DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( + METRICS_CONFIG_PATH +) + +with open( + CALCS_ROOT / "surfaces" / "graphene_wetting_under_strain" / "database_info.yml" +) as fp: + DATABASE_INFO = yaml.safe_load(fp) +ORIENTATIONS = DATABASE_INFO["orientations"] +STRAINS = DATABASE_INFO["strains"] + + +def get_molecule_distance(sys: Atoms) -> float: + """ + Compute distance of water molecule from graphene sheet for one configuration. + + Parameters + ---------- + sys + Single frame of water molecule + graphene system. + + Returns + ------- + float + Water molecule distance. + """ + assert np.sum(sys.symbols == "O") == 1 + assert np.sum(sys.symbols == "C") > 1 + oxygens = sys.positions[sys.symbols == "O"] + carbons = sys.positions[sys.symbols == "C"] + return oxygens[0, 2] - np.mean(carbons[:, 2]) + + +def morse_potential(r: np.ndarray, de: float, a: float, re: float) -> np.ndarray: + """ + Compute Morse potential. + + Parameters + ---------- + r + Radial coordinates. + de + Potential well depth. + a + Decay coefficient (related to spring constant). + re + Equilibrium length. + + Returns + ------- + NDArray + Potentials at corresponding radii. + """ + return de * (((1.0 - np.exp(a * (re - r))) ** 2) - 1.0) + + +def get_binding_parameters( + distances: list[float], adsorption_energies: list[float] +) -> tuple[float, float, float]: + """ + Compute best-fit parameters for adsorption energy curve. + + Parameters + ---------- + distances + Water molecule distances. + adsorption_energies + Corresponding adsorption energies. + + Returns + ------- + float + Potential well depth (de) of Morse potential. + float + Decay coefficient (a) of Morse potential. + float + Equilibrium length (re) of Morse potential. + """ + popt = (np.inf, 0.0, np.inf) + if np.min(adsorption_energies) < 0.0: + depth = max(abs(np.min(adsorption_energies)), 5.0) + idx = np.argmin(adsorption_energies) + re = distances[idx] + if idx > 0 and idx < (len(distances) - 1): + second_deriv = ( + adsorption_energies[idx + 1] + + adsorption_energies[idx - 1] + - (2.0 * adsorption_energies[idx]) + ) + second_deriv /= ((distances[idx + 1] - distances[idx - 1]) / 2.0) ** 2 + a = max(np.sqrt(second_deriv / (2.0 * depth)), 0.5) + else: + a = 1.3 + try: + popt, _ = curve_fit( + morse_potential, + distances, + adsorption_energies, + p0=(depth, a, re), + bounds=(0.0, np.inf), + ) + except (ValueError, RuntimeError): + pass + return popt + + +@pytest.fixture +def processed_data() -> dict[str, list]: + """ + Gather and process all data for all systems. + + Returns + ------- + dict[str, list | dict[str, dict[str, dict[str, dict[str, list | tuple]]]]] + Dictionary of all processed data. + """ + results = {"distances": [], "ref": {}} | {model: {} for model in MODELS} + ref_stored = False + dist_stored = False + + for model in MODELS: + model_dir = CALC_PATH / model + if not model_dir.exists(): + continue + + for orientation in ORIENTATIONS: + if not ref_stored: + results["ref"][orientation] = {} + results[model][orientation] = {} + + for strain in STRAINS: + if not ref_stored: + results["ref"][orientation][strain] = { + "energies": [], + } + results[model][orientation][strain] = { + "energies": [], + } + + struct_write_dir = OUT_PATH / model / "structs" + struct_write_dir.mkdir(parents=True, exist_ok=True) + systems = ase.io.iread( + model_dir / f"{orientation}_{strain}.xyz", + index=":", + format="extxyz", + ) + for atoms in systems: + dist = get_molecule_distance(atoms) + if not dist_stored: + results["distances"].append(dist) + if not ref_stored: + results["ref"][orientation][strain]["energies"].append( + atoms.info["ref_adsorption_energy"] * 1000.0 + ) + results[model][orientation][strain]["energies"].append( + atoms.info["mlip_adsorption_energy"] * 1000.0 + ) + + ase.io.write( + struct_write_dir / f"{orientation}_{strain}_L{dist:.2f}.xyz", + atoms, + format="xyz", + ) + + if not ref_stored: + results["ref"][orientation][strain]["params"] = ( + get_binding_parameters( + results["distances"], + results["ref"][orientation][strain]["energies"], + ) + ) + results[model][orientation][strain]["params"] = get_binding_parameters( + results["distances"], + results[model][orientation][strain]["energies"], + ) + + @plot_scatter( + filename=OUT_PATH / model / f"figure_{orientation}_{strain}.json", + title=f"{orientation} binding energy curve ({strain[1:5]}% strain)", + x_label="Distance / Å", + y_label="Adsorption energy / meV", + show_line=True, + ) + def plot_model_binding_energy_curve( + model, orientation, strain + ) -> dict[str, tuple[list[float], list[float]]]: + return { + "ref": ( + results["distances"], + results["ref"][orientation][strain]["energies"], + ), + model: ( + results["distances"], + results[model][orientation][strain]["energies"], + ), + } + + plot_model_binding_energy_curve(model, orientation, strain) + dist_stored = True + + def get_all_hover_data() -> dict[str, list[str]]: + hover_data = {"Orientation": [], "Strain": [], "Distance": []} + for orientation in ORIENTATIONS: + for strain in STRAINS: + for dist in results["distances"]: + hover_data["Orientation"].append(orientation) + hover_data["Strain"].append(strain[1:5] + "%") + hover_data["Distance"].append(f"{dist:.2f} Å") + return hover_data + + @plot_parity( + filename=OUT_PATH / model / "figure_all_parity.json", + title="Adsorption energies", + x_label="Predicted adsorption energy / meV", + y_label="Reference adsorption energy / meV", + hoverdata=get_all_hover_data(), + ) + def plot_model_all_parity(model) -> dict[str, list[float]]: + parity_data = {"ref": [], model: []} + for orientation in ORIENTATIONS: + for strain in STRAINS: + parity_data["ref"].extend( + results["ref"][orientation][strain]["energies"] + ) + parity_data[model].extend( + results[model][orientation][strain]["energies"] + ) + return parity_data + + def get_binding_hover_data() -> dict[str, list[str]]: + hover_data = {"Orientation": [], "Strain": []} + for orientation in ORIENTATIONS: + for strain in STRAINS: + hover_data["Orientation"].append(orientation) + hover_data["Strain"].append(strain[1:5] + "%") + return hover_data + + @plot_parity( + filename=OUT_PATH / model / "figure_binding_energies_parity.json", + title="Binding energies", + x_label="Predicted binding energy / meV", + y_label="Reference binding energy / meV", + hoverdata=get_binding_hover_data(), + ) + def plot_model_binding_energies_parity(model) -> dict[str, list[float]]: + parity_data = {"ref": [], model: []} + for orientation in ORIENTATIONS: + for strain in STRAINS: + parity_data["ref"].append( + results["ref"][orientation][strain]["params"][0] + ) + parity_data[model].append( + np.nan_to_num( + results[model][orientation][strain]["params"][0], + nan=-1.0, + posinf=-1.0, + neginf=-1.0, + ) + ) + return parity_data + + @plot_parity( + filename=OUT_PATH / model / "figure_binding_lengths_parity.json", + title="Binding lengths", + x_label="Predicted binding length / Å", + y_label="Reference binding length / Å", + hoverdata=get_binding_hover_data(), + ) + def plot_model_binding_lengths_parity(model) -> dict[str, list[float]]: + parity_data = {"ref": [], model: []} + for orientation in ORIENTATIONS: + for strain in STRAINS: + parity_data["ref"].append( + results["ref"][orientation][strain]["params"][2] + ) + parity_data[model].append( + np.nan_to_num( + results[model][orientation][strain]["params"][2], + nan=-1.0, + posinf=-1.0, + neginf=-1.0, + ) + ) + return parity_data + + plot_model_all_parity(model) + plot_model_binding_energies_parity(model) + plot_model_binding_lengths_parity(model) + ref_stored = True + + return results + + +@pytest.fixture +def all_adsorption_energies_mae(processed_data) -> dict[str, float]: + """ + Get mean absolute error for all adsorption energies. + + Parameters + ---------- + processed_data + Dictionary of processed data. + + Returns + ------- + dict[str, float] + Dictionary of MAEs for all models. + """ + results = {} + for model in MODELS: + deviations = [] + for orientation in ORIENTATIONS: + for strain in STRAINS: + for i in range(len(processed_data["distances"])): + deviations.append( + abs( + processed_data[model][orientation][strain]["energies"][i] + - processed_data["ref"][orientation][strain]["energies"][i] + ) + ) + results[model] = np.mean(deviations) + return results + + +@pytest.fixture +def binding_energies_mae(processed_data) -> dict[str, float]: + """ + Get mean absolute error of binding energies across all orientations and strains. + + Parameters + ---------- + processed_data + Dictionary of processed data. + + Returns + ------- + dict[str, float] + Dictionary of binding energy MAEs for all models. + """ + results = {} + for model in MODELS: + deviations = [] + for orientation in ORIENTATIONS: + for strain in STRAINS: + deviations.append( + abs( + processed_data[model][orientation][strain]["params"][0] + - processed_data["ref"][orientation][strain]["params"][0] + ) + ) + results[model] = np.nan_to_num( + np.mean(deviations), nan=99999, posinf=99999, neginf=99999 + ) + return results + + +@pytest.fixture +def binding_lengths_mae(processed_data) -> dict[str, float]: + """ + Get mean absolute error of binding lengths across all orientations and strains. + + Parameters + ---------- + processed_data + Dictionary of processed data. + + Returns + ------- + dict[str, float] + Dictionary of binding length MAEs for all models. + """ + results = {} + for model in MODELS: + deviations = [] + for orientation in ORIENTATIONS: + for strain in STRAINS: + deviations.append( + abs( + processed_data[model][orientation][strain]["params"][2] + - processed_data["ref"][orientation][strain]["params"][2] + ) + ) + results[model] = np.nan_to_num( + np.mean(deviations), nan=999, posinf=999, neginf=999 + ) + return results + + +@pytest.fixture +@build_table( + filename=OUT_PATH / "graphene_wetting_under_strain_metrics_table.json", + metric_tooltips=DEFAULT_TOOLTIPS, + thresholds=DEFAULT_THRESHOLDS, + weights=DEFAULT_WEIGHTS, +) +def metrics( + all_adsorption_energies_mae: dict[str, float], + binding_energies_mae: dict[str, float], + binding_lengths_mae: dict[str, float], +) -> dict[str, dict]: + """ + Get all graphene wetting metrics. + + Parameters + ---------- + all_adsorption_energies_mae + Mean absolute errors across all orientations, distances, and strains for all + models. + binding_energies_mae + Mean absolute errors of binding energies across all orientations and strains + for all models. + binding_lengths_mae + Mean absolute errors of binding lengths across all orientations and strains for + all models. + + Returns + ------- + dict[str, dict] + Metric names and values for all models. + """ + return { + "All Adsorption Energies MAE": all_adsorption_energies_mae, + "Binding Energies MAE": binding_energies_mae, + "Binding Lengths MAE": binding_lengths_mae, + } + + +def test_graphene_wetting_under_strain(metrics: dict[str, dict]) -> None: + """ + Run graphene wetting test. + + Parameters + ---------- + metrics + All graphene wetting metrics. + """ + return diff --git a/ml_peg/analysis/surfaces/graphene_wetting_under_strain/metrics.yml b/ml_peg/analysis/surfaces/graphene_wetting_under_strain/metrics.yml new file mode 100644 index 000000000..711268165 --- /dev/null +++ b/ml_peg/analysis/surfaces/graphene_wetting_under_strain/metrics.yml @@ -0,0 +1,22 @@ +metrics: + All Adsorption Energies MAE: + good: 40.0 + bad: 1000.0 + unit: meV + weight: 1.0 + tooltip: Mean Absolute Error across all orientations, distances, and strains + level_of_theory: PBE + Binding Energies MAE: + good: 40.0 + bad: 1000.0 + unit: meV + weight: 1.0 + tooltip: Mean Absolute Error of binding energies across all orientations and strains + level_of_theory: PBE + Binding Lengths MAE: + good: 0.0 + bad: 1.0 + unit: Å + weight: 1.0 + tooltip: Mean Absolute Error of binding lengths across all orientations and strains + level_of_theory: PBE diff --git a/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py b/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py new file mode 100644 index 000000000..3fab93e07 --- /dev/null +++ b/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py @@ -0,0 +1,373 @@ +"""Run graphene wetting under strain app.""" + +from __future__ import annotations + +from dash import Dash, Input, Output, callback +from dash.exceptions import PreventUpdate +from dash.html import Div, Iframe +import yaml + +from ml_peg.app import APP_ROOT +from ml_peg.app.base_app import BaseApp +from ml_peg.app.utils.build_callbacks import plot_from_scatter +from ml_peg.app.utils.load import read_plot +from ml_peg.app.utils.weas import generate_weas_html +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +# Get all models +MODELS = get_model_names(current_models) +BENCHMARK_NAME = "Graphene Wetting Under Strain" +DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/surfaces.html#graphene-wetting-under-strain" +DATA_PATH = APP_ROOT / "data" / "surfaces" / "graphene_wetting_under_strain" + +with open( + CALCS_ROOT / "surfaces" / "graphene_wetting_under_strain" / "database_info.yml" +) as fp: + DATABASE_INFO = yaml.safe_load(fp) +ORIENTATIONS = DATABASE_INFO["orientations"] +STRAINS = DATABASE_INFO["strains"] + + +def plot_from_table_cell_with_erasure( + table_id, plot_id, plot_id_erase, struct_id_erase, cell_to_plot +): + """ + Attach callback to show plot when a table cell is clicked. + + Parameters + ---------- + table_id + ID for Dash table being clicked. + plot_id + ID for Dash plot placeholder Div. + plot_id_erase + ID of plot to be erased. + struct_id_erase + ID of struct to be erased. + cell_to_plot + Nested dictionary of model names, column names, and plot to show. + """ + + @callback( + Output(plot_id, "children", allow_duplicate=True), + Output(plot_id_erase, "children", allow_duplicate=True), + Output(struct_id_erase, "children", allow_duplicate=True), + Input(table_id, "active_cell"), + Input(table_id, "data"), + prevent_initial_call="initial_duplicate", + ) + def show_plot(active_cell, current_table_data): + """ + Register callback to show plot when a table cell is clicked. + + Parameters + ---------- + active_cell + Clicked cell in Dash table. + current_table_data + Current table data (includes live updates from callbacks). + + Returns + ------- + Div + Message explaining interactivity, or plot on cell click. + Div + Message explaining interactivity. + Div + Message explaining interactivity. + """ + if not active_cell: + return ( + Div("Click on a metric to view plot."), + Div("Click on a metric to view plot."), + Div("Click on a metric to view the structure."), + ) + column_id = active_cell.get("column_id", None) + row_id = active_cell.get("row_id", None) + row_index = active_cell.get("row", None) + if current_table_data and row_index is not None: + try: + cell_value = current_table_data[row_index].get(column_id) + if cell_value is None: + return ( + Div("No data available for this model."), + Div("Click on a metric to view plot."), + Div("Click on a metric to view the structure."), + ) + except (IndexError, KeyError, TypeError): + pass + if row_id in cell_to_plot and column_id in cell_to_plot[row_id]: + return ( + Div(cell_to_plot[row_id][column_id]), + Div("Click on a metric to view plot."), + Div("Click on a metric to view the structure."), + ) + return ( + Div("Click on a metric to view plot."), + Div("Click on a metric to view plot."), + Div("Click on a metric to view the structure."), + ) + + +def plot_and_struct_from_scatter(scatter_id, plot_id, plots_list, struct_id, structs): + """ + Attach callback to show a structure when a scatter point is clicked. + + Parameters + ---------- + scatter_id + ID for Dash scatter being clicked. + plot_id + ID for Dash plot placeholder Div where new plot will be rendered. + plots_list + List of plots to show, in same order as scatter data. + struct_id + ID for Dash plot placeholder Div where structures will be visualised. + structs + List of structure filenames in same order as scatter data to be visualised. + """ + + @callback( + Output(plot_id, "children", allow_duplicate=True), + Output(struct_id, "children", allow_duplicate=True), + Input(scatter_id, "clickData"), + prevent_initial_call="initial_duplicate", + ) + def show_plot_and_struct(click_data): + """ + Register callback to show plot and structure when a scatter point is clicked. + + Parameters + ---------- + click_data + Clicked data point in scatter plot. + + Returns + ------- + Div + Plot on scatter click. + Div + Visualised structure on plot click. + """ + if not click_data: + return Div("Click on a metric to view plot."), Div( + "Click on a metric to view the structure." + ) + idx = click_data["points"][0]["pointNumber"] + return Div(plots_list[idx]), Div( + Iframe( + srcDoc=generate_weas_html(structs[idx], "struct", 0), + style={ + "height": "550px", + "width": "100%", + "border": "1px solid #ddd", + "borderRadius": "5px", + }, + ) + ) + + +def struct_from_scatter_custom(scatter_id, struct_id, structs): + """ + Attach callback to show a structure when a scatter point is clicked. + + Parameters + ---------- + scatter_id + ID for Dash scatter being clicked. + struct_id + ID for Dash plot placeholder Div where structures will be visualised. + structs + List of structure filenames in same order as scatter data to be visualised. + """ + + @callback( + Output(struct_id, "children", allow_duplicate=True), + Input(scatter_id, "clickData"), + prevent_initial_call="initial_duplicate", + ) + def show_struct(click_data): + """ + Register callback to show structure when a scatter point is clicked. + + Parameters + ---------- + click_data + Clicked data point in scatter plot. + + Returns + ------- + Div + Visualised structure on plot click. + """ + if not click_data: + raise PreventUpdate() + idx = click_data["points"][0]["pointNumber"] + return Div( + Iframe( + srcDoc=generate_weas_html(structs[idx], "struct", 0), + style={ + "height": "550px", + "width": "100%", + "border": "1px solid #ddd", + "borderRadius": "5px", + }, + ) + ) + + +class GrapheneWettingUnderStrainApp(BaseApp): + """Graphene wetting under strain benchmark app layout and callbacks.""" + + def register_callbacks(self) -> None: + """Register callbacks to app.""" + parity_plots = { + model: { + "All Adsorption Energies MAE": read_plot( + DATA_PATH / model / "figure_all_parity.json", + id=f"{BENCHMARK_NAME}-{model}-all-parity-figure", + ), + "Binding Energies MAE": read_plot( + DATA_PATH / model / "figure_binding_energies_parity.json", + id=f"{BENCHMARK_NAME}-{model}-binding-energies-parity-figure", + ), + "Binding Lengths MAE": read_plot( + DATA_PATH / model / "figure_binding_lengths_parity.json", + id=f"{BENCHMARK_NAME}-{model}-binding-lengths-parity-figure", + ), + } + for model in MODELS + } + + structs_from_all = {model: [] for model in MODELS} + for model in MODELS: + for orientation in ORIENTATIONS: + for strain in STRAINS: + xyz_files = sorted( + (DATA_PATH / model / "structs").glob( + f"{orientation}_{strain}_L*.xyz" + ) + ) + for xyz_file in xyz_files: + structs_from_all[model].append( + f"assets/surfaces/graphene_wetting_under_strain/{model}/structs/{xyz_file.name}" + ) + + binding_curve_plots = {model: [] for model in MODELS} + for model in MODELS: + for orientation in ORIENTATIONS: + for strain in STRAINS: + binding_curve_plots[model].append( + read_plot( + DATA_PATH / model / f"figure_{orientation}_{strain}.json", + id=f"{BENCHMARK_NAME}-{model}-{orientation}-{strain[0:3]}", + ) + ) + + n_distances = len( + list( + (DATA_PATH / MODELS[0] / "structs").glob( + f"{ORIENTATIONS[0]}_{STRAINS[0]}_L*.xyz" + ) + ) + ) + curve_plots_from_all = {model: [] for model in MODELS} + for model in MODELS: + for i in range(len(ORIENTATIONS)): + for j in range(len(STRAINS)): + idx = (i * len(STRAINS)) + j + for _ in range(n_distances): + curve_plots_from_all[model].append( + binding_curve_plots[model][idx] + ) + + structs_from_binding_curves = {model: {} for model in MODELS} + for model in MODELS: + structs_from_binding_curves[model] = { + orientation: {} for orientation in ORIENTATIONS + } + for orientation in ORIENTATIONS: + structs_from_binding_curves[model][orientation] = { + strain: [] for strain in STRAINS + } + for strain in STRAINS: + xyz_files = sorted( + (DATA_PATH / model / "structs").glob( + f"{orientation}_{strain}_L*.xyz" + ) + ) + for xyz_file in xyz_files: + structs_from_binding_curves[model][orientation][strain].append( + f"assets/surfaces/graphene_wetting_under_strain/{model}/structs/{xyz_file.name}" + ) + + plot_from_table_cell_with_erasure( + table_id=self.table_id, + plot_id=f"{BENCHMARK_NAME}-figure-placeholder", + plot_id_erase=f"{BENCHMARK_NAME}-subfigure-placeholder", + struct_id_erase=f"{BENCHMARK_NAME}-struct-placeholder", + cell_to_plot=parity_plots, + ) + + for model in MODELS: + plot_and_struct_from_scatter( + scatter_id=f"{BENCHMARK_NAME}-{model}-all-parity-figure", + plot_id=f"{BENCHMARK_NAME}-subfigure-placeholder", + plots_list=curve_plots_from_all[model], + struct_id=f"{BENCHMARK_NAME}-struct-placeholder", + structs=structs_from_all[model], + ) + plot_from_scatter( + scatter_id=f"{BENCHMARK_NAME}-{model}-binding-energies-parity-figure", + plot_id=f"{BENCHMARK_NAME}-subfigure-placeholder", + plots_list=binding_curve_plots[model], + ) + plot_from_scatter( + scatter_id=f"{BENCHMARK_NAME}-{model}-binding-lengths-parity-figure", + plot_id=f"{BENCHMARK_NAME}-subfigure-placeholder", + plots_list=binding_curve_plots[model], + ) + for orientation in ORIENTATIONS: + for strain in STRAINS: + struct_from_scatter_custom( + scatter_id=f"{BENCHMARK_NAME}-{model}-{orientation}-{strain[0:3]}", + struct_id=f"{BENCHMARK_NAME}-struct-placeholder", + structs=structs_from_binding_curves[model][orientation][strain], + ) + + +def get_app() -> GrapheneWettingUnderStrainApp: + """ + Get graphene wetting under strain benchmark app layout and callback registration. + + Returns + ------- + GrapheneWettingUnderStrainApp + Benchmark layout and callback registration. + """ + return GrapheneWettingUnderStrainApp( + name=BENCHMARK_NAME, + description=("Adsorption energies for water on graphene."), + docs_url=DOCS_URL, + table_path=DATA_PATH / "graphene_wetting_under_strain_metrics_table.json", + extra_components=[ + Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), + Div(id=f"{BENCHMARK_NAME}-subfigure-placeholder"), + Div(id=f"{BENCHMARK_NAME}-struct-placeholder"), + ], + ) + + +if __name__ == "__main__": + # Create Dash app + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + + # Construct layout and register callbacks + graphene_wetting_under_strain_app = get_app() + full_app.layout = graphene_wetting_under_strain_app.layout + graphene_wetting_under_strain_app.register_callbacks() + + # Run app + full_app.run(port=8052, debug=True) diff --git a/ml_peg/app/utils/build_callbacks.py b/ml_peg/app/utils/build_callbacks.py index 5714c71a9..4c43e1fd9 100644 --- a/ml_peg/app/utils/build_callbacks.py +++ b/ml_peg/app/utils/build_callbacks.py @@ -136,7 +136,7 @@ def plot_from_scatter( plots_list: list[Graph], ) -> None: """ - Attach callback to show plot when a table cell is clicked. + Attach callback to show plot when a scatter point is clicked. Parameters ---------- diff --git a/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py new file mode 100644 index 000000000..6183a32fb --- /dev/null +++ b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py @@ -0,0 +1,85 @@ +"""Run calculations for graphene wetting under strain benchmark.""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any + +import ase.io +import pytest +import yaml + +from ml_peg.calcs.utils.utils import download_s3_data +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + +MODELS = load_models(current_models) + +DATA_PATH = Path(__file__).parent / "data" +OUT_PATH = Path(__file__).parent / "outputs" + +with open(Path(__file__).parent / "database_info.yml") as fp: + DATABASE_INFO = yaml.safe_load(fp) + +# List of orientations used in the database +ORIENTATIONS = DATABASE_INFO["orientations"] + +# List of strains used in the database +STRAINS = DATABASE_INFO["strains"] + + +@pytest.mark.parametrize("mlip", MODELS.items()) +def test_graphene_wetting_energy(mlip: tuple[str, Any]) -> None: + """ + Run graphene wetting adsorption energy test. + + Parameters + ---------- + mlip + Name of model use and model to get calculator. + """ + model_name, model = mlip + calc = model.get_calculator() + write_dir = OUT_PATH / model_name + write_dir.mkdir(parents=True, exist_ok=True) + + # Add D3 calculator for this test (for models where applicable) + calc = model.add_d3_calculator(calc) + + # Download dataset + structs_dir = ( + download_s3_data( + key="inputs/surfaces/graphene_wetting_under_strain/graphene_wetting_under_strain.zip", + filename="graphene_wetting_under_strain.zip", + ) + / "graphene_wetting_under_strain" + ) + + # Calculate energy of single water molecule + atoms = ase.io.read(structs_dir / "ref_water.xyz", format="extxyz") + atoms.calc = calc + water_energy = atoms.get_potential_energy() + + # Iterate through strain conditions + for strain in STRAINS: + atoms = ase.io.read(structs_dir / f"ref_graphene_{strain}.xyz", format="extxyz") + atoms.calc = calc + graphene_energy = atoms.get_potential_energy() + + # Iterate through orientations + for orientation in ORIENTATIONS: + systems = ase.io.iread( + structs_dir / f"{orientation}_{strain}.xyz", index=":", format="extxyz" + ) + write_file = write_dir / f"{orientation}_{strain}.xyz" + if os.path.isfile(write_file): + os.remove(write_file) + for atoms in systems: + atoms.calc = calc + mlip_potential_energy = atoms.get_potential_energy() + mlip_adsorption_energy = ( + mlip_potential_energy - graphene_energy - water_energy + ) + atoms.info["mlip_adsorption_energy"] = mlip_adsorption_energy + ase.io.write(write_file, atoms, append=True) diff --git a/ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml new file mode 100644 index 000000000..cbbea5d6e --- /dev/null +++ b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml @@ -0,0 +1,8 @@ +orientations: +- 0-leg +- 1-leg +- 2-leg +strains: +- s+0.00 +- s+1.00 +- s+2.00 From ab24795a95ce6b19989d5e1ea4b9ece9a3001ea1 Mon Sep 17 00:00:00 2001 From: joehart2001 Date: Tue, 3 Feb 2026 22:14:15 +0000 Subject: [PATCH 3/3] update logic for database info yml -> include in s3 --- .../analyse_graphene_wetting_under_strain.py | 4 +-- .../app_graphene_wetting_under_strain.py | 5 ++-- .../calc_graphene_wetting_under_strain.py | 29 ++++++++++++------- .../database_info.yml | 8 ----- 4 files changed, 22 insertions(+), 24 deletions(-) delete mode 100644 ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml diff --git a/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py b/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py index 90d079b00..7df90957e 100644 --- a/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py +++ b/ml_peg/analysis/surfaces/graphene_wetting_under_strain/analyse_graphene_wetting_under_strain.py @@ -27,9 +27,7 @@ METRICS_CONFIG_PATH ) -with open( - CALCS_ROOT / "surfaces" / "graphene_wetting_under_strain" / "database_info.yml" -) as fp: +with open(CALC_PATH / "database_info.yml") as fp: DATABASE_INFO = yaml.safe_load(fp) ORIENTATIONS = DATABASE_INFO["orientations"] STRAINS = DATABASE_INFO["strains"] diff --git a/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py b/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py index 3fab93e07..e3388d448 100644 --- a/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py +++ b/ml_peg/app/surfaces/graphene_wetting_under_strain/app_graphene_wetting_under_strain.py @@ -22,9 +22,8 @@ DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/surfaces.html#graphene-wetting-under-strain" DATA_PATH = APP_ROOT / "data" / "surfaces" / "graphene_wetting_under_strain" -with open( - CALCS_ROOT / "surfaces" / "graphene_wetting_under_strain" / "database_info.yml" -) as fp: +CALC_PATH = CALCS_ROOT / "surfaces" / "graphene_wetting_under_strain" / "outputs" +with open(CALC_PATH / "database_info.yml") as fp: DATABASE_INFO = yaml.safe_load(fp) ORIENTATIONS = DATABASE_INFO["orientations"] STRAINS = DATABASE_INFO["strains"] diff --git a/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py index 6183a32fb..c2c907b3b 100644 --- a/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py +++ b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/calc_graphene_wetting_under_strain.py @@ -19,14 +19,7 @@ DATA_PATH = Path(__file__).parent / "data" OUT_PATH = Path(__file__).parent / "outputs" -with open(Path(__file__).parent / "database_info.yml") as fp: - DATABASE_INFO = yaml.safe_load(fp) - -# List of orientations used in the database -ORIENTATIONS = DATABASE_INFO["orientations"] - -# List of strains used in the database -STRAINS = DATABASE_INFO["strains"] +DATABASE_INFO_SAVED = False @pytest.mark.parametrize("mlip", MODELS.items()) @@ -56,19 +49,35 @@ def test_graphene_wetting_energy(mlip: tuple[str, Any]) -> None: / "graphene_wetting_under_strain" ) + db_info_path = Path(structs_dir) / "database_info.yml" + with open(db_info_path) as fp: + database_info = yaml.safe_load(fp) + orientations = database_info["orientations"] + strains = database_info["strains"] + + # save database info for use in analysis + # (without needing to redownload to get the path) + global DATABASE_INFO_SAVED + if not DATABASE_INFO_SAVED: + OUT_PATH.mkdir(parents=True, exist_ok=True) + database_info_path = OUT_PATH / "database_info.yml" + with database_info_path.open("w", encoding="utf-8") as target_fp: + yaml.safe_dump(database_info, target_fp, sort_keys=False) + DATABASE_INFO_SAVED = True + # Calculate energy of single water molecule atoms = ase.io.read(structs_dir / "ref_water.xyz", format="extxyz") atoms.calc = calc water_energy = atoms.get_potential_energy() # Iterate through strain conditions - for strain in STRAINS: + for strain in strains: atoms = ase.io.read(structs_dir / f"ref_graphene_{strain}.xyz", format="extxyz") atoms.calc = calc graphene_energy = atoms.get_potential_energy() # Iterate through orientations - for orientation in ORIENTATIONS: + for orientation in orientations: systems = ase.io.iread( structs_dir / f"{orientation}_{strain}.xyz", index=":", format="extxyz" ) diff --git a/ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml b/ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml deleted file mode 100644 index cbbea5d6e..000000000 --- a/ml_peg/calcs/surfaces/graphene_wetting_under_strain/database_info.yml +++ /dev/null @@ -1,8 +0,0 @@ -orientations: -- 0-leg -- 1-leg -- 2-leg -strains: -- s+0.00 -- s+1.00 -- s+2.00