Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ml_peg/analysis/element_sets.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Named element sets used to filter structures by composition.
# ``elements: null`` means no filtering (every structure is included).
element_sets:
  all:
    name: All elements
    description: Include every structure regardless of composition.
    elements: null
  hcno:
    name: HCNO
    description: Structures composed exclusively of hydrogen, carbon, nitrogen, or oxygen.
    elements: [H, C, N, O]
238 changes: 143 additions & 95 deletions ml_peg/analysis/supramolecular/LNCI16/analyse_LNCI16.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
from __future__ import annotations

from pathlib import Path
import shutil
from typing import Any

from ase.io import read
import pytest

from ml_peg.analysis.utils.decorators import build_table, plot_parity
from ml_peg.analysis.utils.element_filters import (
build_element_set_masks,
filter_hoverdata_dict,
filter_results_dict,
load_element_sets,
write_element_sets_summary_file,
)
from ml_peg.analysis.utils.utils import load_metrics_config, mae
from ml_peg.app import APP_ROOT
from ml_peg.calcs import CALCS_ROOT
Expand All @@ -22,103 +31,156 @@
DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config(
METRICS_CONFIG_PATH
)
# ``element_sets.yml`` sits in the directory two levels above this file's own
# directory (``parents[2]``); it is parsed once when this module is imported.
ELEMENT_SETS_CONFIG_PATH = Path(__file__).parents[2] / "element_sets.yml"
ELEMENT_SETS = load_element_sets(ELEMENT_SETS_CONFIG_PATH)


def get_structure_info() -> list[dict[str, Any]]:
    """
    Get structure information for LNCI16 systems.

    This reads structures from the first model folder that contains data.
    The returned list is used as the single ordering for hover labels,
    filtering, and saved structure-index mapping.

    Returns
    -------
    list[dict[str, Any]]
        One dictionary per structure with keys:
        ``system``, ``atom_count``, ``charge``, ``is_charged``,
        ``elements``, and ``index``. An empty list is returned when no
        model folder contains any ``.xyz`` file.
    """
    for model_name in MODELS:
        model_dir = CALC_PATH / model_name
        if not model_dir.exists():
            continue

        # Sorted so ``index`` follows the same file order that the
        # ``interaction_energies`` fixture uses when copying structures.
        xyz_files = sorted(model_dir.glob("*.xyz"))
        if not xyz_files:
            continue

        structure_metadata: list[dict[str, Any]] = []
        for index, xyz_file in enumerate(xyz_files):
            atoms = read(xyz_file)
            # A missing ``complex_charge`` entry is treated as neutral.
            charge = int(atoms.info.get("complex_charge", 0))
            structure_metadata.append(
                {
                    "system": atoms.info.get(
                        "system", f"system_{xyz_file.stem}"
                    ),
                    "atom_count": len(atoms),
                    "charge": charge,
                    "is_charged": charge != 0,
                    "elements": sorted(set(atoms.get_chemical_symbols())),
                    "index": index,
                }
            )
        # Only the first model folder with data is used.
        return structure_metadata
    return []


def build_hoverdata_from_structure_info(
    structure_info: list[dict[str, Any]],
) -> dict[str, list]:
    """
    Build hover labels for the predicted-vs-reference scatter plot.

    Parameters
    ----------
    structure_info
        Structure information returned by :func:`get_structure_info`.

    Returns
    -------
    dict[str, list]
        Hover label columns mapped to values in structure order. Every
        column has one value per entry of ``structure_info``.
    """
    return {
        "System": [entry["system"] for entry in structure_info],
        # Element symbols are concatenated without a separator, e.g. "CHNO".
        "Elements": ["".join(entry["elements"]) for entry in structure_info],
        "Complex Atoms": [entry["atom_count"] for entry in structure_info],
        "Charge": [entry["charge"] for entry in structure_info],
        "Charged": [entry["is_charged"] for entry in structure_info],
    }

# Both are built once, when this module is imported. The order of
# STRUCTURE_INFO is the order of the hover-label columns in HOVERDATA and of
# the per-structure element sets used to build the element-set masks.
STRUCTURE_INFO = get_structure_info()
HOVERDATA = build_hoverdata_from_structure_info(STRUCTURE_INFO)

def compute_lnci16_mae(energies: dict[str, list]) -> dict[str, float | None]:
    """
    Compute mean absolute error (MAE) for each model.

    Parameters
    ----------
    energies
        Interaction energies with keys ``ref`` and one key per model.

    Returns
    -------
    dict[str, float | None]
        MAE value for each model. Returns ``None`` for models with no data,
        including models that have no entry in ``energies`` and the case
        where the reference list itself is empty.
    """
    ref_values = energies["ref"]
    results: dict[str, float | None] = {}
    for model_name in MODELS:
        # ``.get`` so that a model absent from ``energies`` is reported as
        # "no data" rather than raising ``KeyError``.
        model_values = energies.get(model_name)
        if ref_values and model_values:
            results[model_name] = mae(ref_values, model_values)
        else:
            results[model_name] = None
    return results


