diff --git a/bbttest/bbt/_types.py b/bbttest/bbt/_types.py index deac7ea..d04abf7 100644 --- a/bbttest/bbt/_types.py +++ b/bbttest/bbt/_types.py @@ -24,6 +24,11 @@ "strong_interpretation_raw", ] +InterpretationTypes = Literal[ + "weak", + "strong", +] + ALL_PROPERTIES_COLUMNS: list[ReportedPropertyColumnType] = list( get_args(ReportedPropertyColumnType) ) diff --git a/bbttest/bbt/alg.py b/bbttest/bbt/alg.py index d441cd2..73df17b 100644 --- a/bbttest/bbt/alg.py +++ b/bbttest/bbt/alg.py @@ -7,8 +7,6 @@ import pandas as pd from tqdm.auto import tqdm -from .const import UNNAMED_COLUMNS_WARNING_TEMPLATE - ALG1_COL = 2 ALG2_COL = 3 TIE_COL = 4 @@ -16,6 +14,12 @@ logger = log.getLogger(__name__) +UNNAMED_COLUMNS_WARNING_TEMPLATE = """Some algorithm names are unnamed. This may lead to issues in the win table construction. +Algorithm names extracted: {algorithms_names} +Dataset column: {dataset_col} +""" + + def _gen_pairs(no_algs: int) -> Generator[tuple[int, int, int], None, None]: k = 0 for i in range(no_algs): diff --git a/bbttest/bbt/const.py b/bbttest/bbt/const.py deleted file mode 100644 index 92ad316..0000000 --- a/bbttest/bbt/const.py +++ /dev/null @@ -1,4 +0,0 @@ -UNNAMED_COLUMNS_WARNING_TEMPLATE = """Some algorithm names are unnamed. This may lead to issues in the win table construction. -Algorithm names extracted: {algorithms_names} -Dataset column: {dataset_col} -""" diff --git a/bbttest/bbt/plots/__init__.py b/bbttest/bbt/plots/__init__.py new file mode 100644 index 0000000..e338db3 --- /dev/null +++ b/bbttest/bbt/plots/__init__.py @@ -0,0 +1,3 @@ +from ._critical_difference import plot_cdd_diagram + +__all__ = ["plot_cdd_diagram"] diff --git a/bbttest/bbt/plots/_critical_difference.py b/bbttest/bbt/plots/_critical_difference.py new file mode 100644 index 0000000..4a450b6 --- /dev/null +++ b/bbttest/bbt/plots/_critical_difference.py @@ -0,0 +1,232 @@ +import warnings + +import matplotlib.pyplot as plt +import networkx as nx +import pandas as pd + +NO_EQUIVALENCE_CLIQUES_WARNING_TEMPLATE = """No groups of equivalent algorithms were found in the posterior table. +CDD plot will not contain any equivalence bars.""" + + +def get_bars_for_cdd( + posterior_df: pd.DataFrame, + models_df: pd.DataFrame, + interpretation_col: str, +) -> list[tuple[int, int]]: + """Calculate equivalence bars using the equivalence cliques in the posterior table.""" + # Construct Graph and find the cliques + g = nx.Graph() + + posterior_models = set(posterior_df["left_model"]) | set( + posterior_df["right_model"] + ) + if posterior_models != set(models_df["model"]): + raise ValueError( + "The models in the posterior table do not match the models in the models table." + ) + + for _, row in posterior_df.iterrows(): + left = row["left_model"] + right = row["right_model"] + equiv = row[interpretation_col] == "=" + if equiv: + g.add_edge(left, right) + + cliques = list(nx.find_cliques(g)) + + # Map cliques to bars + res = [] + + for clique in cliques: + clique_pos = models_df.loc[models_df["model"].isin(clique), "pos"] + res.append((clique_pos.min(), clique_pos.max())) + + return res + + +def assign_bar_position( + bars: list[tuple[int, int]], min_distance: int = 1 +) -> list[int]: + """Order the bars vertically to minimize the size of the plot.""" + if len(bars) == 0: + return [] + + indexed_bars = [ + ( + i, + start - min_distance, + end + min_distance, + ) # add min distance to the bar sizes + for i, (start, end) in enumerate(bars) + ] + + rows: list[tuple[int, int]] = [] + rows_assignments = [0] * len(indexed_bars) + + for task_idx, start, end in indexed_bars: + assigned = False + for i, (row_end_value, row_id) in enumerate(rows): + if row_end_value < start: + # This row is available + rows[i] = (end, row_id) + rows_assignments[task_idx] = row_id + assigned = True + break + if not assigned: + # No rows are available, create a new one + new_row_id = len(rows) + rows.append((end, new_row_id)) + rows_assignments[task_idx] = new_row_id + + return rows_assignments + + +def _plot_cdd_diagram( + models_df: pd.DataFrame, + bars: list[tuple[int, int]], + bars_positions: list[int], + bar_y_spacing: float = 0.12, + ax: plt.Axes | None = None, + xlabel_spacing: int = 5, + draw_equivalence_lines_to_axis: bool = True, +) -> plt.Axes: + """Plot a critical difference diagram.""" + if ax is None: + _, ax = plt.subplots() + + n_models = len(models_df) + + # Ruler at the top + ruler_y = 0 + ax.hlines(ruler_y, 0.5, n_models + 0.5, color="black", linewidth=2) + + # Add ticks for each model + for _, row in models_df.iterrows(): + pos = row["pos"] + name = row["model"] + # Invert so rank 1 is on the right + inv_pos = n_models - pos + 1 + + ax.vlines(inv_pos, ruler_y, ruler_y + 0.15, color="black", linewidth=1.2) + ax.text( + inv_pos, + ruler_y + 0.2, + name, + ha="left", + va="bottom", + fontsize=8, + rotation=45, + ) + + if len(bars) == 0: + warnings.warn(NO_EQUIVALENCE_CLIQUES_WARNING_TEMPLATE, UserWarning) + max_bar_pos = 0 + else: + max_bar_pos = max(bars_positions) + # Draw equivalence bars + for i, (min_pos, max_pos) in enumerate(bars): + bar_y = ruler_y - 0.4 - bars_positions[i] * bar_y_spacing + + inv_min = n_models - max_pos + 1 + inv_max = n_models - min_pos + 1 + + ax.hlines(bar_y, inv_min, inv_max, color="black", linewidth=2.5) + + if draw_equivalence_lines_to_axis: + ax.vlines(inv_min, bar_y, -0.25, color="black", linewidth=0.5) + ax.vlines(inv_max, bar_y, -0.25, color="black", linewidth=0.5) + else: + ax.vlines(inv_min, bar_y, bar_y + 0.05, color="black", linewidth=1.5) + ax.vlines(inv_max, bar_y, bar_y + 0.05, color="black", linewidth=1.5) + + # Add rank numbers - first and last manually + ax.text( + 1, + ruler_y - 0.1, + str(n_models), + ha="center", + va="top", + fontsize=8, + fontweight="bold", + ) + ax.text( + n_models, + ruler_y - 0.1, + "1", + ha="center", + va="top", + fontsize=8, + fontweight="bold", + ) + + for i in range(xlabel_spacing + 1, n_models, xlabel_spacing): + inv_pos = n_models - i + 1 + ax.text(inv_pos, ruler_y - 0.1, str(i), ha="center", va="top", fontsize=8) + + # Clip axes + min_bar_y = ruler_y - 0.4 - max_bar_pos * bar_y_spacing + ax.set_xlim(0, n_models + 1) + ax.set_ylim(min_bar_y - 0.3, 2.5) + ax.axis("off") + + # Legend + ax.text( + 0.5, + min_bar_y - 0.1, + "← worse better →", + fontsize=8, + style="italic", + ) + + return ax + + +def plot_cdd_diagram( + models_df: pd.DataFrame, + posterior_df: pd.DataFrame, + interpretation_col: str, + ax: plt.Axes | None = None, + bar_y_spacing: float = 0.12, + xlabel_spacing: int = 5, + draw_equivalence_lines_to_axis: bool = True, +) -> plt.Axes: + """Plot a critical difference diagram. + + Parameters + ---------- + models_df : pd.DataFrame + DataFrame containing model names and their ranks. Must have columns "model" and "pos + posterior_df : pd.DataFrame + DataFrame containing pairwise model comparisons and their interpretations. Must have columns "left_model", + "right_model", and the specified interpretation_col. + interpretation_col : str + Name of the column in posterior_df that contains the interpretation of model comparisons. + ax : plt.Axes, optional + Matplotlib Axes to plot on. If None, a new figure and axes will be created. + bar_y_spacing : float, optional + Vertical spacing between equivalence bars. Default is 0.12. + xlabel_spacing : int, optional + Spacing between x-axis labels. Default is 5. + draw_equivalence_lines_to_axis : bool, optional + Whether to draw equivalence lines to extend equivalence bars up to the axis. + If False, equivalence bars will not have vertical lines connecting them to + the axis. Default is True. + """ + if ax is not None and not isinstance(ax, plt.Axes): + raise ValueError("ax must be a matplotlib Axes object or None.") + + bars = get_bars_for_cdd( + posterior_df=posterior_df, + models_df=models_df, + interpretation_col=interpretation_col, + ) + bars_positions = assign_bar_position(bars) + return _plot_cdd_diagram( + models_df=models_df, + bars=bars, + bars_positions=bars_positions, + ax=ax, + bar_y_spacing=bar_y_spacing, + xlabel_spacing=xlabel_spacing, + draw_equivalence_lines_to_axis=draw_equivalence_lines_to_axis, + ) diff --git a/bbttest/bbt/py_bbt.py b/bbttest/bbt/py_bbt.py index c800c17..ae3b32b 100644 --- a/bbttest/bbt/py_bbt.py +++ b/bbttest/bbt/py_bbt.py @@ -1,18 +1,20 @@ from collections.abc import Iterable, Sequence -from typing import Literal +import matplotlib.pyplot as plt import numpy as np import pandas as pd from ._types import ( ALL_PROPERTIES_COLUMNS, HyperPriorType, + InterpretationTypes, ReportedPropertyColumnType, TieSolverType, ) from ._utils import _validate_params from .alg import _construct_win_table, _get_pwin, _hdi from .model import _mcmcbbt_pymc +from .plots import plot_cdd_diagram class PyBBT: @@ -111,6 +113,16 @@ def _check_if_fitted(self): if not self._fitted: raise RuntimeError("The model must be fitted before accessing this method.") + @staticmethod + def _get_interpretation_columns( + interpretation: InterpretationTypes, + ) -> ReportedPropertyColumnType: + return ( + "weak_interpretation_raw" + if interpretation == "weak" + else "strong_interpretation_raw" + ) + @property def fitted(self): """Whether the model has been fitted.""" @@ -163,6 +175,24 @@ def fit( return self + @property + def beta_ranking(self) -> dict[str, float]: + r""" + Get the $\beta$ values for each model. + + Beta values can be used for ranking the models globally from best to worst (higher beta indicates better performance). + However, they do not have a direct probabilistic interpretation like the pairwise probabilities obtained from the posterior table. + + Returns + ------- + dict[str, float] + Dictionary mapping model names to their posterior mean beta values. + """ + self._check_if_fitted() + beta = self._fit_posterior.posterior["beta"].to_numpy() + mean_beta = np.mean(beta.reshape(-1, beta.shape[-1]), axis=0) + return dict(zip(self._algorithms, mean_beta, strict=True)) + def posterior_table( self, rope_value: tuple[float, float] = (0.45, 0.55), @@ -275,7 +305,7 @@ def rope_comparison_control_table( rope_values: Sequence[tuple[float, float]], control_model: str, selected_models: Sequence[str] | None = None, - interpretation: Literal["weak", "strong"] = "weak", + interpretation: InterpretationTypes = "weak", return_as_array: bool = False, join_char: str = ", ", ) -> pd.DataFrame: @@ -307,6 +337,7 @@ def rope_comparison_control_table( """ self._check_if_fitted() records = [] + interpretation_col = self._get_interpretation_columns(interpretation) for rope in rope_values: posterior_df = self.posterior_table( rope_value=rope, @@ -324,11 +355,6 @@ def rope_comparison_control_table( worse_models: list[str] = [] unknown_models: list[str] = [] for _, row in posterior_df.iterrows(): - interpretation_col = ( - "weak_interpretation_raw" - if interpretation == "weak" - else "strong_interpretation_raw" - ) non_control_model = ( row["right_model"] if row["left_model"] == control_model @@ -374,3 +400,65 @@ def rope_comparison_control_table( ) result_df = pd.DataFrame.from_records(records) return result_df + + @_validate_params + def plot_cdd_diagram( + self, + rope_value: tuple[float, float] = (0.45, 0.55), + interpretation: InterpretationTypes = "weak", + ax: plt.Axes | None = None, + **kwargs, + ) -> plt.Axes: + """ + Plot the Critical Difference Diagram (CDD) based on the fitted BBT model. + + Critical Difference Diagram visualizes the global ranking of the models along + with the equivalence bars connecting models that are considered equivalent based on the specified BBT interpretation. + The global ranking is determined based on the posterior mean beta values for each model. + + Parameters + ---------- + rope_value : tuple[float, float], optional + Region of Practical Equivalence (ROPE) used to determine ties in the posterior table. Defaults to (0.45, 0.55). + interpretation : {"weak", "strong"}, optional + Type of interpretation to use for determining equivalence bars. Defaults to "weak". + ax : plt.Axes | None, optional + Matplotlib Axes to plot on. If None, a new figure and axes are created. Defaults to None. + **kwargs + Additional keyword arguments passed to the underlying plotting function. + See :func:`bbttest.bbt.plots.plot_cdd_diagram` for available parameters. + + Returns + ------- + plt.Axes + Matplotlib Axes containing the CDD plot. + """ + self._check_if_fitted() + interpretation_col = self._get_interpretation_columns(interpretation) + + model_ranking = self.beta_ranking + models_df = pd.DataFrame( + { + "model": list(model_ranking.keys()), + "beta": list(model_ranking.values()), + } + ) + models_df["pos"] = ( + models_df["beta"].rank(ascending=False, method="first").astype(int) + ) + models_df = models_df.sort_values("pos").reset_index(drop=True) + posterior_df = self.posterior_table( + rope_value=rope_value, + columns=( + "left_model", + "right_model", + interpretation_col, + ), + ) + return plot_cdd_diagram( + models_df=models_df, + posterior_df=posterior_df, + interpretation_col=interpretation_col, + ax=ax, + **kwargs, + ) diff --git a/pyproject.toml b/pyproject.toml index bbd70c2..696c55a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "pandas>=2.0.0", "pymc", "tqdm", + "networkx", + "matplotlib", ] [dependency-groups] diff --git a/tests/bbt/plots/__init__.py b/tests/bbt/plots/__init__.py new file mode 100644 index 0000000..3e732c1 --- /dev/null +++ b/tests/bbt/plots/__init__.py @@ -0,0 +1 @@ +"""Tests for bbt.plots module.""" diff --git a/tests/bbt/plots/test__critical_difference.py b/tests/bbt/plots/test__critical_difference.py new file mode 100644 index 0000000..475e801 --- /dev/null +++ b/tests/bbt/plots/test__critical_difference.py @@ -0,0 +1,112 @@ +import matplotlib.pyplot as plt +import pandas as pd +import pytest + +from bbttest.bbt.plots._critical_difference import ( + _plot_cdd_diagram, + assign_bar_position, + get_bars_for_cdd, +) + + +@pytest.fixture +def models_df() -> pd.DataFrame: + """Create a simple models DataFrame for testing.""" + return pd.DataFrame( + { + "model": ["A", "B", "C", "D"], + "pos": [1, 2, 3, 4], + } + ) + + +@pytest.fixture +def posterior_df() -> pd.DataFrame: + """Create a posterior DataFrame with a known equivalence structure. + + A, B, C form a clique (all equivalent); D is isolated. + """ + return pd.DataFrame( + [ + {"left_model": "A", "right_model": "B", "interp": "="}, + {"left_model": "B", "right_model": "C", "interp": "="}, + {"left_model": "A", "right_model": "C", "interp": "="}, + {"left_model": "C", "right_model": "D", "interp": "<"}, + ] + ) + + +class TestGetBarsForCDD: + """Test equivalence bar extraction from the posterior table.""" + + def test_single_clique( + self, models_df: pd.DataFrame, posterior_df: pd.DataFrame + ) -> None: + """Test that a single equivalence clique produces one bar spanning the correct positions.""" + bars = get_bars_for_cdd( + posterior_df=posterior_df, + models_df=models_df, + interpretation_col="interp", + ) + # Only one equivalence group: models A (pos 1), B (pos 2), C (pos 3) + # so we expect a single bar spanning from 1 to 3. + assert len(bars) == 1 + assert bars[0] == (1, 3) + + +class TestAssignBarPosition: + """Test vertical bar positioning for the CDD plot.""" + + def test_non_overlapping(self) -> None: + """Check that non-overlapping bars are placed on the same row.""" + bars = [(0, 1), (2, 3), (4, 5)] + positions = assign_bar_position(bars, min_distance=0) + # All bars are disjoint; the greedy algorithm should be able to place + # them all on the same row. + assert len(positions) == len(bars) + assert set(positions) == {0} + + def test_overlapping(self) -> None: + """Check that overlapping bars are not placed on the same row.""" + # Bar 0 overlaps with bar 1, bar 1 overlaps with bar 2 + bars = [(0, 3), (2, 5), (4, 7)] + positions = assign_bar_position(bars, min_distance=0) + assert len(positions) == len(bars) + # At least two rows are required for these overlapping intervals. + assert max(positions) >= 1 + # Overlapping bars should not share the same row id. + for i in range(len(bars)): + for j in range(i + 1, len(bars)): + s1, e1 = bars[i] + s2, e2 = bars[j] + if not (e1 <= s2 or e2 <= s1): + # Bars i and j overlap; they must be on different rows. + assert positions[i] != positions[j] + + +class TestPlotCDDDiagram: + """Test the CDD diagram plotting function.""" + + def test_smoke(self) -> None: + """Ensure _plot_cdd_diagram runs without error and returns an Axes.""" + models_df = pd.DataFrame( + { + "model": ["A", "B", "C"], + "pos": [1, 2, 3], + "mean": [0.1, 0.2, 0.3], + } + ) + # A single bar spanning all three models on row 0 + bars = [(1, 3)] + bars_positions = [0] + fig, ax = plt.subplots() + try: + result_ax = _plot_cdd_diagram( + models_df=models_df, + bars=bars, + bars_positions=bars_positions, + ax=ax, + ) + finally: + plt.close(fig) + assert isinstance(result_ax, plt.Axes) diff --git a/uv.lock b/uv.lock index cb01286..d802741 100644 --- a/uv.lock +++ b/uv.lock @@ -148,6 +148,8 @@ name = "bbt-test" version = "0.0.0" source = { editable = "." } dependencies = [ + { name = "matplotlib" }, + { name = "networkx" }, { name = "pandas" }, { name = "pymc" }, { name = "tqdm" }, @@ -171,6 +173,8 @@ test = [ [package.metadata] requires-dist = [ + { name = "matplotlib" }, + { name = "networkx" }, { name = "pandas", specifier = ">=2.0.0" }, { name = "pymc" }, { name = "tqdm" }, @@ -1680,6 +1684,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, ] +[[package]] +name = "networkx" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, +] + [[package]] name = "nodeenv" version = "1.9.1"