diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7c1ea62..3a32b93 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,8 +36,7 @@ jobs: key: ${{ matrix.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/uv.lock') }} - name: Install the project dependencies - run: uv sync --group test --python "$(python -c 'import sys; print(sys.executable)')" - shell: bash + run: uv sync --group test - name: Check pre-commit run: uv run pre-commit run --all-files diff --git a/.gitignore b/.gitignore index f86617a..b282f62 100644 --- a/.gitignore +++ b/.gitignore @@ -8,9 +8,12 @@ wheels/ # Dev cache .ruff_cache/ +.pytest_cache/ # Virtual environments .venv # IDE files .idea/ +.vscode/ + diff --git a/Makefile b/Makefile index 06e235b..ecf31b7 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: setup help +.PHONY: setup test full-test test-coverage help .DEFAULT_GOAL := help setup: ## Install development dependencies @@ -6,11 +6,19 @@ setup: ## Install development dependencies @uv --version >/dev/null 2>&1 || (echo "uv is not installed, please install it" && exit 1) @# install dependencies - uv sync --group dev + uv sync --group dev --group test uv run pre-commit install -test: +test: ## Run tests without regression uv run ruff check + uv run pytest tests -m "not slow" + +full-test: ## Run all tests including regression + uv run ruff check + uv run pytest tests + +test-coverage: ## Run tests and calculate test coverage + uv run pytest --cov=bbttest tests help: @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md index c825155..98ff6db 100644 --- a/README.md +++ b/README.md @@ -87,28 +87,33 @@ Once you obtained a fitted PyBBT model, you can generate statistic dataframe con ```python -stats_df = model.get_stats_dataframe( +stats_df = model.posterior_table( rope_value=(0.45, 0.55), # Defines ROPE of hypothesis for interpretations 
control_model="alg1", # If provided, only hypotheses comparing to control_model will be included - selected_models=["alg2", "alg3"], # If provided, only hypotheses comparing selected_models will be included + selected_models=["alg2"], # If provided, only hypotheses comparing selected_models will be included ) print(stats_df) + + pair mean delta above_50 in_rope weak_interpretation +0 alg1 > alg2 0.63 0.53 0.75 0.19 Unknown ``` Additionally, you can generate multiple hypothesis interpretations regarding control model for different ROPE values: ```python -from bbttest import multiple_ropes_control_table - -stats_df = multiple_ropes_control_table( - model, - ropes=[(0.4, 0.6), (0.45, 0.55), (0.48, 0.52)], +stats_df = model.rope_comparison_control_table( + rope_values=[(0.4, 0.6), (0.45, 0.55), (0.48, 0.52)], control_model="alg1", - interpretation_type="weak", + interpretation="weak", ) print(stats_df) + +rope_value better_models equivalent_models worse_models unknown_models +0 (0.4, 0.6) alg3, alg1 +1 (0.45, 0.55) alg3, alg1 +2 (0.48, 0.52) alg3, alg1 ``` ## License diff --git a/bbttest/__init__.py b/bbttest/__init__.py index 28989cd..2776693 100644 --- a/bbttest/__init__.py +++ b/bbttest/__init__.py @@ -1,13 +1,10 @@ """bbt-test: Bayesian Bradley-Terry model for algorithm comparison.""" -from .const import HyperPrior, ReportedProperty, TieSolver -from .py_bbt import PyBBT -from .utils import multiple_ropes_control_table +from .bbt import HyperPrior, PyBBT, ReportedProperty, TieSolver __all__ = [ "HyperPrior", "PyBBT", "ReportedProperty", "TieSolver", - "multiple_ropes_control_table", ] diff --git a/bbttest/bbt/__init__.py b/bbttest/bbt/__init__.py new file mode 100644 index 0000000..16f29c2 --- /dev/null +++ b/bbttest/bbt/__init__.py @@ -0,0 +1,11 @@ +"""bbt module: Bayesian Bradley-Terry model implementation.""" + +from .const import HyperPrior, ReportedProperty, TieSolver +from .py_bbt import PyBBT + +__all__ = [ + "HyperPrior", + "PyBBT", + "ReportedProperty", + 
"TieSolver", +] diff --git a/bbttest/alg.py b/bbttest/bbt/alg.py similarity index 96% rename from bbttest/alg.py rename to bbttest/bbt/alg.py index 9855561..a55653b 100644 --- a/bbttest/alg.py +++ b/bbttest/bbt/alg.py @@ -1,5 +1,5 @@ import logging as log -from collections.abc import Generator +from collections.abc import Generator, Iterable import arviz as az import numpy as np @@ -158,9 +158,9 @@ def _construct_win_table( def _get_pwin( bbt_result: az.InferenceData, - alg_names: list[str] | None = None, + alg_names: Iterable[str] | None = None, control: str | None = None, - selected: list[str] | None = None, + selected: Iterable[str] | None = None, ): def _pairwise_prob(strength_i, strength_j): return strength_i / (strength_i + strength_j) @@ -183,6 +183,9 @@ def _pairwise_prob(strength_i, strength_j): # Filter by selected algorithms if specified if selected is not None: selected_set = set(selected) + if control not in selected_set and control is not None: + selected_set.add(control) + indices = [i for i, name in enumerate(ordered_names) if name in selected_set] ordered_names = ordered_names[indices] strengths = strengths[:, indices] diff --git a/bbttest/const.py b/bbttest/bbt/const.py similarity index 67% rename from bbttest/const.py rename to bbttest/bbt/const.py index 0680ff8..f5d8044 100644 --- a/bbttest/const.py +++ b/bbttest/bbt/const.py @@ -14,16 +14,17 @@ class HyperPrior(str, Enum): NORMAL = "normal" def _get_pymc_dist(self, scale, name="sigma"): - if self == HyperPrior.LOG_NORMAL: - return LogNormal(name, mu=0, sigma=1) - elif self == HyperPrior.LOG_NORMAL_SCALED: - return LogNormal(name, mu=0, sigma=scale) - elif self == HyperPrior.CAUCHY: - return Cauchy(name, alpha=0, beta=scale) - elif self == HyperPrior.NORMAL: - return Normal(name, mu=0, sigma=scale) - else: - raise ValueError(f"Unsupported hyperprior: {self}") + match self: + case HyperPrior.LOG_NORMAL: + return LogNormal(name, mu=0, sigma=1) + case HyperPrior.LOG_NORMAL_SCALED: + return 
LogNormal(name, mu=0, sigma=scale) + case HyperPrior.CAUCHY: + return Cauchy(name, alpha=0, beta=scale) + case HyperPrior.NORMAL: + return Normal(name, mu=0, sigma=scale) + case _: + raise ValueError(f"Unsupported hyperprior: {self}") class ReportedProperty(str, Enum): @@ -31,6 +32,8 @@ class ReportedProperty(str, Enum): Enum containing properties that can be reported from BBT results. """ + LEFT_MODEL = "left_model" + RIGHT_MODEL = "right_model" MEDIAN = "median" MEAN = "mean" HDI_LOW = "hdi_low" @@ -65,3 +68,5 @@ class TieSolver(str, Enum): ReportedProperty.IN_ROPE, ReportedProperty.WEAK_INTERPRETATION, ) + +ALL_PROPERTIES = tuple(ReportedProperty) diff --git a/bbttest/model.py b/bbttest/bbt/model.py similarity index 99% rename from bbttest/model.py rename to bbttest/bbt/model.py index 3d8137c..105217f 100644 --- a/bbttest/model.py +++ b/bbttest/bbt/model.py @@ -128,7 +128,7 @@ def _mcmcbbt_pymc( sample_kwargs = { k: v for k, v in kwargs.items() - if k in ["draws", "tune", "chains", "cores", "target_accept"] + if k in ["draws", "tune", "chains", "cores", "target_accept", "random_seed"] } fit = pm.sample(**sample_kwargs) diff --git a/bbttest/py_bbt.py b/bbttest/bbt/py_bbt.py similarity index 64% rename from bbttest/py_bbt.py rename to bbttest/bbt/py_bbt.py index fd3140f..44f677f 100644 --- a/bbttest/py_bbt.py +++ b/bbttest/bbt/py_bbt.py @@ -1,4 +1,5 @@ -from collections.abc import Sequence +from collections.abc import Iterable, Sequence +from typing import Literal import numpy as np import pandas as pd @@ -75,14 +76,18 @@ class PyBBT: def __init__( self, local_rope_value: float | None = None, - tie_solver: TieSolver = TieSolver.SPREAD, - hyper_prior: HyperPrior = HyperPrior.LOG_NORMAL, + tie_solver: TieSolver | str = TieSolver.SPREAD, + hyper_prior: HyperPrior | str = HyperPrior.LOG_NORMAL, scale: float = 1.0, ): self._local_rope_value = local_rope_value - self._tie_solver = tie_solver - self._use_davidson = tie_solver == TieSolver.DAVIDSON - self._hyper_prior = 
hyper_prior + self._tie_solver = ( + TieSolver(tie_solver) if isinstance(tie_solver, str) else tie_solver + ) + self._use_davidson = self._tie_solver == TieSolver.DAVIDSON + self._hyper_prior = ( + HyperPrior(hyper_prior) if isinstance(hyper_prior, str) else hyper_prior + ) self._scale = scale self._fitted = False @@ -140,8 +145,8 @@ def posterior_table( self, rope_value: tuple[float, float] = (0.45, 0.55), control_model: str | None = None, - selected_models: list[str] | None = None, - columns: Sequence[ReportedProperty | str] = DEFAULT_PROPERTIES, + selected_models: Iterable[str] | None = None, + columns: Iterable[ReportedProperty | str] = DEFAULT_PROPERTIES, hdi_proba: float = 0.89, round_ndigits: int | None = 2, ) -> pd.DataFrame: @@ -165,7 +170,7 @@ def posterior_table( bbt_result=self._fit_posterior, alg_names=self._algorithms, control=control_model, - selected=selected_models, + selected=list(selected_models) if selected_models is not None else None, ) out_table = pd.DataFrame({"pair": names}) out_table["left_model"] = out_table["pair"].str.split(">").str[0].str.strip() @@ -220,10 +225,100 @@ def posterior_table( out_table["delta"] = out_table["hdi_high"] - out_table["hdi_low"] if round_ndigits is not None: - return out_table.round(round_ndigits) + return out_table.round(round_ndigits)[["pair", *columns]] for col in columns: if col not in out_table.columns: raise ValueError( f"Column {col} is not available in the posterior table." ) return out_table[["pair", *columns]] + + def rope_comparison_control_table( + self, + rope_values: Sequence[tuple[float, float]], + control_model: str, + selected_models: Sequence[str] | None = None, + interpretation: Literal["weak", "strong"] = "weak", + return_as_array: bool = False, + join_char: str = ", ", + ) -> pd.DataFrame: + """ + Construct a table comparing models against predefined control models across multiple ROPEs. 
+ The output table contains N rows (one per ROPE) and 5 columns + (rope value, better models, equivalent models, worse models, unknown models). + + Args: + rope_values: List of ROPE tuples to evaluate. + control_model: Name of the control model every other model is compared against. + selected_models: If provided, only the listed models are compared against the control model. + interpretation: Type of interpretation to use ("weak" or "strong"), see [1]_. + return_as_array: Whether the individual cells should contain model names as list or joined into single string. + join_char: Character(s) used to join multiple model names in a single cell. + + Returns + ------- + pd.DataFrame: Table comparing models against control models across multiple ROPEs. + + References + ---------- + .. [1] `Jacques Wainer + "A Bayesian Bradley-Terry model to compare multiple ML algorithms on multiple data sets" + Journal of Machine Learning Research 24 (2023): 1-34 + `_ + """ + self._check_if_fitted() + records = [] + for rope in rope_values: + posterior_df = self.posterior_table( + rope_value=rope, + control_model=control_model, + selected_models=selected_models, + columns=[ + ReportedProperty.LEFT_MODEL, + ReportedProperty.WEAK_INTERPRETATION, + ReportedProperty.STRONG_INTERPRETATION, + ], + ) + better_models: list[str] = [] + equivalent_models: list[str] = [] + worse_models: list[str] = [] + unknown_models: list[str] = [] + for _, row in posterior_df.iterrows(): + interpretation_col = ( + "weak_interpretation" + if interpretation == "weak" + else "strong_interpretation" + ) + if row[interpretation_col] == f"{row['left_model']} better": + better_models.append(row["left_model"]) + elif row[interpretation_col] == "Equivalent": + equivalent_models.append(row["left_model"]) + elif row[interpretation_col] == "Unknown": + unknown_models.append(row["left_model"]) + else: + worse_models.append(row["left_model"]) + if not return_as_array: + better_models_str = join_char.join(better_models) + equivalent_models_str = join_char.join(equivalent_models) + worse_models_str = join_char.join(worse_models) + unknown_models_str = 
join_char.join(unknown_models) + records.append( + { + "rope_value": rope, + "better_models": better_models_str, + "equivalent_models": equivalent_models_str, + "worse_models": worse_models_str, + "unknown_models": unknown_models_str, + } + ) + else: + records.append( + { + "rope_value": rope, + "better_models": better_models, + "equivalent_models": equivalent_models, + "worse_models": worse_models, + "unknown_models": unknown_models, + } + ) + result_df = pd.DataFrame.from_records(records) + return result_df diff --git a/bbttest/utils.py b/bbttest/utils.py deleted file mode 100644 index eba1776..0000000 --- a/bbttest/utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Literal - -import pandas as pd - -from .py_bbt import PyBBT - - -def multiple_ropes_control_table( - model: PyBBT, - ropes: list[tuple[float, float]], - control_model: str, - selected_models: list[str] | None = None, - interpretation: Literal["weak", "strong"] = "weak", - return_as_array: bool = False, - join_char: str = ", ", -) -> pd.DataFrame: - """ - Construct a table comparing models against predefined control models across multiple ROPEs. - The output table contains N rows (one per ROPE) and 5 columns - (rope value, better models, equivalent models, worse models, unknown models). - - Args: - model: Fitted PyBBT model. - ropes: List of ROPE tuples to evaluate. - interpretation: Type of interpretation to use ("weak" or "strong"), see [1]_. - return_as_array: Whether the individual cells should contain model names as list or joined into single string. - join_char: Character(s) used to join multiple model names in a single cell. - - Returns - ------- - pd.DataFrame: Table comparing models against control models across multiple ROPEs. - - References - ---------- - .. 
[1] `Jacques Wainer - "A Bayesian Bradley-Terry model to compare multiple ML algorithms on multiple data sets" - Journal of Machine Learning Research 24 (2023): 1-34 - `_ - """ - rows = [] - interpretation_col = f"{interpretation}_interpretation_raw" - for rope in ropes: - post_table = model.posterior_table( - rope_value=rope, - columns=["left_model", "right_model", interpretation_col], - control_model=control_model, - selected_models=selected_models, - ) - better_models = list() - equivalent_models = list() - worse_models = list() - unknown_models = list() - - for _, row in post_table.iterrows(): - decision = row[interpretation_col] - if decision == ">": - if row["left_model"] == control_model: - worse_models.append(row["right_model"]) - else: - better_models.append(row["left_model"]) - elif decision == "=": - if row["left_model"] == control_model: - equivalent_models.append(row["right_model"]) - else: - equivalent_models.append(row["left_model"]) - elif row["left_model"] == control_model: - unknown_models.append(row["right_model"]) - else: - unknown_models.append(row["left_model"]) - - rows.append( - { - "rope": rope, - "better_models": better_models - if return_as_array - else join_char.join(better_models), - "equivalent_models": equivalent_models - if return_as_array - else join_char.join(equivalent_models), - "worse_models": worse_models - if return_as_array - else join_char.join(worse_models), - "unknown_models": unknown_models - if return_as_array - else join_char.join(unknown_models), - } - ) - - return pd.DataFrame(rows) diff --git a/tests/bbt/__init__.py b/tests/bbt/__init__.py new file mode 100644 index 0000000..7d9169f --- /dev/null +++ b/tests/bbt/__init__.py @@ -0,0 +1 @@ +"""Tests for bbt package.""" diff --git a/tests/bbt/test_py_bbt.py b/tests/bbt/test_py_bbt.py new file mode 100644 index 0000000..b83fc69 --- /dev/null +++ b/tests/bbt/test_py_bbt.py @@ -0,0 +1,612 @@ +""" +Unit tests for PyBBT class. 
+ +This module contains unit tests for the PyBBT class, testing various +functionality including model fitting, posterior table generation, +ROPE comparison tables, and parameter validation. +""" + +import numpy as np +import pandas as pd +import pytest + +from bbttest import HyperPrior, PyBBT, ReportedProperty, TieSolver +from bbttest.bbt.const import ALL_PROPERTIES + + +@pytest.fixture(scope="module") +def mock_data(): + """ + Create simple mock data for testing. + + Returns + ------- + pd.DataFrame + Mock dataset with 3 datasets and 3 models. + """ + return pd.DataFrame( + { + "dataset": ["ds1", "ds2", "ds3"], + "model_a": [0.8, 0.75, 0.9], + "model_b": [0.7, 0.8, 0.85], + "model_c": [0.6, 0.65, 0.7], + } + ) + + +@pytest.fixture(scope="module") +def fitted_model(mock_data): + """ + Create a fitted PyBBT model for testing. + + Parameters + ---------- + mock_data : pd.DataFrame + Mock data fixture. + + Returns + ------- + PyBBT + Fitted PyBBT model instance. + """ + model = PyBBT(local_rope_value=0.01, tie_solver=TieSolver.SPREAD) + model.fit( + mock_data, + dataset_col="dataset", + draws=100, + tune=100, + chains=2, + random_seed=42, + ) + return model + + +class TestPyBBTInitialization: + """Test PyBBT initialization and parameter validation.""" + + def test_init_with_enum_parameters(self): + """Test that PyBBT can be initialized with enum parameters.""" + model = PyBBT( + local_rope_value=0.01, + tie_solver=TieSolver.SPREAD, + hyper_prior=HyperPrior.LOG_NORMAL, + scale=1.0, + ) + assert model._local_rope_value == 0.01 + assert model._tie_solver == TieSolver.SPREAD + assert model._hyper_prior == HyperPrior.LOG_NORMAL + assert model._scale == 1.0 + assert not model.fitted + + def test_init_with_string_parameters(self): + """Test that PyBBT can be initialized with string parameters that are cast to enums.""" + model = PyBBT( + local_rope_value=0.01, + tie_solver="spread", + hyper_prior="logNormal", + scale=1.0, + ) + # Verify string values are accepted and work 
correctly + assert model._local_rope_value == 0.01 + assert model._tie_solver == TieSolver.SPREAD + assert model._hyper_prior == HyperPrior.LOG_NORMAL + assert model._scale == 1.0 + assert not model.fitted + + @pytest.mark.parametrize( + "arg, expected", + [ + ("add", TieSolver.ADD), + ("spread", TieSolver.SPREAD), + ("forget", TieSolver.FORGET), + ("davidson", TieSolver.DAVIDSON), + (TieSolver.ADD, TieSolver.ADD), + (TieSolver.SPREAD, TieSolver.SPREAD), + (TieSolver.FORGET, TieSolver.FORGET), + (TieSolver.DAVIDSON, TieSolver.DAVIDSON), + ], + ) + def test_init_with_different_tie_solvers(self, arg, expected): + """Test initialization with different TieSolver values.""" + model = PyBBT(tie_solver=arg) + assert model._tie_solver == expected + + @pytest.mark.parametrize( + "arg, expected", + [ + ("logNormal", HyperPrior.LOG_NORMAL), + ("logNormalScaled", HyperPrior.LOG_NORMAL_SCALED), + ("cauchy", HyperPrior.CAUCHY), + ("normal", HyperPrior.NORMAL), + (HyperPrior.LOG_NORMAL, HyperPrior.LOG_NORMAL), + (HyperPrior.LOG_NORMAL_SCALED, HyperPrior.LOG_NORMAL_SCALED), + (HyperPrior.CAUCHY, HyperPrior.CAUCHY), + (HyperPrior.NORMAL, HyperPrior.NORMAL), + ], + ) + def test_init_with_different_hyper_priors(self, arg, expected): + """Test initialization with different HyperPrior values.""" + model = PyBBT(hyper_prior=arg) + assert model._hyper_prior == expected + + def test_init_defaults(self): + """Test that default initialization values are set correctly.""" + model = PyBBT() + assert model._local_rope_value is None + assert model._tie_solver == TieSolver.SPREAD + assert model._hyper_prior == HyperPrior.LOG_NORMAL + assert model._scale == 1.0 + assert not model.fitted + + +class TestPyBBTFitting: + """Test PyBBT model fitting functionality.""" + + def test_fit_updates_fitted_property(self, mock_data): + """Test that fit() updates the fitted property.""" + model = PyBBT() + assert not model.fitted + model.fit(mock_data, dataset_col="dataset", draws=50, tune=50, chains=2) + 
assert model.fitted + + def test_fit_returns_self(self, mock_data): + """Test that fit() returns self for method chaining.""" + model = PyBBT() + result = model.fit( + mock_data, dataset_col="dataset", draws=50, tune=50, chains=2 + ) + assert result is model + + +class TestPyBBTUnfittedErrors: + """Test that methods raise errors when called on unfitted models.""" + + def test_posterior_table_without_fitting_raises_error(self): + """Test that posterior_table() raises error on unfitted model.""" + model = PyBBT() + with pytest.raises( + RuntimeError, match="The model must be fitted before accessing this method" + ): + model.posterior_table() + + def test_rope_comparison_control_table_without_fitting_raises_error(self): + """Test that rope_comparison_control_table() raises error on unfitted model.""" + model = PyBBT() + with pytest.raises( + RuntimeError, match="The model must be fitted before accessing this method" + ): + model.rope_comparison_control_table( + rope_values=[(0.45, 0.55)], control_model="model_a" + ) + + +class TestPosteriorTable: + """Test posterior_table method functionality.""" + + def test_posterior_table_has_required_columns(self, fitted_model): + """Test that posterior_table contains required columns.""" + result = fitted_model.posterior_table() + required_cols = ["pair", "mean", "delta", "above_50", "in_rope"] + for col in required_cols: + assert col in result.columns + + def test_posterior_table_weak_interpretation_values(self, fitted_model): + """Test that weak interpretation contains valid values.""" + result = fitted_model.posterior_table(rope_value=(0.45, 0.55)) + valid_values = {"Equivalent", "Unknown"} + # Weak interpretation should end with "better", be "Equivalent", or be "Unknown" + for interp in result["weak_interpretation"]: + assert interp in valid_values or interp.endswith(" better"), ( + f"Invalid weak interpretation: {interp}" + ) + + def test_posterior_table_strong_interpretation_values(self, fitted_model): + """Test that strong 
interpretation contains valid values.""" + result = fitted_model.posterior_table() + # Add strong_interpretation to columns + result = fitted_model.posterior_table( + columns=[ + ReportedProperty.MEAN, + ReportedProperty.STRONG_INTERPRETATION, + ] + ) + valid_values = {"Equivalent", "Unknown"} + for interp in result["strong_interpretation"]: + assert interp in valid_values or interp.endswith(" better"), ( + f"Invalid strong interpretation: {interp}" + ) + + def test_posterior_table_with_control_model(self, fitted_model): + """Test posterior_table with control_model parameter.""" + result = fitted_model.posterior_table( + control_model="model_a", columns=ALL_PROPERTIES + ) + assert len(result) > 0 + # All comparisons should involve model_a + for _, row in result.iterrows(): + assert row["left_model"] == "model_a" or row["right_model"] == "model_a", ( + f"Comparison {row['pair']} does not involve control model" + ) + + def test_posterior_table_returns_only_requested_columns(self, fitted_model): + """Test that posterior_table returns only requested columns.""" + requested_columns = [ReportedProperty.MEAN, ReportedProperty.DELTA] + result = fitted_model.posterior_table( + columns=requested_columns, round_ndigits=None + ) + + # Should have 'pair' column plus requested columns + expected_cols = ["pair", "mean", "delta"] + assert set(result.columns) == set(expected_cols) + + result = fitted_model.posterior_table( + columns=requested_columns, round_ndigits=3 + ) + assert set(result.columns) == set(expected_cols) + + def test_posterior_table_requested_columns_with_strings(self, fitted_model): + """Test that posterior_table accepts string column names.""" + requested_columns = ["mean", "delta", "above_50"] + # Must set round_ndigits=None to get column filtering + result = fitted_model.posterior_table( + columns=requested_columns, round_ndigits=None + ) + + expected_cols = ["pair", "mean", "delta", "above_50"] + assert set(result.columns) == set(expected_cols) + + def 
test_posterior_table_invalid_column_raises_error(self, fitted_model): + """Test that requesting invalid column raises ValueError.""" + with pytest.raises(ValueError, match="is not available in the posterior table"): + # Must set round_ndigits=None to trigger column validation + fitted_model.posterior_table(columns=["invalid_column"], round_ndigits=None) + + def test_posterior_table_rope_value_affects_in_rope(self, fitted_model): + """Test that changing ROPE value affects in_rope column.""" + result1 = fitted_model.posterior_table(rope_value=(0.4, 0.6)) + result2 = fitted_model.posterior_table(rope_value=(0.45, 0.55)) + + # Wider ROPE should generally have higher in_rope values + mean_in_rope_1 = result1["in_rope"].mean() + mean_in_rope_2 = result2["in_rope"].mean() + assert mean_in_rope_1 >= mean_in_rope_2 + + def test_posterior_table_rounding(self, fitted_model): + """Test that rounding parameter works correctly.""" + result_rounded = fitted_model.posterior_table(round_ndigits=2) + + # Check that rounded version has at most 2 decimal places + for col in ["mean", "delta"]: + if col in result_rounded.columns: + for val in result_rounded[col]: + if not pd.isna(val): + str_val = str(val) + if "." 
in str_val: + decimals = len(str_val.split(".")[1]) + assert decimals <= 2 + + +class TestRopeComparisonControlTable: + """Test rope_comparison_control_table method functionality.""" + + def test_rope_comparison_returns_dataframe(self, fitted_model): + """Test that rope_comparison_control_table returns a DataFrame.""" + result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55), (0.4, 0.6)], control_model="model_a" + ) + assert isinstance(result, pd.DataFrame) + assert len(result) == 2 # One row per ROPE + + def test_rope_comparison_has_required_columns(self, fitted_model): + """Test that rope_comparison_control_table has required columns.""" + result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55)], control_model="model_a" + ) + required_cols = [ + "rope_value", + "better_models", + "equivalent_models", + "worse_models", + "unknown_models", + ] + for col in required_cols: + assert col in result.columns + + def test_rope_comparison_weak_interpretation(self, fitted_model): + """Test rope_comparison_control_table with weak interpretation.""" + result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55), (0.4, 0.6)], + control_model="model_a", + interpretation="weak", + ) + assert len(result) == 2 + # Check that each row has the correct ROPE value + assert result.iloc[0]["rope_value"] == (0.45, 0.55) + assert result.iloc[1]["rope_value"] == (0.4, 0.6) + + def test_rope_comparison_strong_interpretation(self, fitted_model): + """Test rope_comparison_control_table with strong interpretation.""" + result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55)], + control_model="model_a", + interpretation="strong", + ) + assert len(result) == 1 + # Verify it runs without error and returns expected structure + assert "better_models" in result.columns + + def test_rope_comparison_return_as_array(self, fitted_model): + """Test rope_comparison_control_table with return_as_array=True.""" + 
result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55)], + control_model="model_a", + return_as_array=True, + ) + # When return_as_array=True, columns should contain lists + assert isinstance(result.iloc[0]["better_models"], list) + assert isinstance(result.iloc[0]["equivalent_models"], list) + assert isinstance(result.iloc[0]["worse_models"], list) + assert isinstance(result.iloc[0]["unknown_models"], list) + + def test_rope_comparison_return_as_string(self, fitted_model): + """Test rope_comparison_control_table with return_as_array=False.""" + result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55)], + control_model="model_a", + return_as_array=False, + ) + # When return_as_array=False, columns should contain strings + assert isinstance(result.iloc[0]["better_models"], str) + assert isinstance(result.iloc[0]["equivalent_models"], str) + assert isinstance(result.iloc[0]["worse_models"], str) + assert isinstance(result.iloc[0]["unknown_models"], str) + + def test_rope_comparison_custom_join_char(self, fitted_model): + """Test rope_comparison_control_table with custom join character.""" + result = fitted_model.rope_comparison_control_table( + rope_values=[(0.45, 0.55)], + control_model="model_a", + return_as_array=False, + join_char=" | ", + ) + # Check if custom join character is used (if there are multiple models) + for col in [ + "better_models", + "equivalent_models", + "worse_models", + "unknown_models", + ]: + value = result.iloc[0][col] + if " | " in value: + # Found the custom separator, test passes + assert True + return + + def test_rope_comparison_multiple_ropes(self, fitted_model): + """Test rope_comparison_control_table with multiple ROPE values.""" + rope_values = [(0.3, 0.7), (0.4, 0.6), (0.45, 0.55)] + result = fitted_model.rope_comparison_control_table( + rope_values=rope_values, control_model="model_a" + ) + assert len(result) == len(rope_values) + # Verify ROPE values are correctly stored + for i, 
rope in enumerate(rope_values): + assert result.iloc[i]["rope_value"] == rope + + +class TestPosteriorTableInterpretations: + """Test interpretation logic in posterior_table with mocked samples.""" + + def test_weak_interpretation_logic(self, fitted_model, monkeypatch): + """Test that weak interpretation follows expected logic with controlled samples.""" + # Create synthetic samples with known properties for testing interpretations + # Case 1: Model A > Model B - should be "A better" (96% above 0.5, only 3% in ROPE) + samples_a_better = np.concatenate( + [np.full(960, 0.8), np.full(40, 0.4)] + ) # 96% > 0.5, mean ~0.784 + + # Case 2: Model B > Model C - should be "Equivalent" (40% above 0.5, 96% in ROPE) + samples_equivalent = np.concatenate( + [np.full(400, 0.52), np.full(600, 0.48)] + ) # 40% > 0.5, mean = 0.496 + + # Case 3: Model C > Model D - should be "Unknown" (70% above 0.5, 80% in ROPE) + samples_unknown = np.concatenate( + [np.full(700, 0.6), np.full(300, 0.4)] + ) # 70% > 0.5, mean = 0.54 + + samples = np.column_stack( + [samples_a_better, samples_equivalent, samples_unknown] + ) + names = ["A > B", "B > C", "C > D"] + + # Mock _get_pwin to return our controlled samples + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table(rope_value=(0.45, 0.55)) + + # Test Case 1: A > B should be "A better" + row_a_b = result[result["pair"] == "A > B"].iloc[0] + assert row_a_b["weak_interpretation"] == "A better" + assert row_a_b["above_50"] >= 0.95 + assert row_a_b["in_rope"] < 0.95 + + # Test Case 2: B > C should be "Equivalent" + row_b_c = result[result["pair"] == "B > C"].iloc[0] + assert row_b_c["weak_interpretation"] == "Equivalent" + assert row_b_c["in_rope"] >= 0.95 + + # Test Case 3: C > D should be "Unknown" + row_c_d = result[result["pair"] == "C > D"].iloc[0] + assert row_c_d["weak_interpretation"] == "Unknown" + assert 
row_c_d["above_50"] < 0.95 + assert row_c_d["in_rope"] < 0.95 + + def test_strong_interpretation_logic(self, fitted_model, monkeypatch): + """Test that strong interpretation follows expected logic with controlled samples.""" + # Create synthetic samples with known properties for testing strong interpretations + # Case 1: Model A > Model B - mean > 0.70, should be "A better" + samples_a_better = np.full(1000, 0.75) # mean = 0.75 + + # Case 2: Model B > Model C - mean <= 0.55, should be "Equivalent" + samples_equivalent = np.full(1000, 0.50) # mean = 0.50 + + # Case 3: Model C > Model D - 0.55 < mean <= 0.70, should be "Unknown" + samples_unknown = np.full(1000, 0.62) # mean = 0.62 + + samples = np.column_stack( + [samples_a_better, samples_equivalent, samples_unknown] + ) + names = ["A > B", "B > C", "C > D"] + + # Mock _get_pwin to return our controlled samples + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table( + columns=[ + ReportedProperty.MEAN, + ReportedProperty.STRONG_INTERPRETATION, + ], + round_ndigits=None, + ) + + # Test Case 1: A > B should be "A better" (mean > 0.70) + row_a_b = result[result["pair"] == "A > B"].iloc[0] + assert row_a_b["strong_interpretation"] == "A better" + assert row_a_b["mean"] > 0.70 + + # Test Case 2: B > C should be "Equivalent" (mean <= 0.55) + row_b_c = result[result["pair"] == "B > C"].iloc[0] + assert row_b_c["strong_interpretation"] == "Equivalent" + assert row_b_c["mean"] <= 0.55 + + # Test Case 3: C > D should be "Unknown" (0.55 < mean <= 0.70) + row_c_d = result[result["pair"] == "C > D"].iloc[0] + assert row_c_d["strong_interpretation"] == "Unknown" + assert 0.55 < row_c_d["mean"] <= 0.70 + + +class TestPosteriorTableStructure: + """Test the structure and content of posterior_table output with mocked samples.""" + + def test_pair_column_format(self, fitted_model, monkeypatch): + """Test that pair 
column has correct format.""" + # Create simple samples for structure testing + samples = np.column_stack([np.full(100, 0.6), np.full(100, 0.5)]) + names = ["model_a > model_b", "model_b > model_c"] + + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table() + + for pair in result["pair"]: + assert " > " in pair, f"Pair {pair} does not contain ' > '" + parts = pair.split(" > ") + assert len(parts) == 2, f"Pair {pair} does not have exactly 2 parts" + + def test_left_right_model_consistency(self, fitted_model, monkeypatch): + """Test that left_model and right_model match the pair column.""" + samples = np.column_stack([np.full(100, 0.6), np.full(100, 0.5)]) + names = ["model_a > model_b", "model_b > model_c"] + + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table(columns=ALL_PROPERTIES) + + for _, row in result.iterrows(): + expected_pair = f"{row['left_model']} > {row['right_model']}" + assert row["pair"] == expected_pair + + def test_probability_values_in_range(self, fitted_model, monkeypatch): + """Test that probability values are in valid range [0, 1].""" + # Create samples with values that should produce probabilities in [0, 1] + samples = np.column_stack( + [ + np.random.uniform(0.3, 0.9, 100), + np.random.uniform(0.2, 0.8, 100), + np.random.uniform(0.1, 0.7, 100), + ] + ) + names = ["A > B", "B > C", "C > D"] + + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table() + + for col in ["mean", "median", "above_50", "in_rope"]: + if col in result.columns: + assert (result[col] >= 0).all() + assert (result[col] <= 1).all() + + def test_hdi_values_consistent(self, fitted_model, monkeypatch): + """Test that HDI 
values are consistent (low <= high).""" + # Create samples with varying distributions + samples = np.column_stack( + [ + np.random.beta(2, 5, 100), # Skewed distribution + np.random.beta(5, 5, 100), # Symmetric distribution + np.random.beta(8, 2, 100), # Right-skewed distribution + ] + ) + names = ["A > B", "B > C", "C > D"] + + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table( + columns=[ReportedProperty.HDI_LOW, ReportedProperty.HDI_HIGH], + round_ndigits=None, + ) + + assert (result["hdi_low"] <= result["hdi_high"]).all() + + def test_delta_equals_hdi_difference(self, fitted_model, monkeypatch): + """Test that delta equals hdi_high - hdi_low.""" + # Create samples with known distributions + np.random.seed(42) + samples = np.column_stack( + [ + np.random.normal(0.6, 0.1, 1000), + np.random.normal(0.5, 0.15, 1000), + np.random.normal(0.7, 0.08, 1000), + ] + ) + # Clip to [0, 1] range + samples = np.clip(samples, 0, 1) + names = ["A > B", "B > C", "C > D"] + + def mock_get_pwin(*args, **kwargs): + return samples, names + + monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) + + result = fitted_model.posterior_table( + columns=[ + ReportedProperty.HDI_LOW, + ReportedProperty.HDI_HIGH, + ReportedProperty.DELTA, + ], + round_ndigits=None, # Don't round to avoid rounding differences + ) + + calculated_delta = result["hdi_high"] - result["hdi_low"] + np.testing.assert_array_almost_equal( + result["delta"], calculated_delta, decimal=10 + ) diff --git a/tests/data/README.md b/tests/data/README.md new file mode 100644 index 0000000..1c39fc1 --- /dev/null +++ b/tests/data/README.md @@ -0,0 +1,7 @@ +# Datasets for Testing BBT + +This directory contains various datasets used in the testing of the BBT package. 
+ +## Contents + +- `benchmarking_mol.csv` - Adaptation of the benchmarking results from [Praski, Mateusz, Jakub Adamczyk, and Wojciech Czech. "Benchmarking pretrained molecular embedding models for molecular representation learning." arXiv preprint arXiv:2508.06199 (2025).](https://arxiv.org/pdf/2508.06199) - [data source](https://github.com/scikit-fingerprints/benchmarking_molecular_models/blob/31779f16c004b3fb8aa555ecceb6b95ca71a1d7d/results/arxiv_preprint_2025_08.csv) \ No newline at end of file diff --git a/tests/data/benchmarking_mol.csv b/tests/data/benchmarking_mol.csv new file mode 100644 index 0000000..9f8d95e --- /dev/null +++ b/tests/data/benchmarking_mol.csv @@ -0,0 +1,26 @@ +dataset,AtomPair_count,CDDD,CLAMP,ChemBERTa-10M-MTR,ChemFM-3B,ChemGPT-4.7M,ECFP_count,GEM,GNN-GraphCL-sum,GraphFP-CP,GraphMVP_CP-max,MoLFormer-XL-both-10pct,SELFormer-Lite,SimSon,TT,chemformer_mask,coati,grover_large,mat_masking_2M,mol2vec,mol_r_tag_1024,molbert,rmat_4M,unimolv1,unimolv2 +AMES,0.8145254459652627,0.8355244864790773,0.8459975719125106,0.8426452446689773,0.8194070767001508,0.7335193561651882,0.8481701227750689,0.7126309502829505,0.7789569014470619,0.5807867786719928,0.8080058352425151,0.8166421899782647,0.7205085276782393,0.7463177269968083,0.8370175644715973,0.8017074937829212,0.8062239323268519,0.8149405706005599,0.8176819597015802,0.8063541483091503,0.8127807476159705,0.8539838258043041,0.8554191388121953,0.7725518416260353,0.8058469913254617 +Bioavailability_Ma,0.7120053209178583,0.6687728633189225,0.641835716661124,0.744595942800133,0.6704356501496508,0.695710009976721,0.6910542068506818,0.7005320917858331,0.6769205187894911,0.6614566012637179,0.648819421350183,0.7223145992683738,0.6696042567342868,0.6311938809444629,0.6915530428999002,0.7031925507149982,0.6217159960093117,0.8839374792151646,0.6867309610907882,0.7263052876621217,0.5874625872963086,0.6835716661124045,0.7283006318589956,0.6601263717991352,0.7653807781842367 
+CYP1A2_Veith,0.9288641730543384,0.923714467918524,0.9507841998257334,0.9261630403212568,0.9038250262031344,0.8679914508328177,0.9278665597494602,0.8753829446008915,0.8971167081286542,0.7630870449178547,0.8894167750571418,0.9142961774994002,0.8789724582959755,0.8582107363396432,0.9151504628168052,0.9044431676116632,0.8970137897940371,0.8569611941052419,0.9269052519920696,0.9133440250539848,0.9013234161310284,0.9293585599009964,0.9308565583603784,0.901898622283398,0.9156012830064784 +CYP2C19_Veith,0.8581617924425365,0.8721135770601274,0.9223048361905296,0.8712846316196391,0.8536529402778946,0.7899815755263134,0.8744442262188677,0.770266453429103,0.8349491378530968,0.6038650867134916,0.8503511873030125,0.8550894880739722,0.7905159164383093,0.7917155336082976,0.8512889337810075,0.8396796448534616,0.8488420015649891,0.7794356686587545,0.8746050896439495,0.8515386461908725,0.8513609482213056,0.8851874713578931,0.8807565568992016,0.8265592996829495,0.8631962564961079 +CYP2C9_Substrate_CarbonMangels,0.5963103635377102,0.6551817688551276,0.6977753662506783,0.692078133478025,0.667932718393923,0.6524688008681498,0.6272381985892567,0.5973955507325014,0.6587086272381986,0.5879001627780792,0.6740368963646228,0.6250678241996744,0.6367335865436788,0.5755561584373305,0.5826098752034725,0.6405317417254477,0.6724091155724362,0.5721649484536082,0.6245252306022789,0.644872490504612,0.6049918610960392,0.6466359196961478,0.6663049376017363,0.5949538795442213,0.6793271839392295 +CYP2C9_Veith,0.8717376922652849,0.8812034842053036,0.9155552755310185,0.8796941081962307,0.8565025635371301,0.8259968654753372,0.8819434052908887,0.8048740701046527,0.8287907665409181,0.6694722500059067,0.8384621139927386,0.8659746560292346,0.8167854583257859,0.8095870775676718,0.8695888102194955,0.863039386326227,0.8663519015853764,0.8590219968024698,0.876420183818607,0.8636198246871382,0.8579658667590747,0.8845333259826892,0.8793838060060013,0.8477920502784055,0.8649319146590219 
+CYP2D6_Substrate_CarbonMangels,0.7936046511627907,0.8100353892821032,0.8431496461071789,0.8145854398382204,0.7834934277047523,0.7296511627906976,0.804726996966633,0.7669362992922144,0.8211577350859454,0.4970930232558139,0.7930990899898888,0.8303842264914054,0.7306622851365014,0.7831142568250758,0.7832406471183013,0.7906976744186046,0.8312689585439839,0.7520222446916077,0.8303842264914054,0.8120576339737109,0.782355915065723,0.8224216380182002,0.8162285136501517,0.7554347826086957,0.7960060667340749 +CYP2D6_Veith,0.8584982043680441,0.8681169595964715,0.8934522769316356,0.8601299390219312,0.8567858136286035,0.7904575578371347,0.872068073736207,0.7947070531932723,0.8134802170269133,0.5143666067260971,0.8281027517167855,0.8552481777411953,0.8198453264451658,0.758300079932249,0.847481927349548,0.8301848161816788,0.8464835497777305,0.8185724079667531,0.8598931410271655,0.842437302819137,0.8385404764292931,0.8682570736938108,0.8752777723202793,0.8240435791034559,0.8543082241287849 +CYP3A4_Substrate_CarbonMangels,0.6839963833634719,0.6589059674502713,0.64376130198915,0.6583408679927667,0.6338155515370705,0.6499773960216997,0.6637658227848102,0.6177667269439421,0.6317811934900542,0.523508137432188,0.6778933092224231,0.6306509945750453,0.66873869801085,0.6353978300180831,0.589624773960217,0.6421790235081375,0.6017179023508138,0.5752712477396021,0.6808318264014466,0.6348327305605785,0.6192359855334539,0.6772151898734177,0.6462477396021701,0.6625226039783002,0.60623869801085 +CYP3A4_Veith,0.8669962030469048,0.8633300413060452,0.903643473444684,0.8622570183575008,0.8446845993180165,0.8035443789746224,0.87512161594053,0.7692566913791147,0.7938341218628425,0.6597229358655252,0.831473004264065,0.8499916587146413,0.7942605283703799,0.7804099908579513,0.8613992005712112,0.8433620051115396,0.8393291604663111,0.8504447573353264,0.8771281955464209,0.8439282115616888,0.8600512488572438,0.8658360970792156,0.8824839680495404,0.8243925876001788,0.8605303722882482 
+DILI,0.9286956521739133,0.9234782608695652,0.888695652173913,0.8965217391304348,0.8860869565217392,0.8130434782608695,0.9152173913043478,0.8910869565217392,0.895,0.6552173913043479,0.9021739130434784,0.9204347826086956,0.7130434782608696,0.8843478260869565,0.9169565217391304,0.8817391304347826,0.883695652173913,0.8956521739130435,0.9352173913043478,0.9221739130434784,0.9230434782608696,0.9215217391304348,0.912608695652174,0.9221739130434782,0.9208695652173912 +HIA_Hou,0.9868312757201646,0.94320987654321,0.9728395061728394,0.9641975308641976,0.9716049382716048,0.8934156378600823,0.9497942386831276,0.920576131687243,0.8,0.6325102880658435,0.931275720164609,0.9806584362139916,0.8987654320987655,0.9238683127572016,0.9139917695473252,0.962962962962963,0.937448559670782,0.9530864197530864,0.9786008230452676,0.9925925925925928,0.9613168724279836,0.9732510288065844,0.9872427983539096,0.9386831275720164,0.9485596707818932 +PAMPA_NCATS,0.7048673705897502,0.7622971928920937,0.7482616533608034,0.7296420293587433,0.7262941024980686,0.6829770795776461,0.7644604687097605,0.6549060005150656,0.6740664434715427,0.4690703064640741,0.6934329126963689,0.7164563481843935,0.7187226371362349,0.5711048158640227,0.728920937419521,0.7227916559361318,0.7376770538243627,0.6671130569147568,0.7455060520216328,0.7203193407159412,0.6732938449652331,0.739943342776204,0.726191089363894,0.7263456090651559,0.7356682977079578 +Pgp_Broccatelli,0.903425753132498,0.91762196747534,0.9329512130098642,0.928918954945348,0.9049586776859504,0.8626699546787523,0.9277525993068516,0.8675686483604372,0.9162556651559584,0.7138762996534258,0.8774993335110637,0.9292188749666755,0.8533724340175954,0.8679018928285791,0.9264196214342842,0.8866302319381498,0.8844308184484136,0.8661690215942415,0.922354038922954,0.8943615035990402,0.9253199146894162,0.9186883497733938,0.9265529192215408,0.8856971474273527,0.8776992801919489 
+SARSCoV2_3CLPro_Diamond,0.7265745007680491,0.7402457757296466,0.7379416282642088,0.7714285714285715,0.7204301075268817,0.6728110599078341,0.7331797235023042,0.7503840245775729,0.7152073732718893,0.6287250384024577,0.7516129032258064,0.7485407066052226,0.6359447004608295,0.7585253456221198,0.708141321044547,0.7268817204301076,0.7084485407066051,0.7397849462365591,0.7156682027649769,0.7247311827956988,0.7185867895545314,0.72642089093702,0.7658986175115207,0.6924731182795699,0.7353302611367126 +SARSCoV2_Vitro_Touret,0.5645802805107808,0.6276952061963575,0.6324052752773708,0.5295164329076826,0.52857441909148,0.5626962528783757,0.5938873770148628,0.6161817039983253,0.5480427046263345,0.5545321331379527,0.5520200962947457,0.5869792756960435,0.4967552857441909,0.629997906635964,0.5688716767845928,0.55714883818296,0.4975926313585932,0.486079129160561,0.6083315888633033,0.5449026585723257,0.4785430186309399,0.6133556625497174,0.5830018840276324,0.517479589700649,0.5534854511199497 +hERG,0.8435415403274712,0.8770466949666464,0.9003941782898727,0.8424802910855064,0.751516070345664,0.7612189205579138,0.8045785324439054,0.7328684050939963,0.7152819890842934,0.6704063068526379,0.8154942389326865,0.8308065494238932,0.7897210430563978,0.747725894481504,0.8423286840509401,0.832929047907823,0.8606731352334749,0.9530018192844149,0.8211036992116435,0.8233778047301394,0.8521831412977562,0.8573377804730139,0.8365676167374165,0.8091267434808975,0.78077622801698 +hERG_Karim,0.8707022332602238,0.8634323749475452,0.8596954388812743,0.8684191029915945,0.8326721774356292,0.7821298268757524,0.8783917297552222,0.7649976419540876,0.8106986726938284,0.5694962849028115,0.846036188402616,0.8413800843810004,0.7679469962144103,0.7809715370312728,0.8624753345080254,0.8289131260501322,0.8316733943310697,0.6962339270006486,0.8532706622045324,0.8538285211320167,0.8334478722574932,0.8761459190877873,0.8699798916459933,0.7781623386065469,0.8066255837753172 
+ogbg-molbace,0.8709789601808381,0.8112502173535037,0.8342027473482871,0.8262910798122066,0.864545296470179,0.789340984176665,0.8596765779864372,0.7625630325160841,0.7815162580420796,0.6442357850808555,0.8177708224656582,0.8365501651886628,0.7723874108850635,0.7122239610502521,0.8730655538167275,0.8125543383759346,0.8334202747348287,0.8535037384802643,0.838810641627543,0.823074247956877,0.8135106937923839,0.8194227090940706,0.8302903842809947,0.8036863154234045,0.8258563728047296 +ogbg-molbbbp,0.7381365740740741,0.7231867283950617,0.7009066358024691,0.7208236882716049,0.6943479938271606,0.6887056327160495,0.7201003086419753,0.6756847993827161,0.7038966049382716,0.5511188271604938,0.6868248456790123,0.714940200617284,0.646701388888889,0.7098765432098765,0.679783950617284,0.6837384259259259,0.6879822530864199,0.7307098765432098,0.7261766975308642,0.7304205246913581,0.7055844907407406,0.7426215277777777,0.7265142746913581,0.6687885802469136,0.6610725308641976 +ogbg-molclintox,0.685350415319918,0.9361415215653563,0.761158030097661,0.7393515622284781,0.918466096687867,0.7579775310186634,0.7629661488200743,0.6992831821499322,0.6730128940326,0.5836126750773294,0.7893932679943002,0.8836496020574844,0.8100345810308275,0.7490446773016369,0.6624643763250269,0.7107483578354709,0.8784537587321448,0.6691351266812637,0.8898203176589163,0.8933166510270044,0.8905540784763494,0.871610537656831,0.8641139262503041,0.9166136655892676,0.8237766308692177 +ogbg-molhiv,0.8114312752274087,0.8144556673554917,0.93112651847274,0.7870854979818073,0.7817841209756851,0.7468844221105528,0.7893238571621699,0.7427734827986084,0.7545863935059914,0.5222922303826827,0.7017665249323541,0.7801743950250101,0.7613857750289911,0.7546559721685351,0.8008690781977249,0.7757971378358021,0.7881602580196605,0.7304640877575852,0.7759180517974488,0.7801533433886375,0.7757557015848473,0.7794491689215307,0.7774246231155778,0.771157713578528,0.7609693778792924 
+ogbg-molmuv,0.7847982062603639,0.7505515332745804,0.9538875049933596,0.7610820032921798,0.6592362259713641,0.6078312235391929,0.7394131645651287,0.6466112111932127,0.6648343348475036,0.5777699271769144,0.6569077928851594,0.7981076630230599,0.597234324900685,0.6090277349727774,0.7801389611725246,0.7427803814620861,0.7555934977158004,0.6485992802102507,0.7541811833214094,0.784767265466513,0.7583653622561716,0.7660191429147382,0.7735443844162032,0.7465818892952342,0.7388285736382725 +ogbg-molsider,0.6869266781940169,0.639254436731632,0.6921749849321256,0.6917469549305515,0.6728463472551298,0.6329586662880131,0.6863617997542489,0.652717459185894,0.6409775235393228,0.5307502431492865,0.6729195664622901,0.6655429178829346,0.5950587662813914,0.5991447413702715,0.6808944388372183,0.6298043585963656,0.6395070055414712,0.5790117403415879,0.655327943527307,0.6825739326709382,0.6628242686972917,0.6399717638635757,0.6873461835744622,0.6144490489451263,0.6218491127682689 +ogbg-moltox21,0.7793677101920942,0.7833803433436679,0.8320623962565512,0.7887222679343062,0.7582697383714505,0.7083346524640269,0.7817977043382754,0.7198347490077192,0.7551421822081971,0.5520488086890106,0.759697908862892,0.7502529731163675,0.7122280489638442,0.7210298631507531,0.7561706408668748,0.7466120217303461,0.748534989937518,0.689452274900356,0.7703849859168921,0.7778255851570117,0.7703945599045738,0.7638956942207354,0.7797656511703458,0.749669751655326,0.7250482371061816 diff --git a/tests/regression/__init__.py b/tests/regression/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/regression/test_benchmarking_mol.py b/tests/regression/test_benchmarking_mol.py new file mode 100644 index 0000000..20ebef0 --- /dev/null +++ b/tests/regression/test_benchmarking_mol.py @@ -0,0 +1,304 @@ +""" +Regression tests for PyBBT model using molecular embeddings benchmarking data. 
+ +This test suite validates the PyBBT model's weak interpretation results against +the ECFP baseline using molecular embeddings benchmarking data from the study +of pretrained molecular embedding models. + +Notes +----- +The test data is adapted from the benchmarking study: + +Praski, Mateusz, Jakub Adamczyk, and Wojciech Czech. +"Benchmarking pretrained molecular embedding models for molecular representation learning." +arXiv preprint arXiv:2508.06199 (2025). +https://arxiv.org/pdf/2508.06199 + +Data source: +https://github.com/scikit-fingerprints/benchmarking_molecular_models/blob/31779f16c004b3fb8aa555ecceb6b95ca71a1d7d/results/arxiv_preprint_2025_08.csv + +The tests validate that the PyBBT model correctly identifies: +- Better performing models (vs ECFP baseline) +- Equivalent performing models (within ROPE) +- Unknown comparisons (insufficient evidence) +- Worse performing models + +Test parameters: +- local_rope_value: 0.01 +- tie_solver: TieSolver.SPREAD +- MCMC sampling: 2000 draws, 1000 tune, 4 chains +""" + +from pathlib import Path + +import pandas as pd +import pytest + +from bbttest import PyBBT, TieSolver +from bbttest.bbt.const import DEFAULT_PROPERTIES, ReportedProperty + + +@pytest.fixture(scope="module") +def benchmarking_data(): + """ + Load benchmarking molecular data. + + Returns + ------- + pd.DataFrame + Molecular embeddings benchmarking results with columns for dataset + and various model scores. + """ + data_path = Path(__file__).parent.parent / "data" / "benchmarking_mol.csv" + return pd.read_csv(data_path) + + +@pytest.fixture(scope="module") +def fitted_model(benchmarking_data): + """ + Fit PyBBT model with local_rope_value=0.01. + + Parameters + ---------- + benchmarking_data : pd.DataFrame + Benchmarking molecular data fixture. + + Returns + ------- + PyBBT + Fitted PyBBT model instance. 
+ """ + model = PyBBT(local_rope_value=0.01, tie_solver=TieSolver.SPREAD) + model.fit( + benchmarking_data, + dataset_col="dataset", + draws=2000, + tune=1000, + chains=4, + random_seed=42, + ) + return model + + +def _extract_interpretations(results): + """ + Extract model interpretations from posterior table results. + + Parameters + ---------- + results : pd.DataFrame + Posterior table results with comparisons against ECFP_count. + + Returns + ------- + dict + Dictionary mapping model names to their weak interpretations. + """ + interpretations = {} + for _, row in results.iterrows(): + if row["left_model"] != "ECFP_count": + interpretations[row["left_model"]] = row["weak_interpretation"] + if row["right_model"] != "ECFP_count": + # Invert interpretation when ECFP is on the left + if row["weak_interpretation"] == f"{row['left_model']} better": + interpretations[row["right_model"]] = "ECFP better" + else: + interpretations[row["right_model"]] = row["weak_interpretation"] + return interpretations + + +class TestWeakInterpretationAgainstECFP: + """Test weak interpretation results against ECFP baseline for different ROPE values.""" + + @pytest.mark.slow + @pytest.mark.parametrize( + "rope,better_models,equivalent_models,unknown_models,worse_models", + [ + ( + (0.45, 0.55), + ["CLAMP", "rmat_4M"], + [], + [ + "AtomPair_count", + "CDDD", + "ChemBERTa-10M-MTR", + "mat_masking_2M", + "molbert", + ], + [ + "ChemFM-3B", + "ChemGPT-4.7M", + "GEM", + "GNN-GraphCL-sum", + "GraphFP-CP", + "GraphMVP_CP-max", + "MoLFormer-XL-both-10pct", + "SELFormer-Lite", + "SimSon", + "TT", + "chemformer_mask", + "coati", + "grover_large", + "mol2vec", + "mol_r_tag_1024", + "unimolv1", + "unimolv2", + ], + ), + ( + (0.4, 0.6), + ["CLAMP", "rmat_4M"], + [ + "CDDD", + "ChemBERTa-10M-MTR", + "mat_masking_2M", + "molbert", + ], + ["AtomPair_count"], + [ + "ChemFM-3B", + "ChemGPT-4.7M", + "GEM", + "GNN-GraphCL-sum", + "GraphFP-CP", + "GraphMVP_CP-max", + "MoLFormer-XL-both-10pct", + 
"SELFormer-Lite", + "SimSon", + "TT", + "chemformer_mask", + "coati", + "grover_large", + "mol2vec", + "mol_r_tag_1024", + "unimolv1", + "unimolv2", + ], + ), + ( + (0.35, 0.65), + ["CLAMP"], + [ + "AtomPair_count", + "CDDD", + "ChemBERTa-10M-MTR", + "mat_masking_2M", + "molbert", + "rmat_4M", + ], + [], + [ + "ChemFM-3B", + "ChemGPT-4.7M", + "GEM", + "GNN-GraphCL-sum", + "GraphFP-CP", + "GraphMVP_CP-max", + "MoLFormer-XL-both-10pct", + "SELFormer-Lite", + "SimSon", + "TT", + "chemformer_mask", + "coati", + "grover_large", + "mol2vec", + "mol_r_tag_1024", + "unimolv1", + "unimolv2", + ], + ), + ( + (0.3, 0.7), + [], + [ + "AtomPair_count", + "CDDD", + "CLAMP", + "ChemBERTa-10M-MTR", + "MoLFormer-XL-both-10pct", + "mat_masking_2M", + "mol2vec", + "molbert", + "rmat_4M", + ], + [], + [ + "ChemFM-3B", + "ChemGPT-4.7M", + "GEM", + "GNN-GraphCL-sum", + "GraphFP-CP", + "GraphMVP_CP-max", + "SELFormer-Lite", + "SimSon", + "TT", + "chemformer_mask", + "coati", + "grover_large", + "mol_r_tag_1024", + "unimolv1", + "unimolv2", + ], + ), + ], + ids=["rope_0.45_0.55", "rope_0.4_0.6", "rope_0.35_0.65", "rope_0.3_0.7"], + ) + def test_weak_interpretation_for_rope( + self, + fitted_model, + rope, + better_models, + equivalent_models, + unknown_models, + worse_models, + ): + """ + Test weak interpretation results for different ROPE values. + + Parameters + ---------- + fitted_model : PyBBT + Fitted PyBBT model fixture. + rope : tuple of float + Region of Practical Equivalence (ROPE) bounds. + better_models : list of str + Models expected to be better than ECFP. + equivalent_models : list of str + Models expected to be equivalent to ECFP. + unknown_models : list of str + Models with unknown comparison to ECFP. + worse_models : list of str + Models expected to be worse than ECFP. 
+ """ + results = fitted_model.posterior_table( + rope_value=rope, + control_model="ECFP_count", + columns=list(DEFAULT_PROPERTIES) + + [ReportedProperty.LEFT_MODEL, ReportedProperty.RIGHT_MODEL], + ) + + interpretations = _extract_interpretations(results) + + # Validate better models + for model in better_models: + assert interpretations[model] == f"{model} better", ( + f"Model {model} should be better than ECFP for ROPE {rope}" + ) + + # Validate equivalent models + for model in equivalent_models: + assert interpretations[model] == "Equivalent", ( + f"Model {model} should be equivalent to ECFP for ROPE {rope}" + ) + + # Validate unknown models + for model in unknown_models: + assert interpretations[model] == "Unknown", ( + f"Model {model} should have unknown comparison with ECFP for ROPE {rope}" + ) + + # Validate worse models + for model in worse_models: + assert interpretations[model] == "ECFP better", ( + f"Model {model} should be worse than ECFP for ROPE {rope}" + )