From 9c2f97db2393e05b5260d21cf8d6adc81855cc9e Mon Sep 17 00:00:00 2001
From: joehart2001
Date: Wed, 18 Feb 2026 16:52:01 +0000
Subject: [PATCH] model selection

---
Reviewer notes (this region is not part of the commit message or the diff):
* Line breaks, indentation and blank context lines were reconstructed from a
  copy of this patch whose newlines had been collapsed. Hunk line counts were
  used as the check; confirm context indentation (in particular the nesting
  of update_table_scores) against the tree before applying.
* "model-filter-checklist" is a dcc.Dropdown, and the sync_model_filter
  docstring also calls it a checklist; consider renaming for consistency.
* update_category_from_benchmark still binds `style` from
  update_score_style() but no longer uses it; the other callbacks changed
  here bind `_` instead.
* update_table_scores loses the `ctx.triggered_id` lookup shown in its hunk;
  check whether `ctx` is still used elsewhere in register_callbacks.py.
* filter_rows_by_models annotates with `Sequence`, and this patch adds no
  import for it in utils.py; verify the name is already in scope.
* update_scores_store now guards on row.get("MLIP") but still indexes
  row["Score"] directly.

 ml_peg/app/build_app.py                | 149 ++++++++++++++++++++++++-
 ml_peg/app/utils/register_callbacks.py | 111 +++++++++++++-----
 ml_peg/app/utils/utils.py              |  34 ++++++
 3 files changed, 261 insertions(+), 33 deletions(-)

diff --git a/ml_peg/app/build_app.py b/ml_peg/app/build_app.py
index ff7b61ef3..ae3aa3a7c 100644
--- a/ml_peg/app/build_app.py
+++ b/ml_peg/app/build_app.py
@@ -5,10 +5,11 @@
 from importlib import import_module
 import warnings
 
-from dash import Dash, Input, Output, callback
+from dash import Dash, Input, Output, callback, ctx, no_update
 from dash.dash_table import DataTable
-from dash.dcc import Store, Tab, Tabs
-from dash.html import H1, H3, Div
+from dash.dcc import Dropdown, Store, Tab, Tabs
+from dash.exceptions import PreventUpdate
+from dash.html import H1, H3, Button, Details, Div, Summary
 from yaml import safe_load
 
 from ml_peg.analysis.utils.utils import calc_table_scores, get_table_style
@@ -375,6 +376,67 @@ def build_tabs(
         Tab(label=category_name, value=category_name) for category_name in layouts
     ]
 
