diff --git a/lrmodule/mcmc.py b/lrmodule/mcmc.py index 48db446..7a6b4c5 100644 --- a/lrmodule/mcmc.py +++ b/lrmodule/mcmc.py @@ -3,6 +3,8 @@ import arviz as az import numpy as np import pymc as pm +from lir.algorithms.bayeserror import ELUBBounder +from lir.bounding import LLRBounder from lir.data.models import FeatureData, LLRData from lir.transform import Transformer from scipy.stats import betabinom, binom, norm @@ -17,12 +19,13 @@ class McmcLLRModel(Transformer): determined. """ - def __init__( + def __init__( # noqa: PLR0913 (more than five arguments are allowed) self, distribution_h1: str, parameters_h1: dict[str, dict[str, int]] | None, distribution_h2: str, parameters_h2: dict[str, dict[str, int]] | None, + bounding: LLRBounder | None = ELUBBounder(), interval: tuple[float, float] = (0.05, 0.95), **mcmc_kwargs, ): @@ -33,17 +36,31 @@ def __init__( :param parameters_h1: definition of the parameters of distribution_h1, and their prior distributions :param distribution_h2: statistical distribution used to model H2, for example 'normal' or 'binomial' :param parameters_h2: definition of the parameters of distribution_h2, and their prior distributions + :param bounding: bounding method to apply to the unbound llrs, to prevent overextrapolation :param interval: lower and upper bounds of the credible interval in range 0..1; default: (0.05, 0.95) :param mcmc_kwargs: mcmc simulation settings, see `McmcModel.__init__` for more details. 
""" self.model_h1 = McmcModel(distribution_h1, parameters_h1, **mcmc_kwargs) self.model_h2 = McmcModel(distribution_h2, parameters_h2, **mcmc_kwargs) + self.bounding = bounding + self.bounders = None self.interval = interval def fit(self, instances: FeatureData) -> Self: """Fit the defined model to the supplied instances.""" + if instances.labels is None: + raise ValueError("Labels are required to fit this model.") self.model_h1.fit(instances.features[instances.labels == 1]) self.model_h2.fit(instances.features[instances.labels == 0]) + if self.bounding is not None: + # determine the bounds based on the LLRs of the training data, each sample results into an LR-system + logp_h1 = self.model_h1.transform(instances.features) + logp_h2 = self.model_h2.transform(instances.features) + llrs = logp_h1 - logp_h2 + # determine the bounds for each LR-system individually + self.bounders = [self.bounding.__class__() for _ in range(llrs.shape[1])] + for i_system in range(llrs.shape[1]): + self.bounders[i_system] = self.bounders[i_system].fit(llrs[:, i_system], instances.labels) return self def transform(self, instances: FeatureData) -> LLRData: @@ -51,12 +68,16 @@ def transform(self, instances: FeatureData) -> LLRData: logp_h1 = self.model_h1.transform(instances.features) logp_h2 = self.model_h2.transform(instances.features) llrs = logp_h1 - logp_h2 + if (self.bounding is not None) and (self.bounders is not None): + # apply the bounders one by one + for i_system in range(llrs.shape[1]): + llrs[:, i_system] = self.bounders[i_system].transform(llrs[:, i_system]) quantiles = np.quantile(llrs, [0.5] + list(self.interval), axis=1, method="midpoint") return instances.replace_as(LLRData, features=quantiles.transpose(1, 0)) class McmcModel: - def __init__( # noqa: PLR0913 + def __init__( # noqa: PLR0913 (more than five arguments are allowed) self, distribution: str, parameters: dict[str, dict[str, int]] | None, diff --git a/tests/mcmc_matlab_comparison/test_mcmc.py 
b/tests/mcmc_matlab_comparison/test_mcmc.py index f160da4..83fd51e 100644 --- a/tests/mcmc_matlab_comparison/test_mcmc.py +++ b/tests/mcmc_matlab_comparison/test_mcmc.py @@ -88,7 +88,12 @@ def test_llr_dataset(dataset_name: str): labels = np.concatenate([np.ones(scores_km.shape[0]), np.zeros(scores_knm.shape[0])]) model = McmcLLRModel( - cfg.distribution_h1, cfg.parameters_h1, cfg.distribution_h2, cfg.parameters_h2, random_seed=cfg.random_seed + cfg.distribution_h1, + cfg.parameters_h1, + cfg.distribution_h2, + cfg.parameters_h2, + bounding=None, + random_seed=cfg.random_seed, ) model.fit(FeatureData(features=features, labels=labels)) llrs = model.transform(FeatureData(features=scores_eval))