Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ wheels/
.idea/
.vscode/

# OS dependent files
.DS_Store
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test: ## Run tests without regression
uv run ruff check
uv run pytest tests -m "not slow"

full-test:
test-all:
uv run ruff check
uv run pytest tests

Expand Down
5 changes: 1 addition & 4 deletions bbttest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""bbt-test: Bayesian Bradley-Terry model for algorithm comparison."""

from .bbt import HyperPrior, PyBBT, ReportedProperty, TieSolver
from .bbt import PyBBT

__all__ = [
"HyperPrior",
"PyBBT",
"ReportedProperty",
"TieSolver",
]
4 changes: 0 additions & 4 deletions bbttest/bbt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
"""bbt module: Bayesian Bradley-Terry model implementation."""

from .params import HyperPrior, ReportedProperty, TieSolver
from .py_bbt import PyBBT

__all__ = [
"HyperPrior",
"PyBBT",
"ReportedProperty",
"TieSolver",
]
27 changes: 27 additions & 0 deletions bbttest/bbt/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Literal, get_args

HyperPriorType = Literal[
"log_normal",
"cauchy",
"normal",
]

TieSolverType = Literal["add", "spread", "forget", "davidson"]

ReportedPropertyColumnType = Literal[
"left_model",
"right_model",
"median",
"mean",
"hdi_low",
"hdi_high",
"delta",
"above_50",
"in_rope",
"weak_interpretation",
"strong_interpretation",
]

ALL_PROPERTIES_COLUMNS: list[ReportedPropertyColumnType] = list(
get_args(ReportedPropertyColumnType)
)
64 changes: 64 additions & 0 deletions bbttest/bbt/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import sys
from functools import wraps
from typing import (
Literal,
get_args,
get_origin,
)

from pymc.distributions import Cauchy, LogNormal, Normal

if sys.version_info >= (3, 12):
from typing import TypeAliasType
else:
from typing_extensions import TypeAliasType


def is_literal_value(value: object, typx: object) -> bool:
if isinstance(typx, TypeAliasType):
typx = typx.__value__
if get_origin(typx) is Literal:
return value in get_args(typx)
return False


def _validate_params(func):
from inspect import signature

sig = signature(func)

@wraps(func)
def wrapper(*args, **kwargs):
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
for name, value in bound_args.arguments.items():
param = sig.parameters[name]
# If type annotation is a Literal, validate the value
if param.annotation is not param.empty and is_literal_value(
value, param.annotation
):
continue # Valid value, continue to next parameter
elif (
param.annotation is not param.empty
and get_origin(param.annotation) is Literal
):
raise ValueError(
f"Invalid value '{value}' for parameter '{name}'. Expected one of {get_args(param.annotation)}."
)
return func(*args, **kwargs)

return wrapper


def _get_distribution_for_prior(prior: str, scale: float):
match prior:
case "log_normal":
return LogNormal("sigma", mu=0, sigma=scale)
case "cauchy":
return Cauchy("sigma", alpha=0, beta=scale)
case "normal":
return Normal("sigma", mu=0, sigma=scale)
case _:
raise ValueError(f"Unsupported hyperprior: {prior}")
11 changes: 5 additions & 6 deletions bbttest/bbt/alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from tqdm.auto import tqdm

from .const import UNNAMED_COLUMNS_WARNING_TEMPLATE
from .params import TieSolver

ALG1_COL = 2
ALG2_COL = 3
Expand Down Expand Up @@ -107,12 +106,12 @@ def _construct_lrope(
return out_array


def _solve_ties(table: np.ndarray, tie_solver: TieSolver) -> np.ndarray:
if tie_solver == TieSolver.DAVIDSON:
def _solve_ties(table: np.ndarray, tie_solver: str) -> np.ndarray:
if tie_solver == "davidson":
return table
if tie_solver == TieSolver.SPREAD:
if tie_solver == "spread":
tie_val = np.ceil(table[:, TIE_COL] / 2).astype(int)
elif tie_solver == TieSolver.ADD:
elif tie_solver == "add":
tie_val = table[:, TIE_COL].astype(int)
else:
tie_val = 0
Expand All @@ -126,7 +125,7 @@ def _construct_win_table(
data_sd: pd.DataFrame | None,
dataset_col: str | int | None,
local_rope_value: float | None,
tie_solver: TieSolver,
tie_solver: str,
maximize: bool,
) -> tuple[np.ndarray, list[str]]:
# Extract algorithm names
Expand Down
8 changes: 4 additions & 4 deletions bbttest/bbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pymc as pm
import pytensor.tensor as pt

from .params import HyperPrior
from ._utils import _get_distribution_for_prior


def _build_bbt_model(
Expand All @@ -12,7 +12,7 @@ def _build_bbt_model(
win1: list[int],
win2: list[int],
ties: list[int] | None,
hyp: HyperPrior,
hyp: str,
scale: float,
use_davidson: bool,
):
Expand Down Expand Up @@ -53,7 +53,7 @@ def _build_bbt_model(

with pm.Model() as model:
# Hyperprior for sigma
sigma = hyp._get_pymc_dist(scale=scale, name="sigma")
sigma = _get_distribution_for_prior(hyp, scale=scale)

# Abilities for each player
beta = pm.Normal("beta", mu=0, sigma=sigma, shape=K)
Expand Down Expand Up @@ -100,7 +100,7 @@ def _build_bbt_model(
def _mcmcbbt_pymc(
table: np.ndarray,
use_davidson: bool,
hyper_prior: HyperPrior,
hyper_prior: str,
scale: float,
**kwargs,
) -> az.InferenceData:
Expand Down
72 changes: 0 additions & 72 deletions bbttest/bbt/params.py

This file was deleted.

Loading