diff --git a/ml_peg/analysis/element_sets.yml b/ml_peg/analysis/element_sets.yml new file mode 100644 index 000000000..b43ed2250 --- /dev/null +++ b/ml_peg/analysis/element_sets.yml @@ -0,0 +1,9 @@ +element_sets: + all: + name: All elements + description: Include every structure regardless of composition. + elements: null + hcno: + name: HCNO + description: Structures composed exclusively of hydrogen, carbon, nitrogen, or oxygen. + elements: [H, C, N, O] diff --git a/ml_peg/analysis/supramolecular/LNCI16/analyse_LNCI16.py b/ml_peg/analysis/supramolecular/LNCI16/analyse_LNCI16.py index 1902a03b5..28e867887 100644 --- a/ml_peg/analysis/supramolecular/LNCI16/analyse_LNCI16.py +++ b/ml_peg/analysis/supramolecular/LNCI16/analyse_LNCI16.py @@ -3,11 +3,20 @@ from __future__ import annotations from pathlib import Path +import shutil +from typing import Any from ase.io import read import pytest from ml_peg.analysis.utils.decorators import build_table, plot_parity +from ml_peg.analysis.utils.element_filters import ( + build_element_set_masks, + filter_hoverdata_dict, + filter_results_dict, + load_element_sets, + write_element_sets_summary_file, +) from ml_peg.analysis.utils.utils import load_metrics_config, mae from ml_peg.app import APP_ROOT from ml_peg.calcs import CALCS_ROOT @@ -22,103 +31,156 @@ DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( METRICS_CONFIG_PATH ) +ELEMENT_SETS_CONFIG_PATH = Path(__file__).parents[2] / "element_sets.yml" +ELEMENT_SETS = load_element_sets(ELEMENT_SETS_CONFIG_PATH) -def get_system_names() -> list[str]: +def get_structure_info() -> list[dict[str, Any]]: """ - Get list of LNCI16 system names. + Get structure information for LNCI16 systems. + + This reads structures from the first model folder that contains data. + The returned list is used as the single ordering for hover labels, + filtering, and saved structure-index mapping. Returns ------- - list[str] - List of system names from structure files. 
+ list[dict[str, Any]] + One dictionary per structure with keys: + ``system``, ``atom_count``, ``charge``, ``is_charged``, + ``elements``, and ``index``. """ - system_names = [] for model_name in MODELS: model_dir = CALC_PATH / model_name if model_dir.exists(): xyz_files = sorted(model_dir.glob("*.xyz")) if xyz_files: - for xyz_file in xyz_files: + structure_metadata: list[dict[str, Any]] = [] + for index, xyz_file in enumerate(xyz_files): atoms = read(xyz_file) - system_names.append( - atoms.info.get("system", f"system_{xyz_file.stem}") + charge = int(atoms.info.get("complex_charge", 0)) + structure_metadata.append( + { + "system": atoms.info.get( + "system", f"system_{xyz_file.stem}" + ), + "atom_count": len(atoms), + "charge": charge, + "is_charged": charge != 0, + "elements": sorted(set(atoms.get_chemical_symbols())), + "index": index, + } ) - break - return system_names + return structure_metadata + return [] -def get_atom_counts() -> list[int]: +def build_hoverdata_from_structure_info( + structure_info: list[dict[str, Any]], +) -> dict[str, list]: """ - Get complex atom counts for LNCI16. + Build hover labels for the predicted-vs-reference scatter plot. + + Parameters + ---------- + structure_info + Structure information returned by :func:`get_structure_info`. Returns ------- - list[int] - List of complex atom counts from structure files. + dict[str, list] + Hover label columns mapped to values in structure order. 
""" - from ase.io import read + return { + "System": [entry["system"] for entry in structure_info], + "Elements": ["".join(entry["elements"]) for entry in structure_info], + "Complex Atoms": [entry["atom_count"] for entry in structure_info], + "Charge": [entry["charge"] for entry in structure_info], + "Charged": [entry["is_charged"] for entry in structure_info], + } - for model_name in MODELS: - model_dir = CALC_PATH / model_name - if model_dir.exists(): - xyz_files = sorted(model_dir.glob("*.xyz")) - if xyz_files: - atom_counts = [] - for xyz_file in xyz_files: - atoms = read(xyz_file) - atom_counts.append(len(atoms)) - return atom_counts - return [] +STRUCTURE_INFO = get_structure_info() +HOVERDATA = build_hoverdata_from_structure_info(STRUCTURE_INFO) -def get_charges() -> list[int]: + +def compute_lnci16_mae(energies: dict[str, list]) -> dict[str, float | None]: """ - Get complex charges for LNCI16. + Compute mean absolute error (MAE) for each model. + + Parameters + ---------- + energies + Interaction energies with keys ``ref`` and one key per model. Returns ------- - list[int] - List of complex charges from structure files. + dict[str, float | None] + MAE value for each model. Returns ``None`` for models with no data. 
""" - from ase.io import read - + results: dict[str, float | None] = {} + ref_values = energies["ref"] for model_name in MODELS: - model_dir = CALC_PATH / model_name - if model_dir.exists(): - xyz_files = sorted(model_dir.glob("*.xyz")) - if xyz_files: - charges = [] - for xyz_file in xyz_files: - atoms = read(xyz_file) - charges.append(atoms.info.get("complex_charge", 0)) - return charges - return [] + model_values = energies[model_name] + if ref_values and model_values: + results[model_name] = mae(ref_values, model_values) + else: + results[model_name] = None + return results -def get_is_charged() -> list[bool]: +def write_lnci16_element_set_outputs(interaction_energies: dict[str, list]) -> None: """ - Get whether systems are charged for LNCI16. + Write filtered LNCI16 outputs for every configured element set. - Returns - ------- - list[bool] - List of boolean values indicating if systems are charged. - """ - from ase.io import read + For each element set (for example ``all`` and ``hcno``), this writes: + 1. A filtered predicted-vs-reference scatter plot JSON. + 2. A filtered metrics table JSON. + 3. A ``element_sets.json`` file with counts and original + structure positions. - for model_name in MODELS: - model_dir = CALC_PATH / model_name - if model_dir.exists(): - xyz_files = sorted(model_dir.glob("*.xyz")) - if xyz_files: - is_charged = [] - for xyz_file in xyz_files: - atoms = read(xyz_file) - charge = atoms.info.get("complex_charge", 0) - is_charged.append(charge != 0) - return is_charged - return [] + Parameters + ---------- + interaction_energies + Full LNCI16 interaction energies before filtering. 
+ """ + structure_elements = [set(entry["elements"]) for entry in STRUCTURE_INFO] + element_set_masks = build_element_set_masks(structure_elements, ELEMENT_SETS) + + for element_set_key, element_set_mask in element_set_masks.items(): + filtered_results = filter_results_dict(interaction_energies, element_set_mask) + filtered_hoverdata = filter_hoverdata_dict(HOVERDATA, element_set_mask) + filtered_mae = compute_lnci16_mae(filtered_results) + element_set_out_path = OUT_PATH / "element_sets" / element_set_key + + @plot_parity( + filename=element_set_out_path / "figure_interaction_energies.json", + title="LNCI16 Interaction Energies", + x_label="Predicted interaction energy / kcal/mol", + y_label="Reference interaction energy / kcal/mol", + hoverdata=filtered_hoverdata, + ) + def _filtered_interaction_energies( + results: dict[str, list] = filtered_results, + ) -> dict[str, list]: + return results + + _filtered_interaction_energies() + + @build_table( + filename=element_set_out_path / "lnci16_metrics_table.json", + metric_tooltips=DEFAULT_TOOLTIPS, + thresholds=DEFAULT_THRESHOLDS, + weights=DEFAULT_WEIGHTS, + ) + def _filtered_metrics( + mae_by_model: dict[str, float | None] = filtered_mae, + ) -> dict[str, dict]: + return {"MAE": mae_by_model} + + _filtered_metrics() + + write_element_sets_summary_file(OUT_PATH, ELEMENT_SETS, element_set_masks) @pytest.fixture @@ -127,37 +189,33 @@ def get_is_charged() -> list[bool]: title="LNCI16 Interaction Energies", x_label="Predicted interaction energy / kcal/mol", y_label="Reference interaction energy / kcal/mol", - hoverdata={ - "System": get_system_names(), - "Complex Atoms": get_atom_counts(), - "Charge": get_charges(), - "Charged": get_is_charged(), - }, + hoverdata=HOVERDATA, ) def interaction_energies() -> dict[str, list]: """ Get interaction energies for all LNCI16 systems. + This fixture also copies structure files for the app and writes + element-set specific outputs used by the element-set selector. 
+ Returns ------- dict[str, list] Dictionary of reference and predicted interaction energies. """ - from ase.io import read - - results = {"ref": []} | {mlip: [] for mlip in MODELS} - ref_stored = False + interaction_energy_results = {"ref": []} | {mlip: [] for mlip in MODELS} + reference_is_stored = False for model_name in MODELS: - model_dir = CALC_PATH / model_name + model_output_dir = CALC_PATH / model_name - if not model_dir.exists(): - results[model_name] = [] + if not model_output_dir.exists(): + interaction_energy_results[model_name] = [] continue - xyz_files = sorted(model_dir.glob("*.xyz")) + xyz_files = sorted(model_output_dir.glob("*.xyz")) if not xyz_files: - results[model_name] = [] + interaction_energy_results[model_name] = [] continue model_energies = [] @@ -166,27 +224,25 @@ def interaction_energies() -> dict[str, list]: for xyz_file in xyz_files: atoms = read(xyz_file) model_energies.append(atoms.info["E_int_model_kcal"]) - if not ref_stored: + if not reference_is_stored: ref_energies.append(atoms.info["E_int_ref_kcal"]) - results[model_name] = model_energies + interaction_energy_results[model_name] = model_energies # Store reference energies (only once) - if not ref_stored: - results["ref"] = ref_energies - ref_stored = True + if not reference_is_stored: + interaction_energy_results["ref"] = ref_energies + reference_is_stored = True # Copy individual structure files to app data directory structs_dir = OUT_PATH / model_name structs_dir.mkdir(parents=True, exist_ok=True) - # Copy individual structure files - import shutil - for i, xyz_file in enumerate(xyz_files): shutil.copy(xyz_file, structs_dir / f"{i}.xyz") - return results + write_lnci16_element_set_outputs(interaction_energy_results) + return interaction_energy_results @pytest.fixture @@ -204,15 +260,7 @@ def lnci16_mae(interaction_energies) -> dict[str, float]: dict[str, float] Dictionary of predicted interaction energy errors for all models. 
""" - results = {} - for model_name in MODELS: - if interaction_energies[model_name]: - results[model_name] = mae( - interaction_energies["ref"], interaction_energies[model_name] - ) - else: - results[model_name] = None - return results + return compute_lnci16_mae(interaction_energies) @pytest.fixture diff --git a/ml_peg/analysis/utils/element_filters.py b/ml_peg/analysis/utils/element_filters.py new file mode 100644 index 000000000..1988af0f7 --- /dev/null +++ b/ml_peg/analysis/utils/element_filters.py @@ -0,0 +1,250 @@ +"""Helpers for filtering benchmark data by allowed element groups.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +import json +from pathlib import Path +from typing import Any + +import numpy as np +from yaml import safe_load + + +def normalize_element_set_key(key: str) -> str: + """ + Normalize an element-set key for file paths and dictionary lookups. + + Parameters + ---------- + key + Raw element-set key from config or UI input. + + Returns + ------- + str + Lowercase normalized key. + """ + return key.strip().lower() + + +def build_allowed_mask( + structure_elements: Sequence[set[str] | Iterable[str]], + allowed_elements: set[str] | None, +) -> np.ndarray: + """ + Build a boolean mask for structures that pass an element filter. + + A structure is selected only when all of its elements are within + ``allowed_elements``. If ``allowed_elements`` is ``None``, every + structure is selected. + + Parameters + ---------- + structure_elements + Element symbols for each structure. + allowed_elements + Allowed elements. ``None`` means no filtering (all True). + + Returns + ------- + np.ndarray + Boolean array where ``True`` means the structure is selected. 
+ """ + if allowed_elements is None: + return np.ones(len(structure_elements), dtype=bool) + + allowed = set(allowed_elements) + return np.array( + [set(elements).issubset(allowed) for elements in structure_elements], + dtype=bool, + ) + + +def filter_sequence(values: Sequence, mask: np.ndarray) -> list: + """ + Filter a sequence with a boolean mask and return a plain list. + + Parameters + ---------- + values + Sequence to filter. + mask + Boolean mask with the same length as ``values``. + + Returns + ------- + list + Filtered values as a list. + """ + return list(np.array(values, dtype=object)[mask]) + + +def build_element_set_masks( + structure_elements: Sequence[set[str] | Iterable[str]], + element_sets: dict[str, dict[str, Any]], +) -> dict[str, np.ndarray]: + """ + Build one structure-selection mask for each configured element set. + + Parameters + ---------- + structure_elements + Per-structure element collections. + element_sets + Element-set mapping returned by ``load_element_sets``. + + Returns + ------- + dict[str, np.ndarray] + Mapping of set key to boolean mask. + """ + return { + key: build_allowed_mask(structure_elements, set_info.get("elements")) + for key, set_info in element_sets.items() + } + + +def filter_results_dict( + results: dict[str, Sequence], + mask: np.ndarray, +) -> dict[str, list]: + """ + Filter a results dictionary with a structure mask. + + Parameters + ---------- + results + Mapping of ``ref`` and model prediction arrays. + mask + Boolean mask selecting rows to keep. + + Returns + ------- + dict[str, list] + Filtered result mapping. + """ + filtered: dict[str, list] = {} + expected = len(mask) + for key, values in results.items(): + if len(values) == 0: + filtered[key] = [] + continue + if len(values) != expected: + raise ValueError( + f"Length mismatch for '{key}': got {len(values)}, expected {expected}." 
+ ) + filtered[key] = filter_sequence(values, mask) + return filtered + + +def filter_hoverdata_dict( + hoverdata: dict[str, Sequence], + mask: np.ndarray, +) -> dict[str, list]: + """ + Filter hover-label columns with a structure mask. + + Parameters + ---------- + hoverdata + Hover column mapping used by plot decorators. + mask + Boolean mask selecting rows to keep. + + Returns + ------- + dict[str, list] + Filtered hover-label mapping. + """ + filtered: dict[str, list] = {} + expected = len(mask) + for key, values in hoverdata.items(): + if len(values) != expected: + raise ValueError( + f"Length mismatch for hover column '{key}': got {len(values)}, " + f"expected {expected}." + ) + filtered[key] = filter_sequence(values, mask) + return filtered + + +def write_element_sets_summary_file( + out_path: str | Path, + element_sets: dict[str, dict[str, Any]], + element_set_masks: dict[str, np.ndarray], +) -> None: + """ + Write element-set summary information to ``element_sets.json``. + + The output contains, for each set: + 1. Display name and description. + 2. Allowed elements. + 3. Number of selected structures. + 4. Original structure positions used by this set. + + Parameters + ---------- + out_path + Benchmark output directory under ``app/data``. + element_sets + Element-set mapping returned by ``load_element_sets``. + element_set_masks + Mapping of set key to structure-selection mask. 
+ """ + element_sets_data: dict[str, dict[str, Any]] = {} + for key, set_info in element_sets.items(): + mask = element_set_masks[key] + indices = np.flatnonzero(mask).astype(int).tolist() + elements = set_info.get("elements") + element_sets_data[key] = { + "name": set_info.get("label", key), + "description": set_info.get("description", ""), + "elements": sorted(elements) if elements is not None else None, + "count": len(indices), + "indices": indices, + } + + summary_data = {"element_sets": element_sets_data} + output_path = Path(out_path) / "element_sets.json" + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as fp: + json.dump(summary_data, fp, indent=2) + fp.write("\n") + + +def load_element_sets(config_path: str | Path) -> dict[str, dict]: + """ + Load element-set definitions from YAML config. + + Parameters + ---------- + config_path + Path to ``element_sets.yml``. + + Returns + ------- + dict[str, dict] + Mapping keyed by normalized set key. Each value contains: + ``label``, ``description``, and ``elements`` (``set[str]`` or ``None``). 
+ """ + with open(config_path) as f: + config = safe_load(f) or {} + + raw_sets = config.get("element_sets") or {} + element_sets: dict[str, dict] = {} + + for raw_key, set_config in raw_sets.items(): + key = normalize_element_set_key(raw_key) + set_config = set_config or {} + raw_elements = set_config.get("elements") + element_sets[key] = { + "label": set_config.get("name", raw_key), + "description": set_config.get("description", ""), + "elements": set(raw_elements) if raw_elements is not None else None, + } + + if not element_sets: + raise ValueError(f"No element sets defined in config: {config_path}") + + return element_sets diff --git a/ml_peg/app/base_app.py b/ml_peg/app/base_app.py index c8f1b24a3..b7b7e4326 100644 --- a/ml_peg/app/base_app.py +++ b/ml_peg/app/base_app.py @@ -63,6 +63,7 @@ def __init__( self.table = rebuild_table( self.table_path, id=self.table_id, description=description ) + self.table.source_table_path = str(self.table_path) self.layout = self.build_layout() def build_layout(self) -> Div: diff --git a/ml_peg/app/build_app.py b/ml_peg/app/build_app.py index 8a8860a30..35918c69a 100644 --- a/ml_peg/app/build_app.py +++ b/ml_peg/app/build_app.py @@ -5,12 +5,13 @@ from importlib import import_module import warnings -from dash import Dash, Input, Output, callback +from dash import Dash, Input, Output, State, callback from dash.dash_table import DataTable -from dash.dcc import Store, Tab, Tabs +from dash.dcc import Dropdown, Store, Tab, Tabs from dash.html import H1, H3, Div from yaml import safe_load +from ml_peg.analysis.utils.element_filters import load_element_sets from ml_peg.analysis.utils.utils import calc_table_scores, get_table_style from ml_peg.app import APP_ROOT from ml_peg.app.utils.build_components import build_footer, build_weight_components @@ -19,7 +20,11 @@ build_tutorial_button, register_onboarding_callbacks, ) -from ml_peg.app.utils.register_callbacks import register_benchmark_to_category_callback +from 
ml_peg.app.utils.register_callbacks import ( + register_benchmark_to_category_callback, + register_element_set_table_callbacks, + register_summary_scores_from_files_callback, +) from ml_peg.app.utils.utils import ( build_level_of_theory_warnings, calculate_column_widths, @@ -31,6 +36,18 @@ # Get all models MODELS = get_model_names(current_models) +ELEMENT_SETS_CONFIG_PATH = APP_ROOT.parent / "analysis" / "element_sets.yml" + +try: + ELEMENT_SETS = load_element_sets(ELEMENT_SETS_CONFIG_PATH) +except (FileNotFoundError, ValueError): + ELEMENT_SETS = { + "all": { + "label": "All elements", + "description": "Include every structure regardless of composition.", + "elements": None, + } + } def get_all_tests( @@ -336,6 +353,13 @@ def build_tabs( Tab(label=category_name, value=category_name) for category_name in layouts ] + element_set_options = [ + {"label": value.get("label", key), "value": key} + for key, value in ELEMENT_SETS.items() + ] + default_element_set = ( + "all" if "all" in ELEMENT_SETS else element_set_options[0]["value"] + ) tabs_layout = [ build_onboarding_modal(), build_tutorial_button(), @@ -344,6 +368,11 @@ def build_tabs( H1("ML-PEG"), Tabs(id="all-tabs", value="summary-tab", children=all_tabs), Div(id="tabs-content"), + Store( + id="global-element-set-store", + storage_type="session", + data=default_element_set, + ), ], style={"flex": "1", "marginBottom": "40px"}, ), @@ -355,8 +384,36 @@ def build_tabs( style={"display": "flex", "flexDirection": "column", "minHeight": "100vh"}, ) - @callback(Output("tabs-content", "children"), Input("all-tabs", "value")) - def select_tab(tab) -> Div: + @callback( + Output("global-element-set-store", "data"), + Input("global-element-set-dropdown", "value", allow_optional=True), + State("global-element-set-store", "data"), + prevent_initial_call=True, + ) + def store_element_set(selected_set: str, current_set: str) -> str: + """ + Persist selected element-set key in a top-level Store. 
+ + Parameters + ---------- + selected_set + Selected element-set value from summary-tab dropdown. + current_set + Current store value. + + Returns + ------- + str + Updated store value. + """ + return selected_set if selected_set is not None else current_set + + @callback( + Output("tabs-content", "children"), + Input("all-tabs", "value"), + State("global-element-set-store", "data"), + ) + def select_tab(tab, selected_element_set) -> Div: """ Select tab contents to be displayed. @@ -364,6 +421,8 @@ def select_tab(tab) -> Div: ---------- tab Name of tab selected. + selected_element_set + Selected element-set key from session store. Returns ------- @@ -371,9 +430,24 @@ def select_tab(tab) -> Div: Summary or tab contents to be displayed. """ if tab == "summary-tab": + summary_element_set = selected_element_set or default_element_set + summary_controls = Div( + [ + H3("Element Set"), + Dropdown( + id="global-element-set-dropdown", + options=element_set_options, + value=summary_element_set, + clearable=False, + style={"maxWidth": "320px"}, + ), + ], + style={"marginBottom": "12px"}, + ) return Div( [ H1("Benchmarks Summary"), + summary_controls, summary_table, weight_components, Store( @@ -398,6 +472,22 @@ def build_full_app(full_app: Dash, category: str = "*") -> None: """ # Get layouts and tables for each test, grouped by categories all_layouts, all_tables = get_all_tests(category=category) + category_title_map = {} + for category_name in all_tables: + try: + with open(APP_ROOT / category_name / f"{category_name}.yml") as file: + category_info = safe_load(file) + category_title_map[category_name] = category_info.get( + "title", category_name + ) + except FileNotFoundError: + category_title_map[category_name] = category_name + + benchmark_tables = [ + table + for category_tables in all_tables.values() + for table in category_tables.values() + ] if not all_layouts: raise ValueError("No tests were built successfully") @@ -413,4 +503,6 @@ def build_full_app(full_app: 
Dash, category: str = "*") -> None: ) # Build summary and category tabs build_tabs(full_app, category_layouts, summary_table, weight_components) + register_summary_scores_from_files_callback(all_tables, category_title_map) + register_element_set_table_callbacks(benchmark_tables) register_onboarding_callbacks() diff --git a/ml_peg/app/supramolecular/LNCI16/app_LNCI16.py b/ml_peg/app/supramolecular/LNCI16/app_LNCI16.py index 895526e8c..44f339356 100644 --- a/ml_peg/app/supramolecular/LNCI16/app_LNCI16.py +++ b/ml_peg/app/supramolecular/LNCI16/app_LNCI16.py @@ -8,10 +8,9 @@ from ml_peg.app import APP_ROOT from ml_peg.app.base_app import BaseApp from ml_peg.app.utils.build_callbacks import ( - plot_from_table_column, - struct_from_scatter, + plot_from_table_column_for_element_set, + struct_from_scatter_for_element_set, ) -from ml_peg.app.utils.load import read_plot from ml_peg.models.get_models import get_model_names from ml_peg.models.models import current_models @@ -29,27 +28,22 @@ class LNCI16App(BaseApp): def register_callbacks(self) -> None: """Register callbacks to app.""" - scatter = read_plot( - DATA_PATH / "figure_interaction_energies.json", - id=f"{BENCHMARK_NAME}-figure", - ) - - # Assets dir will be parent directory - individual files for each system - structs = [ - f"assets/supramolecular/LNCI16/{MODELS[0]}/{i}.xyz" - for i in range(16) # LNCI16 has 16 systems - ] - - plot_from_table_column( + plot_from_table_column_for_element_set( table_id=self.table_id, plot_id=f"{BENCHMARK_NAME}-figure-placeholder", - column_to_plot={"MAE": scatter}, + figure_id=f"{BENCHMARK_NAME}-figure", + data_path=DATA_PATH, + plot_filename="figure_interaction_energies.json", + metric_columns=("MAE",), ) - - struct_from_scatter( + struct_from_scatter_for_element_set( scatter_id=f"{BENCHMARK_NAME}-figure", struct_id=f"{BENCHMARK_NAME}-struct-placeholder", - structs=structs, + data_path=DATA_PATH, + structure_path_template=( + f"assets/supramolecular/LNCI16/{MODELS[0]}/{{index}}.xyz" 
+ ), + default_structure_count=16, mode="struct", ) diff --git a/ml_peg/app/utils/build_callbacks.py b/ml_peg/app/utils/build_callbacks.py index f3aab2ee2..ecafac0a1 100644 --- a/ml_peg/app/utils/build_callbacks.py +++ b/ml_peg/app/utils/build_callbacks.py @@ -3,7 +3,7 @@ from __future__ import annotations import base64 -from collections.abc import Callable +from collections.abc import Callable, Sequence import io import json import math @@ -23,6 +23,8 @@ PERIODIC_TABLE_POSITIONS, PERIODIC_TABLE_ROWS, ) +from ml_peg.analysis.utils.element_filters import normalize_element_set_key +from ml_peg.app.utils.load import read_plot from ml_peg.app.utils.weas import generate_weas_html @@ -195,6 +197,234 @@ def show_struct(click_data): ) +def _resolve_element_set_file_path( + data_path: Path, + element_set_key: str, + filename: str, +) -> Path: + """ + Resolve a set-specific file path with fallback to ``all`` then top-level file. + + Parameters + ---------- + data_path + Benchmark data directory under ``app/data``. + element_set_key + Selected element-set key. + filename + Benchmark file name to load. + + Returns + ------- + Path + Existing path to a set-specific or fallback file. + """ + requested = data_path / "element_sets" / element_set_key / filename + if requested.exists(): + return requested + + fallback = data_path / "element_sets" / "all" / filename + if fallback.exists(): + return fallback + + return data_path / filename + + +def _load_element_set_structure_indices( + element_set_index_file: Path, + default_structure_count: int, +) -> dict[str, list[int]]: + """ + Load structure-index mapping for each element set from ``element_sets.json``. + + Parameters + ---------- + element_set_index_file + Path to ``element_sets.json`` for one benchmark. + default_structure_count + Number of structures in the unfiltered benchmark. + + Returns + ------- + dict[str, list[int]] + Mapping of set key to original structure indices. 
+ """ + default_indices = list(range(default_structure_count)) + if not element_set_index_file.exists(): + return {"all": default_indices} + + with open(element_set_index_file) as f: + data = json.load(f) + + element_sets = data.get("element_sets") or {} + element_set_indices: dict[str, list[int]] = {} + for key, value in element_sets.items(): + raw_indices = value.get("indices", []) + element_set_indices[normalize_element_set_key(key)] = [ + int(index) for index in raw_indices + ] + + if "all" not in element_set_indices: + element_set_indices["all"] = default_indices + + return element_set_indices + + +def plot_from_table_column_for_element_set( + table_id: str, + plot_id: str, + figure_id: str, + data_path: str | Path, + plot_filename: str, + metric_columns: Sequence[str], + selector_store_id: str = "global-element-set-store", +) -> None: + """ + Attach callback to show a set-specific plot when a metric cell is clicked. + + Parameters + ---------- + table_id + ID for Dash table being clicked. + plot_id + ID for Dash plot placeholder Div. + figure_id + ID for Dash graph component. + data_path + Benchmark data directory under ``app/data``. + plot_filename + Plot JSON file name, e.g. ``figure_interaction_energies.json``. + metric_columns + Table columns that should trigger this plot. + selector_store_id + Store ID that keeps selected element-set key. + """ + data_dir = Path(data_path) + metric_column_set = set(metric_columns) + + @callback( + Output(plot_id, "children"), + Input(table_id, "active_cell"), + Input(selector_store_id, "data", allow_optional=True), + ) + def show_plot(active_cell, selected_set) -> Div: + """ + Show set-specific plot for selected metric and element set. + + Parameters + ---------- + active_cell + Clicked cell in Dash table. + selected_set + Selected element-set key from shared store. + + Returns + ------- + Div + Message explaining interactivity, or plot on table click. 
+ """ + if not active_cell: + return Div("Click on a metric to view plot.") + + column_id = active_cell.get("column_id") + if column_id not in metric_column_set: + raise PreventUpdate + + set_key = normalize_element_set_key(selected_set or "all") + plot_path = _resolve_element_set_file_path(data_dir, set_key, plot_filename) + scatter = read_plot(plot_path, id=figure_id) + return Div(scatter) + + +def struct_from_scatter_for_element_set( + scatter_id: str, + struct_id: str, + data_path: str | Path, + structure_path_template: str, + default_structure_count: int, + selector_store_id: str = "global-element-set-store", + mode: Literal["struct", "traj"] = "struct", +) -> None: + """ + Attach callback to show a structure for clicked set-filtered scatter points. + + Parameters + ---------- + scatter_id + ID for Dash scatter being clicked. + struct_id + ID for Dash placeholder Div where structures are visualised. + data_path + Benchmark data directory under ``app/data``. + structure_path_template + Structure path template with ``{index}``, e.g. + ``assets/.../{index}.xyz``. + default_structure_count + Number of structures in the unfiltered benchmark. + selector_store_id + Store ID that keeps selected element-set key. + mode + Whether to display a single structure (``"struct"``) or trajectory + (``"traj"``). Default is ``"struct"``. + """ + data_dir = Path(data_path) + element_set_index_file = data_dir / "element_sets.json" + element_set_indices = _load_element_set_structure_indices( + element_set_index_file=element_set_index_file, + default_structure_count=default_structure_count, + ) + + @callback( + Output(struct_id, "children", allow_duplicate=True), + Input(scatter_id, "clickData"), + State(selector_store_id, "data", allow_optional=True), + prevent_initial_call="initial_duplicate", + ) + def show_structure(click_data, selected_set) -> Div: + """ + Show structure for clicked scatter point and selected element set. 
+ + Parameters + ---------- + click_data + Clicked data point in scatter plot. + selected_set + Selected element-set key from shared store. + + Returns + ------- + Div + Message explaining interactivity, or visualised structure. + """ + if not click_data: + return Div("Click on a metric to view the structure.") + + point_index = click_data["points"][0]["pointNumber"] + set_key = normalize_element_set_key(selected_set or "all") + allowed_indices = element_set_indices.get( + set_key, + element_set_indices.get("all", []), + ) + if allowed_indices and point_index >= len(allowed_indices): + raise PreventUpdate + + structure_index = ( + allowed_indices[point_index] if allowed_indices else point_index + ) + structure_path = structure_path_template.format(index=structure_index) + return Div( + Iframe( + srcDoc=generate_weas_html(structure_path, mode, 0), + style={ + "height": "550px", + "width": "100%", + "border": "1px solid #ddd", + "borderRadius": "5px", + }, + ) + ) + + def struct_from_table( table_id: str, struct_id: str, @@ -282,8 +512,8 @@ def register_image_gallery_callbacks( manifest_dir Directory containing per-model ``manifest.json`` files. curve_dir - Directory of per-model curve JSON payloads. Element selections are rendered on - the fly from these payloads instead of relying on pre-generated element images. + Directory of per-model curve JSON files. Element selections are rendered on + the fly from these files instead of relying on pre-generated element images. overview_label Dropdown label representing the overview image. Default is ``"All"``. 
""" @@ -461,10 +691,10 @@ def _update_options(model_name: str): if model_curve_dir.exists(): for curve_file in model_curve_dir.glob("*.json"): try: - payload = json.loads(curve_file.read_text()) + curve_data = json.loads(curve_file.read_text()) except Exception: continue - pair = payload.get("pair") or curve_file.stem + pair = curve_data.get("pair") or curve_file.stem try: first, second = pair.split("-") except ValueError: @@ -508,11 +738,11 @@ def _update_figure(model_name: str, element_value: str | None): curves: dict[str, dict] = {} for curve_file in model_curve_dir.glob("*.json"): try: - payload = json.loads(curve_file.read_text()) + curve_data = json.loads(curve_file.read_text()) except Exception: continue - pair = payload.get("pair") or curve_file.stem - curves[pair] = payload + pair = curve_data.get("pair") or curve_file.stem + curves[pair] = curve_data if not curves: raise PreventUpdate @@ -521,17 +751,17 @@ def _update_figure(model_name: str, element_value: str | None): # all pairs involving the selected element. 
selected_element = None if element_value == overview_label else element_value filtered: dict[str, dict] = {} - for pair, payload in curves.items(): + for pair, curve_data in curves.items(): try: first, second = pair.split("-") except ValueError: first = second = pair if selected_element is None: if first == second: - filtered[pair] = payload + filtered[pair] = curve_data else: if selected_element in (first, second): - filtered[pair] = payload + filtered[pair] = curve_data if not filtered: raise PreventUpdate @@ -555,14 +785,14 @@ def _update_figure(model_name: str, element_value: str | None): ax.axis("off") has_data = False - for pair, payload in filtered.items(): + for pair, curve_data in filtered.items(): first, second = pair.split("-") if "-" in pair else (pair, pair) other = second if selected_element == first else first pos = PERIODIC_TABLE_POSITIONS.get(other) if pos is None: continue - x_vals = payload.get("distance") or [] - y_vals = payload.get("energy") or [] + x_vals = curve_data.get("distance") or [] + y_vals = curve_data.get("energy") or [] if not x_vals or not y_vals: continue try: diff --git a/ml_peg/app/utils/register_callbacks.py b/ml_peg/app/utils/register_callbacks.py index 7a1adc675..bca14bebc 100644 --- a/ml_peg/app/utils/register_callbacks.py +++ b/ml_peg/app/utils/register_callbacks.py @@ -3,17 +3,20 @@ from __future__ import annotations from copy import deepcopy +from pathlib import Path from typing import Any from dash import Input, Output, State, callback, ctx from dash.exceptions import PreventUpdate +from ml_peg.analysis.utils.element_filters import normalize_element_set_key from ml_peg.analysis.utils.utils import ( calc_metric_scores, calc_table_scores, get_table_style, update_score_style, ) +from ml_peg.app.utils.load import rebuild_table from ml_peg.app.utils.utils import ( Thresholds, build_level_of_theory_warnings, @@ -50,7 +53,7 @@ def register_summary_table_callbacks( ), # Needed to display model config & level of theory tooltips 
Input("all-tabs", "value"), Input("summary-table-weight-store", "data"), - State("summary-table-scores-store", "data"), + Input("summary-table-scores-store", "data"), State("summary-table", "data"), prevent_initial_call=False, ) @@ -95,6 +98,218 @@ def update_summary_table( return updated_rows, style_with_warnings, tooltip_rows +def _resolve_element_set_table_path( + table_path: str | Path, element_set_key: str +) -> Path: + """ + Resolve set-specific table path with fallback to ``all`` then baseline table. + + Parameters + ---------- + table_path + Baseline benchmark metrics table path. + element_set_key + Requested element-set key. + + Returns + ------- + Path + Existing path to load table JSON from. + """ + base_path = Path(table_path) + candidate = base_path.parent / "element_sets" / element_set_key / base_path.name + if candidate.exists(): + return candidate + fallback = base_path.parent / "element_sets" / "all" / base_path.name + if fallback.exists(): + return fallback + return base_path + + +def register_element_set_table_callbacks( + benchmark_tables: list[Any], + selector_id: str = "global-element-set-store", +) -> None: + """ + Register callbacks to switch benchmark tables across configured element sets. + + Parameters + ---------- + benchmark_tables + List of benchmark ``DataTable`` instances. + selector_id + ID of global element-set dropdown component. 
+ """ + table_specs: list[dict[str, Any]] = [] + for table in benchmark_tables: + table_path = getattr(table, "source_table_path", None) + table_id = getattr(table, "id", None) + if not table_id or not table_path: + continue + table_specs.append( + { + "id": table_id, + "path": table_path, + "description": getattr(table, "description", None), + } + ) + + if not table_specs: + return + + for spec in table_specs: + table_id = spec["id"] + + @callback( + Output(table_id, "data", allow_duplicate=True), + Output(table_id, "columns", allow_duplicate=True), + Output(table_id, "style_data_conditional", allow_duplicate=True), + Output(table_id, "tooltip_data", allow_duplicate=True), + Output(table_id, "tooltip_header", allow_duplicate=True), + Output(f"{table_id}-raw-data-store", "data", allow_duplicate=True), + Output(f"{table_id}-computed-store", "data", allow_duplicate=True), + Output(f"{table_id}-raw-tooltip-store", "data", allow_duplicate=True), + Output(f"{table_id}-thresholds-store", "data", allow_duplicate=True), + Output(f"{table_id}-weight-store", "data", allow_duplicate=True), + Input(selector_id, "data", allow_optional=True), + Input(table_id, "id", allow_optional=True), + prevent_initial_call="initial_duplicate", + ) + def switch_single_table( + selected_set: str, + _mounted_table_id: str, + *, + table_id: str = table_id, + table_path: str = spec["path"], + description: str | None = spec["description"], + ) -> tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]: + """ + Load one benchmark table for the selected element set. + + Parameters + ---------- + selected_set + Selected element-set key. + _mounted_table_id + Mounted table ID (unused, triggers refresh when table appears). + table_id + Benchmark table component ID. + table_path + Baseline JSON table path. + description + Optional table description. + + Returns + ------- + tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] + Updated table data and associated stores. 
+ """ + set_key = normalize_element_set_key(selected_set or "all") + resolved_path = _resolve_element_set_table_path(table_path, set_key) + rebuilt_table = rebuild_table( + filename=resolved_path, + id=table_id, + description=description, + ) + return ( + rebuilt_table.data, + rebuilt_table.columns, + rebuilt_table.style_data_conditional, + rebuilt_table.tooltip_data, + rebuilt_table.tooltip_header, + rebuilt_table.data, + rebuilt_table.data, + rebuilt_table.tooltip_header, + getattr(rebuilt_table, "thresholds", {}), + getattr(rebuilt_table, "weights", {}), + ) + + +def register_summary_scores_from_files_callback( + all_tables: dict[str, dict[str, Any]], + category_title_map: dict[str, str], + selector_store_id: str = "global-element-set-store", +) -> None: + """ + Recompute summary score-store directly from benchmark table JSON files. + + Parameters + ---------- + all_tables + Nested mapping of category -> benchmark name -> benchmark DataTable. + category_title_map + Mapping of category directory names to displayed category titles. + selector_store_id + Store ID containing the selected element-set key. + """ + + @callback( + Output("summary-table-scores-store", "data", allow_duplicate=True), + Input(selector_store_id, "data"), + prevent_initial_call=True, + ) + def recompute_summary_scores(selected_set: str) -> dict[str, dict[str, float]]: + """ + Recompute category scores for the selected element set. + + Parameters + ---------- + selected_set + Selected element-set key. + + Returns + ------- + dict[str, dict[str, float]] + Summary score-store data keyed by ``"<category title> Score"``.
+ """ + set_key = normalize_element_set_key(selected_set or "all") + summary_scores: dict[str, dict[str, float]] = {} + + for category_dir, benchmarks in all_tables.items(): + benchmark_columns: list[str] = [] + model_scores: dict[str, dict[str, float | None]] = {} + + for benchmark_name, benchmark_table in benchmarks.items(): + benchmark_column = benchmark_name + " Score" + benchmark_columns.append(benchmark_column) + + source_path = getattr(benchmark_table, "source_table_path", None) + if not source_path: + continue + table_path = _resolve_element_set_table_path(source_path, set_key) + rebuilt_table = rebuild_table( + filename=table_path, + id=getattr(benchmark_table, "id", "metrics"), + description=getattr(benchmark_table, "description", None), + ) + name_map = getattr(rebuilt_table, "model_name_map", {}) or {} + for row in rebuilt_table.data: + display_name = row.get("MLIP") + original_name = name_map.get(display_name, display_name) + if original_name is None: + continue + model_scores.setdefault(original_name, {}) + model_scores[original_name][benchmark_column] = row.get("Score") + + if not benchmark_columns: + continue + + category_rows: list[dict[str, Any]] = [] + for model_name, model_row_scores in model_scores.items(): + category_row: dict[str, Any] = {"MLIP": model_name} + for column in benchmark_columns: + category_row[column] = model_row_scores.get(column, None) + category_rows.append(category_row) + + category_rows = calc_table_scores(category_rows) + category_title = category_title_map.get(category_dir, category_dir) + summary_scores[category_title + " Score"] = { + row["MLIP"]: row["Score"] for row in category_rows + } + + return summary_scores + + def register_category_table_callbacks( table_id: str, use_thresholds: bool = False, @@ -133,7 +348,6 @@ def register_category_table_callbacks( Output(f"{table_id}-raw-data-store", "data"), Input(f"{table_id}-weight-store", "data"), Input(f"{table_id}-thresholds-store", "data"), - Input("all-tabs", 
"value"), Input(f"{table_id}-normalized-toggle", "value"), State(f"{table_id}-raw-data-store", "data"), State(f"{table_id}-computed-store", "data"), @@ -144,7 +358,6 @@ def register_category_table_callbacks( def update_benchmark_table_scores( stored_weights: dict[str, float] | None, stored_threshold: dict | None, - _tabs_value: str, toggle_value: list[str] | None, stored_raw_data: list[dict] | None, stored_computed_data: list[dict] | None, @@ -168,8 +381,6 @@ def update_benchmark_table_scores( Stored weights dictionary for table metrics. stored_threshold Stored thresholds dictionary for table metric thresholds. - _tabs_value - Current tab identifier (unused, required to trigger on tab change). toggle_value Value of toggle to show normalised values. stored_raw_data @@ -194,12 +405,8 @@ def apply_levels_of_theory( tooltip_data = tooltip_rows if tooltip_rows else [{} for _ in rows] return combined_style, tooltip_data - # Tab switches and toggle flips reuse the cached scored rows rather than - # recalculating scores, we only re-score when weights/thresholds change. - if ( - trigger_id in ("all-tabs", f"{table_id}-normalized-toggle") - and stored_computed_data - ): + # Toggle flips reuse cached scored rows; weight/threshold edits rescore. 
+ if trigger_id == f"{table_id}-normalized-toggle" and stored_computed_data: display_rows = get_scores( stored_raw_data, stored_computed_data, thresholds, toggle_value ) @@ -256,19 +463,13 @@ def apply_levels_of_theory( Output(table_id, "tooltip_data", allow_duplicate=True), Output(f"{table_id}-computed-store", "data", allow_duplicate=True), Input(f"{table_id}-weight-store", "data"), - Input("all-tabs", "value"), State(table_id, "data"), - State(f"{table_id}-computed-store", "data"), prevent_initial_call="initial_duplicate", ) def update_table_scores( stored_weights: dict[str, float] | None, - _tabs_value: str, table_data: list[dict] | None, - computed_store: list[dict] | None, ) -> tuple[list[dict], list[dict], list[dict], list[dict]]: - trigger_id = ctx.triggered_id - def apply_levels( rows: list[dict], base_style: list[dict] ) -> tuple[list[dict], list[dict]]: @@ -279,11 +480,6 @@ def apply_levels( tooltips = tooltip_rows if tooltip_rows else [{} for _ in rows] return combined_style, tooltips - if trigger_id == "all-tabs" and computed_store: - style = get_table_style(computed_store) - style, tooltip_data = apply_levels(computed_store, style) - return computed_store, style, tooltip_data, computed_store - if not table_data: raise PreventUpdate