diff --git a/docs/source/user_guide/benchmarks/bulk_crystal.rst b/docs/source/user_guide/benchmarks/bulk_crystal.rst index 575bbcc11..e8b9dc0c0 100644 --- a/docs/source/user_guide/benchmarks/bulk_crystal.rst +++ b/docs/source/user_guide/benchmarks/bulk_crystal.rst @@ -115,3 +115,170 @@ Reference data: * Same as input data * PBE + + +Diamond phonons +=============== + +Summary +------- + +Performance in predicting the phonon dispersion of bulk diamond (carbon). + +The benchmark evaluates the accuracy of phonon frequencies along a fixed +high-symmetry path in the Brillouin zone for a single reference crystal +structure of diamond. + + +Metrics +------- + +1. Band MAE + +Mean absolute error (MAE) between predicted and reference phonon frequencies. + +For bulk diamond, the phonon band structure is computed for each model along the same +q-point path as the reference calculation. At each q-point, the six phonon frequencies +are compared to the reference frequencies after sorting the modes to avoid branch +labelling ambiguities. The MAE is evaluated over all q-points and all phonon branches. + + +2. Band RMSE + +Root mean squared error (RMSE) between predicted and reference phonon frequencies. + +The RMSE is computed using the same sorted, mode-unlabelled comparison procedure as in +(1), over all q-points and all phonon branches. + + +Computational cost +------------------ + +Medium: tests typically take a few minutes to run on CPU. + + +Data availability +----------------- + +Input structures (https://github.com/7radians/ml-peg-data/tree/main/diamond_data): + +* A primitive bulk diamond unit cell containing two carbon atoms, used to generate a + phonopy displacement dataset on a 4×4×4 supercell. + +Reference data: + +* DFT phonon band structure for bulk diamond along a fixed high-symmetry path, provided + as ``dft_band.npz``. +* RSCAN. 
+ + +Ti64 phonons +============ + +Summary +------- + +Performance in predicting phonon dispersions, vibrational densities of states (DOS/PDOS), +and thermodynamic Helmholtz free energies for a suite of 10 Ti–Al–V alloy phases. + +Each case is evaluated by comparing ML-predicted phonon frequencies to CASTEP reference +phonon frequencies along a fixed high-symmetry q-path. For a subset of cases, +Helmholtz free-energy errors per atom are additionally reported. + + +Metrics +------- + +1. Dispersion RMSE (mean) + + Mean root mean squared error (RMSE) between predicted and reference phonon frequencies, + averaged over 10 Ti64 cases. + + For each case, reference phonon frequencies are parsed from CASTEP ``.castep`` outputs + along a fixed high-symmetry q-path. The structure is then relaxed for each model using + the LBFGS optimiser (maximum 10000 steps, ``fmax=0.001``). Phonon frequencies are computed + using finite displacements in a 2×2×2 supercell with a displacement magnitude of 0.02 Å and + ``plusminus=True``. The reference dispersion is linearly interpolated onto an inferred ML + path-coordinate grid spanning the same path (a uniform grid with the same number of + q-points as the ML dispersion), and the RMSE is evaluated over all q-points and all + phonon branches. + +2. Dispersion RMSE (max) + + Maximum per-case dispersion RMSE (in THz) over the 10 Ti64 cases. + + Computed as in (1), but taking the maximum RMSE value across cases. + +3. ω_avg MAE + + Mean absolute error (MAE) in the average phonon frequency ω_avg over the 10 Ti64 cases. + + For each case, ω_avg is computed as the arithmetic mean of all phonon frequencies after + interpolating the reference dispersion onto the inferred ML grid. The per-case absolute error is + then averaged across cases. Frequencies are averaged as stored; if imaginary modes are present + as negative values, they contribute directly. + +4. 
ΔF (0 K) mean + + Mean absolute error in Helmholtz free energy at 0 K, reported as eV/atom, over the + subset of cases where thermodynamic outputs are available. + + For applicable cases, CASTEP q-point phonon frequencies and q-point weights are parsed + from CASTEP qpoints ``.castep`` outputs. A reference Helmholtz free energy is computed + in the harmonic approximation by combining a weighted zero-point energy contribution + with a weighted thermal free-energy contribution evaluated on a dense temperature grid + (2000 points spanning 0–2000 K) and interpolated to the ML temperatures. The absolute + difference between ML and reference free energy at 0 K is divided by the number of atoms + and averaged across thermodynamics-enabled cases. Weights are taken directly from CASTEP; + no explicit renormalisation is applied. + +5. ΔF (2000 K) mean + + Mean absolute error in Helmholtz free energy at 2000 K, reported as eV/atom, over the + subset of cases where thermodynamic outputs are available. + + Computed as in (4), but using the final temperature point (2000 K). + + +Computational cost +------------------ + +Medium: dispersion, DOS/PDOS and thermodynamic calculations typically take minutes per model on CPU. +Thermodynamic calculations are enabled for a 7/10 subset of cases. + + +Data availability +----------------- + +Full details on the data and benchmark: + +* Allen, C. S. & Bartók, A. P. Multi-phase dataset for Ti and Ti-6Al-4V. + Preprint at https://arxiv.org/abs/2501.06116 (2025). + +* Radova, M., Stark, W. G., Allen, C. S., Maurer, R. J. & Bartók, A. P. + Fine-tuning foundation models of materials interatomic potentials + with frozen transfer learning. + npj Comput Mater 11, 237 (2025). + https://doi.org/10.1038/s41524-025-01727-x + +Input structures (https://github.com/7radians/ml-peg-data/tree/main/ti64_data): + +* CASTEP ``.castep`` outputs providing reference phonon dispersions along fixed + high-symmetry q-paths for 10 Ti–Al–V alloy cases. 
+* Corresponding CASTEP qpoints ``.castep`` outputs (subset) providing q-point phonon + frequencies and weights for thermodynamic reference reconstruction. + +Reference data: + +* CASTEP phonon dispersions parsed from ``.castep`` outputs (q-path dispersion). +* CASTEP q-point phonon frequencies and weights parsed from qpoints ``.castep`` outputs + (subset), used to compute reference Helmholtz free energies in the harmonic + approximation. +* PBE + +Computational environment +------------------------- + +Ti64 phonon calculations were run as a single process on CPU on an +x86_64 machine (11th Gen Intel(R) Core(TM) i5-1145G7; 4 cores / 8 threads). No explicit +parallel execution (MPI or multiprocessing) was used in the benchmark driver. diff --git a/ml_peg/analysis/bulk_crystal/diamond_phonons/analyse_diamond_phonons.py b/ml_peg/analysis/bulk_crystal/diamond_phonons/analyse_diamond_phonons.py new file mode 100644 index 000000000..11397b5b1 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/diamond_phonons/analyse_diamond_phonons.py @@ -0,0 +1,471 @@ +"""Analyse diamond phonon dispersion benchmark (bands only).""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +import yaml # type: ignore + +from ml_peg.analysis.utils.decorators import build_table, cell_to_scatter +from ml_peg.analysis.utils.utils import load_metrics_config +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.calcs.utils.utils import download_github_data +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +GITHUB_BASE = "https://raw.githubusercontent.com/7radians/ml-peg-data/main" + +EXTRACTED_ROOT = Path( + download_github_data( + filename="diamond_data/data.zip", + github_uri=GITHUB_BASE, + ) +) + +CALC_DATA = EXTRACTED_ROOT / "data" + + +MODELS = get_model_names(current_models) + +CATEGORY = "bulk_crystal" +BENCH = 
"diamond_phonons" + +CALC_PATH = CALCS_ROOT / CATEGORY / BENCH / "outputs" + +# CALC_DATA = CALCS_ROOT / CATEGORY / BENCH / "data" + + +OUT_PATH = APP_ROOT / "data" / CATEGORY / BENCH + +# cm^-1 per THz (to convert the DFT reference to THz) +THZ_TO_CM1 = 33.35640951981521 +NBANDS = 6 + +SCATTER_FILENAME = OUT_PATH / "diamond_phonons_bands_interactive.json" + +METRIC_KEY_MAE = "band_mae" +METRIC_KEY_RMSE = "band_rmse" + +METRIC_LABEL_MAE = "Band MAE" +METRIC_LABEL_RMSE = "Band RMSE" + + +METRICS_YML = Path(__file__).with_name("metrics.yml") +THRESHOLDS, METRIC_TOOLTIPS, WEIGHTS = load_metrics_config(METRICS_YML) + +_expected_metric_labels = {METRIC_LABEL_MAE, METRIC_LABEL_RMSE} +_yaml_metric_labels = set(THRESHOLDS.keys()) +missing = _expected_metric_labels - _yaml_metric_labels +if missing: + raise ValueError( + f"{METRICS_YML}: missing metrics for labels {sorted(missing)}. " + f"Found: {sorted(_yaml_metric_labels)}" + ) + + +def _load_reference_npz(path: Path) -> dict[str, Any]: + """ + Load DFT reference bands from an NPZ file. + + Parameters + ---------- + path + Path to ``dft_band.npz``. Must contain a ``frequencies`` array in cm-1 with + shape ``(Nq, NBANDS)`` or ``(1, Nq, NBANDS)``. + + Returns + ------- + dict[str, Any] + Mapping with keys: + + - ``freqs``: array of shape ``(Nq, NBANDS)`` in THz + - ``units``: ``"THz"`` + - ``path``: input path + """ + if not path.exists(): + raise FileNotFoundError(f"Missing DFT reference: {path}") + + d = np.load(path, allow_pickle=False) + if "frequencies" not in d.files: + raise KeyError(f"{path}: missing 'frequencies'. 
Found keys: {list(d.files)}") + + freqs_cm1 = np.asarray(d["frequencies"], dtype=float) + + if freqs_cm1.ndim == 3 and freqs_cm1.shape[0] == 1: + freqs_cm1 = freqs_cm1[0] + + if freqs_cm1.ndim != 2 or freqs_cm1.shape[1] != NBANDS: + msg = f"{path}: expected (Nq, {NBANDS}) frequencies, got {freqs_cm1.shape}" + raise ValueError(msg) + + if not np.isfinite(freqs_cm1).all(): + raise ValueError(f"{path}: contains non-finite reference frequencies.") + + freqs_thz = freqs_cm1 / THZ_TO_CM1 + + return {"freqs": freqs_thz, "units": "THz", "path": path} + + +def _load_band_yaml(path: Path) -> dict[str, Any]: + """ + Load phonopy ``band.yaml`` and return frequencies in THz. + + Parameters + ---------- + path + Path to a phonopy ``band.yaml``. + + Returns + ------- + dict[str, Any] + Mapping with keys ``freqs`` (``(Nq, NBANDS)``), ``units`` (``"THz"``), + and ``path``. + """ + if not path.exists(): + raise FileNotFoundError(f"Missing predicted band.yaml: {path}") + + with path.open("r", encoding="utf8") as f: + y = yaml.safe_load(f) + + phonon = y.get("phonon") + if not isinstance(phonon, list) or not phonon: + raise ValueError( + f"{path} does not look like a phonopy band.yaml (missing 'phonon' list)." + ) + + freqs = np.array( + [[b.get("frequency", np.nan) for b in p.get("band", [])] for p in phonon], + dtype=float, + ) + + if freqs.ndim != 2 or freqs.shape[1] != NBANDS: + raise ValueError(f"{path}: expected (Nq, {NBANDS}) freqs, got {freqs.shape}") + + if not np.isfinite(freqs).all(): + raise ValueError(f"{path}: contains non-finite predicted frequencies.") + + return {"freqs": freqs, "units": "THz", "path": path} + + +def _sorted_flat(freqs: np.ndarray) -> np.ndarray: + """ + Sort each q-point's bands then flatten. + + Parameters + ---------- + freqs + Array of shape ``(Nq, NBANDS)``. + + Returns + ------- + numpy.ndarray + Flattened array of shape ``(Nq * NBANDS,)``. + """ + if freqs.ndim != 2: + raise ValueError(f"Expected (Nq, nb). 
Got {freqs.shape}") + if freqs.shape[1] != NBANDS: + raise ValueError(f"Expected {NBANDS} bands, got {freqs.shape[1]}") + return np.sort(freqs, axis=1).reshape(-1) + + +def _mae(a: np.ndarray, b: np.ndarray) -> float: + """ + Compute mean absolute error (THz). + + Parameters + ---------- + a + Predicted values (THz), same shape as ``b``. + b + Reference values (THz), same shape as ``a``. + + Returns + ------- + float + Mean absolute error in THz. + """ + return float(np.mean(np.abs(a - b))) + + +def _rmse(a: np.ndarray, b: np.ndarray) -> float: + """ + Compute root mean squared error (THz). + + Parameters + ---------- + a + Predicted values (THz), same shape as ``b``. + b + Reference values (THz), same shape as ``a``. + + Returns + ------- + float + Root mean squared error in THz. + """ + d = a - b + return float(np.sqrt(np.mean(d * d))) + + +@pytest.fixture +def reference() -> dict[str, Any]: + """ + Load the DFT reference and ensure output directory exists. + + Returns + ------- + dict[str, Any] + Reference mapping as returned by :func:`_load_reference_npz`. + """ + OUT_PATH.mkdir(parents=True, exist_ok=True) + return _load_reference_npz(CALC_DATA / "dft_band.npz") + + +def _model_flat(model_name: str) -> np.ndarray: + """ + Load and flatten one model's predicted bands (THz). + + Parameters + ---------- + model_name + Model identifier used to locate ``{CALC_PATH}/{model_name}/band.yaml``. + + Returns + ------- + numpy.ndarray + Flattened frequencies of shape ``(Nq * NBANDS,)`` in THz. + """ + pred = _load_band_yaml(CALC_PATH / model_name / "band.yaml") + return _sorted_flat(np.asarray(pred["freqs"], dtype=float)) + + +@pytest.fixture +def flat_bands(reference: dict[str, Any]) -> tuple[np.ndarray, dict[str, np.ndarray]]: + """ + Load and cache flattened reference and predicted bands. + + Parameters + ---------- + reference + Reference mapping as returned by :func:`reference`. 
+ + Returns + ------- + tuple[numpy.ndarray, dict[str, numpy.ndarray]] + ``(ref_flat, pred_flats)`` where ``ref_flat`` has shape ``(Nq * NBANDS,)`` + and ``pred_flats`` maps model name to an array of the same shape. + """ + ref_flat = _sorted_flat(np.asarray(reference["freqs"], dtype=float)) + + pred_flats: dict[str, np.ndarray] = {} + for model_name in MODELS: + pred_flat = _model_flat(model_name) + if pred_flat.shape != ref_flat.shape: + raise ValueError( + f"{model_name}: prediction and reference flattened shapes differ " + f"{pred_flat.shape} vs {ref_flat.shape}." + ) + pred_flats[model_name] = pred_flat + + return ref_flat, pred_flats + + +@pytest.fixture +def band_errors( + flat_bands: tuple[np.ndarray, dict[str, np.ndarray]], +) -> dict[str, dict[str, float]]: + """ + Compute MAE and RMSE for each model (THz). + + Parameters + ---------- + flat_bands + Tuple ``(ref_flat, pred_flats)`` as returned by :func:`flat_bands`. + + Returns + ------- + dict[str, dict[str, float]] + Mapping ``model_name -> {"mae": float, "rmse": float}`` in THz. + """ + ref_flat, pred_flats = flat_bands + + out: dict[str, dict[str, float]] = {} + for model_name in MODELS: + pred_flat = pred_flats[model_name] + out[model_name] = { + "mae": _mae(pred_flat, ref_flat), + "rmse": _rmse(pred_flat, ref_flat), + } + return out + + +@pytest.fixture +@build_table( + filename=OUT_PATH / "diamond_phonons_bands_table.json", + thresholds=THRESHOLDS, + metric_tooltips=METRIC_TOOLTIPS, + weights=WEIGHTS, +) +def metrics(band_errors: dict[str, dict[str, float]]) -> dict[str, dict[str, float]]: + """ + Build the metrics table mapping for the Dash table. + + Parameters + ---------- + band_errors + Per-model MAE/RMSE mapping as returned by :func:`band_errors`. + + Returns + ------- + dict[str, dict[str, float]] + Mapping from visible metric label to per-model values. 
+ """ + return { + METRIC_LABEL_MAE: {m: band_errors[m]["mae"] for m in MODELS}, + METRIC_LABEL_RMSE: {m: band_errors[m]["rmse"] for m in MODELS}, + } + + +@pytest.fixture +def band_stats( + flat_bands: tuple[np.ndarray, dict[str, np.ndarray]], + band_errors: dict[str, dict[str, float]], +) -> dict[str, dict[str, Any]]: + """ + Build per-model structures consumed by ``cell_to_scatter``. + + Parameters + ---------- + flat_bands + Tuple ``(ref_flat, pred_flats)`` as returned by :func:`flat_bands`. + band_errors + Per-model MAE/RMSE mapping as returned by :func:`band_errors`. + + Returns + ------- + dict[str, dict[str, Any]] + Per-model structures containing points and metric values used to build the + interactive scatter dataset. + """ + ref_flat, pred_flats = flat_bands + + stats: dict[str, dict[str, Any]] = {} + for model_name in MODELS: + pred_flat = pred_flats[model_name] + + points = [ + { + "id": f"diamond-{i}", + "label": "diamond", + "ref": float(ref_val), + "pred": float(pred_val), + } + for i, (pred_val, ref_val) in enumerate( + zip(pred_flat, ref_flat, strict=True) + ) + ] + + stats[model_name] = { + "model": model_name, + "metrics": { + METRIC_KEY_MAE: { + "points": points, + "mae": float(band_errors[model_name]["mae"]), + }, + METRIC_KEY_RMSE: { + "points": points, + "rmse": float(band_errors[model_name]["rmse"]), + }, + }, + } + + return stats + + +@pytest.fixture +@cell_to_scatter( + filename=SCATTER_FILENAME, + x_label="Predicted frequency (THz)", + y_label="DFT frequency (THz)", +) +def interactive_dataset(band_stats: dict[str, dict[str, Any]]) -> dict[str, Any]: + """ + Build the interactive scatter dataset for the phonon Dash app. + + Parameters + ---------- + band_stats + Per-model point/metric structures as returned by :func:`band_stats`. + + Returns + ------- + dict[str, Any] + Interactive dataset payload written to JSON by the decorator. 
+ """ + dataset: dict[str, Any] = { + "metrics": { + METRIC_KEY_MAE: METRIC_LABEL_MAE, + METRIC_KEY_RMSE: METRIC_LABEL_RMSE, + }, + "models": {}, + } + + for model_name, model_data in band_stats.items(): + dataset["models"][model_name] = {"metrics": {}} + + dataset["models"][model_name]["metrics"][METRIC_KEY_MAE] = { + "points": model_data["metrics"][METRIC_KEY_MAE]["points"], + "mae": model_data["metrics"][METRIC_KEY_MAE]["mae"], + } + + dataset["models"][model_name]["metrics"][METRIC_KEY_RMSE] = { + "points": model_data["metrics"][METRIC_KEY_RMSE]["points"], + "rmse": model_data["metrics"][METRIC_KEY_RMSE]["rmse"], + } + + return dataset + + +def test_diamond_phonons_analysis( + metrics: dict[str, Any], interactive_dataset: dict[str, Any] +) -> None: + """ + Generate JSON artifacts for the diamond phonons benchmark. + + Parameters + ---------- + metrics + Table fixture output (decorator writes JSON). + interactive_dataset + Scatter fixture output (decorator writes JSON). + """ + assert isinstance(metrics, dict) + assert isinstance(interactive_dataset, dict) + + table_path = OUT_PATH / "diamond_phonons_bands_table.json" + assert table_path.exists() + + table_payload = json.loads(table_path.read_text(encoding="utf8")) + rows = table_payload.get("data", []) + ids = {row.get("id") for row in rows if isinstance(row, dict)} + missing_rows = [m for m in MODELS if m not in ids] + assert not missing_rows, f"Table missing model rows: {missing_rows}" + + assert SCATTER_FILENAME.exists() + scatter_payload = json.loads(SCATTER_FILENAME.read_text(encoding="utf8")) + models = scatter_payload.get("models", {}) + missing_models = [m for m in MODELS if m not in models] + assert not missing_models, f"Interactive dataset missing models: {missing_models}" + + # Ensure each model has both metrics and some points. 
+ for model_name in MODELS: + model_metrics = (models.get(model_name) or {}).get("metrics", {}) + for key in (METRIC_KEY_MAE, METRIC_KEY_RMSE): + assert key in model_metrics, f"{model_name}: missing metric '{key}'" + points = model_metrics[key].get("points", []) + assert points, f"{model_name}: empty points for '{key}'" diff --git a/ml_peg/analysis/bulk_crystal/diamond_phonons/metrics.yml b/ml_peg/analysis/bulk_crystal/diamond_phonons/metrics.yml new file mode 100644 index 000000000..c296a52b0 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/diamond_phonons/metrics.yml @@ -0,0 +1,16 @@ +metrics: + Band MAE: + good: 0 + bad: 2.0 + unit: THz + tooltip: "Mean absolute error of phonon frequencies for diamond over all sampled q-points and phonon branches along the chosen band path." + weight: 1.0 + level_of_theory: RSCAN + + Band RMSE: + good: 0.0 + bad: 2.0 + unit: THz + tooltip: "Root mean squared error of phonon frequencies for diamond over all sampled q-points and phonon branches along the chosen band path." 
+ weight: 1.0 + level_of_theory: RSCAN diff --git a/ml_peg/analysis/bulk_crystal/ti64_phonons/analyse_ti64_phonons.py b/ml_peg/analysis/bulk_crystal/ti64_phonons/analyse_ti64_phonons.py new file mode 100644 index 000000000..36d75fb2f --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/ti64_phonons/analyse_ti64_phonons.py @@ -0,0 +1,538 @@ +"""Analyse Ti64 phonons benchmark.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import numpy as np +import pytest + +from ml_peg.analysis.utils.decorators import build_table, cell_to_scatter +from ml_peg.analysis.utils.utils import load_metrics_config +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + +THIS_DIR = Path(__file__).resolve().parent + +CALC_OUT_PATH = CALCS_ROOT / "bulk_crystal" / "ti64_phonons" / "outputs" +APP_OUT_PATH = APP_ROOT / "data" / "bulk_crystal" / "ti64_phonons" + +METRICS_YAML_PATH = THIS_DIR / "metrics.yml" +SCATTER_FILENAME = APP_OUT_PATH / "ti64_phonons_interactive.json" + +TP_ON: set[str] = { + "hcp_Ti6AlV", + "hex_Ti8AlV", + "hcp_Ti6Al2", + "hcp_Ti6V2", + "hcp_Ti7V", + "hex_Ti10Al2", + "hex_Ti10V2", +} + +CASES: list[str] = [ + "hcp_Ti6AlV", + "bcc_Ti6AlV", + "hex_Ti8AlV", + "hcp_Ti6Al2", + "hcp_Ti6V2", + "hcp_Ti7V", + "bcc_Ti6Al2", + "bcc_Ti6V2", + "hex_Ti10Al2", + "hex_Ti10V2", +] + + +THRESHOLDS, METRIC_TOOLTIPS, WEIGHTS = load_metrics_config(METRICS_YAML_PATH) + +METRIC_ID_TO_LABEL: dict[str, str] = { + "dispersion_rmse_thz_avg": "Dispersion RMSE (mean)", + "dispersion_rmse_thz_max": "Dispersion RMSE (max)", + "deltaF_0K_eV_per_atom_avg": "ΔF (0 K) mean", + "deltaF_2000K_eV_per_atom_avg": "ΔF (2000 K) mean", + "omega_avg_thz_mae": "ω_avg MAE", +} + +TABLE_METRIC_LABELS: list[str] = list(METRIC_ID_TO_LABEL.values()) +METRIC_LABELS: dict[str, str] = dict(METRIC_ID_TO_LABEL) # id -> label + + +MODELS = 
load_models(current_models) +MODEL_ITEMS = list(MODELS.items()) +MODEL_IDS: list[str] = [name for name, _ in MODEL_ITEMS] + + +def rmse(a: np.ndarray, b: np.ndarray) -> float: + """ + Compute root mean squared error. + + Parameters + ---------- + a + First array of values. + b + Second array of values. Must be broadcast-compatible with ``a``. + + Returns + ------- + float + Root mean squared error. + """ + a = np.asarray(a, dtype=float) + b = np.asarray(b, dtype=float) + return float(np.sqrt(np.mean((a - b) ** 2))) + + +def resample_dft_to_ml_grid( + dft_x: np.ndarray, dft_freqs: np.ndarray, n_ml: int +) -> np.ndarray: + """ + Resample DFT frequencies onto an inferred ML grid spanning the same path. + + Parameters + ---------- + dft_x + DFT path coordinate array of shape ``(n_dft,)``. + dft_freqs + DFT frequencies array of shape ``(n_dft, n_branches)``. + n_ml + Number of ML q-points. + + Returns + ------- + numpy.ndarray + DFT frequencies interpolated onto the ML grid, shape + ``(n_ml, n_branches)``. + """ + dft_x = np.asarray(dft_x, dtype=float) + dft_freqs = np.asarray(dft_freqs, dtype=float) + + ml_x = np.linspace(dft_x[0], dft_x[-1], n_ml, dtype=float) + + out = np.empty((n_ml, dft_freqs.shape[1]), dtype=float) + for j in range(dft_freqs.shape[1]): + out[:, j] = np.interp(ml_x, dft_x, dft_freqs[:, j]) + return out + + +def write_json(path: Path, obj: dict[str, Any]) -> None: + """ + Write a JSON object to disk. + + Parameters + ---------- + path + Output file path. + obj + JSON-serialisable mapping. + """ + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(obj, indent=2), encoding="utf8") + + +k_b = 8.617333262e-5 # eV/K +hbar = 6.582119569e-16 # eV*s + + +def zp_energy(weights_flat: np.ndarray, freqs_flat_thz: np.ndarray) -> float: + """ + Compute zero-point energy from phonon frequencies. + + Parameters + ---------- + weights_flat + Flattened q-point weights (tiled over branches), shape ``(n_modes,)``. 
+ freqs_flat_thz + Flattened frequencies in THz, shape ``(n_modes,)``. + + Returns + ------- + float + Zero-point energy contribution (eV). + """ + w = np.asarray(weights_flat, dtype=float).reshape(-1) + f_thz = np.asarray(freqs_flat_thz, dtype=float).reshape(-1) + + omega = f_thz * 2.0 * np.pi * 1e12 # rad/s + zpe = w * (hbar * omega) + zpe = zpe[np.isfinite(zpe)] + zpe[zpe == float("-inf")] = 0.0 + zpe[zpe == float("+inf")] = 0.0 + return float(0.5 * np.sum(zpe)) + + +def helmholtz_free_energy_og( + weights_flat: np.ndarray, freqs_flat_thz: np.ndarray, t: float +) -> float: + """ + Compute Helmholtz free energy. + + Parameters + ---------- + weights_flat + Flattened q-point weights (tiled over branches), shape ``(n_modes,)``. + freqs_flat_thz + Flattened frequencies in THz, shape ``(n_modes,)``. + t + Temperature (K). + + Returns + ------- + float + Helmholtz free-energy thermal contribution (eV). + """ + w = np.asarray(weights_flat, dtype=float).reshape(-1) + f_thz = np.asarray(freqs_flat_thz, dtype=float).reshape(-1) + + if t <= 1e-12: + t = 1e-12 + + omega = f_thz * 2.0 * np.pi * 1e12 # rad/s + arg = -(hbar * omega) / (k_b * t) + + integrand = w * np.log(1.0 - np.exp(arg)) + integrand = integrand[np.isfinite(integrand)] + integrand[integrand == float("-inf")] = 0.0 + integrand[integrand == float("+inf")] = 0.0 + return float(np.sum(integrand) * k_b * t) + + +def analyse_one_model(model_id: str) -> None: + """ + Analyse a single Ti64 phonon model and write per-model metrics. + + Parameters + ---------- + model_id + Identifier of the model under + ``ml_peg/calcs/bulk_crystal/ti64_phonons/outputs``. + + Notes + ----- + Writes per-model metrics to: + + - ``ml_peg/app/data/bulk_crystal/ti64_phonons//metrics.json`` + + The JSON file contains aggregated metrics and per-case values. + """ + model_calc_dir = CALC_OUT_PATH / model_id + assert model_calc_dir.exists(), ( + f"No calc outputs found for model '{model_id}'. 
Expected:\n" + f" {model_calc_dir}\n\n" + "Calc stage writes to:\n" + " ml_peg/calcs/bulk_crystal/ti64_phonons/outputs//...\n" + ) + + model_app_dir = APP_OUT_PATH / model_id + model_app_dir.mkdir(parents=True, exist_ok=True) + + rmse_by_case: dict[str, float] = {} + df0_by_case: dict[str, float] = {} + df2000_by_case: dict[str, float] = {} + + omega_avg_ref_thz_by_case: dict[str, float] = {} + omega_avg_pred_thz_by_case: dict[str, float] = {} + + for case in CASES: + npz_path = model_calc_dir / f"{case}.npz" + assert npz_path.exists(), f"Missing {npz_path}" + + data = np.load(npz_path, allow_pickle=True) + + dft_x = np.asarray(data["dft_x"], dtype=float) + dft_freq = np.asarray(data["dft_frequencies"], dtype=float) + ml_freq = np.asarray(data["ml_frequencies"], dtype=float) + + dft_on_ml = resample_dft_to_ml_grid(dft_x, dft_freq, n_ml=ml_freq.shape[0]) + + omega_avg_ref = float(np.mean(dft_on_ml)) + omega_avg_pred = float(np.mean(ml_freq)) + omega_avg_ref_thz_by_case[case] = omega_avg_ref + omega_avg_pred_thz_by_case[case] = omega_avg_pred + + rmse_by_case[case] = rmse(dft_on_ml, ml_freq) + + if case in TP_ON: + assert "tp_temperatures" in data and "tp_free_energy" in data, ( + f"TP expected for {case} but tp_* arrays missing in {npz_path}" + ) + + required = ["q_weights", "q_frequencies_dft", "n_atoms"] + missing = [k for k in required if k not in data.files] + assert not missing, ( + f"{case}: missing required thermo keys in {npz_path}: {missing}" + ) + + q_w = np.asarray(data["q_weights"], dtype=float) + q_f = np.asarray(data["q_frequencies_dft"], dtype=float) + + weights_tile = np.tile(q_w[:, None], (1, q_f.shape[1])).reshape( + -1, order="F" + ) + freqs_flat = q_f.reshape(-1, order="F") # THz + + n_atoms = int(np.asarray(data["n_atoms"]).item()) + + ml_t = np.asarray(data["tp_temperatures"], dtype=float) + ml_f = np.asarray(data["tp_free_energy"], dtype=float) + + # Legacy unit fix (keep existing behavior) + if np.nanmax(np.abs(ml_f)) > 100.0: + ml_f = ml_f / 
96.32 + + t_dense = np.linspace(0.0, 2000.0, 2000, dtype=float) + zpe_const = zp_energy(weights_tile, freqs_flat) + dft_f_dense = np.array( + [ + helmholtz_free_energy_og(weights_tile, freqs_flat, tt) + zpe_const + for tt in t_dense + ], + dtype=float, + ) + dft_f_on_mlt = np.interp(ml_t, t_dense, dft_f_dense) + + df0_by_case[case] = float(np.abs(dft_f_on_mlt[0] - ml_f[0]) / n_atoms) + df2000_by_case[case] = float(np.abs(dft_f_on_mlt[-1] - ml_f[-1]) / n_atoms) + + rmse_vals = np.asarray(list(rmse_by_case.values()), dtype=float) + + omega_avg_mae = ( + float( + np.mean( + [ + abs(omega_avg_pred_thz_by_case[c] - omega_avg_ref_thz_by_case[c]) + for c in omega_avg_ref_thz_by_case + ] + ) + ) + if omega_avg_ref_thz_by_case + else None + ) + + write_json( + model_app_dir / "metrics.json", + { + "model": model_id, + "n_cases": len(CASES), + "metrics": { + "dispersion_rmse_thz_avg": float(np.mean(rmse_vals)), + "dispersion_rmse_thz_max": float(np.max(rmse_vals)), + "deltaF_0K_eV_per_atom_avg": float(np.mean(list(df0_by_case.values()))) + if df0_by_case + else None, + "deltaF_2000K_eV_per_atom_avg": float( + np.mean(list(df2000_by_case.values())) + ) + if df2000_by_case + else None, + "omega_avg_thz_mae": omega_avg_mae, + }, + "by_case": { + "rmse_thz": rmse_by_case, + "deltaF_0K_eV_per_atom": df0_by_case, + "deltaF_2000K_eV_per_atom": df2000_by_case, + "omega_avg_ref_thz": omega_avg_ref_thz_by_case, + "omega_avg_pred_thz": omega_avg_pred_thz_by_case, + }, + }, + ) + + assert len(rmse_by_case) == len(CASES) + + +@pytest.fixture(scope="session") +def run_all_models() -> None: + """ + Generate per-model ``metrics.json`` for all configured models. + + Returns + ------- + None + This fixture exists for its side effects (writing per-model metrics). 
+ """ + for model_id in MODEL_IDS: + analyse_one_model(model_id) + + +@pytest.fixture(scope="session") +@build_table( + filename=APP_OUT_PATH / "ti64_phonons_metrics_table.json", + thresholds=THRESHOLDS, + metric_tooltips=METRIC_TOOLTIPS, + weights=WEIGHTS, +) +def metrics_table(run_all_models: None) -> dict[str, dict[str, float | None]]: + """ + Build the Ti64 metrics table for the Dash app. + + Parameters + ---------- + run_all_models + Session-scoped fixture ensuring per-model metrics are generated. + + Returns + ------- + dict[str, dict[str, float | None]] + Mapping of metric label to per-model values. + """ + _ = run_all_models + + table: dict[str, dict[str, float | None]] = { + label: {} for label in TABLE_METRIC_LABELS + } + + for model_id in MODEL_IDS: + mpath = APP_OUT_PATH / model_id / "metrics.json" + if not mpath.exists(): + continue + + m = json.loads(mpath.read_text(encoding="utf8")) + metrics = m.get("metrics", {}) + + for metric_id, label in METRIC_ID_TO_LABEL.items(): + table[label][model_id] = metrics.get(metric_id) + + return table + + +@pytest.fixture(scope="session") +@cell_to_scatter( + filename=SCATTER_FILENAME, + x_label="Predicted", + y_label="Reference", +) +def interactive_dataset(run_all_models: None) -> dict[str, Any]: + """ + Build the interactive scatter dataset for the Ti64 phonons Dash app. + + Parameters + ---------- + run_all_models + Session-scoped fixture ensuring per-model metrics are generated. + + Returns + ------- + dict[str, Any] + Interactive dataset written to JSON by the decorator. 
+ """ + _ = run_all_models + + dataset: dict[str, Any] = { + "metrics": METRIC_LABELS, # id -> label + "models": {}, + } + + metric_id = "omega_avg_thz_mae" + + for model_id in MODEL_IDS: + metrics_path = APP_OUT_PATH / model_id / "metrics.json" + if not metrics_path.exists(): + continue + + m = json.loads(metrics_path.read_text(encoding="utf8")) + by_case = m.get("by_case") or {} + ref_map = by_case.get("omega_avg_ref_thz", {}) or {} + pred_map = by_case.get("omega_avg_pred_thz", {}) or {} + + points: list[dict[str, Any]] = [] + for case in CASES: + if case not in ref_map or case not in pred_map: + continue + + data_paths = { + "npz": str( + (CALC_OUT_PATH / model_id / f"{case}.npz").relative_to( + CALC_OUT_PATH.parent + ) + ), + "meta": str( + (CALC_OUT_PATH / model_id / f"{case}.json").relative_to( + CALC_OUT_PATH.parent + ) + ), + } + + points.append( + { + "id": case, + "label": case, + "ref": ref_map[case], + "pred": pred_map[case], + "data_paths": data_paths, + } + ) + + dataset["models"][model_id] = { + "model": model_id, + "metrics": { + metric_id: { + "points": points, + "mae": (m.get("metrics") or {}).get(metric_id), + } + }, + } + + return dataset + + +def test_all_models_metrics_written(run_all_models: None) -> None: + """ + Check per-model ``metrics.json`` exists for every configured model. + + Parameters + ---------- + run_all_models + Session-scoped fixture ensuring per-model metrics are generated. + """ + _ = run_all_models + + missing: list[str] = [] + for model_id in MODEL_IDS: + if not (APP_OUT_PATH / model_id / "metrics.json").exists(): + missing.append(model_id) + + assert not missing, f"Missing metrics.json for models: {missing}" + + +def test_write_metrics_table(metrics_table: dict[str, Any]) -> None: + """ + Check the table JSON artifact is produced and includes all models. + + Parameters + ---------- + metrics_table + Fixture providing the metrics table mapping (and/or triggering JSON writing). 
+ """ + assert isinstance(metrics_table, dict) + + table_path = APP_OUT_PATH / "ti64_phonons_metrics_table.json" + assert table_path.exists() + + payload = json.loads(table_path.read_text(encoding="utf8")) + rows = payload.get("data", []) + ids = {row.get("id") for row in rows if isinstance(row, dict)} + missing = [m for m in MODEL_IDS if m not in ids] + assert not missing, f"Table missing model rows: {missing}" + + +def test_write_interactive_json(interactive_dataset: dict[str, Any]) -> None: + """ + Check the interactive JSON artifact is produced and includes all models. + + Parameters + ---------- + interactive_dataset + Fixture providing the interactive dataset (and/or triggering JSON writing). + """ + assert isinstance(interactive_dataset, dict) + assert SCATTER_FILENAME.exists() + + payload = json.loads(SCATTER_FILENAME.read_text(encoding="utf8")) + models = payload.get("models", {}) + missing = [m for m in MODEL_IDS if m not in models] + assert not missing, f"Interactive dataset missing models: {missing}" diff --git a/ml_peg/analysis/bulk_crystal/ti64_phonons/metrics.yml b/ml_peg/analysis/bulk_crystal/ti64_phonons/metrics.yml new file mode 100644 index 000000000..a3bbc6681 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/ti64_phonons/metrics.yml @@ -0,0 +1,40 @@ +metrics: + Dispersion RMSE (mean): + good: 0.0 + bad: 2.0 + unit: THz + tooltip: "Average RMSE between ML and DFT phonon dispersions across the 10 Ti64 cases (after resampling onto a common k-grid)." + weight: 1.0 + level_of_theory: PBE + + Dispersion RMSE (max): + good: 0.0 + bad: 4.0 + unit: THz + tooltip: "Worst-case RMSE between ML and DFT phonon dispersions among the 10 Ti64 cases." + weight: 0.5 + level_of_theory: PBE + + ΔF (0 K) mean: + good: 0.0 + bad: 0.007 + unit: eV/atom + tooltip: "Mean |ΔF| (free energy) at 0 K per atom between ML and DFT." 
+ weight: 1.0 + level_of_theory: PBE + + ΔF (2000 K) mean: + good: 0.0 + bad: 0.03 + unit: eV/atom + tooltip: "Mean |ΔF| (free energy) at 2000 K per atom between ML and DFT." + weight: 1.0 + level_of_theory: PBE + + ω_avg MAE: + good: 0.0 + bad: 1.0 + unit: THz + tooltip: "MAE of ω_avg per case, where ω_avg is the mean of all phonon band frequencies for that case, averaged over the 10 Ti64 cases." + weight: 0.5 + level_of_theory: PBE diff --git a/ml_peg/app/bulk_crystal/diamond_phonons/app_diamond_phonons.py b/ml_peg/app/bulk_crystal/diamond_phonons/app_diamond_phonons.py new file mode 100644 index 000000000..0f9f205a4 --- /dev/null +++ b/ml_peg/app/bulk_crystal/diamond_phonons/app_diamond_phonons.py @@ -0,0 +1,202 @@ +"""Run diamond phonon dispersion app (bands-only benchmark).""" + +from __future__ import annotations + +from collections.abc import Mapping +from functools import partial +import json +from pathlib import Path +from typing import Any + +from dash import Dash, dcc, html + +from ml_peg.app import APP_ROOT +from ml_peg.app.base_app import BaseApp +from ml_peg.app.bulk_crystal.diamond_phonons.diamond_interactive_helpers import ( + render_dispersion_component, +) +from ml_peg.app.utils.build_callbacks import ( + model_asset_from_scatter, + scatter_and_assets_from_table, +) +from ml_peg.app.utils.plot_helpers import build_serialized_scatter_content +from ml_peg.calcs import CALCS_ROOT +from ml_peg.calcs.utils.utils import download_github_data + +GITHUB_BASE = "https://raw.githubusercontent.com/7radians/ml-peg-data/main" + +EXTRACTED_ROOT = Path( + download_github_data( + filename="diamond_data/data.zip", + github_uri=GITHUB_BASE, + ) +) + +DATA_PATH = EXTRACTED_ROOT / "data" +DFT_REF_PATH = DATA_PATH / "dft_band.npz" + + +BENCHMARK_NAME = "diamond_phonons" + +DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / BENCHMARK_NAME +TABLE_PATH = DATA_PATH / "diamond_phonons_bands_table.json" +SCATTER_PATH = DATA_PATH / "diamond_phonons_bands_interactive.json" + 
+DOCS_URL = ( + "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk_crystal.html" + f"#{BENCHMARK_NAME}" +) + +CALC_BASE = CALCS_ROOT / "bulk_crystal" / BENCHMARK_NAME +# DFT_REF_PATH = Path("data") / "dft_band.npz" + +PLOT_CONTAINER_ID = f"{BENCHMARK_NAME}-plot-container" +DISPERSION_CONTAINER_ID = f"{BENCHMARK_NAME}-dispersion-container" +LAST_CELL_STORE_ID = f"{BENCHMARK_NAME}-last-cell" +SCATTER_METADATA_STORE_ID = f"{BENCHMARK_NAME}-scatter-meta" +SCATTER_GRAPH_ID = f"{BENCHMARK_NAME}-scatter" + + +def model_only_lookup( + click_data: Mapping[str, Any] | None, + metadata: Mapping[str, Any], +) -> dict[str, Any]: + """ + Build a selection context for the dispersion preview. + + For this benchmark we ignore which scatter point was clicked and return a + single dispersion preview per model. The returned ``band_yaml`` must be + resolvable under the app's ``calc_root`` passed to the renderer. + + Parameters + ---------- + click_data + Dash click payload from the scatter plot. Unused for this benchmark. + metadata + Metadata payload produced by the scatter callback helpers. Must contain + a ``model`` key. + + Returns + ------- + dict + Selection context consumed by ``render_dispersion_component``. 
+ """ + _ = click_data + model = str(metadata["model"]) + return { + "model": model, + "selection": { + "id": "diamond", + "label": "Carbon diamond", + "band_yaml": f"outputs/{model}/band.yaml", + }, + } + + +class DiamondPhononApp(BaseApp): + """Bands-only phonon benchmark app wiring callbacks and layout.""" + + def register_callbacks(self) -> None: + """Register scatter and dispersion callbacks.""" + with SCATTER_PATH.open(encoding="utf8") as handle: + interactive_data = json.load(handle) + + calc_root = Path(CALC_BASE) + models_data = interactive_data.get("models", {}) + metric_labels = interactive_data.get("metrics", {}) + label_to_key = {label: key for key, label in metric_labels.items()} + + refresh_msg = ( + "Click on a metric to view DFT vs predicted frequency scatter plots." + ) + + metric_handler = partial( + build_serialized_scatter_content, + models_data=models_data, + label_map=label_to_key, + scatter_id=SCATTER_GRAPH_ID, + instructions=refresh_msg, + ) + + # Bands-only benchmark: no BZ violin panel and no stability panel. 
+        scatter_and_assets_from_table(
+            table_id=self.table_id,
+            table_data=self.table.data,
+            plot_container_id=PLOT_CONTAINER_ID,
+            scatter_metadata_store_id=SCATTER_METADATA_STORE_ID,
+            last_cell_store_id=LAST_CELL_STORE_ID,
+            column_handlers={},  # only metric scatter
+            default_handler=metric_handler,
+        )
+
+        dispersion_renderer = partial(
+            render_dispersion_component,
+            calc_root=calc_root,
+            frequency_scale=1,
+            frequency_unit="THz",
+            reference_label="RSCAN",
+            reference_band_npz=DFT_REF_PATH,
+        )
+
+        model_asset_from_scatter(
+            scatter_id=SCATTER_GRAPH_ID,
+            metadata_store_id=SCATTER_METADATA_STORE_ID,
+            asset_container_id=DISPERSION_CONTAINER_ID,
+            data_lookup=model_only_lookup,
+            asset_renderer=dispersion_renderer,
+            empty_message="Select a model to preview the phonon dispersion.",
+            missing_message="No band.yaml found for this model.",
+        )
+
+
+def get_app() -> DiamondPhononApp:
+    """
+    Construct the diamond phonon ``DiamondPhononApp`` instance.
+
+    Returns
+    -------
+    DiamondPhononApp
+        Configured application with table + scatter/dispersion panels.
+    """
+    return DiamondPhononApp(
+        name=BENCHMARK_NAME,
+        description=(
+            "Accuracy of MLIPs in predicting phonon dispersions for Carbon diamond "
+            "(RSCAN)."
+ ), + docs_url=DOCS_URL, + table_path=TABLE_PATH, + extra_components=[ + dcc.Store(id=LAST_CELL_STORE_ID), + dcc.Store(id=SCATTER_METADATA_STORE_ID), + html.Div( + [ + html.Div( + "Click on a metric to view DFT vs predicted frequency scatter " + "plots.", + id=PLOT_CONTAINER_ID, + style={"flex": "1", "minWidth": 0}, + ), + html.Div( + "Click on a scatter point to view the dispersion plot.", + id=DISPERSION_CONTAINER_ID, + style={"flex": "1", "minWidth": 0}, + ), + ], + style={ + "display": "flex", + "gap": "24px", + "alignItems": "stretch", + "flexWrap": "wrap", + }, + ), + ], + ) + + +if __name__ == "__main__": + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + diamond_phonon_app = get_app() + full_app.layout = diamond_phonon_app.layout + diamond_phonon_app.register_callbacks() + full_app.run(port=8060, debug=True) diff --git a/ml_peg/app/bulk_crystal/diamond_phonons/diamond_interactive_helpers.py b/ml_peg/app/bulk_crystal/diamond_phonons/diamond_interactive_helpers.py new file mode 100644 index 000000000..729fbf673 --- /dev/null +++ b/ml_peg/app/bulk_crystal/diamond_phonons/diamond_interactive_helpers.py @@ -0,0 +1,287 @@ +""" +Utilities for diamond phonon interactive assets (bands-only). 
+ +Supports rendering a dispersion preview from: +- Predicted Phonopy band.yaml (THz) +- Reference dft_band.npz with keys: distance, qpoints, frequencies (cm^-1) +""" + +from __future__ import annotations + +import base64 +from collections.abc import Mapping +from io import BytesIO +from pathlib import Path +from typing import Any + +import matplotlib + +matplotlib.use("Agg") +from dash import html +import matplotlib.pyplot as plt +import numpy as np + +THZ_TO_CM1 = 33.35640951981521 +HS_LABELS = [r"$\Gamma$", "X", "W", "K", r"$\Gamma$", "L", "U", "W", "L", "K", "X"] + + +def render_dispersion_component( + selection_context: Mapping[str, Any], + calc_root: Path, + frequency_scale: float, + frequency_unit: str, + reference_label: str, + reference_band_npz: Path | None = None, +): + """ + Return a Dash component containing a dispersion PNG, or None. + + Parameters + ---------- + selection_context : Mapping[str, Any] + Selection context; expects selection["band_yaml"] and optional label/id. + calc_root : Path + Root directory used to resolve files. + frequency_scale : float + Multiplicative scale applied to frequencies. + frequency_unit : str + Y-axis unit label. + reference_label : str + Legend label for the reference overlay. + reference_band_npz : Path | None + Optional reference .npz (distance, qpoints, frequencies in cm^-1). + + Returns + ------- + html.Div | None + Component with an embedded PNG (data URI), or None if unavailable. 
+ """ + model_display = selection_context.get("model") + selected = selection_context.get("selection") or {} + + band_yaml = selected.get("band_yaml") + label = selected.get("label") or selected.get("id", "") + + if not band_yaml: + return None + + image_src = render_band_yaml_png( + calc_root=calc_root, + band_yaml=str(band_yaml), + reference_band_npz=reference_band_npz, + frequency_scale=float(frequency_scale), + frequency_unit=str(frequency_unit), + reference_label=str(reference_label), + prediction_label=str(model_display) + if model_display is not None + else "Prediction", + ) + + if not image_src: + return None + + return html.Div( + [ + html.H4(str(label)) if label else None, + html.Img( + src=image_src, + style={ + "width": "100%", + "maxWidth": "820px", + "height": "auto", + "display": "block", + "borderRadius": "8px", + "border": "1px solid #ddd", + }, + ), + ] + ) + + +def render_band_yaml_png( + *, + calc_root: Path, + band_yaml: str, + reference_band_npz: Path | None, + frequency_scale: float, + frequency_unit: str, + reference_label: str, + prediction_label: str, +) -> str | None: + """ + Render a dispersion PNG from band.yaml with optional reference overlay. + + Parameters + ---------- + calc_root : Path + Root directory used to resolve files. + band_yaml : str + Predicted Phonopy band.yaml path (relative to calc_root). + reference_band_npz : Path | None + Optional reference .npz (distance, qpoints, frequencies in cm^-1). + frequency_scale : float + Multiplicative scale applied to frequencies. + frequency_unit : str + Y-axis unit label. + reference_label : str + Legend label for the reference overlay. + prediction_label : str + Legend label for the prediction. + + Returns + ------- + str | None + Base64 PNG data URI, or None on failure. + """ + import yaml # type: ignore + + def _detect_symmetry_boundaries(q_ref: np.ndarray) -> list[int]: + """ + Return indices of k-path corners (including first and last). 
+ + Parameters + ---------- + q_ref : np.ndarray + Array of q-points with shape (N, 3) along the band path. + + Returns + ------- + list[int] + Sorted indices of symmetry boundaries (path corners), including + indices 0 and N-1. + """ + dq = np.diff(q_ref, axis=0) + dq_norm = np.linalg.norm(dq, axis=1) + eps = 1e-12 + dq_unit = dq / (dq_norm[:, None] + eps) + cosang = np.sum(dq_unit[1:] * dq_unit[:-1], axis=1) + cand = np.where(cosang < 0.95)[0] + 1 + + boundaries = [0] + for i in cand: + boundaries.append(int(i)) + + boundaries = sorted(set(boundaries)) + if boundaries[0] != 0: + boundaries = [0] + boundaries + if boundaries[-1] != len(q_ref) - 1: + boundaries.append(len(q_ref) - 1) + return boundaries + + pred_path = Path(calc_root) / band_yaml + if not pred_path.exists(): + return None + + try: + with pred_path.open("r", encoding="utf8") as f: + y = yaml.safe_load(f) + except Exception: + return None + + phonon = y.get("phonon", None) + if not isinstance(phonon, list) or not phonon: + return None + + s_pred = np.asarray([p.get("distance", np.nan) for p in phonon], dtype=float) + + freqs_pred_thz = np.asarray( + [[b.get("frequency", np.nan) for b in p.get("band", [])] for p in phonon], + dtype=float, + ) + if freqs_pred_thz.ndim != 2 or not np.isfinite(freqs_pred_thz).all(): + return None + + freqs_pred_thz = freqs_pred_thz * float(frequency_scale) + + s_ref: np.ndarray | None = None + q_ref: np.ndarray | None = None + freqs_ref_thz: np.ndarray | None = None + sym_pos: np.ndarray | None = None + + if reference_band_npz is not None: + ref_path = Path(calc_root) / reference_band_npz + if ref_path.exists(): + obj = np.load(ref_path, allow_pickle=False) + + s_ref = np.asarray(obj["distance"], dtype=float) + q_ref = np.asarray(obj["qpoints"], dtype=float) + freqs_ref_cm1 = np.asarray(obj["frequencies"], dtype=float) + + if freqs_ref_cm1.ndim == 3 and freqs_ref_cm1.shape[0] == 1: + freqs_ref_cm1 = freqs_ref_cm1[0] + + if freqs_ref_cm1.ndim == 2 and 
np.isfinite(freqs_ref_cm1).all(): + freqs_ref_thz = (freqs_ref_cm1 / THZ_TO_CM1) * float(frequency_scale) + bounds = _detect_symmetry_boundaries(q_ref) + sym_pos = s_ref[bounds] + else: + s_ref = None + q_ref = None + freqs_ref_thz = None + sym_pos = None + + use_ref_x = s_ref is not None and len(s_ref) == len(s_pred) + x = s_ref if use_ref_x else s_pred + + fig, ax = plt.subplots(figsize=(10, 7)) + + for j in range(freqs_pred_thz.shape[1]): + ax.plot( + x, + freqs_pred_thz[:, j], + color="#1f77b4", + lw=2.0, + label=prediction_label if j == 0 else None, + ) + + if freqs_ref_thz is not None and s_ref is not None: + x_ref = x if use_ref_x else s_ref + for j in range(freqs_ref_thz.shape[1]): + ax.plot( + x_ref, + freqs_ref_thz[:, j], + color="k", + lw=2.0, + ls="--", + label=reference_label if j == 0 else None, + ) + + if sym_pos is not None and len(sym_pos) == len(HS_LABELS): + for xpos in sym_pos: + ax.axvline(float(xpos), color="k", lw=1.0, alpha=0.6) + ax.set_xticks(sym_pos) + ax.set_xticklabels(HS_LABELS) + ax.set_xlim(float(x[0]), float(x[-1])) + + ax.set_xlabel("Wave vector", fontsize=18) + ax.set_ylabel(f"Frequency ({frequency_unit})", fontsize=18) + ax.axhline(0.0, color="k", lw=1.5) + ax.grid(axis="x") + ax.set_ylim( + float( + min( + np.nanmin(freqs_pred_thz), + np.nanmin(freqs_ref_thz) if freqs_ref_thz is not None else np.inf, + ) + ), + float( + max( + np.nanmax(freqs_pred_thz), + np.nanmax(freqs_ref_thz) if freqs_ref_thz is not None else -np.inf, + ) + ), + ) + + handles, labels = ax.get_legend_handles_labels() + by_label = dict(zip(labels, handles, strict=False)) + if by_label: + ax.legend(by_label.values(), by_label.keys(), loc=1, fontsize=14) + + fig.tight_layout() + + buffer = BytesIO() + fig.savefig(buffer, format="png", dpi=200) + plt.close(fig) + + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + return f"data:image/png;base64,{encoded}" diff --git a/ml_peg/app/bulk_crystal/ti64_phonons/app_ti64_phonons.py 
b/ml_peg/app/bulk_crystal/ti64_phonons/app_ti64_phonons.py new file mode 100644 index 000000000..93ad5bd08 --- /dev/null +++ b/ml_peg/app/bulk_crystal/ti64_phonons/app_ti64_phonons.py @@ -0,0 +1,170 @@ +"""Run Ti64 phonon dispersion + DOS + TP app.""" + +from __future__ import annotations + +from functools import partial +import json + +from dash import Dash, dcc, html + +from ml_peg.app import APP_ROOT +from ml_peg.app.base_app import BaseApp +from ml_peg.app.bulk_crystal.ti64_phonons.ti64_interactive_helpers import ( + lookup_system_entry, + render_dispersion_component, +) +from ml_peg.app.utils.build_callbacks import ( + model_asset_from_scatter, + scatter_and_assets_from_table, +) +from ml_peg.app.utils.plot_helpers import ( + build_serialized_scatter_content, + resolve_scatter_selection, +) + +BENCHMARK_NAME = "ti64_phonons" + +DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / BENCHMARK_NAME +TABLE_PATH = DATA_PATH / "ti64_phonons_metrics_table.json" +SCATTER_PATH = DATA_PATH / "ti64_phonons_interactive.json" + +CALC_ROOT = APP_ROOT.parent / "calcs" / "bulk_crystal" / BENCHMARK_NAME + +DOCS_URL = ( + "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk_crystal.html#phonons" +) + +PLOT_CONTAINER_ID = f"{BENCHMARK_NAME}-plot-container" +DISPERSION_CONTAINER_ID = f"{BENCHMARK_NAME}-dispersion-container" +LAST_CELL_STORE_ID = f"{BENCHMARK_NAME}-last-cell" +SCATTER_METADATA_STORE_ID = f"{BENCHMARK_NAME}-scatter-meta" +SCATTER_GRAPH_ID = f"{BENCHMARK_NAME}-scatter" + + +class Ti64PhononsApp(BaseApp): + """Ti64 phonons benchmark app wiring callbacks and layout.""" + + def register_callbacks(self) -> None: + """Register scatter/dispersion callbacks via shared helpers.""" + with SCATTER_PATH.open(encoding="utf8") as handle: + interactive_data = json.load(handle) + + models_data = interactive_data.get("models", {}) + + metric_labels = interactive_data.get("metrics", {}) # {metric_id: label} + label_to_key = {label: key for key, label in metric_labels.items()} + + 
omega_label = metric_labels.get("omega_avg_thz_mae", "ω_avg MAE") + + metric_handler = partial( + build_serialized_scatter_content, + models_data=models_data, + label_map=label_to_key, + scatter_id=SCATTER_GRAPH_ID, + instructions="Click any cell to view ω_avg (ref vs pred) scatter.", + ) + + def omega_only_handler(model_display: str, column_id: str): + """ + Render the ω_avg scatter for the selected model. + + Parameters + ---------- + model_display + Display name of the selected model row. + column_id + Column identifier from the table callback. + + Returns + ------- + Any + Dash component(s) produced by the scatter renderer. + """ + _ = column_id + return metric_handler(model_display, omega_label) + + column_handlers = dict.fromkeys(label_to_key.keys(), omega_only_handler) + + scatter_and_assets_from_table( + table_id=self.table_id, + table_data=self.table.data, + plot_container_id=PLOT_CONTAINER_ID, + scatter_metadata_store_id=SCATTER_METADATA_STORE_ID, + last_cell_store_id=LAST_CELL_STORE_ID, + column_handlers=column_handlers, + default_handler=omega_only_handler, + ) + + selection_lookup = partial( + resolve_scatter_selection, + models_data=models_data, + system_lookup=partial( + lookup_system_entry, + data_root=DATA_PATH, # kept for API compatibility; unused by new helper + assets_prefix=f"bulk_crystal/{BENCHMARK_NAME}", # unused by new helper + ), + ) + + dispersion_renderer = partial(render_dispersion_component, calc_root=CALC_ROOT) + + model_asset_from_scatter( + scatter_id=SCATTER_GRAPH_ID, + metadata_store_id=SCATTER_METADATA_STORE_ID, + asset_container_id=DISPERSION_CONTAINER_ID, + data_lookup=selection_lookup, + asset_renderer=dispersion_renderer, + empty_message="Click on a data point to preview the dispersion + DOS.", + missing_message="No dispersion plot available for this point.", + ) + + +def get_app() -> Ti64PhononsApp: + """ + Construct the Ti64PhononsApp instance. 
+ + Returns + ------- + Ti64PhononsApp + Configured application with table + scatter/dispersion panels. + """ + return Ti64PhononsApp( + name=BENCHMARK_NAME, + description=( + "Accuracy of MLIPs in predicting phonon dispersions and vibrational " + "thermodynamics for Ti64 alloy." + ), + docs_url=DOCS_URL, + table_path=TABLE_PATH, + extra_components=[ + dcc.Store(id=LAST_CELL_STORE_ID), + dcc.Store(id=SCATTER_METADATA_STORE_ID), + html.Div( + [ + html.Div( + "Click any cell to view ω_avg (ref vs pred) scatter.", + id=PLOT_CONTAINER_ID, + style={"flex": "1", "minWidth": 0}, + ), + html.Div( + "Click on a data point to preview the dispersion + DOS.", + id=DISPERSION_CONTAINER_ID, + style={"flex": "1", "minWidth": 0}, + ), + ], + style={ + "display": "flex", + "gap": "24px", + "alignItems": "stretch", + "flexWrap": "wrap", + }, + ), + ], + ) + + +if __name__ == "__main__": + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + ti64_app = get_app() + full_app.layout = ti64_app.layout + ti64_app.register_callbacks() + full_app.run(port=8060, debug=True) diff --git a/ml_peg/app/bulk_crystal/ti64_phonons/ti64_interactive_helpers.py b/ml_peg/app/bulk_crystal/ti64_phonons/ti64_interactive_helpers.py new file mode 100644 index 000000000..ef2fecaf2 --- /dev/null +++ b/ml_peg/app/bulk_crystal/ti64_phonons/ti64_interactive_helpers.py @@ -0,0 +1,510 @@ +"""Helpers for Ti64 phonon interactive dispersion/DOS rendering.""" + +from __future__ import annotations + +import base64 +from collections.abc import Mapping +from functools import lru_cache +from io import BytesIO +import json +from pathlib import Path +from typing import Any + +import matplotlib +import numpy as np + +matplotlib.use("Agg") +from dash import html +import matplotlib.pyplot as plt + + +def _load_json(path: Path) -> Any: + """ + Load a JSON file from disk. + + Parameters + ---------- + path + Path to a JSON file. + + Returns + ------- + Any + Parsed JSON payload. 
+ """ + with path.open("r", encoding="utf8") as f: + return json.load(f) + + +def lookup_system_entry( + model_entry: Mapping[str, Any], + point_id: str | int, + *, + data_root: Path, + assets_prefix: str, +) -> dict[str, Any] | None: + """ + Resolve a scatter (model, point) selection into a data-path payload. + + Parameters + ---------- + model_entry + Model entry from the interactive dataset (contains ``model`` and ``metrics``). + point_id + Point identifier selected in the scatter. + data_root + Root directory for app data (kept for API compatibility; unused here). + assets_prefix + Assets prefix for app data (kept for API compatibility; unused here). + + Returns + ------- + dict[str, Any] or None + Selection dictionary containing ``model``, ``system``, and ``data_paths`` if the + point is found; otherwise ``None``. + """ + _ = data_root + _ = assets_prefix + + model_name = ( + model_entry.get("model") or model_entry.get("id") or model_entry.get("name") + ) + if not isinstance(model_name, str) or not model_name.strip(): + return None + + system_name = str(point_id) + + metrics = model_entry.get("metrics", {}) + if not isinstance(metrics, dict): + return None + + # Search all metrics for a matching point id + for _metric_id, metric_payload in metrics.items(): + if not isinstance(metric_payload, dict): + continue + points = metric_payload.get("points", []) + if not isinstance(points, list): + continue + for p in points: + if not isinstance(p, dict): + continue + if str(p.get("id")) != system_name: + continue + + data_paths = p.get("data_paths") + if not isinstance(data_paths, dict): + return None + + # must include at least the calc npz path + npz_rel = data_paths.get("npz") + if not isinstance(npz_rel, str) or not npz_rel: + return None + + meta_rel = data_paths.get("meta") + + return { + "model": model_name, + "system": system_name, + "data_paths": { + "npz": npz_rel, + "meta": meta_rel, + }, + "label": p.get("label"), + "ref": p.get("ref"), + "pred": 
p.get("pred"), + } + + return None + + +def resample_dft_to_ml_grid( + dft_x: np.ndarray, dft_freqs: np.ndarray, n_ml: int +) -> np.ndarray: + """ + Resample DFT frequencies onto a uniform ML grid spanning the same path. + + Parameters + ---------- + dft_x + DFT path coordinate array of shape ``(n_dft,)``. + dft_freqs + DFT frequencies array of shape ``(n_dft, n_branches)``. + n_ml + Number of ML q-points. + + Returns + ------- + numpy.ndarray + DFT frequencies interpolated onto the ML grid, shape + ``(n_ml, n_branches)``. + """ + dft_x = np.asarray(dft_x, dtype=float).reshape(-1) + dft_freqs = np.asarray(dft_freqs, dtype=float) + + ml_x = np.linspace(dft_x[0], dft_x[-1], int(n_ml), dtype=float) + + out = np.empty((ml_x.size, dft_freqs.shape[1]), dtype=float) + for j in range(dft_freqs.shape[1]): + out[:, j] = np.interp(ml_x, dft_x, dft_freqs[:, j]) + return out + + +def gaussian_broadened_dos( + freqs_flat: np.ndarray, + weights_flat: np.ndarray, + grid: np.ndarray, + sigma: float, +) -> np.ndarray: + """ + Compute a Gaussian-broadened DOS on a target frequency grid. + + Parameters + ---------- + freqs_flat + Flattened frequencies (e.g. THz), shape ``(n_modes,)``. + weights_flat + Flattened weights matching ``freqs_flat``, shape ``(n_modes,)``. + grid + Frequency grid to evaluate the DOS on. + sigma + Gaussian broadening (same units as ``grid``). + + Returns + ------- + numpy.ndarray + DOS values evaluated on ``grid``. + """ + f = np.asarray(freqs_flat, dtype=float).reshape(-1) + w = np.asarray(weights_flat, dtype=float).reshape(-1) + x = np.asarray(grid, dtype=float).reshape(-1) + + diff = f[:, None] - x[None, :] + return np.sum(w[:, None] * np.exp(-0.5 * (diff / sigma) ** 2), axis=0) + + +def gaussian_smooth_on_grid(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray: + """ + Smooth 1D data on a uniform grid using a Gaussian kernel. + + Parameters + ---------- + x + Grid values. + y + Values defined on ``x``. 
+ sigma + Gaussian kernel standard deviation in the same units as ``x``. + + Returns + ------- + numpy.ndarray + Smoothed values on the same grid. + """ + x = np.asarray(x, dtype=float).reshape(-1) + y = np.asarray(y, dtype=float).reshape(-1) + + if x.size < 3: + return y + + dx = float(np.median(np.diff(x))) + if dx <= 0: + return y + + sigma_pts = sigma / dx + half = int(max(3, np.ceil(4.0 * sigma_pts))) + t = np.arange(-half, half + 1, dtype=float) + + k = np.exp(-0.5 * (t / sigma_pts) ** 2) + k /= np.sum(k) + + return np.convolve(y, k, mode="same") + + +@lru_cache(maxsize=512) +def _render_npz_to_data_uri( + npz_path_str: str, meta_path_str: str | None +) -> tuple[str | None, dict[str, Any]]: + """ + Render a dispersion/DOS PNG from a calc-stage NPZ and optional meta JSON. + + Parameters + ---------- + npz_path_str + Path to the calc-stage NPZ file (string for cache key stability). + meta_path_str + Optional path to a metadata JSON file (string for cache key stability). + + Returns + ------- + tuple[str | None, dict[str, Any]] + ``(data_uri, extras)`` where ``data_uri`` is a PNG data URI or ``None`` if the + NPZ file is missing. ``extras`` contains values used to build the caption. 
+ """ + npz_path = Path(npz_path_str) + if not npz_path.exists(): + return None, {} + + data = np.load(npz_path, allow_pickle=True) + + labels = None + if meta_path_str: + meta_path = Path(meta_path_str) + if meta_path.exists(): + try: + meta = _load_json(meta_path) + labels = meta.get("labels", None) if isinstance(meta, dict) else None + if not isinstance(labels, list): + labels = None + except Exception: + labels = None + + # Dispersion + dft_x = np.asarray(data["dft_x"], dtype=float) + dft_freq = np.asarray(data["dft_frequencies"], dtype=float) + ml_freq = np.asarray(data["ml_frequencies"], dtype=float) + + dft_on_ml = resample_dft_to_ml_grid(dft_x, dft_freq, n_ml=ml_freq.shape[0]) + + omega_avg_ref = float(np.mean(dft_on_ml)) + omega_avg_pred = float(np.mean(ml_freq)) + + y_lo = float(min(np.min(dft_on_ml), np.min(ml_freq))) + y_hi = float(max(np.max(dft_on_ml), np.max(ml_freq))) + + # DOS + required = [ + "pdos_frequency_points", + "pdos_projected", + "q_weights", + "q_frequencies_dft", + ] + has_dos = all(k in data.files for k in required) + + dos_grid = None + dft_dos_plot = None + ml_dos_plot = None + + if has_dos: + fgrid = np.asarray(data["pdos_frequency_points"], dtype=float).reshape(-1) + pdos_proj = np.asarray(data["pdos_projected"], dtype=float) + + if pdos_proj.ndim == 2 and pdos_proj.shape[0] == fgrid.size: + ml_dos = np.mean(pdos_proj, axis=1) + else: + ml_dos = np.mean(pdos_proj, axis=0) + ml_dos = np.asarray(ml_dos, dtype=float).reshape(-1) + + area_ml = float(np.trapz(ml_dos, fgrid)) + ml_dos_n = ml_dos / area_ml if area_ml > 0 else ml_dos + + q_w = np.asarray(data["q_weights"], dtype=float) + q_f = np.asarray(data["q_frequencies_dft"], dtype=float) + + weights_tile = np.tile(q_w[:, None], (1, q_f.shape[1])).reshape(-1, order="F") + freqs_flat = q_f.reshape(-1, order="F") # THz + + dft_dos = gaussian_broadened_dos(freqs_flat, weights_tile, fgrid, sigma=0.05) + area_dft = float(np.trapz(dft_dos, fgrid)) + dft_dos_n = dft_dos / area_dft if 
area_dft > 0 else dft_dos + + order = np.argsort(fgrid) + fgrid = fgrid[order] + ml_dos_n = ml_dos_n[order] + dft_dos_n = dft_dos_n[order] + + ml_dos_n = gaussian_smooth_on_grid(fgrid, ml_dos_n, sigma=0.05) + dft_dos_n = gaussian_smooth_on_grid(fgrid, dft_dos_n, sigma=0.05) + + df = float(np.median(np.diff(fgrid))) if fgrid.size > 2 else 0.01 + dos_grid = np.arange(y_lo, y_hi + df, df) + + ml_dos_plot = np.interp(dos_grid, fgrid, ml_dos_n, left=0.0, right=0.0) + + dft_dos_plot = gaussian_broadened_dos( + freqs_flat, + weights_tile, + dos_grid, + sigma=0.05, + ) + area_dft_plot = float(np.trapz(dft_dos_plot, dos_grid)) + dft_dos_plot = ( + dft_dos_plot / area_dft_plot if area_dft_plot > 0 else dft_dos_plot + ) + + ml_dos_plot = gaussian_smooth_on_grid(dos_grid, ml_dos_plot, sigma=0.05) + dft_dos_plot = gaussian_smooth_on_grid(dos_grid, dft_dos_plot, sigma=0.05) + + fig, (a0, a1) = plt.subplots( + 1, + 2, + gridspec_kw={"width_ratios": [4, 1], "wspace": 0.05}, + figsize=(10, 8), + ) + + n_br = ml_freq.shape[1] if ml_freq.ndim == 2 else 1 + x = np.arange(ml_freq.shape[0], dtype=float) + + ml_color = "C0" + dft_color = "k" + + if n_br == 1 and ml_freq.ndim == 1: + a0.plot(x, ml_freq, lw=2.0, color=ml_color, label="ML") + a0.plot(x, dft_on_ml, lw=2.0, ls="--", color=dft_color, label="DFT") + else: + for j in range(n_br): + a0.plot( + x, + ml_freq[:, j], + lw=2.0, + color=ml_color, + label="ML" if j == 0 else None, + ) + a0.plot( + x, + dft_on_ml[:, j], + lw=2.0, + ls="--", + color=dft_color, + label="DFT" if j == 0 else None, + ) + + if "ml_normal_ticks" in data.files and labels is not None: + ticks = np.asarray(data["ml_normal_ticks"], dtype=float) + if np.max(ticks) <= ml_freq.shape[0] - 1 + 1e-6: + xt = ticks + else: + xt = ( + (ticks - ticks.min()) + / (ticks.max() - ticks.min() + 1e-12) + * (ml_freq.shape[0] - 1) + ) + a0.set_xticks(xt) + a0.set_xticklabels(labels) + a0.set_xlim(float(xt[0]), float(xt[-1])) + + a0.set_ylabel("Frequency (THz)", fontsize=20) + 
a0.axhline(0.0, color="k", lw=1.5) + a0.grid(axis="x") + a0.set_ylim(y_lo, y_hi) + + handles, lbls = a0.get_legend_handles_labels() + by_label = dict(zip(lbls, handles, strict=False)) + if by_label: + a0.legend(by_label.values(), by_label.keys(), loc=1, fontsize=18) + + if ( + has_dos + and dos_grid is not None + and dft_dos_plot is not None + and ml_dos_plot is not None + ): + a1.fill_betweenx(dos_grid, 0.0, dft_dos_plot, color="k", alpha=0.25) + a1.plot(dft_dos_plot, dos_grid, color="k", lw=2.0) + a1.plot(ml_dos_plot, dos_grid, lw=2.0) + a1.set_xlabel("DOS", fontsize=20) + a1.grid(True, linestyle=":", linewidth=0.6) + plt.setp(a1.get_yticklabels(), visible=False) + a1.set_ylim(y_lo, y_hi) + else: + a1.axis("off") + + bio = BytesIO() + fig.savefig(bio, format="png", dpi=200) + plt.close(fig) + + data_uri = "data:image/png;base64," + base64.b64encode(bio.getvalue()).decode( + "ascii" + ) + return ( + data_uri, + { + "omega_avg_ref_thz": omega_avg_ref, + "omega_avg_pred_thz": omega_avg_pred, + }, + ) + + +def render_dispersion_component( + selection_context: dict[str, Any], + *, + calc_root: Path, + **_: Any, +): + """ + Render the dispersion + DOS panel for a selected scatter point. + + Parameters + ---------- + selection_context + Selection payload produced by the scatter/table callbacks. + calc_root + Root directory containing calculation artifacts referenced by ``data_paths``. + **_ + Additional keyword arguments accepted for callback API compatibility. + + Returns + ------- + Any or None + A Dash component containing the rendered image and caption, or ``None`` if the + required artifacts are unavailable. 
+ """ + selected = ( + selection_context.get("selection") + if isinstance(selection_context, dict) + else None + ) + if not isinstance(selected, dict): + selected = selection_context if isinstance(selection_context, dict) else {} + + data_paths = selected.get("data_paths") + if not isinstance(data_paths, dict): + return None + + npz_rel = data_paths.get("npz") + if not isinstance(npz_rel, str) or not npz_rel: + return None + + meta_rel = data_paths.get("meta") + meta_rel_str = meta_rel if isinstance(meta_rel, str) and meta_rel else None + + npz_path = (calc_root / npz_rel).resolve() + meta_path = (calc_root / meta_rel_str).resolve() if meta_rel_str else None + + src, extras = _render_npz_to_data_uri( + str(npz_path), + str(meta_path) if meta_path else None, + ) + if not src: + return None + + caption_bits = [] + oref = extras.get("omega_avg_ref_thz") + opred = extras.get("omega_avg_pred_thz") + if oref is not None and opred is not None: + caption_bits.append(f"ω_avg: DFT {oref:.3f} | ML {opred:.3f} THz") + + caption = " | ".join(caption_bits) + title = selected.get("system") or selected.get("label") or selected.get("id") or "" + + return html.Div( + [ + html.H4(title) if title else None, + html.Img( + src=src, + style={ + "width": "100%", + "maxWidth": "820px", + "height": "auto", + "display": "block", + "borderRadius": "8px", + "border": "1px solid #ddd", + }, + ), + html.Div( + caption, + style={"marginTop": "8px", "fontSize": "0.95rem", "opacity": 0.85}, + ) + if caption + else None, + ] + ) diff --git a/ml_peg/calcs/bulk_crystal/diamond_phonons/calc_diamond_phonons.py b/ml_peg/calcs/bulk_crystal/diamond_phonons/calc_diamond_phonons.py new file mode 100644 index 000000000..2d291c566 --- /dev/null +++ b/ml_peg/calcs/bulk_crystal/diamond_phonons/calc_diamond_phonons.py @@ -0,0 +1,269 @@ +""" +Calculate diamond phonon bands (phonopy version-friendly). + +This module computes phonon force constants and phonon band dispersions for +diamond using a set of MLIP models. 
Outputs are written to ``outputs//`` +as: + +- ``FORCE_CONSTANTS`` +- ``band.yaml`` + +The band path is taken from the DFT reference NPZ so that predicted dispersions +are evaluated on the exact same q-path as the reference. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from ase import Atoms +import numpy as np +from phonopy.structure.atoms import PhonopyAtoms +import pytest + +from ml_peg.calcs.utils.utils import download_github_data +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + +GITHUB_BASE = "https://raw.githubusercontent.com/7radians/ml-peg-data/main" + +EXTRACTED_ROOT = Path( + download_github_data( + filename="diamond_data/data.zip", + github_uri=GITHUB_BASE, + ) +) + +DATA_PATH = EXTRACTED_ROOT / "data" +DIAMOND_YAML = DATA_PATH / "diamond.yaml" +DFT_BAND_NPZ = DATA_PATH / "dft_band.npz" +OUT_PATH = Path(__file__).parent / "outputs" + +MODELS = load_models(current_models) + + +def phonopy_atoms_to_ase(ph_atoms: PhonopyAtoms) -> Atoms: + """ + Convert ``PhonopyAtoms`` to ASE ``Atoms``. + + Parameters + ---------- + ph_atoms + Phonopy atoms object. + + Returns + ------- + ase.Atoms + ASE representation with periodic boundary conditions enabled. + """ + return Atoms( + symbols=ph_atoms.symbols, + cell=ph_atoms.cell, + scaled_positions=ph_atoms.scaled_positions, + pbc=True, + ) + + +def load_phonon_yaml(yaml_path: Path) -> Any: + """ + Load a phonopy object from a phonopy YAML file (version-tolerant). + + Parameters + ---------- + yaml_path + Path to a phonopy YAML file (e.g. ``diamond.yaml``). + + Returns + ------- + object + A phonopy object as returned by ``phonopy.load`` or the older YAML reader. + + Raises + ------ + FileNotFoundError + If ``yaml_path`` does not exist. 
+ """ + if not yaml_path.exists(): + raise FileNotFoundError(f"Missing phonopy YAML: {yaml_path}") + + try: + import phonopy # type: ignore + + return phonopy.load(str(yaml_path)) + except Exception: + from phonopy.interface.phonopy_yaml import read_phonopy_yaml # type: ignore + + obj = read_phonopy_yaml(str(yaml_path)) + return getattr(obj, "phonopy", obj) + + +def get_displaced_supercells(phonon: Any) -> list[PhonopyAtoms]: + """ + Return displaced supercells from a phonopy object (version-tolerant). + + Parameters + ---------- + phonon + Phonopy object. + + Returns + ------- + list[phonopy.structure.atoms.PhonopyAtoms] + Displaced supercells. Returns an empty list if not available. + """ + if hasattr(phonon, "get_supercells_with_displacements"): + return phonon.get_supercells_with_displacements() + return getattr(phonon, "supercells_with_displacements", []) or [] + + +def ref_band_path() -> list[list[list[float]]]: + """ + Build a q-point path for ``phonon.run_band_structure`` from the DFT NPZ. + + The DFT reference NPZ is used so that predicted dispersions are computed on + the exact same q-path as the reference. + + The NPZ is expected to contain: + + - ``qpoints``: array with shape ``(Nq, 3)`` + - ``distance``: array with shape ``(Nq,)`` (optional; used to detect segment + boundaries). If missing, the q-path is treated as a single segment. + + Returns + ------- + list[list[list[float]]] + List of q-point segments, each a list of ``[qx, qy, qz]`` points. + + Raises + ------ + FileNotFoundError + If the DFT reference NPZ does not exist. + KeyError + If the NPZ does not contain ``qpoints``. + ValueError + If ``qpoints`` or ``distance`` have invalid shapes. 
+ """ + if not DFT_BAND_NPZ.exists(): + raise FileNotFoundError(f"Missing DFT reference NPZ: {DFT_BAND_NPZ}") + + obj = np.load(DFT_BAND_NPZ, allow_pickle=False) + if "qpoints" not in obj: + raise KeyError(f"{DFT_BAND_NPZ} missing required key 'qpoints'.") + + q = np.asarray(obj["qpoints"], float) + if q.ndim != 2 or q.shape[1] != 3: + raise ValueError(f"Bad qpoints shape in {DFT_BAND_NPZ}: {q.shape}") + + x = np.asarray(obj["distance"], float) if "distance" in obj else None + if x is None: + return [q.tolist()] + + if x.ndim != 1 or x.shape[0] != q.shape[0]: + raise ValueError(f"Bad distance shape in {DFT_BAND_NPZ}: {x.shape}") + + cuts = [0] + [i for i in range(1, len(x)) if x[i] <= x[i - 1] + 1e-12] + [len(x)] + return [q[cuts[i] : cuts[i + 1]].tolist() for i in range(len(cuts) - 1)] + + +def write_force_constants(phonon: Any, fc_path: Path) -> None: + """ + Write force constants to disk in a phonopy-version-tolerant way. + + Parameters + ---------- + phonon + Phonopy object with computed force constants. + fc_path + Output path for ``FORCE_CONSTANTS``. + """ + fc_path.parent.mkdir(parents=True, exist_ok=True) + try: + phonon.write_force_constants(filename=str(fc_path)) + except AttributeError: + from phonopy.file_IO import write_FORCE_CONSTANTS # type: ignore + + write_FORCE_CONSTANTS(phonon.force_constants, filename=str(fc_path)) + + +def write_band_yaml(phonon: Any, bands_path: Path) -> None: + """ + Write phonopy band-structure YAML to disk. + + Parameters + ---------- + phonon + Phonopy object that has already run ``run_band_structure``. + bands_path + Output path for ``band.yaml``. + + Raises + ------ + RuntimeError + If a YAML writer is not available for the current phonopy version. 
+ """ + bands_path.parent.mkdir(parents=True, exist_ok=True) + + if hasattr(phonon, "write_yaml_band_structure"): + phonon.write_yaml_band_structure(filename=str(bands_path)) + return + + bs = getattr(phonon, "band_structure", None) + if bs is None or not hasattr(bs, "write_yaml"): + raise RuntimeError( + "Phonopy band-structure YAML writer not found in this phonopy version." + ) + bs.write_yaml(filename=str(bands_path)) + + +@pytest.mark.parametrize("mlip", MODELS.items()) +def test_diamond_phonons_band(mlip) -> None: + """ + Compute diamond phonon bands for one MLIP and write output files. + + Parameters + ---------- + mlip + Tuple ``(model_name, model)`` from ``MODELS.items()``. + + Raises + ------ + RuntimeError + If the phonopy YAML lacks a displacement dataset, has no displaced + supercells, or cannot write band YAML for the current phonopy version. + """ + model_name, model = mlip + + phonon = load_phonon_yaml(DIAMOND_YAML) + if getattr(phonon, "dataset", None) is None: + raise RuntimeError(f"{DIAMOND_YAML} has no displacement dataset.") + + scells = get_displaced_supercells(phonon) + if not scells: + raise RuntimeError(f"No displaced supercells found in {DIAMOND_YAML}.") + + calc = model.get_calculator() + + forces = [] + for scp in scells: + ase_sc = phonopy_atoms_to_ase(scp) + ase_sc.calc = calc + forces.append(ase_sc.get_forces()) + + if hasattr(phonon, "set_forces"): + phonon.set_forces(forces) + else: + phonon.forces = forces + + phonon.produce_force_constants() + phonon.run_band_structure(ref_band_path()) + + write_dir = OUT_PATH / model_name + fc_path = write_dir / "FORCE_CONSTANTS" + bands_path = write_dir / "band.yaml" + + write_force_constants(phonon, fc_path) + write_band_yaml(phonon, bands_path) + + assert fc_path.exists(), f"Missing FORCE_CONSTANTS for {model_name}: {fc_path}" + assert bands_path.exists(), f"Missing band.yaml for {model_name}: {bands_path}" diff --git a/ml_peg/calcs/bulk_crystal/ti64_phonons/calc_ti64_phonons.py 
b/ml_peg/calcs/bulk_crystal/ti64_phonons/calc_ti64_phonons.py new file mode 100644 index 000000000..e640dab32 --- /dev/null +++ b/ml_peg/calcs/bulk_crystal/ti64_phonons/calc_ti64_phonons.py @@ -0,0 +1,591 @@ +""" +Run Ti64 CASTEP phonon suite (raw outputs only). + +This module writes per-case raw calculation outputs for each MLIP model to +``outputs//`` as: + +- ``.npz``: raw arrays used by analysis +- ``.json``: minimal metadata (no metrics) +""" + +from __future__ import annotations + +import json +from pathlib import Path +import re +from typing import Any + +import ase.io +from ase.optimize import LBFGS +import numpy as np +import pytest + +from ml_peg.calcs import CALCS_ROOT +from ml_peg.calcs.utils.ASE_to_phonons import AtomsToPDOS, AtomsToPhonons +from ml_peg.calcs.utils.CASTEP_reader_phonon_dispersion import PhononFromCastep +from ml_peg.calcs.utils.utils import download_github_data +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + +GITHUB_BASE = "https://raw.githubusercontent.com/7radians/ml-peg-data/main" + +EXTRACTED_ROOT = Path( + download_github_data( + filename="ti64_data/data.zip", + github_uri=GITHUB_BASE, + ) +) + +DATA_PATH = EXTRACTED_ROOT / "data" + +OUT_PATH = CALCS_ROOT / "bulk_crystal" / "ti64_phonons" / "outputs" + +FMAX = 0.001 +STEPS = 10000 +MESH_202020 = [20, 20, 20] + + +def _json_default(obj: Any) -> Any: + """ + JSON serializer for numpy/Path objects. + + Parameters + ---------- + obj + Object to convert. + + Returns + ------- + Any + JSON-serialisable representation. + """ + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return str(obj) + + +def write_case_npz(case_name: str, model_name: str, **arrays: Any) -> None: + """ + Write raw per-case arrays to a compressed NPZ file. 
+ + Parameters + ---------- + case_name + Case identifier used in the output filename. + model_name + Model identifier used in the output directory. + **arrays + Keyword arrays to store in the NPZ file. + + Notes + ----- + Output path: + ``outputs//.npz`` + """ + out_dir = OUT_PATH / model_name + out_dir.mkdir(parents=True, exist_ok=True) + np.savez_compressed(out_dir / f"{case_name}.npz", **arrays) + + +def write_case_metadata(case_name: str, model_name: str, meta: dict[str, Any]) -> None: + """ + Write minimal per-case metadata to JSON (no metrics). + + Parameters + ---------- + case_name + Case identifier used in the output filename. + model_name + Model identifier used in the output directory. + meta + Metadata mapping to write. + + Notes + ----- + Output path: + ``outputs//.json`` + """ + out_dir = OUT_PATH / model_name + out_dir.mkdir(parents=True, exist_ok=True) + + meta_out = dict(meta) + meta_out["case"] = case_name + meta_out["model_name"] = model_name + + out_file = out_dir / f"{case_name}.json" + with out_file.open("w", encoding="utf8") as handle: + json.dump(meta_out, handle, indent=2, default=_json_default) + + +def patch_ase_castep_reader_dummy_energy(dummy_energy: float = 0.0) -> None: + """ + Patch ASE CASTEP reader to tolerate missing energy keys. + + ASE's CASTEP reader can raise ``KeyError`` for CASTEP phonon/qpoints outputs. + This patch only adds dummy values when ASE raises ``KeyError``. + + Parameters + ---------- + dummy_energy + Dummy energy/free energy (eV) to use when missing. + """ + import ase.io.castep.castep_reader as cr # noqa: WPS433 (local import is intentional) + + orig = cr._set_energy_and_free_energy + + def _safe_set_energy_and_free_energy(results: dict[str, Any]) -> None: + """ + Set energy/free_energy keys, tolerating missing values. + + Parameters + ---------- + results + ASE results dictionary to patch in-place. 
+ """ + try: + orig(results) + except KeyError: + results.setdefault("energy", float(dummy_energy)) + results.setdefault("free_energy", float(dummy_energy)) + + cr._set_energy_and_free_energy = _safe_set_energy_and_free_energy + + +class PhononFromCastepPDOS(PhononFromCastep): + """ + Parse CASTEP q-point phonon output for PDOS. + + Parameters + ---------- + castep_file + Path to a CASTEP output file containing q-point frequencies and weights. + atoms_in + Structure used to determine the number of branches (3N). + + Attributes + ---------- + frequencies + Phonon frequencies array with shape ``(nq, 3N)``. + weights + Q-point weights array with shape ``(nq,)``. + """ + + def __init__(self, castep_file: str, atoms_in: Any | None = None) -> None: + """ + Initialise from a CASTEP file and extract frequencies/weights. + + Parameters + ---------- + castep_file + Path to the CASTEP q-points output file. + atoms_in + Structure used to determine the number of branches (3N). Required. + """ + if atoms_in is None: + raise ValueError( + "atoms_in must be provided to set number_of_branches (=3N)." + ) + + self.number_of_branches = len(atoms_in) * 3 + self.filename = castep_file + self.read_in_file() + self.get_frequencies() + self.get_weights() + delattr(self, "filelines") + + def __str__(self) -> str: + """ + Return a short human-readable description of this object. + + Returns + ------- + str + Description string. 
+ """ + return "Phonon q-point frequencies+weights from CASTEP file object" + + def get_weights(self) -> None: + """Extract q-point weights from filelines.""" + float_re = re.compile(r"[-+]?(?:\d*\.\d+|\d+)(?:[Ee][-+]?\d+)?") + weights: list[float] = [] + + for line in self.filelines: + if "q-pt" not in line: + continue + + # 1) Preferred: explicit 'weight=' + match = re.search( + r"weight\s*=\s*(" + float_re.pattern + r")", + line, + flags=re.I, + ) + if match: + weights.append(float(match.group(1))) + continue + + # 2) Fallback: take the last float on the q-pt line + nums = float_re.findall(line) + if nums: + weights.append(float(nums[-1])) + + nq_freq = self.frequencies.shape[0] + + if not weights: + # 3) No weights printed -> assume uniform weights + self.weights = np.ones(nq_freq, dtype=float) + return + + # Some files include extra q-pt header lines; align length + w = np.array(weights, dtype=float) + if w.shape[0] > nq_freq: + w = w[-nq_freq:] + + if w.shape[0] != nq_freq: + raise ValueError( + f"Parsed {w.shape[0]} weights but frequencies has {nq_freq} q-points." + ) + + self.weights = w + + +def run_case( + *, + case_name: str, + structure_file: str, + qpoints_file: str | None, + kpath: list, + labels: list, + grid: list, + disp_phonons: float, + disp_pdos: float, + do_tp: bool, + calc: Any, + model_name: str, +) -> None: + """ + Run one Ti64 phonon case and write raw artifacts. + + Parameters + ---------- + case_name + Case identifier. + structure_file + CASTEP structure file path (relative to repository root). + qpoints_file + Optional CASTEP qpoints file path. + kpath + High-symmetry k-path used by CASTEP dispersion output. + labels + Tick labels for the k-path. + grid + Phonon supercell grid (3x3 matrix-like list). + disp_phonons + Displacement magnitude for dispersion calculation. + disp_pdos + Displacement magnitude for DOS/PDOS calculation. + do_tp + Whether to compute thermo/TP outputs. + calc + ASE calculator from the selected model. 
+ model_name + Model identifier used for output paths. + """ + print(f"{case_name} | {model_name}") + + # DFT dispersion along k-path (raw reference) + pfc = PhononFromCastep(castep_file=structure_file, kpath_in=kpath) + + # Relax + atoms = ase.io.read(structure_file) + atoms.calc = calc + dyn = LBFGS(atoms, logfile=None) + dyn.run(fmax=FMAX, steps=STEPS) + + # ML dispersion + atp_ml = AtomsToPhonons( + primitive_cell=atoms, + phonon_grid=grid, + displacement=disp_phonons, + kpath=[kpath], + calculator=calc, + plusminus=True, + ) + + # ML DOS/PDOS + atp_pdos = AtomsToPDOS( + primitive_cell=atoms, + phonon_grid=grid, + displacement=disp_pdos, + calculator=calc, + ) + atp_pdos.get_pdos(MESH_202020) + atp_pdos.get_dos(MESH_202020) + + # ML thermo + if do_tp: + atp_pdos.get_tp(MESH_202020, tmax=2000, tstep=1) + + # DFT qpoints + q_weights: np.ndarray | None = None + q_freq_dft: np.ndarray | None = None + if qpoints_file is not None: + q_atoms = ase.io.read(qpoints_file) + pfc_q = PhononFromCastepPDOS(castep_file=qpoints_file, atoms_in=q_atoms) + q_weights = np.asarray(pfc_q.weights, dtype=float) + q_freq_dft = np.asarray(pfc_q.frequencies, dtype=float) + + # JSON metadata + meta: dict[str, Any] = { + "structure_file": str(structure_file), + "qpoints_file": str(qpoints_file) if qpoints_file is not None else None, + "kpath": [kpath], + "labels": labels, + "phonon_grid": grid, + "displacement_phonons": float(disp_phonons), + "displacement_pdos": float(disp_pdos), + "relax_settings": {"fmax": float(FMAX), "steps": int(STEPS)}, + "did_tp": bool(do_tp), + "tp_settings": {"mesh": MESH_202020, "tmax": 2000, "tstep": 1} + if do_tp + else None, + } + write_case_metadata(case_name, model_name, meta) + + arrays: dict[str, np.ndarray] = { + "n_atoms": np.int64(len(atoms)), + "labels": np.array(labels, dtype=object), + "kpath": np.array([kpath], dtype=object), + "phonon_grid": np.asarray(grid, dtype=int), + # DFT dispersion + "dft_x": np.asarray(pfc.xscale, dtype=float), + 
"dft_frequencies": np.asarray(pfc.frequencies, dtype=float), + # ML dispersion + "ml_frequencies": np.asarray(atp_ml.frequencies, dtype=float), + } + + if hasattr(atp_ml, "normal_ticks") and atp_ml.normal_ticks is not None: + arrays["ml_normal_ticks"] = np.asarray(atp_ml.normal_ticks, dtype=float) + + # PDOS + if "frequency_points" in atp_pdos.pdos: + arrays["pdos_frequency_points"] = np.asarray( + atp_pdos.pdos["frequency_points"], + dtype=float, + ) + if "projected_dos" in atp_pdos.pdos: + arrays["pdos_projected"] = np.asarray( + atp_pdos.pdos["projected_dos"], dtype=float + ) + if "dos" in atp_pdos.pdos: + arrays["dos_total"] = np.asarray(atp_pdos.pdos["dos"], dtype=float) + + # TP + if do_tp and hasattr(atp_pdos, "tp_dict") and atp_pdos.tp_dict: + if "temperatures" in atp_pdos.tp_dict: + arrays["tp_temperatures"] = np.asarray( + atp_pdos.tp_dict["temperatures"], dtype=float + ) + if "free_energy" in atp_pdos.tp_dict: + arrays["tp_free_energy"] = np.asarray( + atp_pdos.tp_dict["free_energy"], dtype=float + ) + + # qpoints + if qpoints_file is not None: + # These are defined whenever qpoints_file is not None (see above). + arrays["q_weights"] = np.asarray(q_weights, dtype=float) # type: ignore[arg-type] + arrays["q_frequencies_dft"] = np.asarray(q_freq_dft, dtype=float) # type: ignore[arg-type] + + write_case_npz(case_name, model_name, **arrays) + + +def _hex_path() -> tuple[list[list[float]], list[str]]: + """ + Return the high-symmetry path and tick labels for the hexagonal cell. + + Returns + ------- + tuple[list[list[float]], list[str]] + ``(kpath, labels)`` where ``kpath`` is a list of fractional k-points and + ``labels`` are the corresponding tick labels. 
+ """ + gam = [0, 0, 0] + a_pt = [0, 0, 1 / 2] + k_pt = [1 / 3, 1 / 3, 0] + m_pt = [0.5, 0, 0] + return [gam, k_pt, m_pt, gam, a_pt], ["$\\Gamma$", "K", "M", "$\\Gamma$", "A"] + + +def _bcc_path() -> tuple[list[list[float]], list[str]]: + """ + Return the high-symmetry path and tick labels for the BCC cell. + + Returns + ------- + tuple[list[list[float]], list[str]] + ``(kpath, labels)`` where ``kpath`` is a list of fractional k-points and + ``labels`` are the corresponding tick labels. + """ + gam = [0, 0, 0] + h_pt = [0.5, -0.5, 0.5] + p_pt = [0.25, 0.25, 0.25] + n_pt = [0, 0, 0.5] + return ( + [gam, h_pt, n_pt, gam, p_pt, h_pt, p_pt, n_pt], + ["$\\Gamma$", "H", "N", "$\\Gamma$", "P", "H", "P", "N"], + ) + + +GRID_222 = [[2, 0, 0], [0, 2, 0], [0, 0, 2]] +HEX_KPATH, HEX_LABELS = _hex_path() +BCC_KPATH, BCC_LABELS = _bcc_path() + +TP_ON = { + "hcp_Ti6AlV", # 1/10 + "hex_Ti8AlV", # 3/10 + "hcp_Ti6Al2", # 4/10 + "hcp_Ti6V2", # 5/10 + "hcp_Ti7V", # 6/10 + "hex_Ti10Al2", # 9/10 + "hex_Ti10V2", # 10/10 +} + +CASES: list[dict[str, Any]] = [ + { + "case_name": "hcp_Ti6AlV", + "structure_file": DATA_PATH / "ti64_hcp_phonon.castep", + "qpoints_file": DATA_PATH / "ti64_hcp_phonon_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "bcc_Ti6AlV", + "structure_file": DATA_PATH / "ti64_bcc_phonon.castep", + "qpoints_file": DATA_PATH / "ti64_bcc_phonon_qpoints.castep", + "kpath": BCC_KPATH, + "labels": BCC_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "hex_Ti8AlV", + "structure_file": DATA_PATH / "ti64_hex_phonon.castep", + "qpoints_file": DATA_PATH / "ti64_hex_phonon_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.01, + }, + { + "case_name": "hcp_Ti6Al2", + "structure_file": DATA_PATH / "ti64_hcp_phonon_AlAl_qpath.castep", + "qpoints_file": DATA_PATH / 
"ti64_hcp_phonon_AlAl_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "hcp_Ti6V2", + "structure_file": DATA_PATH / "ti64_hcp_phonon_VV_qpath.castep", + "qpoints_file": DATA_PATH / "ti64_hcp_phonon_VV_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "hcp_Ti7V", + "structure_file": DATA_PATH / "Ti7V_hcp_phonon_qpath.castep", + "qpoints_file": DATA_PATH / "Ti7V_hcp_phonon_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "bcc_Ti6Al2", + "structure_file": DATA_PATH / "ti64_bcc_phonon_AlAl_qpath.castep", + "qpoints_file": DATA_PATH / "ti64_bcc_phonon_AlAl_qpoints.castep", + "kpath": BCC_KPATH, + "labels": BCC_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "bcc_Ti6V2", + "structure_file": DATA_PATH / "ti64_bcc_phonon_VV_qpath.castep", + "qpoints_file": DATA_PATH / "ti64_bcc_phonon_VV_qpoints.castep", + "kpath": BCC_KPATH, + "labels": BCC_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "hex_Ti10Al2", + "structure_file": DATA_PATH / "ti64_hex_phonon_AlAl_qpath.castep", + "qpoints_file": DATA_PATH / "ti64_hex_phonon_AlAl_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, + { + "case_name": "hex_Ti10V2", + "structure_file": DATA_PATH / "ti64_hex_phonon_VV_qpath.castep", + "qpoints_file": DATA_PATH / "ti64_hex_phonon_VV_qpoints.castep", + "kpath": HEX_KPATH, + "labels": HEX_LABELS, + "grid": GRID_222, + "disp_phonons": 0.02, + "disp_pdos": 0.02, + }, +] + + +MODELS = load_models(current_models) +MODEL_ITEMS = list(MODELS.items()) +MODEL_IDS = [name for name, _ in MODEL_ITEMS] + + 
+@pytest.mark.parametrize("mlip", MODEL_ITEMS, ids=MODEL_IDS) +def test_phonon_suite(mlip: tuple[str, Any]) -> None: + """ + Run the full Ti64 phonon suite for one model and write artifacts. + + Parameters + ---------- + mlip + Tuple ``(model_name, model)`` from ``MODEL_ITEMS``. + """ + patch_ase_castep_reader_dummy_energy(dummy_energy=0.0) + + model_name, model = mlip + calc = model.get_calculator() + + for spec in CASES: + case_name = spec["case_name"] + do_tp = case_name in TP_ON + + run_case(**spec, do_tp=do_tp, calc=calc, model_name=model_name) + + out_dir = OUT_PATH / model_name + assert (out_dir / f"{case_name}.npz").exists() + assert (out_dir / f"{case_name}.json").exists() diff --git a/ml_peg/calcs/utils/ASE_to_phonons.py b/ml_peg/calcs/utils/ASE_to_phonons.py new file mode 100644 index 000000000..faaf1f6b8 --- /dev/null +++ b/ml_peg/calcs/utils/ASE_to_phonons.py @@ -0,0 +1,492 @@ +"""ASE → phonopy helpers for band structures and DOS/PDOS.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import ase +import ase.io +import numpy as np +from phonopy import Phonopy +from phonopy.phonon.band_structure import get_band_qpoints_and_path_connections +from phonopy.structure.atoms import PhonopyAtoms + + +class AtomsToPhonons: + """ + Compute phonon band structures from an ASE/phonopy workflow. + + Parameters + ---------- + primitive_cell + Primitive cell as an ASE ``Atoms`` or a ``PhonopyAtoms``-like object. + phonon_grid + Phonopy supercell matrix. + displacement + Displacement magnitude for finite-difference force constants. + kpath + K-path provided to phonopy's band structure routine. + calculator + Either an ASE calculator (default workflow) or an iterable of CASTEP files + (when ``castep=True``). + kpoints + Number of points per k-path segment. + castep + If ``True``, forces are read from CASTEP output structures. + plusminus + Whether to generate plus/minus displacements. 
+ diagonal + Whether to include diagonal displacements. + + Attributes + ---------- + phonon + Phonopy object holding force constants and band structure. + supercells + Displaced supercells used for force evaluation. + forces + List of forces arrays (one per displaced supercell). + frequencies + Concatenated frequencies along the k-path. + xticks + Tick positions corresponding to k-path segment boundaries. + normal_ticks + Evenly spaced tick positions used by some plotting utilities. + """ + + def __init__( + self, + primitive_cell: Any, + phonon_grid: Any, + displacement: float, + kpath: Any, + calculator: Any, + kpoints: int = 100, + castep: bool = False, + plusminus: bool = False, + diagonal: bool = True, + ) -> None: + """ + Initialise the workflow and compute the phonon band structure. + + Parameters + ---------- + primitive_cell + Primitive cell structure. + phonon_grid + Phonopy supercell matrix. + displacement + Displacement magnitude for finite differences. + kpath + K-path definition compatible with phonopy band structure routines. + calculator + ASE calculator or iterable of CASTEP output paths (when ``castep=True``). + kpoints + Number of points per k-path segment. + castep + If ``True``, read forces from CASTEP outputs instead of computing with ASE. + plusminus + Whether to generate ± displacements. + diagonal + Whether to generate diagonal displacements. + """ + self.calculator_string = calculator + self.get_supercell( + primitive_cell, phonon_grid, displacement, plusminus, diagonal + ) + + self.forces: list[np.ndarray] | None = None + if castep is False: + self.get_forces_model() + if castep is True: + self.get_forces_castep() + + if self.forces is None or len(self.forces) == 0: + raise RuntimeError("Forces were not generated.") + + self.get_band_struct(kpath, kpoints) + + def _scaled_positions(self, obj: Any) -> Any: + """ + Return scaled positions from either PhonopyAtoms or ASE Atoms. 
+ + Parameters + ---------- + obj + Object providing either a ``scaled_positions`` attribute or a + ``get_scaled_positions()`` method. + + Returns + ------- + Any + Scaled positions array. + """ + if hasattr(obj, "scaled_positions"): + return obj.scaled_positions + return obj.get_scaled_positions() + + def get_supercell( + self, + primitive_cell: Any, + phonon_grid: Any, + displacement: float, + plusminus: bool, + diagonal: bool, + ) -> None: + """ + Construct phonopy object and displaced supercells. + + Parameters + ---------- + primitive_cell + Primitive cell structure. + phonon_grid + Phonopy supercell matrix. + displacement + Displacement magnitude for finite differences. + plusminus + Whether to use ± displacements. + diagonal + Whether to generate diagonal displacements. + """ + self.unitcell = PhonopyAtoms( + symbols=primitive_cell.symbols + if hasattr(primitive_cell, "symbols") + else primitive_cell.get_chemical_symbols(), + cell=primitive_cell.cell, + scaled_positions=self._scaled_positions(primitive_cell), + ) + + self.phonon = Phonopy(self.unitcell, phonon_grid) + self.phonon.generate_displacements( + distance=displacement, + is_plusminus=plusminus, + is_diagonal=diagonal, + ) + self.supercells = self.phonon.supercells_with_displacements + + def get_forces_model(self) -> None: + """Compute forces using an ASE calculator.""" + potential = self.calculator_string + self.forces = [] + self.atoms: list[ase.Atoms] = [] + + for s in self.supercells: + atoms = ase.Atoms( + symbols=list(s.symbols), + cell=s.cell, + scaled_positions=self._scaled_positions(s), + pbc=True, + ) + self.atoms.append(atoms) + atoms.calc = potential + self.forces.append(atoms.get_forces()) + + def get_forces_castep(self) -> None: + """Read forces from CASTEP output structures.""" + forces: list[np.ndarray] = [] + self.atoms = [] + + for _i, path in enumerate(self.calculator_string): + castep_atoms = ase.io.read(path) + self.atoms.append(castep_atoms) + 
forces.append(castep_atoms.get_forces()) + + self.forces = forces + + def get_band_struct(self, kpath: Any, kpoints: int) -> None: + """ + Compute band structure along the provided k-path. + + Parameters + ---------- + kpath + K-path definition compatible with + ``get_band_qpoints_and_path_connections``. + kpoints + Number of points per k-path segment. + """ + self.phonon.forces = np.asarray(self.forces, dtype=float) + self.phonon.produce_force_constants() + + qpoints, connections = get_band_qpoints_and_path_connections( + kpath, npoints=kpoints + ) + self.phonon.run_band_structure( + qpoints, + with_eigenvectors=True, + path_connections=connections, + ) + + bs = self.phonon.get_band_structure_dict() + self.frequencies_array = bs["frequencies"] + self.eigenvectors = bs["eigenvectors"] + + self.frequencies = np.array(self.frequencies_array[0], copy=True) + + xticks: list[int] = [] + x = 0 + for val in self.frequencies_array[1:]: + self.frequencies = np.append(self.frequencies, val, axis=0) + x += len(val) + xticks.append(x) + + self.xticks = xticks + n_kpoints = kpoints * len(kpath[0]) + n_trace = int(n_kpoints / (len(kpath[0]))) + self.normal_ticks = [i * n_trace for i in range(len(kpath[0]))] + + +class AtomsToPDOS: + """ + Compute force constants and DOS/PDOS/thermal properties via phonopy. + + Parameters + ---------- + primitive_cell + Primitive cell as an ASE ``Atoms`` or a ``PhonopyAtoms``-like object. + phonon_grid + Phonopy supercell matrix. + displacement + Displacement magnitude for finite-difference force constants. + calculator + Either an ASE calculator (default workflow) or an iterable of CASTEP files + (when ``castep=True``). + kpoints + Unused (kept for API compatibility with :class:`AtomsToPhonons`). + castep + If ``True``, forces are read from CASTEP output structures. + plusminus + Whether to generate plus/minus displacements. + diagonal + Whether to include diagonal displacements. 
+ + Attributes + ---------- + phonon + Phonopy object holding force constants and mesh results. + supercells + Displaced supercells used for force evaluation. + forces + List of forces arrays (one per displaced supercell). + pdos + Projected DOS dictionary (set by :meth:`get_pdos`). + dos + Total DOS dictionary (set by :meth:`get_dos`). + tp_dict + Thermal properties dictionary (set by :meth:`get_tp`). + """ + + def __init__( + self, + primitive_cell: Any, + phonon_grid: Any, + displacement: float, + calculator: Any, + kpoints: int = 100, + castep: bool = False, + plusminus: bool = False, + diagonal: bool = True, + ) -> None: + """ + Initialise the workflow and build force constants. + + Parameters + ---------- + primitive_cell + Primitive cell structure. + phonon_grid + Phonopy supercell matrix. + displacement + Displacement magnitude for finite differences. + calculator + ASE calculator or iterable of CASTEP output paths (when ``castep=True``). + kpoints + Unused (kept for API compatibility). + castep + If ``True``, read forces from CASTEP outputs instead of computing with ASE. + plusminus + Whether to generate ± displacements. + diagonal + Whether to generate diagonal displacements. + """ + _ = kpoints + self.calculator_string = calculator + self.get_supercell( + primitive_cell, phonon_grid, displacement, plusminus, diagonal + ) + + self.forces: list[np.ndarray] | None = None + if castep is False: + self.get_forces_model() + if castep is True: + self.get_forces_castep() + + if self.forces is None: + raise RuntimeError( + "Forces were not generated. Check calculator/castep inputs." + ) + + self.get_dynamical_matrix() + + def _scaled_positions(self, obj: Any) -> Any: + """ + Return scaled positions from either PhonopyAtoms or ASE Atoms. + + Parameters + ---------- + obj + Object providing either a ``scaled_positions`` attribute or a + ``get_scaled_positions()`` method. + + Returns + ------- + Any + Scaled positions array. 
+ """ + return ( + obj.scaled_positions + if hasattr(obj, "scaled_positions") + else obj.get_scaled_positions() + ) + + def get_supercell( + self, + primitive_cell: Any, + phonon_grid: Any, + displacement: float, + plusminus: bool, + diagonal: bool, + ) -> None: + """ + Construct phonopy object and displaced supercells. + + Parameters + ---------- + primitive_cell + Primitive cell structure. + phonon_grid + Phonopy supercell matrix. + displacement + Displacement magnitude for finite differences. + plusminus + Whether to use ± displacements. + diagonal + Whether to generate diagonal displacements. + """ + self.unitcell = PhonopyAtoms( + symbols=primitive_cell.symbols + if hasattr(primitive_cell, "symbols") + else primitive_cell.get_chemical_symbols(), + cell=primitive_cell.cell, + scaled_positions=self._scaled_positions(primitive_cell), + ) + self.phonon = Phonopy(self.unitcell, phonon_grid) + self.phonon.generate_displacements( + distance=displacement, + is_plusminus=plusminus, + is_diagonal=diagonal, + ) + self.supercells = self.phonon.supercells_with_displacements + + def get_forces_model(self) -> None: + """Compute forces using an ASE calculator.""" + potential = self.calculator_string + self.forces = [] + self.atoms: list[ase.Atoms] = [] + + for s in self.supercells: + atoms = ase.Atoms( + symbols=list(s.symbols), + cell=s.cell, + scaled_positions=self._scaled_positions(s), + pbc=True, + ) + self.atoms.append(atoms) + atoms.calc = potential + self.forces.append(atoms.get_forces()) + + def get_forces_castep(self) -> None: + """Read forces from CASTEP output structures.""" + forces: list[np.ndarray] = [] + self.atoms = [] + + for _i, path in enumerate(self.calculator_string): + castep_atoms = ase.io.read(path) + self.atoms.append(castep_atoms) + forces.append(castep_atoms.get_forces()) + + self.forces = forces + + def get_dynamical_matrix(self) -> None: + """ + Build force constants from the stored forces. 
+ + Raises + ------ + ValueError + If the number of force sets does not match the number of supercells. + """ + forces = np.asarray(self.forces, dtype=float) + + if len(forces) != len(self.supercells): + raise ValueError( + f"Number of force sets ({len(forces)}) != number of supercells " + f"({len(self.supercells)})." + ) + + self.phonon.forces = forces + self.phonon.produce_force_constants() + + def get_pdos(self, qmesh: Sequence[int]) -> None: + """ + Compute projected DOS and total DOS on a mesh. + + Parameters + ---------- + qmesh + Q-mesh used for DOS calculations. + """ + self.phonon.run_mesh(qmesh, with_eigenvectors=True, is_mesh_symmetry=False) + self.phonon.run_projected_dos() + self.phonon.run_total_dos() + self.pdos = self.phonon.get_projected_dos_dict() + + def get_tp( + self, + qmesh: Sequence[int], + tmin: float = 0, + tmax: float = 2000, + tstep: float = 100, + ) -> None: + """ + Compute thermal properties. + + Parameters + ---------- + qmesh + Q-mesh for thermal properties. + tmin + Minimum temperature (K). + tmax + Maximum temperature (K). + tstep + Temperature step (K). + """ + self.phonon.run_mesh(qmesh) + self.phonon.run_thermal_properties(t_step=tstep, t_max=tmax, t_min=tmin) + self.tp_dict = self.phonon.get_thermal_properties_dict() + + def get_dos(self, qmesh: Sequence[int]) -> None: + """ + Compute total DOS on a mesh. + + Parameters + ---------- + qmesh + Q-mesh used for total DOS calculation. + """ + self.phonon.run_mesh(qmesh) + self.phonon.run_total_dos() + self.dos = self.phonon.get_total_dos_dict() diff --git a/ml_peg/calcs/utils/CASTEP_reader_phonon_dispersion.py b/ml_peg/calcs/utils/CASTEP_reader_phonon_dispersion.py new file mode 100644 index 000000000..a0e557207 --- /dev/null +++ b/ml_peg/calcs/utils/CASTEP_reader_phonon_dispersion.py @@ -0,0 +1,213 @@ +""" +Ti64 phonons CASTEP phonon dispersion reader. 
+ +This module provides a lightweight parser for CASTEP phonon output files to +extract phonon frequencies (in THz) and q-point coordinates. Optionally, the +dispersion x-axis can be rescaled to span [0, 1] along a provided k-path. + +Notes +----- +- This parser assumes the CASTEP output uses THz units for phonon frequencies. +- The rescaling assumes ``kpath_in`` matches the k-path used in the CASTEP run. +""" + +from __future__ import annotations + +from pathlib import Path +import re +from typing import Any +import warnings + +import ase.io +import numpy as np + +warnings.simplefilter("ignore") + + +class PhononFromCastep: + """ + Extract phonon frequencies and k-points from a CASTEP phonon calculation. + + Parameters + ---------- + castep_file + Path to a CASTEP output file readable by ASE. + kpath_in + Optional k-path (high-symmetry points) used to rescale the dispersion + axis onto [0, 1]. Must match the CASTEP k-path convention. + verbose + If ``True``, print basic debug information. + + Attributes + ---------- + number_of_branches + Number of phonon branches (``3 * n_atoms``). + filename + Input file path as provided. + kpoints + Number of q-points parsed from the file. + frequencies + Frequencies array of shape ``(kpoints, number_of_branches)`` in THz. + kpath + q-point coordinates array of shape ``(kpoints, 3)``. + xscale + Optional rescaled x-axis (only present if ``kpath_in`` was provided). + """ + + RESCALE_TOL = 1e-4 + + def __init__( + self, + castep_file: str, + kpath_in: Any | None = None, + verbose: bool = False, + ) -> None: + """ + Initialise the reader and parse frequencies and q-point coordinates. + + Parameters + ---------- + castep_file + Path to a CASTEP output file readable by ASE. + kpath_in + Optional k-path (high-symmetry points) used to rescale the dispersion + axis onto [0, 1]. Must match the CASTEP k-path convention. + verbose + If ``True``, print basic debug information. 
+ """ + self.filename = castep_file + + if not castep_file: + raise ValueError("castep_file must be provided.") + + try: + atoms = ase.io.read(castep_file) + except AttributeError as exc: + raise TypeError("Invalid input type for castep_file.") from exc + + self.number_of_branches = len(atoms) * 3 + + self.read_in_file() + self.get_frequencies() + self.get_kpath() + + if verbose: + print(f"Atoms object info:\n{atoms}\n") + print(self.__dict__.keys(), "\n") + + if kpath_in is not None: + self.rescale_xaxis(kpath_in) + if verbose: + print("k-path rescaled") + elif verbose: + print("no k-path re-scaling done") + + delattr(self, "filelines") + + def __str__(self) -> str: + """ + Return a short description. + + Returns + ------- + str + Description string. + """ + return "Phonon Dispersion (THz) from CASTEP file object" + + def read_in_file(self) -> None: + """Read the input file into memory as lines.""" + with Path(self.filename).open(encoding="utf8") as handle: + self.filelines = handle.readlines() + + def get_frequencies(self) -> None: + """Parse phonon frequencies (THz) into ``self.frequencies``.""" + headlines = 2 # number of lines before frequency numbers appear + + thz_blocks = [ + self.filelines[i + headlines : i + headlines + self.number_of_branches] + for i, val in enumerate(self.filelines) + if re.search(r" \(THz\) ", val) is not None + ] + + thz_lines = [line for block in thz_blocks for line in block] + thz_vals = [line.split()[2] for line in thz_lines] + + frequencies = np.array(thz_vals, dtype=float) + self.kpoints = int(len(frequencies) / self.number_of_branches) + self.frequencies = np.reshape( + frequencies, + (self.kpoints, self.number_of_branches), + ) + + def get_kpath(self) -> None: + """Parse q-point coordinates into ``self.kpath``.""" + qpt_lines = [ + self.filelines[i] + for i, val in enumerate(self.filelines) + if re.search(r"q-pt=", val) is not None + ] + + qpts: list[list[str]] = [] + for line in qpt_lines: + temp = line.split()[4:7] + temp[2] 
= temp[2].replace(")", "") + qpts.append(temp) + + self.kpath = np.array(qpts, dtype=float) + + def find_index(self, in_path: np.ndarray) -> None: + """ + Find indices of high-symmetry points along the parsed k-path. + + Parameters + ---------- + in_path + Array of target high-symmetry points with shape ``(n_points, 3)``. + """ + j = 0 + sympoint_idx: list[int] = [] + self.kpath_idx: list[list[int]] = [] + + for i, val in enumerate(self.kpath): + if abs(np.linalg.norm(in_path[j] - val)) < self.RESCALE_TOL: + sympoint_idx.append(i) + j += 1 + + for i in range(len(sympoint_idx) - 1): + idxs = list(range(sympoint_idx[i], sympoint_idx[i + 1] + 1)) + self.kpath_idx.append(idxs) + + def rescale_xaxis(self, rescale_xaxis: Any) -> None: + """ + Rescale the dispersion axis to span [0, 1] over the provided k-path. + + Parameters + ---------- + rescale_xaxis + Iterable of high-symmetry points (each a length-3 coordinate) used + to determine segment boundaries for rescaling. + """ + in_path = np.array(rescale_xaxis, dtype=float) + self.find_index(in_path) + + xsplit = 1.0 / len(self.kpath_idx) + xscale = [0.0] + pos = 0.0 + + kpath_cut: list[list[int]] = [self.kpath_idx[0]] + for i in range(len(self.kpath_idx) - 1): + kpath_cut.append(self.kpath_idx[i + 1][1:]) + + x_inc = xsplit / (len(kpath_cut[0]) - 1) + for _ in range(len(kpath_cut[0]) - 1): + pos += x_inc + xscale.append(pos) + + for i in range(len(kpath_cut) - 1): + x_inc = xsplit / len(kpath_cut[i + 1]) + for _ in range(len(kpath_cut[i + 1])): + pos += x_inc + xscale.append(pos) + + self.xscale = np.array(xscale, dtype=float)