Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions docs/source/tutorials/python/adding_benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "38b6013d",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: line 1: pre-commit: command not found\n"
]
}
],
"source": [
"! pre-commit install"
]
Expand All @@ -130,10 +138,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "5f033938",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"M\tdocs/source/tutorials/python/adding_benchmark.ipynb\n",
"M\tuv.lock\n",
"Already on 'main'\n",
"Your branch is up to date with 'origin/main'.\n",
"From https://github.com/Quantumplations/ml-peg-fork\n",
" * branch main -> FETCH_HEAD\n",
"Already up to date.\n",
"Switched to a new branch 'my_new_benchmark'\n"
]
}
],
"source": [
"! git checkout main\n",
"! git pull origin main\n",
Expand Down Expand Up @@ -959,7 +982,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "ml-peg (3.12.8)",
"display_name": "jax_linus",
"language": "python",
"name": "python3"
},
Expand All @@ -973,7 +996,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
219 changes: 219 additions & 0 deletions ml_peg/analysis/molecular/rxn_barriers/analyse_rxn_barriers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""Analyse CRBH20 Reaction Barriers benchmark."""

from __future__ import annotations

from pathlib import Path
import re

from ase import units
from ase.io import read, write
import pytest

# ml_peg imports
from ml_peg.analysis.utils.decorators import build_table, plot_parity
from ml_peg.analysis.utils.utils import build_d3_name_map, load_metrics_config, mae
from ml_peg.app import APP_ROOT
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

# --- Configuration ---
MODELS = get_model_names(current_models)
D3_MODEL_NAMES = build_d3_name_map(MODELS)

# Path to where the calc script outputted the data.
# Update this to match your actual folder structure.
CALC_PATH = CALCS_ROOT / "molecular" / "rxn_barriers" / "outputs"

# Path where this analysis script will save data for the Streamlit App.
OUT_PATH = APP_ROOT / "data" / "reaction_barriers" / "CRBH20"

# Load metrics configuration (thresholds for green/red coloring in tables).
# If the file doesn't exist, we provide defaults, but usually it should exist.
METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml")
try:
    DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config(
        METRICS_CONFIG_PATH
    )
except FileNotFoundError:
    # Fallback defaults if metrics.yml is missing.
    DEFAULT_THRESHOLDS = {"MAE": [1.0, 5.0]}  # Green < 1.0, Red > 5.0
    DEFAULT_TOOLTIPS = {"MAE": "Mean Absolute Error"}
    DEFAULT_WEIGHTS = {}

# --- Reference Data (Appendix B.5 of arXiv:2401.00096) ---
# The paper compares MACE against these specific DFT (r2SCAN) values.
# Values are tabulated in eV and converted to kcal/mol below.
_EV_TO_KCAL_PER_MOL = 23.0605  # 1 eV = 23.0605 kcal/mol

# Barriers in eV for reactions 1..20, in order.
_REF_BARRIERS_EV = [
    1.7194, 1.9241, 1.7499, 1.8238, 1.7237,
    1.5653, 1.0911, 1.8983, 1.5477, 1.7115,
    1.7379, 2.0361, 1.8739, 1.9760, 1.8865,
    1.5741, 1.2587, 1.7497, 1.6989, 1.7654,
]

# Reaction ID (1-based) -> reference barrier in kcal/mol.
REF_BARRIERS_KCAL = {
    rxn_id: barrier_ev * _EV_TO_KCAL_PER_MOL
    for rxn_id, barrier_ev in enumerate(_REF_BARRIERS_EV, start=1)
}
def get_reaction_ids() -> list[str]:
    """
    Return the CRBH20 reaction IDs used for plotting hover data.

    Returns
    -------
    list[str]
        The IDs "1".."20", in ascending numerical order.
    """
    return list(map(str, range(1, 21)))

def numeric_sort_key(filepath: Path) -> int:
    """
    Sort helper so files 1, 2, ... 10 come in numerical order, not alphabetical.

    Parameters
    ----------
    filepath
        Path whose name is expected to look like ``crbh20_12.xyz``.

    Returns
    -------
    int
        The reaction number extracted from the filename, or 0 if the
        filename does not match the expected pattern.
    """
    # Dot is escaped so only a literal ".xyz" suffix matches (the original
    # unescaped "." would also match names like "crbh20_5Xxyz").
    match = re.search(r"crbh20_(\d+)\.xyz", filepath.name)
    return int(match.group(1)) if match else 0

@pytest.fixture
@plot_parity(
    filename=OUT_PATH / "figure_reaction_barriers.json",
    title="CRBH20 Reaction Barriers",
    x_label="Predicted Barrier (kcal/mol)",
    y_label="Reference Barrier (kcal/mol)",
    hoverdata={
        "Reaction ID": get_reaction_ids(),
    },
)
def reaction_barriers() -> dict[str, list]:
    """
    Get barriers for all CRBH20 systems.

    Returns
    -------
    dict[str, list]
        Dictionary of reference and predicted barriers in kcal/mol.
        Format: {'ref': [10.53, ...], 'mace-mp-0b3': [10.2, ...]}
        Missing calculations are recorded as ``None`` so every model's
        list stays aligned with the reference list.
    """
    # Iterate 1..20 so every model's list is aligned with the reference.
    rxn_ids = range(1, 21)

    # The reference values are known up front, so store them once here.
    # (Appending them inside the per-model loop risks duplicating the
    # reference list when the first model's outputs are incomplete.)
    results: dict[str, list] = {
        "ref": [REF_BARRIERS_KCAL[rxn_id] for rxn_id in rxn_ids]
    } | {mlip: [] for mlip in MODELS}

    for model_name in MODELS:
        model_dir = CALC_PATH / model_name
        if not model_dir.exists():
            # No calculation outputs for this model; leave its list empty.
            continue

        # Collect this model's barriers in 1..20 order.
        model_barriers: list[float | None] = []
        for rxn_id in rxn_ids:
            xyz_file = model_dir / f"crbh20_{rxn_id}.xyz"

            if not xyz_file.exists():
                # Calculation failed or is missing. Parity plots need
                # equal-length lists, so record None; downstream metrics
                # filter these out.
                model_barriers.append(None)
                continue

            # Read the combined XYZ (reactant is index 0, TS is index 1).
            # Only the reactant is needed because the calc script stored
            # the barrier in the info tag of both frames.
            structs = read(xyz_file, index=":")
            reactant = structs[0]

            # Extract the ML barrier written by the calc script under
            # the "barrier_kcal" key.
            model_barriers.append(reactant.info.get("barrier_kcal", 0.0))

            # Copy structure files to the APP directory for visualization,
            # so the web app can show the molecule on hover/click.
            structs_dir = OUT_PATH / model_name
            structs_dir.mkdir(parents=True, exist_ok=True)
            write(structs_dir / f"crbh20_{rxn_id}.xyz", structs)

        results[model_name] = model_barriers

    return results

@pytest.fixture
def crbh20_errors(reaction_barriers) -> dict[str, float]:
    """
    Compute the Mean Absolute Error (MAE) of predicted reaction barriers.

    Pairs where either the reference or the prediction is ``None`` (failed
    calculations) are excluded; a model with no usable pairs maps to ``None``.
    """
    errors: dict[str, float] = {}
    refs = reaction_barriers["ref"]

    for model_name in MODELS:
        preds = reaction_barriers.get(model_name)
        if not preds:
            errors[model_name] = None
            continue

        # Keep only (reference, prediction) pairs where both values exist.
        pairs = [
            (ref_val, pred_val)
            for ref_val, pred_val in zip(refs, preds)
            if ref_val is not None and pred_val is not None
        ]

        if pairs:
            y_true, y_pred = zip(*pairs)
            errors[model_name] = mae(list(y_true), list(y_pred))
        else:
            errors[model_name] = None

    return errors

@pytest.fixture
@build_table(
    filename=OUT_PATH / "crbh20_metrics_table.json",
    metric_tooltips=DEFAULT_TOOLTIPS,
    thresholds=DEFAULT_THRESHOLDS,
    mlip_name_map=D3_MODEL_NAMES,
)
def metrics(crbh20_errors: dict[str, float]) -> dict[str, dict]:
    """
    Compile all metrics for the summary table.

    Parameters
    ----------
    crbh20_errors
        Per-model MAE values for the CRBH20 barriers.

    Returns
    -------
    dict[str, dict]
        Mapping of metric name to per-model values, consumed by
        ``@build_table`` to render the app's metrics table.
    """
    table_data = {"MAE": crbh20_errors}
    return table_data

def test_crbh20_analysis(metrics: dict[str, dict]) -> None:
    """
    Trigger the analysis pipeline.

    Requesting the ``metrics`` fixture pulls in the whole fixture chain;
    its decorators (@plot_parity, @build_table) save the JSON outputs as
    a side effect. Here we only sanity-check the compiled metrics.
    """
    required_metrics = ("MAE",)
    assert metrics is not None
    for metric_name in required_metrics:
        assert metric_name in metrics
7 changes: 7 additions & 0 deletions ml_peg/analysis/molecular/rxn_barriers/metrics.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metrics:
  MAE:
    tooltip: "Mean Absolute Error (kcal/mol)"
    unit: kcal/mol
    good: 1.0 # Any error below 1.0 will be Green (Chemical Accuracy)
    bad: 3.0 # Any error above 3.0 will be Red
    weight: 1.0
Loading
Loading