+    model_options = [{"label": m, "value": m} for m in MODELS]
+
+    model_filter = Details(
+        [
+            Summary(
+                "Visible models",
+                style={"cursor": "pointer", "fontWeight": "bold", "padding": "5px"},
+            ),
+            Div(
+                [
+                    Dropdown(
+                        id="model-filter-checklist",
+                        options=model_options,
+                        value=MODELS,
+                        multi=True,
+                        placeholder="Select visible models",
+                        closeOnSelect=False,
+                        style={"fontSize": "13px"},
+                    ),
+                    Div(
+                        [
+                            Button(
+                                "Select all",
+                                id="model-filter-select-all",
+                                n_clicks=0,
+                                style={
+                                    "fontSize": "11px",
+                                    "padding": "4px 8px",
+                                    "backgroundColor": "#6c757d",
+                                    "color": "white",
+                                    "border": "none",
+                                    "borderRadius": "3px",
+                                    "cursor": "pointer",
+                                },
+                            ),
+                            Button(
+                                "Clear",
+                                id="model-filter-clear-all",
+                                n_clicks=0,
+                                style={
+                                    "fontSize": "11px",
+                                    "padding": "4px 8px",
+                                    "backgroundColor": "#6c757d",
+                                    "color": "white",
+                                    "border": "none",
+                                    "borderRadius": "3px",
+                                    "cursor": "pointer",
+                                },
+                            ),
+                        ],
+                        style={"display": "flex", "gap": "8px", "marginTop": "8px"},
+                    ),
+                ],
+                style={"padding": "8px 12px"},
+            ),
+        ],
+        id="model-filter-details",
+        open=True,
+        style={"marginBottom": "8px", "fontSize": "13px"},
+    )
+
     tabs_layout = [
         build_onboarding_modal(),
         build_tutorial_button(),
@@ -382,6 +444,17 @@ def build_tabs(
             [
                 H1("ML-PEG"),
                 Tabs(id="all-tabs", value="summary-tab", children=all_tabs),
+                model_filter,
+                Store(
+                    id="selected-models-store",
+                    storage_type="session",
+                    data=MODELS,
+                ),
+                Store(
+                    id="summary-table-computed-store",
+                    storage_type="session",
+                    data=summary_table.data,
+                ),
                 Div(id="tabs-content"),
             ],
             style={"flex": "1", "marginBottom": "40px"},
@@ -394,6 +467,76 @@ def build_tabs(
         style={"display": "flex", "flexDirection": "column", "minHeight": "100vh"},
     )
 
+    @callback(
+        Output("model-filter-checklist", "value"),
+        Output("selected-models-store", "data"),
+        Input("model-filter-checklist", "value"),
+        Input("model-filter-select-all", "n_clicks"),
+        Input("model-filter-clear-all", "n_clicks"),
+        Input("selected-models-store", "data"),
+        prevent_initial_call=False,
+    )
+    def sync_model_filter(
+        checklist_value: list[str] | None,
+        _select_all: int,
+        _clear_all: int,
+        stored_selection: list[str] | None,
+    ) -> tuple[list[str], list[str] | object]:
+        """
+        Keep the model selector checklist and backing store synchronised.
+
+        Parameters
+        ----------
+        checklist_value
+            Current selection from the model filter control.
+        _select_all
+            Click count for the "Select all" button.
+        _clear_all
+            Click count for the "Clear" button.
+        stored_selection
+            Previously persisted selection from ``selected-models-store``.
+
+        Returns
+        -------
+        tuple[list[str], list[str] | object]
+            Updated checklist value and store payload. The second element may be
+            ``dash.no_update`` when only syncing from store to UI.
+        """
+        trigger_id = ctx.triggered_id
+        stored = stored_selection if stored_selection is not None else MODELS
+
+        if trigger_id in (None, "selected-models-store"):
+            return stored, no_update
+        if trigger_id == "model-filter-select-all":
+            return MODELS, MODELS
+        if trigger_id == "model-filter-clear-all":
+            return [], []
+        if trigger_id == "model-filter-checklist":
+            selected = checklist_value or []
+            return selected, selected
+        raise PreventUpdate
+
+    @callback(
+        Output("model-filter-details", "open"),
+        Input("all-tabs", "value"),
+        prevent_initial_call=False,
+    )
+    def toggle_filter_panel(tab: str) -> bool:
+        """
+        Expand the visible-models panel on the summary tab only.
+
+        Parameters
+        ----------
+        tab
+            Currently selected tab identifier.
+
+        Returns
+        -------
+        bool
+            ``True`` when the summary tab is active, otherwise ``False``.
+        """
+        return tab == "summary-tab"
+
     @callback(Output("tabs-content", "children"), Input("all-tabs", "value"))
     def select_tab(tab) -> Div:
         """
diff --git a/ml_peg/app/utils/register_callbacks.py b/ml_peg/app/utils/register_callbacks.py
index 7a1adc675..6ad77cace 100644
--- a/ml_peg/app/utils/register_callbacks.py
+++ b/ml_peg/app/utils/register_callbacks.py
@@ -18,6 +18,7 @@
     Thresholds,
     build_level_of_theory_warnings,
     clean_thresholds,
+    filter_rows_by_models,
     format_metric_columns,
     format_tooltip_headers,
     get_scores,
@@ -48,18 +49,23 @@ def register_summary_table_callbacks(
         Output(
             "summary-table", "tooltip_data"
         ),  # Needed to display model config & level of theory tooltips
+        Output("summary-table-computed-store", "data", allow_duplicate=True),
         Input("all-tabs", "value"),
         Input("summary-table-weight-store", "data"),
+        Input("selected-models-store", "data"),
         State("summary-table-scores-store", "data"),
         State("summary-table", "data"),
-        prevent_initial_call=False,
+        State("summary-table-computed-store", "data"),
+        prevent_initial_call="initial_duplicate",
     )
     def update_summary_table(
         tabs_value: str,
         stored_weights: dict[str, float],
+        selected_models: list[str] | None,
         stored_scores: dict[str, dict[str, float]],
         summary_data: list[dict],
-    ) -> tuple[list[dict], list[dict], list[dict]]:
+        computed_store: list[dict] | None,
+    ) -> tuple[list[dict], list[dict], list[dict], list[dict]]:
         """
         Update summary table when scores/weights change, and sync on tab change.
 
@@ -69,30 +75,42 @@ def update_summary_table(
             Value of selected tab. Parameter unused, but required to register Input.
         stored_weights
             Stored summary weights dictionary.
+        selected_models
+            List of model names currently selected in the model filter.
         stored_scores
             Stored scores for table scores.
         summary_data
-            Data from summary table to be updated.
+            Data from summary table to be updated (may be filtered).
+        computed_store
+            Full unfiltered summary rows used as the source of truth.
 
         Returns
         -------
-        tuple[list[dict], list[dict], list[dict]]
-            Updated rows, conditional styling rules, and tooltip rows.
+        tuple[list[dict], list[dict], list[dict], list[dict]]
+            Updated rows, conditional styling rules, tooltip rows, and full rows
+            written back to the computed store.
         """
+        # Use the full unfiltered store as source so re-selecting models works.
+        # Fall back to summary_data only before the store is first populated.
+        source_data = computed_store or summary_data
+
         # Update table from stored scores
         if stored_scores:
-            for row in summary_data:
+            for row in source_data:
                 for tab, values in stored_scores.items():
-                    row[tab] = values[row["MLIP"]]
+                    if row["MLIP"] in values:
+                        row[tab] = values[row["MLIP"]]
 
-        # Update table contents
-        updated_rows, base_style = update_score_style(summary_data, stored_weights)
+        # Score all rows, write full rows back to store, then filter for display
+        updated_rows, _ = update_score_style(source_data, stored_weights)
+        filtered_rows = filter_rows_by_models(updated_rows, selected_models)
         warning_styles, tooltip_rows = build_level_of_theory_warnings(
-            updated_rows, model_levels, metric_levels, model_configs
+            filtered_rows, model_levels, metric_levels, model_configs
         )
+        base_style = get_table_style(filtered_rows) if filtered_rows else []
         style_with_warnings = base_style + warning_styles
 
-        return updated_rows, style_with_warnings, tooltip_rows
+        return filtered_rows, style_with_warnings, tooltip_rows, updated_rows
 
 
 def register_category_table_callbacks(
@@ -135,6 +153,7 @@ def register_category_table_callbacks(
             Input(f"{table_id}-thresholds-store", "data"),
             Input("all-tabs", "value"),
             Input(f"{table_id}-normalized-toggle", "value"),
+            Input("selected-models-store", "data"),
             State(f"{table_id}-raw-data-store", "data"),
             State(f"{table_id}-computed-store", "data"),
             State(f"{table_id}-raw-tooltip-store", "data"),
@@ -146,6 +165,7 @@ def update_benchmark_table_scores(
             stored_threshold: dict | None,
             _tabs_value: str,
             toggle_value: list[str] | None,
+            selected_models: list[str] | None,
             stored_raw_data: list[dict] | None,
             stored_computed_data: list[dict] | None,
             raw_tooltips: dict[str, str] | None,
@@ -172,6 +192,8 @@ def update_benchmark_table_scores(
                 Current tab identifier (unused, required to trigger on tab change).
             toggle_value
                 Value of toggle to show normalised values.
+            selected_models
+                List of model names currently selected in the model filter.
             stored_raw_data
                 Table data.
             stored_computed_data
@@ -204,8 +226,14 @@ def apply_levels_of_theory(
                     stored_raw_data, stored_computed_data, thresholds, toggle_value
                 )
                 scored_rows = calc_metric_scores(stored_raw_data, thresholds=thresholds)
-                style = get_table_style(display_rows, scored_data=scored_rows)
-                style, tooltip_data = apply_levels_of_theory(display_rows, style)
+                filtered_rows = filter_rows_by_models(display_rows, selected_models)
+                filtered_scores = filter_rows_by_models(scored_rows, selected_models)
+                style = (
+                    get_table_style(filtered_rows, scored_data=filtered_scores)
+                    if filtered_rows
+                    else []
+                )
+                style, tooltip_data = apply_levels_of_theory(filtered_rows, style)
                 columns = format_metric_columns(
                     current_columns, thresholds, show_normalized
                 )
@@ -213,7 +241,7 @@ def apply_levels_of_theory(
                     raw_tooltips, thresholds, show_normalized
                 )
                 return (
-                    display_rows,
+                    filtered_rows,
                     style,
                     tooltip_data,
                     columns,
@@ -232,14 +260,20 @@ def apply_levels_of_theory(
             display_rows = get_scores(
                 metrics_data, scored_rows, thresholds, toggle_value
             )
-            style = get_table_style(display_rows, scored_data=scored_rows)
-            style, tooltip_data = apply_levels_of_theory(display_rows, style)
+            filtered_rows = filter_rows_by_models(display_rows, selected_models)
+            filtered_scores = filter_rows_by_models(scored_rows, selected_models)
+            style = (
+                get_table_style(filtered_rows, scored_data=filtered_scores)
+                if filtered_rows
+                else []
+            )
+            style, tooltip_data = apply_levels_of_theory(filtered_rows, style)
             columns = format_metric_columns(
                 current_columns, thresholds, show_normalized
             )
             tooltips = format_tooltip_headers(raw_tooltips, thresholds, show_normalized)
             return (
-                display_rows,
+                filtered_rows,
                 style,
                 tooltip_data,
                 columns,
@@ -257,6 +291,7 @@ def apply_levels_of_theory(
             Output(f"{table_id}-computed-store", "data", allow_duplicate=True),
             Input(f"{table_id}-weight-store", "data"),
             Input("all-tabs", "value"),
+            Input("selected-models-store", "data"),
             State(table_id, "data"),
             State(f"{table_id}-computed-store", "data"),
             prevent_initial_call="initial_duplicate",
@@ -264,11 +299,10 @@ def apply_levels_of_theory(
         def update_table_scores(
             stored_weights: dict[str, float] | None,
             _tabs_value: str,
+            selected_models: list[str] | None,
             table_data: list[dict] | None,
             computed_store: list[dict] | None,
         ) -> tuple[list[dict], list[dict], list[dict], list[dict]]:
-            trigger_id = ctx.triggered_id
-
             def apply_levels(
                 rows: list[dict], base_style: list[dict]
             ) -> tuple[list[dict], list[dict]]:
@@ -279,27 +313,30 @@ def apply_levels(
                 tooltips = tooltip_rows if tooltip_rows else [{} for _ in rows]
                 return combined_style, tooltips
 
-            if trigger_id == "all-tabs" and computed_store:
-                style = get_table_style(computed_store)
-                style, tooltip_data = apply_levels(computed_store, style)
-                return computed_store, style, tooltip_data, computed_store
-
-            if not table_data:
+            # Always use computed_store (full unfiltered rows) as the source so
+            # that re-selecting a model restores it. Fall back to table_data only
+            # on first load before the store is populated.
+            source_data = computed_store or table_data
+            if not source_data:
                 raise PreventUpdate
 
-            scored_rows, style = update_score_style(table_data, stored_weights)
-            style, tooltip_data = apply_levels(scored_rows, style)
-            return scored_rows, style, tooltip_data, scored_rows
+            scored_rows, _ = update_score_style(source_data, stored_weights)
+            filtered_rows = filter_rows_by_models(scored_rows, selected_models)
+            style = get_table_style(filtered_rows) if filtered_rows else []
+            style, tooltip_data = apply_levels(filtered_rows, style)
+            return filtered_rows, style, tooltip_data, scored_rows
 
     @callback(
         Output("summary-table-scores-store", "data", allow_duplicate=True),
         Input(table_id, "data"),
         State("summary-table-scores-store", "data"),
+        State(f"{table_id}-computed-store", "data"),
         prevent_initial_call="initial_duplicate",
     )
     def update_scores_store(
         table_data: list[dict],
         scores_data: dict[str, dict[str, float]],
+        computed_rows: list[dict] | None,
     ) -> dict[str, dict[str, float]]:
         """
         Update stored scores values when weights update.
@@ -310,6 +347,8 @@ def update_scores_store(
             Data from `table_id` to be updated.
         scores_data
             Dictionary of scores for each tab.
+        computed_rows
+            Cached unfiltered rows for the category summary.
 
         Returns
         -------
@@ -320,12 +359,16 @@ def update_scores_store(
         if not table_id.endswith("-summary-table"):
             return scores_data
 
+        source_rows = computed_rows or table_data
+        if not source_rows:
+            return scores_data
+
         if not scores_data:
             scores_data = {}
         # Update scores store. Category table IDs are of form "[category]-summary-table"
         # Table headings are of the form "[category] Score"
         scores_data[table_id.removesuffix("-summary-table") + " Score"] = {
-            row["MLIP"]: row["Score"] for row in table_data
+            row["MLIP"]: row["Score"] for row in source_rows if row.get("MLIP")
         }
         return scores_data
 
@@ -363,6 +406,7 @@ def register_benchmark_to_category_callback(
         Output(f"{category_table_id}-computed-store", "data", allow_duplicate=True),
         Input(f"{benchmark_table_id}-computed-store", "data"),
         Input("all-tabs", "value"),
+        Input("selected-models-store", "data"),
         State(category_table_id, "data"),
         State(f"{category_table_id}-weight-store", "data"),
         State(f"{category_table_id}-computed-store", "data"),
@@ -371,6 +415,7 @@ def register_benchmark_to_category_callback(
     def update_category_from_benchmark(
         benchmark_computed_store: list[dict] | None,
         _tabs_value: str,
+        selected_models: list[str] | None,
         category_data: list[dict] | None,
         category_weights: dict[str, float] | None,
         category_computed_store: list[dict] | None,
@@ -384,6 +429,8 @@ def update_category_from_benchmark(
             Latest scored benchmark rows emitted by the benchmark table.
         _tabs_value
             Current tab identifier (unused, required to trigger on tab change).
+        selected_models
+            List of model names currently selected in the model filter.
         category_data
             Existing category table rows shown to the user.
         category_weights
@@ -400,6 +447,8 @@ def update_category_from_benchmark(
         category_rows = category_computed_store or category_data
         if not category_rows:
             raise PreventUpdate
+        if not benchmark_computed_store:
+            raise PreventUpdate
 
         benchmark_scores: dict[str, float] = {}
         for row in benchmark_computed_store:
@@ -416,7 +465,9 @@ def update_category_from_benchmark(
                 row[benchmark_column] = benchmark_scores[mlip]
 
         category_rows, style = update_score_style(category_rows, category_weights)
-        return category_rows, style, category_rows
+        filtered_rows = filter_rows_by_models(category_rows, selected_models)
+        filtered_style = get_table_style(filtered_rows) if filtered_rows else []
+        return filtered_rows, filtered_style, category_rows
 
 
 def register_weight_callbacks(
diff --git a/ml_peg/app/utils/utils.py b/ml_peg/app/utils/utils.py
index 1b408c3df..630600005 100644
--- a/ml_peg/app/utils/utils.py
+++ b/ml_peg/app/utils/utils.py
@@ -189,6 +189,40 @@ def clean_weights(raw_weights: dict[str, float] | None) -> dict[str, float]:
     return weights
 
 
+def filter_rows_by_models(
+    rows: list[dict] | None,
+    selected_models: Sequence[str] | None,
+) -> list[dict]:
+    """
+    Filter table rows to only those whose model identifier is selected.
+
+    Parameters
+    ----------
+    rows
+        Table rows containing an ``MLIP`` display name and optionally an ``id``
+        canonical model key.
+    selected_models
+        Model identifiers to keep. ``None`` returns the original rows unchanged.
+
+    Returns
+    -------
+    list[dict]
+        Filtered rows preserving original order.
+    """
+    if not rows:
+        return []
+    if selected_models is None:
+        return rows
+    selected = {m for m in selected_models if m}
+    if not selected:
+        return []
+    return [
+        row
+        for row in rows
+        if (row.get("MLIP") in selected) or (row.get("id") in selected)
+    ]
+
+
 def get_scores(
     raw_rows: list[dict],
     scored_rows: list[dict],