def write_lnci16_element_set_outputs(interaction_energies: dict[str, list]) -> None:
    """
    Write filtered LNCI16 outputs for every configured element set.

    For each element set (for example ``all`` and ``hcno``), this writes:
    1. A filtered predicted-vs-reference scatter plot JSON.
    2. A filtered metrics table JSON.
    3. A ``element_sets.json`` file with counts and original
       structure positions.

    Parameters
    ----------
    interaction_energies
        Full LNCI16 interaction energies before filtering.
    """
    structure_elements = [set(entry["elements"]) for entry in STRUCTURE_INFO]
    element_set_masks = build_element_set_masks(structure_elements, ELEMENT_SETS)

    for element_set_key, element_set_mask in element_set_masks.items():
        filtered_results = filter_results_dict(interaction_energies, element_set_mask)
        filtered_hoverdata = filter_hoverdata_dict(HOVERDATA, element_set_mask)
        filtered_mae = compute_lnci16_mae(filtered_results)
        element_set_out_path = OUT_PATH / "element_sets" / element_set_key

        # The filtered data is bound as a default argument so each decorated
        # function captures this iteration's values, not a later iteration's.
        @plot_parity(
            filename=element_set_out_path / "figure_interaction_energies.json",
            title="LNCI16 Interaction Energies",
            x_label="Predicted interaction energy / kcal/mol",
            y_label="Reference interaction energy / kcal/mol",
            hoverdata=filtered_hoverdata,
        )
        def _filtered_interaction_energies(
            results: dict[str, list] = filtered_results,
        ) -> dict[str, list]:
            return results

        # Called immediately: the decorated function exists only to produce
        # the plot output for this element set.
        _filtered_interaction_energies()

        @build_table(
            filename=element_set_out_path / "lnci16_metrics_table.json",
            metric_tooltips=DEFAULT_TOOLTIPS,
            thresholds=DEFAULT_THRESHOLDS,
            weights=DEFAULT_WEIGHTS,
        )
        def _filtered_metrics(
            mae_by_model: dict[str, float | None] = filtered_mae,
        ) -> dict[str, dict]:
            return {"MAE": mae_by_model}

        _filtered_metrics()

    # Written once, after all per-set outputs, with every mask available.
    write_element_sets_summary_file(OUT_PATH, ELEMENT_SETS, element_set_masks)


@pytest.fixture
Expand All @@ -127,37 +189,33 @@ def get_is_charged() -> list[bool]:
title="LNCI16 Interaction Energies",
x_label="Predicted interaction energy / kcal/mol",
y_label="Reference interaction energy / kcal/mol",
hoverdata={
"System": get_system_names(),
"Complex Atoms": get_atom_counts(),
"Charge": get_charges(),
"Charged": get_is_charged(),
},
hoverdata=HOVERDATA,
)
def interaction_energies() -> dict[str, list]:
"""
Get interaction energies for all LNCI16 systems.

This fixture also copies structure files for the app and writes
element-set specific outputs used by the element-set selector.

Returns
-------
dict[str, list]
Dictionary of reference and predicted interaction energies.
"""
from ase.io import read

results = {"ref": []} | {mlip: [] for mlip in MODELS}
ref_stored = False
interaction_energy_results = {"ref": []} | {mlip: [] for mlip in MODELS}
reference_is_stored = False

for model_name in MODELS:
model_dir = CALC_PATH / model_name
model_output_dir = CALC_PATH / model_name

if not model_dir.exists():
results[model_name] = []
if not model_output_dir.exists():
interaction_energy_results[model_name] = []
continue

xyz_files = sorted(model_dir.glob("*.xyz"))
xyz_files = sorted(model_output_dir.glob("*.xyz"))
if not xyz_files:
results[model_name] = []
interaction_energy_results[model_name] = []
continue

model_energies = []
Expand All @@ -166,27 +224,25 @@ def interaction_energies() -> dict[str, list]:
for xyz_file in xyz_files:
atoms = read(xyz_file)
model_energies.append(atoms.info["E_int_model_kcal"])
if not ref_stored:
if not reference_is_stored:
ref_energies.append(atoms.info["E_int_ref_kcal"])

results[model_name] = model_energies
interaction_energy_results[model_name] = model_energies

# Store reference energies (only once)
if not ref_stored:
results["ref"] = ref_energies
ref_stored = True
if not reference_is_stored:
interaction_energy_results["ref"] = ref_energies
reference_is_stored = True

# Copy individual structure files to app data directory
structs_dir = OUT_PATH / model_name
structs_dir.mkdir(parents=True, exist_ok=True)

# Copy individual structure files
import shutil

for i, xyz_file in enumerate(xyz_files):
shutil.copy(xyz_file, structs_dir / f"{i}.xyz")

return results
write_lnci16_element_set_outputs(interaction_energy_results)
return interaction_energy_results


@pytest.fixture
Expand All @@ -204,15 +260,7 @@ def lnci16_mae(interaction_energies) -> dict[str, float]:
dict[str, float]
Dictionary of predicted interaction energy errors for all models.
"""
results = {}
for model_name in MODELS:
if interaction_energies[model_name]:
results[model_name] = mae(
interaction_energies["ref"], interaction_energies[model_name]
)
else:
results[model_name] = None
return results
return compute_lnci16_mae(interaction_energies)


@pytest.fixture
Expand Down
Loading