From 394ffe223d1bd9a5aa61b1b6969bc775b4b4b8ba Mon Sep 17 00:00:00 2001 From: Mateusz Praski Date: Tue, 24 Feb 2026 12:03:31 +0100 Subject: [PATCH 1/4] Replace Enums with string literals --- .gitignore | 2 + Makefile | 2 +- bbttest/__init__.py | 5 +- bbttest/bbt/__init__.py | 4 - bbttest/bbt/_types.py | 27 +++++ bbttest/bbt/_utils.py | 62 +++++++++++ bbttest/bbt/alg.py | 11 +- bbttest/bbt/model.py | 8 +- bbttest/bbt/params.py | 72 ------------- bbttest/bbt/py_bbt.py | 119 +++++++++++++--------- tests/bbt/test_alg.py | 5 +- tests/bbt/test_py_bbt.py | 82 +++++---------- tests/regression/test_benchmarking_mol.py | 8 +- 13 files changed, 202 insertions(+), 205 deletions(-) create mode 100644 bbttest/bbt/_types.py create mode 100644 bbttest/bbt/_utils.py delete mode 100644 bbttest/bbt/params.py diff --git a/.gitignore b/.gitignore index b282f62..dd99524 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ wheels/ .idea/ .vscode/ +# OS dependent files +.DS_Store diff --git a/Makefile b/Makefile index ecf31b7..05f5a18 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ test: ## Run tests without regression uv run ruff check uv run pytest tests -m "not slow" -full-test: +test-all: uv run ruff check uv run pytest tests diff --git a/bbttest/__init__.py b/bbttest/__init__.py index 2776693..545b442 100644 --- a/bbttest/__init__.py +++ b/bbttest/__init__.py @@ -1,10 +1,7 @@ """bbt-test: Bayesian Bradley-Terry model for algorithm comparison.""" -from .bbt import HyperPrior, PyBBT, ReportedProperty, TieSolver +from .bbt import PyBBT __all__ = [ - "HyperPrior", "PyBBT", - "ReportedProperty", - "TieSolver", ] diff --git a/bbttest/bbt/__init__.py b/bbttest/bbt/__init__.py index 64cdcbd..b53f709 100644 --- a/bbttest/bbt/__init__.py +++ b/bbttest/bbt/__init__.py @@ -1,11 +1,7 @@ """bbt module: Bayesian Bradley-Terry model implementation.""" -from .params import HyperPrior, ReportedProperty, TieSolver from .py_bbt import PyBBT __all__ = [ - "HyperPrior", "PyBBT", - "ReportedProperty", - "TieSolver", ] diff --git a/bbttest/bbt/_types.py b/bbttest/bbt/_types.py new file mode 100644 index 0000000..9edc5ba --- /dev/null +++ b/bbttest/bbt/_types.py @@ -0,0 +1,27 @@ +from typing import Literal, get_args + +HyperPriorType = Literal[ + "log_normal", + "cauchy", + "normal", +] + +TieSolverType = Literal["add", "spread", "forget", "davidson"] + +ReportedPropertyColumnType = Literal[ + "left_model", + "right_model", + "median", + "mean", + "hdi_low", + "hdi_high", + "delta", + "above_50", + "in_rope", + "weak_interpretation", + "strong_interpretation", +] + +ALL_PROPERTIES_COLUMNS: list[ReportedPropertyColumnType] = list( + get_args(ReportedPropertyColumnType) +) diff --git a/bbttest/bbt/_utils.py b/bbttest/bbt/_utils.py new file mode 100644 index 0000000..a4e8583 --- /dev/null +++ b/bbttest/bbt/_utils.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import sys +from typing import ( + Literal, + get_args, + get_origin, +) + +from pymc.distributions import Cauchy, LogNormal, Normal + +if sys.version_info >= (3, 12): + from typing import TypeAliasType +else: + from typing_extensions import TypeAliasType + + +def is_literal_value(value: object, typx: object) -> bool: + if isinstance(typx, TypeAliasType): + typx = typx.__value__ + if get_origin(typx) is Literal: + return value in get_args(typx) + return False + + +def _validate_params(func): + from inspect import signature + + sig = signature(func) + + def wrapper(*args, **kwargs): + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + for name, value in bound_args.arguments.items(): + param = sig.parameters[name] + # If type annotation is a Literal, validate the value + if param.annotation is not param.empty and is_literal_value( + value, param.annotation + ): + continue # Valid value, continue to next parameter + elif ( + param.annotation is not param.empty + and get_origin(param.annotation) is Literal + ): + raise ValueError( + f"Invalid value '{value}' for parameter '{name}'. Expected one of {get_args(param.annotation)}." + ) + return func(*args, **kwargs) + + return wrapper + + +def _get_distribution_for_prior(prior: str, scale: float): + match prior: + case "log_normal": + return LogNormal("sigma", mu=0, sigma=scale) + case "cauchy": + return Cauchy("sigma", alpha=0, beta=scale) + case "normal": + return Normal("sigma", mu=0, sigma=scale) + case _: + raise ValueError(f"Unsupported hyperprior: {prior}") diff --git a/bbttest/bbt/alg.py b/bbttest/bbt/alg.py index 2db69bc..d441cd2 100644 --- a/bbttest/bbt/alg.py +++ b/bbttest/bbt/alg.py @@ -8,7 +8,6 @@ from tqdm.auto import tqdm from .const import UNNAMED_COLUMNS_WARNING_TEMPLATE -from .params import TieSolver ALG1_COL = 2 ALG2_COL = 3 @@ -107,12 +106,12 @@ def _construct_lrope( return out_array -def _solve_ties(table: np.ndarray, tie_solver: TieSolver) -> np.ndarray: - if tie_solver == TieSolver.DAVIDSON: +def _solve_ties(table: np.ndarray, tie_solver: str) -> np.ndarray: + if tie_solver == "davidson": return table - if tie_solver == TieSolver.SPREAD: + if tie_solver == "spread": tie_val = np.ceil(table[:, TIE_COL] / 2).astype(int) - elif tie_solver == TieSolver.ADD: + elif tie_solver == "add": tie_val = table[:, TIE_COL].astype(int) else: tie_val = 0 @@ -126,7 +125,7 @@ def _construct_win_table( data_sd: pd.DataFrame | None, dataset_col: str | int | None, local_rope_value: float | None, - tie_solver: TieSolver, + tie_solver: str, maximize: bool, ) -> tuple[np.ndarray, list[str]]: # Extract algorithm names diff --git a/bbttest/bbt/model.py b/bbttest/bbt/model.py index bde5e2c..685cc33 100644 --- a/bbttest/bbt/model.py +++ b/bbttest/bbt/model.py @@ -3,7 +3,7 @@ import pymc as pm import pytensor.tensor as pt -from .params import HyperPrior +from ._utils import _get_distribution_for_prior def _build_bbt_model( @@ -12,7 +12,7 @@ def _build_bbt_model( win1: list[int], win2: list[int], ties: list[int] | None, - hyp: HyperPrior, + hyp: str, scale: float, use_davidson: bool, ): @@ -53,7 +53,7 @@ def _build_bbt_model( with pm.Model() as model: # Hyperprior for sigma - sigma = hyp._get_pymc_dist(scale=scale, name="sigma") + sigma = _get_distribution_for_prior(hyp, scale=scale) # Abilities for each player beta = pm.Normal("beta", mu=0, sigma=sigma, shape=K) @@ -100,7 +100,7 @@ def _build_bbt_model( def _mcmcbbt_pymc( table: np.ndarray, use_davidson: bool, - hyper_prior: HyperPrior, + hyper_prior: str, scale: float, **kwargs, ) -> az.InferenceData: diff --git a/bbttest/bbt/params.py b/bbttest/bbt/params.py deleted file mode 100644 index f5d8044..0000000 --- a/bbttest/bbt/params.py +++ /dev/null @@ -1,72 +0,0 @@ -from enum import Enum - -from pymc.distributions import Cauchy, LogNormal, Normal - - -class HyperPrior(str, Enum): - """ - Hyper Prior distributions for BBT MCMC sampling. - """ - - LOG_NORMAL = "logNormal" - LOG_NORMAL_SCALED = "logNormalScaled" - CAUCHY = "cauchy" - NORMAL = "normal" - - def _get_pymc_dist(self, scale, name="sigma"): - match self: - case HyperPrior.LOG_NORMAL: - return LogNormal(name, mu=0, sigma=1) - case HyperPrior.LOG_NORMAL_SCALED: - return LogNormal(name, mu=0, sigma=scale) - case HyperPrior.CAUCHY: - return Cauchy(name, alpha=0, beta=scale) - case HyperPrior.NORMAL: - return Normal(name, mu=0, sigma=scale) - case _: - raise ValueError(f"Unsupported hyperprior: {self}") - - -class ReportedProperty(str, Enum): - """ - Enum containing properties that can be reported from BBT results. - """ - - LEFT_MODEL = "left_model" - RIGHT_MODEL = "right_model" - MEDIAN = "median" - MEAN = "mean" - HDI_LOW = "hdi_low" - HDI_HIGH = "hdi_high" - DELTA = "delta" - ABOVE_50 = "above_50" - IN_ROPE = "in_rope" - WEAK_INTERPRETATION = "weak_interpretation" - STRONG_INTERPRETATION = "strong_interpretation" - - -class TieSolver(str, Enum): - """ - Enum containing tie solving strategies. - - ADD - Add 1 win to both players. - SPREAD - Add 1/2 win to both players. - FOGET - Ignore the tie. - DAVIDSON - Use Davidson's method to handle ties. - """ - - ADD = "add" - SPREAD = "spread" - FORGET = "forget" - DAVIDSON = "davidson" - - -DEFAULT_PROPERTIES = ( - ReportedProperty.MEAN, - ReportedProperty.DELTA, - ReportedProperty.ABOVE_50, - ReportedProperty.IN_ROPE, - ReportedProperty.WEAK_INTERPRETATION, -) - -ALL_PROPERTIES = tuple(ReportedProperty) diff --git a/bbttest/bbt/py_bbt.py b/bbttest/bbt/py_bbt.py index 747bb85..98df5ad 100644 --- a/bbttest/bbt/py_bbt.py +++ b/bbttest/bbt/py_bbt.py @@ -4,9 +4,10 @@ import numpy as np import pandas as pd +from ._types import HyperPriorType, ReportedPropertyColumnType, TieSolverType +from ._utils import _validate_params from .alg import _construct_win_table, _get_pwin, _hdi from .model import _mcmcbbt_pymc -from .params import DEFAULT_PROPERTIES, HyperPrior, ReportedProperty, TieSolver class PyBBT: @@ -19,28 +20,33 @@ class PyBBT: local_rope_value: float | None, default 0.1 The value of the local ROPE to be used when constructing win/tie/loss pairs. If the models is unpaired (i.e., only one score per model per dataset), this value is used to determine the threshold for ties in the followin manner: + - score_a - score_b > local_rope_value => model A wins - score_b - score_a > local_rope_value => model B wins - otherwise => tie + In case of paired BBT (i.e. multiple readings per model per dataset or data_sd provided), the ties are determined based on the following conditions: + - sigma = sqrt(sd_a^2 + sd_b^2) - score_a - score_b > local_rope_value * sigma => model A wins - score_b - score_a > local_rope_value * sigma => model B wins - otherwise => tie + If None, no ties are recorded. tie_solver: TieSolver, default TieSolver.SPREAD The strategy to handle ties when sampling the BBT model. + - ADD - Adds 1 win to both players for each tie. - SPREAD - Adds 0.5 win to both players for each tie. - FORGET - Ignores the ties. - DAVIDSON - Uses Davidson's method to handle ties in the BBT model. See [1]_. - hyper_prior: HyperPrior, default HyperPrior.LOG_NORMAL + hyper_prior: str, default "log_normal" The hyper prior distribution to be used for the BBT MCMC sampling. scale: float, default 1.0 - The scale parameter for the hyper prior distribution. Ignored if the HyperPrior is LOG_NORMAL. + The scale parameter for the hyper prior distribution. maximize: bool, default True Whether higher scores indicate better performance (e.g. accuracy/f1). If using a metric where the goal is to @@ -77,22 +83,19 @@ class PyBBT: _STRONG_INTERPRETATION_BETTER_THRESHOLD = 0.70 _STRONG_INTERPRETATION_EQUAL_THRESHOLD = 0.55 + @_validate_params def __init__( self, local_rope_value: float | None = None, - tie_solver: TieSolver | str = TieSolver.SPREAD, - hyper_prior: HyperPrior | str = HyperPrior.LOG_NORMAL, + tie_solver: TieSolverType = "spread", + hyper_prior: HyperPriorType = "log_normal", maximize: bool = True, scale: float = 1.0, ): self._local_rope_value = local_rope_value - self._tie_solver = ( - TieSolver(tie_solver) if isinstance(tie_solver, str) else tie_solver - ) - self._use_davidson = self._tie_solver == TieSolver.DAVIDSON - self._hyper_prior = ( - HyperPrior(hyper_prior) if isinstance(hyper_prior, str) else hyper_prior - ) + self._tie_solver = tie_solver + self._use_davidson = self._tie_solver == "davidson" + self._hyper_prior = hyper_prior self._maximize = maximize self._scale = scale self._fitted = False @@ -116,16 +119,21 @@ def fit( """ Fits the BBT for a given result dataframes. - Args: - data (pd.DataFrame): Dataframe containing scores for the models on the datasets. - If data_sd is provided, this dataframe should contain mean scores per model per dataset. - If multiple scores per model per dataset are provided, data_sd is ignored, and dataset_col is required. - data_sd (pd.DataFrame | None, optional): Dataframe containing standard deviations of the scores for the models on the datasets. - dataset_col (str, optional): Column name for the dataset identifier. Defaults to "dataset". + Parameters + ---------- + data : pd.DataFrame + Dataframe containing scores for the models on the datasets. + If data_sd is provided, this dataframe should contain mean scores per model per dataset. + If multiple scores per model per dataset are provided, data_sd is ignored, and dataset_col is required. + data_sd : pd.DataFrame | None, optional + Dataframe containing standard deviations of the scores for the models on the datasets. + dataset_col : str, optional + Column name for the dataset identifier. Defaults to "dataset". Returns ------- - self: fitted PyBBT instance + self : PyBBT + Fitted PyBBT instance """ self._win_table, self._algorithms = _construct_win_table( data=data, @@ -153,23 +161,37 @@ def posterior_table( rope_value: tuple[float, float] = (0.45, 0.55), control_model: str | None = None, selected_models: Iterable[str] | None = None, - columns: Iterable[ReportedProperty | str] = DEFAULT_PROPERTIES, + columns: Iterable[ReportedPropertyColumnType] = ( + "mean", + "delta", + "above_50", + "in_rope", + "weak_interpretation", + ), hdi_proba: float = 0.89, round_ndigits: int | None = 2, ) -> pd.DataFrame: """Compute posterior table containing sampling results for the fitted BBT model. - Args: - rope_value (tuple[float, float], optional): Region of Practical Equivalence (ROPE). Defaults to (0.45, 0.55). - control_model (str | None, optional): Control model for comparison. Defaults to None. - selected_models (list[str] | None, optional): Subset of models to include in the posterior table. Defaults to None. - columns (list[ReportedProperty], optional): Columns to include in the posterior table. Defaults to DEFAULT_PROPERTIES. - hdi_proba (float, optional): Highest Density Interval probability. Defaults to 0.89. - round_ndigits (int | None, optional): Number of digits to round the results to. Defaults to 2. + Parameters + ---------- + rope_value : tuple[float, float], optional + Region of Practical Equivalence (ROPE). Defaults to (0.45, 0.55). + control_model : str | None, optional + Control model for comparison. Defaults to None. + selected_models : Iterable[str] | None, optional + Subset of models to include in the posterior table. Defaults to None. + columns : Iterable[ReportedPropertyColumnType], optional + Columns to include in the posterior table. Defaults to minimum set for weak interpretation. + hdi_proba : float, optional + Highest Density Interval probability. Defaults to 0.89. + round_ndigits : int | None, optional + Number of digits to round the results to. Defaults to 2. Returns ------- - pd.DataFrame: Posterior table containing sampling results for the fitted BBT model. + pd.DataFrame + Posterior table containing sampling results for the fitted BBT model. """ self._check_if_fitted() @@ -254,23 +276,26 @@ def rope_comparison_control_table( The output table contains N rows (one per ROPE) and 5 columns (rope value, better models, equivalent models, worse models, unknown models). - Args: - model: Fitted PyBBT model. - ropes: List of ROPE tuples to evaluate. - interpretation: Type of interpretation to use ("weak" or "strong"), see [1]_. - return_as_array: Whether the individual cells should contain model names as list or joined into single string. - join_char: Character(s) used to join multiple model names in a single cell. + Parameters + ---------- + rope_values : Sequence[tuple[float, float]] + List of ROPE tuples to evaluate. + control_model : str + Control model for comparison. + selected_models : Sequence[str] | None, optional + Subset of models to include. Defaults to None. + interpretation : {"weak", "strong"}, optional + Type of interpretation to use, see [1]_. Defaults to "weak". + return_as_array : bool, optional + Whether the individual cells should contain model names as list or joined into single string. + Defaults to False. + join_char : str, optional + Character(s) used to join multiple model names in a single cell. Defaults to ", ". Returns ------- - pd.DataFrame: Table comparing models against control models across multiple ROPEs. - - References - ---------- - .. [1] `Jacques Wainer - "A Bayesian Bradley-Terry model to compare multiple ML algorithms on multiple data sets" - Journal of Machine Learning Research 24 (2023): 1-34 - `_ + pd.DataFrame + Table comparing models against control models across multiple ROPEs. """ self._check_if_fitted() records = [] @@ -279,11 +304,11 @@ def rope_comparison_control_table( rope_value=rope, control_model=control_model, selected_models=selected_models, - columns=[ - ReportedProperty.LEFT_MODEL, - ReportedProperty.WEAK_INTERPRETATION, - ReportedProperty.STRONG_INTERPRETATION, - ], + columns=( + "left_model", + "weak_interpretation", + "strong_interpretation", + ), ) better_models: list[str] = [] equivalent_models: list[str] = [] diff --git a/tests/bbt/test_alg.py b/tests/bbt/test_alg.py index 2532a11..5ce9f0c 100644 --- a/tests/bbt/test_alg.py +++ b/tests/bbt/test_alg.py @@ -7,7 +7,6 @@ from bbttest.bbt.alg import ( _construct_win_table, ) -from bbttest.bbt.params import TieSolver SCORES_1 = pd.DataFrame( { @@ -76,7 +75,7 @@ def test_construct_win_table( data_sd=None, dataset_col=None, local_rope_value=local_rope_value, - tie_solver=TieSolver.DAVIDSON, # Keeps the ties in the table + tie_solver="davidson", # Keeps the ties in the table maximize=maximize, ) @@ -109,6 +108,6 @@ def test_unnamed_columns(self): data_sd=None, dataset_col=None, # This column is unnamed local_rope_value=None, - tie_solver=TieSolver.DAVIDSON, + tie_solver="davidson", maximize=True, ) diff --git a/tests/bbt/test_py_bbt.py b/tests/bbt/test_py_bbt.py index b00b75f..f8e6420 100644 --- a/tests/bbt/test_py_bbt.py +++ b/tests/bbt/test_py_bbt.py @@ -10,8 +10,8 @@ import pandas as pd import pytest -from bbttest import HyperPrior, PyBBT, ReportedProperty, TieSolver -from bbttest.bbt.params import ALL_PROPERTIES +from bbttest import PyBBT +from bbttest.bbt._types import ALL_PROPERTIES_COLUMNS @pytest.fixture(scope="module") @@ -49,7 +49,7 @@ def fitted_model(mock_data): PyBBT Fitted PyBBT model instance. """ - model = PyBBT(local_rope_value=0.01, tie_solver=TieSolver.SPREAD) + model = PyBBT(local_rope_value=0.01, tie_solver="spread") model.fit( mock_data, dataset_col="dataset", @@ -68,13 +68,13 @@ def test_init_with_enum_parameters(self): """Test that PyBBT can be initialized with enum parameters.""" model = PyBBT( local_rope_value=0.01, - tie_solver=TieSolver.SPREAD, - hyper_prior=HyperPrior.LOG_NORMAL, + tie_solver="spread", + hyper_prior="log_normal", scale=1.0, ) assert model._local_rope_value == 0.01 - assert model._tie_solver == TieSolver.SPREAD - assert model._hyper_prior == HyperPrior.LOG_NORMAL + assert model._tie_solver == "spread" + assert model._hyper_prior == "log_normal" assert model._scale == 1.0 assert not model.fitted @@ -83,58 +83,22 @@ def test_init_with_string_parameters(self): model = PyBBT( local_rope_value=0.01, tie_solver="spread", - hyper_prior="logNormal", + hyper_prior="log_normal", scale=1.0, ) # Verify string values are accepted and work correctly assert model._local_rope_value == 0.01 - assert model._tie_solver == TieSolver.SPREAD - assert model._hyper_prior == HyperPrior.LOG_NORMAL + assert model._tie_solver == "spread" + assert model._hyper_prior == "log_normal" assert model._scale == 1.0 assert not model.fitted - @pytest.mark.parametrize( - "arg, expected", - [ - ("add", TieSolver.ADD), - ("spread", TieSolver.SPREAD), - ("forget", TieSolver.FORGET), - ("davidson", TieSolver.DAVIDSON), - (TieSolver.ADD, TieSolver.ADD), - (TieSolver.SPREAD, TieSolver.SPREAD), - (TieSolver.FORGET, TieSolver.FORGET), - (TieSolver.DAVIDSON, TieSolver.DAVIDSON), - ], - ) - def test_init_with_different_tie_solvers(self, arg, expected): - """Test initialization with different TieSolver values.""" - model = PyBBT(tie_solver=arg) - assert model._tie_solver == expected - - @pytest.mark.parametrize( - "arg, expected", - [ - ("logNormal", HyperPrior.LOG_NORMAL), - ("logNormalScaled", HyperPrior.LOG_NORMAL_SCALED), - ("cauchy", HyperPrior.CAUCHY), - ("normal", HyperPrior.NORMAL), - (HyperPrior.LOG_NORMAL, HyperPrior.LOG_NORMAL), - (HyperPrior.LOG_NORMAL_SCALED, HyperPrior.LOG_NORMAL_SCALED), - (HyperPrior.CAUCHY, HyperPrior.CAUCHY), - (HyperPrior.NORMAL, HyperPrior.NORMAL), - ], - ) - def test_init_with_different_hyper_priors(self, arg, expected): - """Test initialization with different HyperPrior values.""" - model = PyBBT(hyper_prior=arg) - assert model._hyper_prior == expected - def test_init_defaults(self): """Test that default initialization values are set correctly.""" model = PyBBT() assert model._local_rope_value is None - assert model._tie_solver == TieSolver.SPREAD - assert model._hyper_prior == HyperPrior.LOG_NORMAL + assert model._tie_solver == "spread" + assert model._hyper_prior == "log_normal" assert model._scale == 1.0 assert model._maximize assert not model.fitted @@ -207,8 +171,8 @@ def test_posterior_table_strong_interpretation_values(self, fitted_model): # Add strong_interpretation to columns result = fitted_model.posterior_table( columns=[ - ReportedProperty.MEAN, - ReportedProperty.STRONG_INTERPRETATION, + "mean", + "strong_interpretation", ] ) valid_values = {"Equivalent", "Unknown"} @@ -220,7 +184,7 @@ def test_posterior_table_strong_interpretation_values(self, fitted_model): def test_posterior_table_with_control_model(self, fitted_model): """Test posterior_table with control_model parameter.""" result = fitted_model.posterior_table( - control_model="model_a", columns=ALL_PROPERTIES + control_model="model_a", columns=ALL_PROPERTIES_COLUMNS ) assert len(result) > 0 # All comparisons should involve model_a @@ -231,7 +195,7 @@ def test_posterior_table_with_control_model(self, fitted_model): def test_posterior_table_returns_only_requested_columns(self, fitted_model): """Test that posterior_table returns only requested columns.""" - requested_columns = [ReportedProperty.MEAN, ReportedProperty.DELTA] + requested_columns = ["mean", "delta"] result = fitted_model.posterior_table( columns=requested_columns, round_ndigits=None ) @@ -471,8 +435,8 @@ def mock_get_pwin(*args, **kwargs): result = fitted_model.posterior_table( columns=[ - ReportedProperty.MEAN, - ReportedProperty.STRONG_INTERPRETATION, + "mean", + "strong_interpretation", ], round_ndigits=None, ) @@ -524,7 +488,7 @@ def mock_get_pwin(*args, **kwargs): monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) - result = fitted_model.posterior_table(columns=ALL_PROPERTIES) + result = fitted_model.posterior_table(columns=ALL_PROPERTIES_COLUMNS) for _, row in result.iterrows(): expected_pair = f"{row['left_model']} > {row['right_model']}" @@ -572,7 +536,7 @@ def mock_get_pwin(*args, **kwargs): monkeypatch.setattr("bbttest.bbt.py_bbt._get_pwin", mock_get_pwin) result = fitted_model.posterior_table( - columns=[ReportedProperty.HDI_LOW, ReportedProperty.HDI_HIGH], + columns=["hdi_low", "hdi_high"], round_ndigits=None, ) @@ -600,9 +564,9 @@ def mock_get_pwin(*args, **kwargs): result = fitted_model.posterior_table( columns=[ - ReportedProperty.HDI_LOW, - ReportedProperty.HDI_HIGH, - ReportedProperty.DELTA, + "hdi_low", + "hdi_high", + "delta", ], round_ndigits=None, # Don't round to avoid rounding differences ) diff --git a/tests/regression/test_benchmarking_mol.py b/tests/regression/test_benchmarking_mol.py index 9995026..f40f94e 100644 --- a/tests/regression/test_benchmarking_mol.py +++ b/tests/regression/test_benchmarking_mol.py @@ -34,8 +34,7 @@ import pandas as pd import pytest -from bbttest import PyBBT, TieSolver -from bbttest.bbt.params import DEFAULT_PROPERTIES, ReportedProperty +from bbttest import PyBBT @pytest.fixture(scope="module") @@ -68,7 +67,7 @@ def fitted_model(benchmarking_data): PyBBT Fitted PyBBT model instance. """ - model = PyBBT(local_rope_value=0.01, tie_solver=TieSolver.SPREAD) + model = PyBBT(local_rope_value=0.01, tie_solver="spread") model.fit( benchmarking_data, dataset_col="dataset", @@ -273,8 +272,7 @@ def test_weak_interpretation_for_rope( results = fitted_model.posterior_table( rope_value=rope, control_model="ECFP_count", - columns=list(DEFAULT_PROPERTIES) - + [ReportedProperty.LEFT_MODEL, ReportedProperty.RIGHT_MODEL], + columns=["left_model", "right_model", "weak_interpretation"], ) interpretations = _extract_interpretations(results) From 41fe1a0fe378e54e15a840218e40b50471481c3d Mon Sep 17 00:00:00 2001 From: Mateusz Praski Date: Tue, 24 Feb 2026 12:19:51 +0100 Subject: [PATCH 2/4] Add unit tests for validate_params --- tests/bbt/test__utils.py | 32 ++++++++++++++++++++++++++++++++ tests/bbt/test_py_bbt.py | 27 +++++++++++++-------------- 2 files changed, 45 insertions(+), 14 deletions(-) create mode 100644 tests/bbt/test__utils.py diff --git a/tests/bbt/test__utils.py b/tests/bbt/test__utils.py new file mode 100644 index 0000000..1a1f0dd --- /dev/null +++ b/tests/bbt/test__utils.py @@ -0,0 +1,32 @@ +from typing import Literal + +import pytest + +from bbttest.bbt._utils import _validate_params + +MockLiteralType = Literal["option1", "option2", "option3"] + + +@_validate_params +def mock_fun(param_lit: MockLiteralType, param_str: str): ... + + +class TestLiteralValidation: + """Tests _validate_params decorator for validating parameters.""" + + @pytest.mark.parametrize( + "params, should_raise", + [ + ({"param_lit": "option1", "param_str": "any string"}, False), + ({"param_lit": "option4", "param_str": "another string"}, True), + ], + ) + def test_literal_validation(self, params, should_raise): + """Test if _validate_params correctly validates Literal parameters and skips strings.""" + if should_raise: + with pytest.raises( + ValueError, match="Invalid value 'option4' for parameter 'param_lit'" + ): + mock_fun(**params) + else: + mock_fun(**params) diff --git a/tests/bbt/test_py_bbt.py b/tests/bbt/test_py_bbt.py index f8e6420..06d78a9 100644 --- a/tests/bbt/test_py_bbt.py +++ b/tests/bbt/test_py_bbt.py @@ -64,20 +64,6 @@ def fitted_model(mock_data): class TestPyBBTInitialization: """Test PyBBT initialization and parameter validation.""" - def test_init_with_enum_parameters(self): - """Test that PyBBT can be initialized with enum parameters.""" - model = PyBBT( - local_rope_value=0.01, - tie_solver="spread", - hyper_prior="log_normal", - scale=1.0, - ) - assert model._local_rope_value == 0.01 - assert model._tie_solver == "spread" - assert model._hyper_prior == "log_normal" - assert model._scale == 1.0 - assert not model.fitted - def test_init_with_string_parameters(self): """Test that PyBBT can be initialized with string parameters that are cast to enums.""" model = PyBBT( @@ -103,6 +89,19 @@ def test_init_defaults(self): assert model._maximize assert not model.fitted + def test_validate_params(self): + """Test that invalid parameter values raise ValueError.""" + with pytest.raises( + ValueError, + match="Invalid value 'invalid_solver' for parameter 'tie_solver'", + ): + PyBBT(tie_solver="invalid_solver") + with pytest.raises( + ValueError, + match="Invalid value 'invalid_prior' for parameter 'hyper_prior'", + ): + PyBBT(hyper_prior="invalid_prior") + class TestPyBBTFitting: """Test PyBBT model fitting functionality.""" From 31e4206b39983e26143a4a3003fc7800c5fccccf Mon Sep 17 00:00:00 2001 From: Mateusz Praski Date: Tue, 24 Feb 2026 12:24:45 +0100 Subject: [PATCH 3/4] fix: keep function signature in @_validate_params --- bbttest/bbt/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bbttest/bbt/_utils.py b/bbttest/bbt/_utils.py index a4e8583..c224112 100644 --- a/bbttest/bbt/_utils.py +++ b/bbttest/bbt/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +from functools import wraps from typing import ( Literal, get_args, @@ -28,6 +29,7 @@ def _validate_params(func): sig = signature(func) + @wraps(func) def wrapper(*args, **kwargs): bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() From 1a4ba991cbd7283f6266cb1ded4f8f3e1164d7d2 Mon Sep 17 00:00:00 2001 From: Mateusz Praski Date: Tue, 24 Feb 2026 12:27:01 +0100 Subject: [PATCH 4/4] Update docs --- bbttest/bbt/py_bbt.py | 16 ++++++++-------- tests/regression/test_benchmarking_mol.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/bbttest/bbt/py_bbt.py b/bbttest/bbt/py_bbt.py index 98df5ad..db6126a 100644 --- a/bbttest/bbt/py_bbt.py +++ b/bbttest/bbt/py_bbt.py @@ -34,15 +34,15 @@ class PyBBT: If None, no ties are recorded. - tie_solver: TieSolver, default TieSolver.SPREAD + tie_solver: str, defaults to `spread` The strategy to handle ties when sampling the BBT model. - - ADD - Adds 1 win to both players for each tie. - - SPREAD - Adds 0.5 win to both players for each tie. - - FORGET - Ignores the ties. - - DAVIDSON - Uses Davidson's method to handle ties in the BBT model. See [1]_. + - `add` - Adds 1 win to both players for each tie. + - `spread` - Adds 0.5 win to both players for each tie. + - `forget` - Ignores the ties. + - `davidson` - Uses Davidson's method to handle ties in the BBT model. See [1]_. - hyper_prior: str, default "log_normal" + hyper_prior: str, default `log_normal` The hyper prior distribution to be used for the BBT MCMC sampling. scale: float, default 1.0 @@ -60,14 +60,14 @@ class PyBBT: Examples -------- >>> import pandas as pd - >>> from bbttest import PyBBT, TieSolver + >>> from bbttest import PyBBT >>> data = pd.DataFrame({ ... 'dataset': ['ds1', 'ds2', 'ds3'], ... 'model_a': [0.8, 0.75, 0.9], ... 'model_b': [0.7, 0.8, 0.85], ... 'model_c': [0.6, 0.65, 0.7] ... }) - >>> model = PyBBT(local_rope_value=0.01, tie_solver=TieSolver.SPREAD) + >>> model = PyBBT(local_rope_value=0.01, tie_solver="spread") >>> model.fit(data, dataset_col='dataset') >>> model.posterior_table(rope_value=(0.45, 0.55)) diff --git a/tests/regression/test_benchmarking_mol.py b/tests/regression/test_benchmarking_mol.py index f40f94e..8ce10a3 100644 --- a/tests/regression/test_benchmarking_mol.py +++ b/tests/regression/test_benchmarking_mol.py @@ -25,7 +25,7 @@ Test parameters: - local_rope_value: 0.01 -- tie_solver: TieSolver.SPREAD +- tie_solver: "spread" - MCMC sampling: 2000 draws, 1000 tune, 4 chains """