diff --git a/.codespellrc b/.codespellrc index 36c5800..4af6b9c 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -skip = ./etc, ./Lib, ./Include, ./Scripts, ./share, pyvenv.cfg, ./.vscode, ./src/__pycache__, ./.mypy_cache, ./__pycache__, app.log, ./venv, requirements.txt, requirements_linux.txt, *.ipynb +skip = ./etc, ./Lib, ./Include, ./Scripts, ./share, pyvenv.cfg, ./.vscode, ./src/__pycache__, ./.mypy_cache, ./__pycache__, app.log, ./venv, requirements.txt, requirements_windows.txt, *.ipynb ignore-words-list = fpr, FPR \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a4178be..ce5faf7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,7 +21,7 @@ jobs: pip install pytest-cov # Ensure pytest-cov is installed for coverage reporting - name: Analyse the code with pylint run: | - pylint src/ tests/ accord_demo.py mc_demo.py + pylint src/ tests/ accord_demo.py mc_demo.py streamlit_app.py - name: Run mypy static type checker run: | mypy . diff --git a/mc_demo.py b/mc_demo.py index 8c3ff5d..c632c3b 100644 --- a/mc_demo.py +++ b/mc_demo.py @@ -1,8 +1,9 @@ -# pylint: disable=protected-access, too-many-locals, too-many-statements, too-many-arguments, too-many-positional-arguments, broad-exception-caught +# pylint: disable=protected-access, too-many-locals, too-many-statements, too-many-arguments, too-many-positional-arguments, broad-exception-caught, duplicate-code """ Monte Carlo Simulation for the ACCORD framework. 
""" import os +import json import asyncio import time import argparse @@ -13,6 +14,7 @@ import numpy as np import matplotlib.pyplot as plt from src.logger import get_logger +from src.plotting import plot_mc_nis_boxplot from accord_demo import run_consensus_demo, DEFAULT_CONFIG # Limit NumPy to 1 thread per process to prevent over-subscription @@ -39,15 +41,13 @@ def calculate_kpis(rep_history: Optional[Dict[str, List[float]]] = None, steps: Optional[int] = None, honest_matrix: Optional[np.ndarray] = None, faulty_matrix: Optional[np.ndarray] = None, + honest_nis: Optional[List[float]] = None, + faulty_nis: Optional[List[float]] = None, detection_threshold: float = 0.4, fpr_offset_percent: float = 0.2) -> Dict[str, Any]: """ Calculate Key Performance Indicators (KPIs) for a single MC simulation run. - This function can be initialised in two ways: - 1. By providing raw simulation output (rep_history, faulty_ids, steps). - 2. By providing pre-processed reputation matrices (honest_matrix, faulty_matrix). - Args: rep_history: Dictionary mapping satellite IDs to their reputation history. faulty_ids: List of IDs identifying faulty satellites. @@ -56,20 +56,15 @@ def calculate_kpis(rep_history: Optional[Dict[str, List[float]]] = None, of an honest node. faulty_matrix: A 2D NumPy array where each row is the reputation history of a faulty node. + honest_nis: List of NIS values for transactions from honest satellites. + faulty_nis: List of NIS values for transactions from faulty satellites. detection_threshold: The reputation value below which a node is considered "detected" as faulty. fpr_offset_percent: The fraction of initial steps to ignore when calculating False Positives (to allow for EKF convergence). Returns: - A dictionary containing: - - "avg_ttd": Average Time to Detection for faulty nodes (in steps), - or None if none detected. - - "fpr": False Positive Rate (%) among honest nodes. - - "final_honest_rep": Mean reputation of honest nodes at the final step. 
- - "final_faulty_rep": Mean reputation of faulty nodes at the final step. - - "honest_matrix": The processed honest reputation matrix. - - "faulty_matrix": The processed faulty reputation matrix. + A dictionary containing KPIs and processed matrices/NIS lists. """ # If matrices aren't provided, convert the raw rep_history dictionary into NumPy matrices if honest_matrix is None or faulty_matrix is None: @@ -86,37 +81,67 @@ def calculate_kpis(rep_history: Optional[Dict[str, List[float]]] = None, ttds = [] # List to store Time to Detection for each faulty node false_positives = 0 + true_positives = 0 + total_flips = 0 # Extract final reputations for reporting final_honest_reps = honest_matrix[:, -1] final_faulty_reps = faulty_matrix[:, -1] - # Calculate Time to Detection (TTD) for each faulty node - # TTD is the first step where reputation drops below the detection_threshold + # Calculate Time to Detection (TTD) and Recall/FNR for faulty nodes for history in faulty_matrix: detected_at = next((i for i, rep in enumerate(history) if rep < detection_threshold), None) if detected_at is not None: ttds.append(detected_at) + true_positives += 1 + + # Calculate flips (stability) + diff = np.diff((history < detection_threshold).astype(int)) + total_flips += np.sum(np.abs(diff)) # Calculate False Positive Rate (FPR) among honest nodes - # We ignore the first X% of steps to account for initial EKF stabilization fpr_start_step = int(fpr_offset_percent * steps) for history in honest_matrix: # A false positive occurs if an honest node's reputation ever drops below the threshold if any(rep < detection_threshold for rep in history[fpr_start_step:]): false_positives += 1 - # Normalise FPR and TTD - fpr = (false_positives / len(honest_matrix)) * 100 if len(honest_matrix) > 0 else 0 + # Calculate flips for honest nodes too + diff = np.diff((history[fpr_start_step:] < detection_threshold).astype(int)) + total_flips += np.sum(np.abs(diff)) + + # Normalise Metrics + num_honest = 
len(honest_matrix) + num_faulty = len(faulty_matrix) + + fpr = (false_positives / num_honest) * 100 if num_honest > 0 else 0 + recall = (true_positives / num_faulty) * 100 if num_faulty > 0 else 0 + fnr = 100 - recall + precision = (true_positives / (true_positives + false_positives)) * 100 \ + if (true_positives + false_positives) > 0 else 0 + avg_ttd = np.mean(ttds) if ttds else None + worst_ttd = np.max(ttds) if ttds else None + + mean_honest = np.mean(final_honest_reps) if num_honest > 0 else 0 + mean_faulty = np.mean(final_faulty_reps) if num_faulty > 0 else 0 return { "avg_ttd": avg_ttd, + "worst_ttd": worst_ttd, "fpr": fpr, - "final_honest_rep": np.mean(final_honest_reps) if len(final_honest_reps) > 0 else 0, - "final_faulty_rep": np.mean(final_faulty_reps) if len(final_faulty_reps) > 0 else 0, + "recall": recall, + "precision": precision, + "fnr": fnr, + "final_honest_rep": mean_honest, + "final_faulty_rep": mean_faulty, + "honest_spread": np.std(final_honest_reps) if num_honest > 0 else 0, + "detection_margin": mean_honest - mean_faulty, + "flips": total_flips, "honest_matrix": honest_matrix, - "faulty_matrix": faulty_matrix + "faulty_matrix": faulty_matrix, + "honest_nis": honest_nis if honest_nis is not None else [], + "faulty_nis": faulty_nis if faulty_nis is not None else [] } def recalculate_all_kpis(all_results: List[Optional[Dict[str, Any]]], @@ -125,11 +150,8 @@ def recalculate_all_kpis(all_results: List[Optional[Dict[str, Any]]], """ Recalculate KPIs for a set of Monte Carlo results using new detection parameters. - This function iterates through previously saved simulation data and reapplies - the KPI logic without needing to re-run the expensive physics/consensus simulations. - Args: - all_results: A list of KPI dictionaries (one per MC run) as returned by calculate_kpis. + all_results: A list of KPI dictionaries (one per MC run). detection_threshold: The new reputation threshold to apply. 
fpr_offset_percent: The new initialization offset percentage to apply. @@ -142,10 +164,12 @@ def recalculate_all_kpis(all_results: List[Optional[Dict[str, Any]]], new_results.append(None) continue - # We reuse the matrices already stored in the previous results + # We reuse the matrices and NIS data already stored in the previous results new_kpis = calculate_kpis( honest_matrix=res["honest_matrix"], faulty_matrix=res["faulty_matrix"], + honest_nis=res.get("honest_nis"), + faulty_nis=res.get("faulty_nis"), detection_threshold=detection_threshold, fpr_offset_percent=fpr_offset_percent ) @@ -158,11 +182,8 @@ def run_single_simulation(run_idx: int, """ Wrapper to run a single simulation iteration within a subprocess. - This function sets up a unique event loop and logger for the simulation run, - executes the consensus demo, and calculates KPIs for the result. - Args: - run_idx: Index of the current Monte Carlo run (used for logging and seeding). + run_idx: Index of the current Monte Carlo run. threshold: Reputation threshold for detection and false positives. fpr_offset: Fraction of initial steps to ignore for FPR calculation. @@ -173,8 +194,6 @@ def run_single_simulation(run_idx: int, log_file = os.path.join(DATA_DIR, f"run_{run_idx}.log") # Initialise logger for this process with the unique log file - # We use the same name "ACCORD" so that all modules using get_logger() - # will get this redirected logger in this subprocess. 
logger = get_logger(name="ACCORD", log_file=log_file) # Create a fresh event loop for this process @@ -188,17 +207,39 @@ def run_single_simulation(run_idx: int, try: # Run the simulation - # Note: we disable saving/loading EKF results to ensure each MC run is independent - # and we pass clear_logs=False to avoid clearing other runs' logs - _, rep_history, _, faulty_ids = loop.run_until_complete( + dag, rep_history, _, faulty_ids = loop.run_until_complete( run_consensus_demo(config, save_ekf_results=False, load_ekf_results=False, clear_logs=False, log_file=log_file, save_sim_results=False) ) - if rep_history is None: + if rep_history is None or dag is None: return None + # Extract NIS data from DAG + honest_nis = [] + faulty_nis = [] + for _, tx_list in dag.ledger.items(): + for tx in tx_list: + if not hasattr(tx.metadata, "nis"): + continue + + try: + tx_data = json.loads(tx.tx_data) + sid = tx_data.get("observer") + nis = getattr(tx.metadata, "nis") + + if sid is None or nis is None: + continue + + if faulty_ids is not None and int(sid) in faulty_ids: + faulty_nis.append(float(nis)) + else: + honest_nis.append(float(nis)) + except (json.JSONDecodeError, TypeError): + continue + kpis = calculate_kpis(rep_history, faulty_ids, config.steps, + honest_nis=honest_nis, faulty_nis=faulty_nis, detection_threshold=threshold, fpr_offset_percent=fpr_offset) return kpis except Exception as e: @@ -208,37 +249,36 @@ def run_single_simulation(run_idx: int, finally: loop.close() -def plot_mc_results(all_kpis: List[Optional[Dict[str, Any]]]) -> None: +def plot_mc_results(all_kpis_raw: List[Optional[Dict[str, Any]]]) -> None: """ Aggregate results from all Monte Carlo runs and generate summary plots. - Generates two plots: - 1. Reputation history over time with 95% confidence intervals for honest vs. faulty nodes. + Generates three plots: + 1. Reputation history over time with 1 Std. Dev. spread for honest vs. faulty nodes. 2. 
Histograms of Time to Detection (TTD) and False Positive Rate (FPR). + 3. Metrics summary for Recall, Precision, and Stability (Flips). Args: - all_kpis: A list of KPI dictionaries from multiple simulation runs. + all_kpis_raw: A list of KPI dictionaries from multiple simulation runs. """ # Filter out failed runs - all_kpis = [k for k in all_kpis if k is not None] + all_kpis: List[Dict[str, Any]] = [k for k in all_kpis_raw if k is not None] if not all_kpis: print("No successful runs to plot.") return # 1. Aggregate Reputation Histories - # We'll average the honest/faulty averages across runs - all_honest_means = [] - all_faulty_means = [] + honest_means_list: List[np.ndarray] = [] + faulty_means_list: List[np.ndarray] = [] for kpi in all_kpis: - if kpi is not None: - all_honest_means.append(np.mean(kpi["honest_matrix"], axis=0)) - all_faulty_means.append(np.mean(kpi["faulty_matrix"], axis=0)) + honest_means_list.append(np.mean(kpi["honest_matrix"], axis=0)) + faulty_means_list.append(np.mean(kpi["faulty_matrix"], axis=0)) - all_honest_means = np.array(all_honest_means) # type: ignore [assignment] - all_faulty_means = np.array(all_faulty_means) # type: ignore [assignment] + all_honest_means = np.array(honest_means_list) + all_faulty_means = np.array(faulty_means_list) - steps = np.arange(all_honest_means.shape[1]) # type: ignore [attr-defined] + steps = np.arange(all_honest_means.shape[1]) plt.figure(figsize=(10, 6)) @@ -246,43 +286,55 @@ def plot_mc_results(all_kpis: List[Optional[Dict[str, Any]]]) -> None: h_mean = np.mean(all_honest_means, axis=0) h_std = np.std(all_honest_means, axis=0) plt.plot(steps, h_mean, color="green", label="Honest (MC Mean)") - plt.fill_between(steps, h_mean - 2*h_std, h_mean + 2*h_std, color="green", - alpha=0.2, label="Honest 95% CI") + plt.fill_between(steps, h_mean - h_std, h_mean + h_std, color="green", + alpha=0.2, label="Honest Pop. 1 Std. Dev. 
Spread") # Faulty f_mean = np.mean(all_faulty_means, axis=0) f_std = np.std(all_faulty_means, axis=0) plt.plot(steps, f_mean, color="red", label="Faulty (MC Mean)") - plt.fill_between(steps, f_mean - 2*f_std, f_mean + 2*f_std, color="red", - alpha=0.2, label="Faulty 95% CI") + plt.fill_between(steps, f_mean - f_std, f_mean + f_std, color="red", + alpha=0.2, label="Faulty Pop. 1 Std. Dev. Spread") plt.axhline(0.5, color="gray", linestyle="--") plt.xlabel("Step") plt.ylabel("Reputation") - plt.title(f"Monte Carlo Results ({len(all_kpis)} runs)") plt.legend() plt.grid(True, alpha=0.3) plt.savefig(os.path.join(DATA_DIR, "mc_reputation.png")) plt.show() # 2. Plot KPI Distributions - _, axes = plt.subplots(1, 2, figsize=(12, 5)) + _, axes = plt.subplots(1, 3, figsize=(18, 5)) - ttds = [k["avg_ttd"] for k in all_kpis if k is not None and k["avg_ttd"] is not None] + # TTD Histogram + ttds = [float(k.get("avg_ttd", 0)) for k in all_kpis if k.get("avg_ttd") is not None] if ttds: axes[0].hist(ttds, bins=10, color='skyblue', edgecolor='black') axes[0].set_title("Time to Detection (Steps)") - axes[0].axvline(np.mean(ttds), color='red', linestyle='dashed', + axes[0].axvline(float(np.mean(ttds)), color='red', linestyle='dashed', label=f'Mean: {np.mean(ttds):.1f}') axes[0].legend() - fprs = [k["fpr"] for k in all_kpis if k is not None] + # FPR Histogram + fprs = [float(k.get("fpr", 0)) for k in all_kpis] axes[1].hist(fprs, bins=10, color='salmon', edgecolor='black') axes[1].set_title("False Positive Rate (%)") - axes[1].axvline(np.mean(fprs), color='red', linestyle='dashed', + axes[1].axvline(float(np.mean(fprs)), color='red', linestyle='dashed', label=f'Mean: {np.mean(fprs):.1f}%') axes[1].legend() + # Recall/Precision Scatter + recalls = [float(k.get("recall", 0)) for k in all_kpis] + precisions = [float(k.get("precision", 0)) for k in all_kpis] + axes[2].scatter(recalls, precisions, color='purple', alpha=0.5) + axes[2].set_xlabel("Recall (%)") + axes[2].set_ylabel("Precision 
(%)") + axes[2].set_title("Detection Reliability") + axes[2].set_xlim(-5, 105) + axes[2].set_ylim(-5, 105) + axes[2].grid(True, alpha=0.3) + plt.tight_layout() plt.savefig(os.path.join(DATA_DIR, "mc_kpis.png")) plt.show() @@ -290,13 +342,30 @@ def plot_mc_results(all_kpis: List[Optional[Dict[str, Any]]]) -> None: # Print Summary print("--- Monte Carlo Summary ---") print(f"Total Runs: {len(all_kpis)}") + print(f"Mean Recall: {np.mean(recalls):.2f}%") + print(f"Mean Precision: {np.mean(precisions):.2f}%") + print(f"Mean FPR: {np.mean(fprs):.2f}%") + if ttds: - print(f"Mean Time to Detection: {np.mean(ttds):.2f} steps") - print(f"Mean False Positive Rate: {np.mean(fprs):.2f}%") - print(f"Avg Final Honest Rep: {np.mean([k['final_honest_rep'] for - k in all_kpis if k is not None]):.4f}") - print(f"Avg Final Faulty Rep: {np.mean([k['final_faulty_rep'] for - k in all_kpis if k is not None]):.4f}") + print(f"Mean TTD: {np.mean(ttds):.2f} steps") + worst_ttds = [float(k.get('worst_ttd', 0)) for k in \ + all_kpis if k.get('worst_ttd') is not None] + if worst_ttds: + print(f"Worst-Case TTD: {np.max(worst_ttds):.2f} steps") + + print(f"Avg Detection Margin: {np.mean([float(k.get('detection_margin', 0)) \ + for k in all_kpis]):.4f}") + print(f"Avg Honest Spread: {np.mean([float(k.get('honest_spread', 0)) \ + for k in all_kpis]):.4f}") + print(f"Avg Stability (Total Flips): {np.mean([float(k.get('flips', 0)) \ + for k in all_kpis]):.2f}") + print(f"Avg Final Honest Rep: {np.mean([float(k.get('final_honest_rep', 0)) \ + for k in all_kpis]):.4f}") + print(f"Avg Final Faulty Rep: {np.mean([float(k.get('final_faulty_rep', 0)) \ + for k in all_kpis]):.4f}") + + # 4. NIS Median Distribution + plot_mc_nis_boxplot(all_kpis) if __name__ == "__main__": # e.g. 
python mc_demo.py --recalculate --threshold 0.3 --fpr-offset 0.1 @@ -322,11 +391,22 @@ def plot_mc_results(all_kpis: List[Optional[Dict[str, Any]]]) -> None: results = list(data['results']) print(f"Successfully loaded {len(results)} MC runs.") - if args.recalculate: - print(f"Recalculating KPIs with threshold={args.threshold}, \ + # Check if new keys are missing and auto-trigger recalculate if needed + NEEDS_RECALCULATE = args.recalculate + if results and not NEEDS_RECALCULATE: + sample = next((r for r in results if r is not None), None) + if sample and "recall" not in sample: + print("New metrics missing from saved data. Auto-recalculating...") + NEEDS_RECALCULATE = True + + if NEEDS_RECALCULATE: + print(f"Calculating KPIs with threshold={args.threshold}, \ fpr_offset={args.fpr_offset}") results = recalculate_all_kpis(results, detection_threshold=args.threshold, fpr_offset_percent=args.fpr_offset) + # Save the updated KPIs back to the file + print(f"Updating saved results at {MC_RESULTS_PATH}") + np.savez_compressed(MC_RESULTS_PATH, results=np.array(results, dtype=object)) except Exception as e: print(f"Failed to load MC results: {e}. 
Rerunning simulation.") diff --git a/sim_data/mc_results/mc_results.npz b/sim_data/mc_results/mc_results.npz index dbf9f99..dad28ef 100644 --- a/sim_data/mc_results/mc_results.npz +++ b/sim_data/mc_results/mc_results.npz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dc75a7480b840032e59ac75b4aca2b96394913d159941bd30fe6a248a0efbbd3 -size 56642786 +oid sha256:522d12c0e48ae67607237ab4e6b126d35a38ccf4f72d526c7f632c8b66c0a3ee +size 59687719 diff --git a/src/plotting.py b/src/plotting.py index da19141..c038907 100644 --- a/src/plotting.py +++ b/src/plotting.py @@ -25,7 +25,7 @@ import json import os import re -from typing import Optional +from typing import Optional, List, Dict, Any import pandas as pd import matplotlib.pyplot as plt from matplotlib.lines import Line2D @@ -376,7 +376,7 @@ def plot_nis_boxplot(dag: DAG, faulty_ids: set[int], ax.set_xticks(np.arange(1, len(labels) + 1)) ax.set_xticklabels(labels, fontsize=20) ax.set_ylabel("Normalised Innovation Squared [-]", fontsize=20) - ax.set_yscale("symlog") + ax.set_yscale("log") ax.tick_params(axis='y', labelsize=20) ax.legend(fontsize=16, loc="upper center") @@ -899,6 +899,70 @@ def plot_ground_tracks_plotly(truth: np.ndarray, n: int) -> go.Figure: return fig +def plot_mc_nis_boxplot(all_kpis: List[Dict[str, Any]]) -> None: + """ + Plots the distribution of median NIS values across multiple Monte Carlo runs, + separated by honest and faulty populations. + + Args: + all_kpis (List[Dict[str, Any]]): A list of KPI dictionaries, each containing + 'honest_nis' and 'faulty_nis' lists. 
+ """ + all_honest_medians = [] + all_faulty_medians = [] + + for kpi in all_kpis: + h_nis = kpi.get("honest_nis", []) + f_nis = kpi.get("faulty_nis", []) + + if h_nis: + all_honest_medians.append(np.median(h_nis)) + if f_nis: + all_faulty_medians.append(np.median(f_nis)) + + if not all_honest_medians and not all_faulty_medians: + print("No MC NIS data available to plot.") + return + + plot_data = [] + labels = [] + if all_honest_medians: + plot_data.append(all_honest_medians) + labels.append("Honest Medians") + if all_faulty_medians: + plot_data.append(all_faulty_medians) + labels.append("Faulty Medians") + + _, ax = plt.subplots(figsize=(10, 6)) + + # Create box plot for the medians + ax.boxplot(plot_data, label=labels, patch_artist=True, + boxprops={"facecolor": 'lightblue', "alpha": 0.5}, + medianprops={"color": 'black', "linewidth": 2}) + + # Reference lines (assuming DOF=2) + dof = 2 + expected_median = chi2.ppf(0.5, df=dof) + # chi2_lower = chi2.ppf(0.025, df=dof) + # chi2_upper = chi2.ppf(0.975, df=dof) + + ax.axhline(expected_median, color='black', linestyle=':', + label=f'Expected Median ({expected_median:.3f})') + # ax.axhline(chi2_lower, color='red', linestyle='--', alpha=0.5, + # label='95% Confidence Interval Bounds') + # ax.axhline(chi2_upper, color='red', linestyle='--', alpha=0.5) + + ax.set_ylabel("Median NIS per Run [-]", fontsize=20) + ax.set_yscale("log") + ax.set_xticklabels(labels, fontsize=18) + ax.tick_params(axis='y', labelsize=16) + ax.grid(True, linestyle=":", alpha=0.7) + ax.legend(fontsize=14) + + plt.tight_layout() + plt.show() + + def main() -> None: """Main function to parse log and generate plots.""" # === Step 1: Parse the log file === diff --git a/streamlit_app.py b/streamlit_app.py index 9a5d277..b6211bc 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -1,7 +1,14 @@ +""" +Streamlit dashboard for the ACCORD project. +Provides interactive visualisation and analysis for +autonomous cooperative consensus orbit determination. 
+""" +import os +from typing import List, Dict, Any + import streamlit as st import numpy as np import pandas as pd -import os import plotly.express as px import plotly.graph_objects as go from graphviz import Digraph @@ -22,12 +29,12 @@ html, body, [class*="st-"] { font-size: 1.15rem; } - + /* Headers */ h1 { font-size: 3.5rem !important; } h2 { font-size: 2.5rem !important; } h3 { font-size: 2rem !important; } - + /* Sidebar */ [data-testid="stSidebar"] { min-width: 350px; @@ -35,26 +42,26 @@ [data-testid="stSidebar"] .stMarkdown p { font-size: 1.2rem !important; } - + /* Tabs */ .stTabs [data-baseweb="tab-list"] button [data-testid="stMarkdownContainer"] p { font-size: 1.5rem !important; font-weight: bold !important; } - + /* Labels and inputs */ .stMarkdown p, .stMarkdown li { font-size: 1.2rem !important; } - + .stNumberInput label, .stSlider label { font-size: 1.3rem !important; } - + .stNumberInput input { font-size: 1.3rem !important; } - + .stButton button { font-size: 1.3rem !important; padding: 0.5rem 2rem !important; @@ -70,29 +77,34 @@ # --- Helper Functions --- @st.cache_data -def load_sim_results(path): +def load_sim_results(path: str) -> Dict[str, Any] | None: + """Load single simulation results from a .npz file.""" if not os.path.exists(path): return None try: + # Use a context manager to ensure the file is closed with np.load(path, allow_pickle=True) as data: + # We use .item() to extract the object from the 0-d array + # Typed access to avoid pylint E1136 return { - "dag_ledger": data["dag_ledger"].item(), - "rep_history": data["rep_history"].item(), + "dag_ledger": data["dag_ledger"], + "rep_history": data["rep_history"], "truth": data["truth"], "faulty_ids": set(data["faulty_ids"]) } - except Exception as e: + except (IOError, ValueError, KeyError) as e: st.error(f"Error loading simulation results: {e}") return None @st.cache_data -def load_mc_results(path): +def load_mc_results(path: str) -> List[Any] | None: + """Load Monte Carlo results from 
a .npz file.""" if not os.path.exists(path): return None try: with np.load(path, allow_pickle=True) as data: return list(data['results']) - except Exception as e: + except (IOError, ValueError, KeyError) as e: st.error(f"Error loading MC results: {e}") return None @@ -112,12 +124,16 @@ def load_mc_results(path): st.sidebar.warning("MC Data (mc_results.npz) not found.") # --- Tabs --- -tab0, tab1, tab2, tab3 = st.tabs(["πŸ“‹ Configuration", "πŸ“Š Results Explorer", "πŸ“ˆ Sensitivity Analysis", "πŸ•ΈοΈ DAG Viewer"]) +tab0, tab1, tab2, tab3 = st.tabs(["πŸ“‹ Configuration", "πŸ“Š Results Explorer", + "πŸ“ˆ Sensitivity Analysis", "πŸ•ΈοΈ DAG Viewer"]) # --- Tab 0: Configuration --- with tab0: st.header("Simulation Parameters") - st.markdown("Parameters currently set in `DEFAULT_CONFIG` (as defined in `accord_demo.py`).") + st.markdown( + "Parameters currently set in `DEFAULT_CONFIG` " + "(as defined in `accord_demo.py`)." + ) col1, col2 = st.columns(2) @@ -136,22 +152,23 @@ def load_mc_results(path): # Split into two columns items = list(config_dict.items()) - midpoint = len(items) // 2 + len(items) % 2 + MIDPOINT = len(items) // 2 + len(items) % 2 - def display_config_item(label, value): + def display_config_item(item_label: str, item_value: Any) -> None: + """Display a single configuration item with custom styling.""" st.markdown(f"""
-

{label}

-

{value}

+

{item_label}

+

{item_value}

""", unsafe_allow_html=True) with col1: - for key, value in items[:midpoint]: + for key, value in items[:MIDPOINT]: display_config_item(key, value) with col2: - for key, value in items[midpoint:]: + for key, value in items[MIDPOINT:]: display_config_item(key, value) # --- Tab 1: Results Explorer --- @@ -179,10 +196,15 @@ def display_config_item(label, value): }) rep_df = pd.DataFrame(rep_df_list) - fig_rep = px.line(rep_df, x="Timestep", y="Reputation [-]", color="Satellite", - line_dash="Status", color_discrete_sequence=px.colors.qualitative.Safe) - fig_rep.add_hline(y=0.5, line_dash="dot", annotation_text="Neutral", line_color="gray") - st.plotly_chart(fig_rep, width='stretch') + fig_rep = px.line( + rep_df, x="Timestep", y="Reputation [-]", color="Satellite", + line_dash="Status", + color_discrete_sequence=px.colors.qualitative.Safe + ) + fig_rep.add_hline( + y=0.5, line_dash="dot", annotation_text="Neutral", line_color="gray" + ) + st.plotly_chart(fig_rep, width="stretch") with col2: st.subheader("Satellite Ground Tracks") @@ -190,11 +212,17 @@ def display_config_item(label, value): # Calculate number of satellites (each state is 6 elements) num_sats = sim_data["truth"].shape[1] // 6 fig_map = plot_ground_tracks_plotly(sim_data["truth"], num_sats) - st.plotly_chart(fig_map, width='stretch') + st.plotly_chart(fig_map, width="stretch") elif os.path.exists("sim_data/orbit_map.png"): - st.image("sim_data/orbit_map.png", caption="Satellite Ground Tracks (Static Backup)") + st.image( + "sim_data/orbit_map.png", + caption="Satellite Ground Tracks (Static Backup)" + ) else: - st.info("Showing current positions. For full ground tracks, see 'accord_demo.py' outputs.") + st.info( + "Showing current positions. For full ground tracks, " + "see 'accord_demo.py' outputs." 
+ ) st.warning("Orbit data not found.") else: st.info("Please run `python accord.py` to generate simulation data.") @@ -212,45 +240,161 @@ def display_config_item(label, value): fpr_offset = st.slider("FPR Offset (Initial % ignored)", 0.0, 0.5, 0.2, 0.05) if st.button("Calculate KPIs"): - new_results = recalculate_all_kpis(mc_data, detection_threshold=threshold, fpr_offset_percent=fpr_offset) - - # Filter out failed runs - valid_kpis = [k for k in new_results if k is not None] - ttds = [k["avg_ttd"] for k in valid_kpis if k["avg_ttd"] is not None] - fprs = [k["fpr"] for k in valid_kpis] - - st.metric("Mean TTD", f"{np.mean(ttds):.2f} steps" if ttds else "N/A") - st.metric("Mean FPR", f"{np.mean(fprs):.2f}%") + new_results = recalculate_all_kpis( + mc_data, detection_threshold=threshold, + fpr_offset_percent=fpr_offset + ) + valid_kpis: List[Dict[str, Any]] = [ + k for k in new_results if k is not None + ] + + if valid_kpis: + st.subheader("Performance Summary") + + # Core Metrics + ttds = [ + float(k.get("avg_ttd", 0)) for k in valid_kpis + if k.get("avg_ttd") is not None + ] + worst_ttds = [ + float(k.get("worst_ttd", 0)) for k in valid_kpis + if k.get("worst_ttd") is not None + ] + fprs = [float(k.get("fpr", 0)) for k in valid_kpis] + recalls = [float(k.get("recall", 0)) for k in valid_kpis] + precisions = [float(k.get("precision", 0)) for k in valid_kpis] + + m1, m2, m3 = st.columns(3) + m1.metric("Mean Recall", f"{np.mean(recalls):.1f}%") + m2.metric("Mean Precision", f"{np.mean(precisions):.1f}%") + m3.metric("Mean FPR", f"{np.mean(fprs):.1f}%") + + m4, m5 = st.columns(2) + m4.metric("Mean TTD", f"{np.mean(ttds):.1f} steps" if ttds else "N/A") + m5.metric( + "Worst-Case TTD", + f"{np.max(worst_ttds):.1f} steps" if worst_ttds else "N/A" + ) + + st.divider() + st.subheader("System Robustness & Stability") + + margins = [float(k.get("detection_margin", 0)) for k in valid_kpis] + spreads = [float(k.get("honest_spread", 0)) for k in valid_kpis] + flips = 
[float(k.get("flips", 0)) for k in valid_kpis] + h_reps = [float(k.get("final_honest_rep", 0)) for k in valid_kpis] + f_reps = [float(k.get("final_faulty_rep", 0)) for k in valid_kpis] + + c1, c2 = st.columns(2) + c1.metric("Detection Margin", f"{np.mean(margins):.3f}") + c2.metric("Honest Spread (Οƒ)", f"{np.mean(spreads):.3f}") + + c3, c4, c5 = st.columns(3) + c3.metric("Avg Final Honest Rep", f"{np.mean(h_reps):.3f}") + c4.metric("Avg Final Faulty Rep", f"{np.mean(f_reps):.3f}") + c5.metric("Avg Flips (Stability)", f"{np.mean(flips):.1f}") with col2: - # We can't easily get new_results here without state, so let's just do it once - new_results = recalculate_all_kpis(mc_data, detection_threshold=threshold, fpr_offset_percent=fpr_offset) - valid_kpis = [k for k in new_results if k is not None] - - all_honest_means = np.array([np.mean(k["honest_matrix"], axis=0) for k in valid_kpis]) - all_faulty_means = np.array([np.mean(k["faulty_matrix"], axis=0) for k in valid_kpis]) - steps = np.arange(all_honest_means.shape[1]) - - fig_mc = go.Figure() - # Honest Mean & CI - h_mean = np.mean(all_honest_means, axis=0) - h_std = np.std(all_honest_means, axis=0) - fig_mc.add_trace(go.Scatter(x=steps, y=h_mean, name="Honest Mean", line=dict(color='green'))) - fig_mc.add_trace(go.Scatter(x=steps, y=h_mean+2*h_std, fill=None, mode='lines', line_color='rgba(0,255,0,0)', showlegend=False)) - fig_mc.add_trace(go.Scatter(x=steps, y=h_mean-2*h_std, fill='tonexty', mode='lines', line_color='rgba(0,255,0,0.2)', name="Honest 95% CI")) - - # Faulty Mean & CI - f_mean = np.mean(all_faulty_means, axis=0) - f_std = np.std(all_faulty_means, axis=0) - fig_mc.add_trace(go.Scatter(x=steps, y=f_mean, name="Faulty Mean", line=dict(color='red'))) - fig_mc.add_trace(go.Scatter(x=steps, y=f_mean+2*f_std, fill=None, mode='lines', line_color='rgba(255,0,0,0)', showlegend=False)) - fig_mc.add_trace(go.Scatter(x=steps, y=f_mean-2*f_std, fill='tonexty', mode='lines', line_color='rgba(255,0,0,0.2)', 
name="Faulty 95% CI")) - - fig_mc.update_layout(title="MC Aggregated Reputation", xaxis_title="Step", yaxis_title="Reputation") - st.plotly_chart(fig_mc, width='stretch') + # We recalculate to ensure valid_kpis_plot is available for the plots + new_results = recalculate_all_kpis(mc_data, detection_threshold=threshold, \ + fpr_offset_percent=fpr_offset) + valid_kpis_plot: List[Dict[str, Any]] = [k for k in new_results if k is not None] + + if valid_kpis_plot: + # Plot 1: Reputation Trends + all_honest_means = np.array([np.mean(k["honest_matrix"], axis=0) + for k in valid_kpis_plot]) + all_faulty_means = np.array([np.mean(k["faulty_matrix"], axis=0) + for k in valid_kpis_plot]) + steps = np.arange(all_honest_means.shape[1]) + + fig_mc = go.Figure() + # Honest Mean & CI + h_mean = np.mean(all_honest_means, axis=0) + h_std = np.std(all_honest_means, axis=0) + fig_mc.add_trace(go.Scatter( + x=steps, y=h_mean, name="Honest Mean", + line={"color": "green", "width": 3} + )) + fig_mc.add_trace(go.Scatter( + x=steps, y=h_mean + h_std, fill=None, mode="lines", + line_color="rgba(0,255,0,0)", showlegend=False + )) + fig_mc.add_trace(go.Scatter( + x=steps, y=h_mean - h_std, fill="tonexty", mode="lines", + line_color="rgba(0,255,0,0.1)", name="Honest Pop. Spread (1σ)" + )) + + # Faulty Mean & CI + f_mean = np.mean(all_faulty_means, axis=0) + f_std = np.std(all_faulty_means, axis=0) + fig_mc.add_trace(go.Scatter( + x=steps, y=f_mean, name="Faulty Mean", + line={"color": "red", "width": 3} + )) + fig_mc.add_trace(go.Scatter( + x=steps, y=f_mean + f_std, fill=None, mode="lines", + line_color="rgba(255,0,0,0)", showlegend=False + )) + fig_mc.add_trace(go.Scatter( + x=steps, y=f_mean - f_std, fill="tonexty", mode="lines", + line_color="rgba(255,0,0,0.1)", name="Faulty Pop. Spread (1σ)"
 + )) + + fig_mc.add_hline( + y=threshold, line_dash="dash", line_color="orange", + annotation_text=f"Threshold ({threshold})" + ) + fig_mc.update_layout( + title="Monte Carlo Reputation Trends", + xaxis_title="Step", yaxis_title="Reputation", + legend={"yanchor": "bottom", "y": 0.01, "xanchor": "right", "x": 0.99} + ) + st.plotly_chart(fig_mc, width="stretch") + + # Plot 2: Distribution Row + col_p1, col_p2, col_p3 = st.columns(3) + + with col_p1: + # Reliability Scatter + recalls = [k.get("recall", 0) for k in valid_kpis_plot] + precisions = [k.get("precision", 0) for k in valid_kpis_plot] + + fig_rel = px.scatter(x=recalls, y=precisions, labels={'x': 'Recall (%)', + 'y': 'Precision (%)'}, + title="Reliability (Recall vs Precision)", + range_x=[-5, 105], range_y=[-5, 105]) + fig_rel.add_vline(x=np.mean(recalls), line_dash="dot", + line_color="purple", opacity=0.5) + fig_rel.add_hline(y=np.mean(precisions), line_dash="dot", + line_color="purple", opacity=0.5) + st.plotly_chart(fig_rel, width='stretch') + + with col_p2: + # TTD Histogram + ttds_flat = [k.get("avg_ttd") for k in valid_kpis_plot \ + if k.get("avg_ttd") is not None] + if ttds_flat: + fig_ttd = px.histogram(x=ttds_flat, nbins=15, labels={'x': 'Steps'}, + title="Time to Detection Distribution") + st.plotly_chart(fig_ttd, width='stretch') + else: + st.info("No detections occurred.") + + with col_p3: + # FPR Histogram + fprs_flat = [k.get("fpr", 0) for k in valid_kpis_plot] + fig_fpr = px.histogram( + x=fprs_flat, nbins=15, labels={'x': 'FPR (%)'}, + title="False Positive Rate Distribution", + color_discrete_sequence=['salmon'] + ) + st.plotly_chart(fig_fpr, width='stretch') + else: st.info("Please run `python mc_demo.py` to generate Monte Carlo results.") + # --- Tab 3: DAG Viewer --- with tab3: if sim_data: @@ -267,15 +411,20 @@ def display_config_item(label, value): # Sort by timestamp all_txs.sort(key=lambda x: x[0]) - num_all_txs = len(all_txs) + NUM_ALL_TXS = len(all_txs) # Number boxes 
and a submit button for transaction selection with st.form("dag_range_form"): col1, col2, col3 = st.columns([2, 2, 1]) with col1: - start_idx = st.number_input("Start Index", min_value=0, max_value=num_all_txs, value=0, step=1) + start_idx = st.number_input( + "Start Index", min_value=0, max_value=NUM_ALL_TXS, value=0, step=1 + ) with col2: - end_idx = st.number_input("End Index", min_value=0, max_value=num_all_txs, value=min(20, num_all_txs), step=1) + end_idx = st.number_input( + "End Index", min_value=0, max_value=NUM_ALL_TXS, + value=min(20, NUM_ALL_TXS), step=1 + ) with col3: st.write("
", unsafe_allow_html=True) # Vertical alignment submit_button = st.form_submit_button("Update View") @@ -302,20 +451,20 @@ def display_config_item(label, value): is_genesis = "Genesis" in tx_hash display_hash = tx_hash[:6] if not is_genesis else "Genesis" - label = f"TX: {display_hash}\nScore: {score_str}" + TX_LABEL = f"TX: {display_hash}\nScore: {score_str}" # Color coding: Green for confirmed, Red for rejected, Gray for pending/initial if getattr(tx.metadata, 'is_confirmed', False): - fillcolor = "#e8f5e9" # Light green for confirmed - color = "#2e7d32" + NODE_FILLCOLOUR = "#e8f5e9" # Light green for confirmed + NODE_COLOUR = "#2e7d32" elif getattr(tx.metadata, 'is_rejected', False): - fillcolor = "#ffebee" # Light red for rejected - color = "#c62828" + NODE_FILLCOLOUR = "#ffebee" # Light red for rejected + NODE_COLOUR = "#c62828" else: - fillcolor = "#f5f5f5" # Gray for pending/initial - color = "#757575" + NODE_FILLCOLOUR = "#f5f5f5" # Gray for pending/initial + NODE_COLOUR = "#757575" - dot.node(tx_hash, label, color=color, fillcolor=fillcolor) + dot.node(tx_hash, TX_LABEL, color=NODE_COLOUR, fillcolor=NODE_FILLCOLOUR) # Create edges (point from child to parent, but with dir='back' to swap arrowhead) for _, tx_hash, tx in selected_txs: