Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions lrmodule/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import arviz as az
import numpy as np
import pymc as pm
from lir.algorithms.bayeserror import ELUBBounder
from lir.bounding import LLRBounder
from lir.data.models import FeatureData, LLRData
from lir.transform import Transformer
from scipy.stats import betabinom, binom, norm
Expand All @@ -17,12 +19,13 @@ class McmcLLRModel(Transformer):
determined.
"""

def __init__(
def __init__( # noqa: PLR0913 (more than five arguments are allowed)
self,
distribution_h1: str,
parameters_h1: dict[str, dict[str, int]] | None,
distribution_h2: str,
parameters_h2: dict[str, dict[str, int]] | None,
bounding: LLRBounder | None = ELUBBounder(),
interval: tuple[float, float] = (0.05, 0.95),
**mcmc_kwargs,
):
Expand All @@ -33,30 +36,48 @@ def __init__(
:param parameters_h1: definition of the parameters of distribution_h1, and their prior distributions
:param distribution_h2: statistical distribution used to model H2, for example 'normal' or 'binomial'
:param parameters_h2: definition of the parameters of distribution_h2, and their prior distributions
:param bounding: bounding method to apply to the unbound llrs, to prevent overextrapolation
:param interval: lower and upper bounds of the credible interval in range 0..1; default: (0.05, 0.95)
:param mcmc_kwargs: mcmc simulation settings, see `McmcModel.__init__` for more details.
"""
self.model_h1 = McmcModel(distribution_h1, parameters_h1, **mcmc_kwargs)
self.model_h2 = McmcModel(distribution_h2, parameters_h2, **mcmc_kwargs)
self.bounding = bounding
self.bounders = None
self.interval = interval

def fit(self, instances: FeatureData) -> Self:
    """Fit the H1 and H2 models (and optional per-system LLR bounders) to the instances.

    :param instances: feature data with binary labels (1 = H1, 0 = H2)
    :raises ValueError: if the instances carry no labels
    :return: this transformer, fitted
    """
    if instances.labels is None:
        raise ValueError("Labels are required to fit this model.")
    self.model_h1.fit(instances.features[instances.labels == 1])
    self.model_h2.fit(instances.features[instances.labels == 0])
    if self.bounding is not None:
        # Each MCMC sample yields one LR-system; compute the training-data
        # LLRs (columns = systems) and bound every system on its own.
        llrs = self.model_h1.transform(instances.features) - self.model_h2.transform(instances.features)
        # NOTE(review): a fresh default-constructed bounder is created per
        # system, so any non-default configuration of `self.bounding` is
        # silently ignored — confirm this is intended.
        self.bounders = [
            self.bounding.__class__().fit(llrs[:, i_system], instances.labels)
            for i_system in range(llrs.shape[1])
        ]
    return self

def transform(self, instances: FeatureData) -> LLRData:
    """Apply the fitted model, returning the median LLR and credible interval per instance.

    :param instances: feature data to score (labels not required)
    :return: the instances re-typed as ``LLRData`` with features
        ``[median, lower, upper]`` taken over the LR-systems axis
    """
    unbounded = self.model_h1.transform(instances.features) - self.model_h2.transform(instances.features)
    if self.bounding is not None and self.bounders is not None:
        # One fitted bounder per LR-system (column); apply each in place.
        for i_system, bounder in enumerate(self.bounders):
            unbounded[:, i_system] = bounder.transform(unbounded[:, i_system])
    # Collapse the systems axis to the median plus the configured interval.
    quantiles = np.quantile(unbounded, [0.5, *self.interval], axis=1, method="midpoint")
    return instances.replace_as(LLRData, features=quantiles.transpose(1, 0))


class McmcModel:
def __init__( # noqa: PLR0913
def __init__( # noqa: PLR0913 (more than five arguments are allowed)
self,
distribution: str,
parameters: dict[str, dict[str, int]] | None,
Expand Down
7 changes: 6 additions & 1 deletion tests/mcmc_matlab_comparison/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ def test_llr_dataset(dataset_name: str):
labels = np.concatenate([np.ones(scores_km.shape[0]), np.zeros(scores_knm.shape[0])])

model = McmcLLRModel(
cfg.distribution_h1, cfg.parameters_h1, cfg.distribution_h2, cfg.parameters_h2, random_seed=cfg.random_seed
cfg.distribution_h1,
cfg.parameters_h1,
cfg.distribution_h2,
cfg.parameters_h2,
bounding=None,
random_seed=cfg.random_seed,
)
model.fit(FeatureData(features=features, labels=labels))
llrs = model.transform(FeatureData(features=scores_eval))
Expand Down