From 896638af70d97b1fd759e751a9e8e0ce0c054c88 Mon Sep 17 00:00:00 2001 From: domosedy Date: Thu, 11 Dec 2025 02:32:30 +0300 Subject: [PATCH 1/5] [feat]: added base version of executable ExponentialFamily --- src/pysatl_core/families/__init__.py | 8 + .../families/exponential_family.py | 228 ++++++++++++++++++ .../unit/families/test_exponential_family.py | 53 ++++ 3 files changed, 289 insertions(+) create mode 100644 src/pysatl_core/families/exponential_family.py create mode 100644 tests/unit/families/test_exponential_family.py diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index ed30528..8ed0cb9 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -21,6 +21,11 @@ constraint, parametrization, ) +from .exponential_family import ( + ExponentialFamily, + ExponentialClassParametrization, + ExponentialConjugateHyperparameters, +) from .registry import ParametricFamilyRegister __all__ = [ @@ -34,6 +39,9 @@ "configure_families_register", # builtins *_builtins_all, + "ExponentialFamily", + "ExponentialClassParametrization", + "ExponentialConjugateHyperparameters", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py new file mode 100644 index 0000000..d87f93e --- /dev/null +++ b/src/pysatl_core/families/exponential_family.py @@ -0,0 +1,228 @@ +from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass +import math +from typing import Any, cast +from scipy.integrate import nquad, quad +import numpy as np + +from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf +from pysatl_core.families.parametric_family import ( + ParametricFamily, +) +from pysatl_core.families.parametrizations import Parametrization, parametrization +from pysatl_core.types import ( + DistributionType, + ParametrizationName, +) +from pysatl_core.distributions import ( + SamplingStrategy, +) + +PDF = "pdf" +CDF = "cdf" +PPF = "ppf" +CF = "char_func" +MEAN = "mean" +VAR = "var" +SKEW = "skewness" +KURT = "kurtosis" + + +class ExponentialClassParametrization(Parametrization): + """ + Standard parametrization of Exponential Family. + """ + + theta: list[Callable[[float], float]] # TODO: mb more clever + + +class ExponentialConjugateHyperparameters: + def __init__(self, alpha: Any, beta: int): + self.alpha = alpha + self.beta = beta + + def __str__(self): + return f"alpha={self.alpha}, beta={self.beta}" + + +def accepts(x, support): + if not hasattr(x, "__len__"): + x = [x] + + def accept_1D(x, borders): + left, right = borders + return left <= x <= right + + return all(accept_1D(x_i, border) for x_i, border in zip(x, support)) + + +class ExponentialFamily(ParametricFamily): + def __init__( + self, + *, + A: Callable[[ExponentialClassParametrization], float], + T: Callable[[Any], Any], + h: Callable[[Any], float], + eta: Callable[[Any], Any], + support: list[tuple[float, float]], + param_space: list[tuple[float, float]], + natural_param_space: list[tuple[float, float]], + name: str = "ExponentialFamily", + theta_from_eta: Callable[[Any], Any] = None, + distr_type: DistributionType | Callable[[Parametrization], DistributionType], + distr_parametrizations: list[ParametrizationName], + sampling_strategy: SamplingStrategy, + support_by_parametrization: SupportArg = None, + ): + + self._A = A + self._T = T + self._h = h + + self._eta = eta if eta is not None else (lambda th: th) + self._theta_from_eta = theta_from_eta + self._natural_param_space = natural_param_space + self._param_space = param_space + self._support = support + + distr_characteristics = { + PDF: self.density, + MEAN: self._mean, + VAR: self._var, + } + + ParametricFamily.__init__( + self, + name=name, + distr_type=distr_type, + distr_parametrizations=distr_parametrizations, + distr_characteristics=distr_characteristics, + sampling_strategy=sampling_strategy, + support_by_parametrization=support_by_parametrization, + ) + parametrization(family=self, name="theta")((ExponentialClassParametrization)) + + @property + def log_density(self) -> ParametrizedFunction: + def log_density_func( + parametrization: ExponentialClassParametrization, x: Any + ) -> Any: + if not accepts(x, self._support): + return float("-inf") + + params = cast(ExponentialClassParametrization, parametrization) + theta = params.parameters.get("theta") + eta = self._eta(theta) + sufficient = self._T(x) + dot = np.dot(eta, sufficient) + + result = float(np.log(self._h(x)) + dot + self._A(parametrization)) + return result + + return log_density_func + + @property + def density(self) -> ParametrizedFunction: + return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) + + @property + def conjugate_prior_family(self): + def conjugate_sufficient(eta: Any): + theta = [self._theta_from_eta(eta)] + if not accepts(theta, self._param_space): + return [float("-inf"), float("-inf")] + + return [eta, self._A(ExponentialClassParametrization(theta=theta))] + + def conjugate_log_partition(parametrization: ExponentialClassParametrization): + alpha = parametrization.theta[0] + beta = parametrization.theta[1] + + def pdf(eta: Any): + theta = self._theta_from_eta(eta) + if not hasattr(theta, "__len__"): + theta = [theta] + parametrization = ExponentialClassParametrization( + theta=theta, + ) + return np.exp(np.dot(eta, alpha) + beta * self._A(parametrization)) + + all_value = nquad(pdf, self._natural_param_space)[0] + return -np.log(all_value) + + if self._theta_from_eta is None: + raise RuntimeError("Theta from eta wasn't specified") + + return ExponentialFamily( + A=conjugate_log_partition, + T=conjugate_sufficient, + h=lambda _: 1, + eta=lambda x: x, + theta_from_eta=lambda eta: eta, + support=self._natural_param_space, + natural_param_space=[(float("-inf"), float("inf"))] * 2, + param_space=[(float("-inf"), float("inf"))] * 2, + sampling_strategy=self.sampling_strategy, + distr_type=self._distr_type, + distr_parametrizations=self.parametrization_names, + support_by_parametrization=self.support_resolver, + ) + + @property + def _mean(self) -> ParametrizedFunction: + def mean_func(parametrization: Parametrization, x: Any) -> Any: + if hasattr(x, "__len__"): + dimension_size = len(x) + else: + dimension_size = 1 + print(dimension_size) + return nquad( + lambda x: np.dot(x, self.density(parametrization, x)), + [(float("-inf"), float("inf"))] * dimension_size, + )[0] + + return mean_func + + @property + def _second_moment(self) -> ParametrizedFunction: + def func(parametrization: Parametrization, x: Any) -> Any: + if hasattr(x, "__len__"): + dimension_size = len(x) + else: + dimension_size = 1 + return nquad( + lambda x: x**2 * self.density(parametrization, x), + [(float("-inf"), float("inf"))] * dimension_size, + )[0] + + return func + + @property + def _var(self): + def func(parametrization, x: Any): + return ( + self._second_moment(parametrization, x) + - self._mean(parametrization, x) ** 2 + ) + + return func + + def posterior_hyperparameters( + self, prior_hyper: ExponentialConjugateHyperparameters, sample + ): + alpha = prior_hyper.alpha + beta = prior_hyper.beta + + alpha_post = None + beta_post = None + if hasattr(sample, "__iter__") and not isinstance(sample, str): + alpha_post = np.sum([self._T(x) for x in sample], axis=0) + beta_post = len(sample) + else: + alpha_post = self.T(sample) + beta_post = 1 + + return ExponentialConjugateHyperparameters( + alpha=alpha + alpha_post, beta=beta + beta_post + ) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py new file mode 100644 index 0000000..620ec32 --- /dev/null +++ b/tests/unit/families/test_exponential_family.py @@ -0,0 +1,53 @@ +from typing import cast +import pytest +import numpy as np + +# from pysatl_core.distributions.computation import PDF +from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy +from pysatl_core.distributions.support import ContinuousSupport +from pysatl_core.families import ( + ExponentialFamily, + ExponentialConjugateHyperparameters, + ExponentialClassParametrization, +) +from pysatl_core.families.registry import ParametricFamilyRegister +from pysatl_core.types import UnivariateContinuous +import math + + +# TODO: WRITE TEEEEEEESTS +def test_exponential(): + pass + # fam = ExponentialFamily( + # A=lambda parametrization: np.log(parametrization.theta[0]), + # T=lambda x: x, + # h=lambda _: 1, + # eta=lambda theta: -1 * theta, + # theta_from_eta=lambda eta: -1 * eta, + # param_space=[(0, float("+inf"))], + # support=[(0, float("+inf"))], + # natural_param_space=[(float("-inf"), 0)], + # distr_type=UnivariateContinuous, + # distr_parametrizations=["theta"], + # sampling_strategy=DefaultSamplingUnivariateStrategy(), + # ) + + # conjugate_fam = fam + # conjugate_fam = fam.conjugate_prior_family + # params = ExponentialClassParametrization(theta=np.array([2, -1])) + # print(conjugate_fam._A(params)) + # ParametricFamilyRegister().register(conjugate_fam) + # # print( + # # fam.posterior_hyperparameters( + # # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] + # # ) + # # ) + # gamma_family: ExponentialFamily = cast( + # ExponentialFamily, ParametricFamilyRegister().get("ExponentialFamily") + # ) + # print(type(gamma_family)) + # # conjugate = gamma_family.conjugate_prior_family + # # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") + # exponential = gamma_family(theta=np.array([0, -1]), parametrization_name="theta") + # pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) + # print(pdf(-1)) From 73da345822ba3468cb3934a0f09752ccb2694913 Mon Sep 17 00:00:00 2001 From: domosedy Date: Sat, 31 Jan 2026 00:04:02 +0300 Subject: [PATCH 2/5] [feat] added new class structure and manual testing for conjugate prior --- src/pysatl_core/families/__init__.py | 12 +- .../families/exponential_family.py | 221 ++++++++++++------ .../unit/families/test_exponential_family.py | 118 +++++++--- 3 files changed, 247 insertions(+), 104 deletions(-) diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 8ed0cb9..3fd94c7 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -22,9 +22,12 @@ parametrization, ) from .exponential_family import ( - ExponentialFamily, - ExponentialClassParametrization, ExponentialConjugateHyperparameters, + ExponentialFamily, + ExponentialFamilyParametrization, + NaturalExponentialFamily, + SpacePredicate, + SpacePredicateArray, ) from .registry import ParametricFamilyRegister @@ -40,8 +43,11 @@ # builtins *_builtins_all, "ExponentialFamily", - "ExponentialClassParametrization", + "ExponentialFamilyParametrization", "ExponentialConjugateHyperparameters", + "SpacePredicate", + "SpacePredicateArray", + "NaturalExponentialFamily", ] del _builtins_all diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index d87f93e..dfa14f1 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -1,15 +1,11 @@ from __future__ import annotations from collections.abc import Callable -from dataclasses import dataclass -import math -from typing import Any, cast +from typing import Any, cast, TYPE_CHECKING from scipy.integrate import nquad, quad import numpy as np from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf -from pysatl_core.families.parametric_family import ( - ParametricFamily, -) +from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import Parametrization, parametrization from pysatl_core.types import ( DistributionType, @@ -19,6 +15,13 @@ SamplingStrategy, ) +if TYPE_CHECKING: + from pysatl_core.distributions.support import Support + + type ParametrizedFunction = Callable[[Parametrization, Any], Any] + type SupportArg = Callable[[Parametrization], Support | None] | None + + PDF = "pdf" CDF = "cdf" PPF = "ppf" @@ -29,7 +32,7 @@ KURT = "kurtosis" -class ExponentialClassParametrization(Parametrization): +class ExponentialFamilyParametrization(Parametrization): """ Standard parametrization of Exponential Family. """ @@ -46,45 +49,56 @@ def __str__(self): return f"alpha={self.alpha}, beta={self.beta}" -def accepts(x, support): +def doesAccept(x, support): if not hasattr(x, "__len__"): x = [x] def accept_1D(x, borders): left, right = borders + if abs(x) == 0 and (abs(left) == 0 or abs(right) == 0): + return False return left <= x <= right return all(accept_1D(x_i, border) for x_i, border in zip(x, support)) -class ExponentialFamily(ParametricFamily): +class SpacePredicate: + def __init__(self, predicate: Callable[[Any], bool]): + self._predicate = predicate + + def accepts(self, x: Any) -> bool: + return self._predicate(x) + + +class SpacePredicateArray(SpacePredicate): + def __init__(self, space: list[tuple[float, float]]): + SpacePredicate.__init__(self, lambda x: doesAccept(x, space)) + self._space = space + + +class NaturalExponentialFamily(ParametricFamily): def __init__( self, *, - A: Callable[[ExponentialClassParametrization], float], - T: Callable[[Any], Any], - h: Callable[[Any], float], - eta: Callable[[Any], Any], - support: list[tuple[float, float]], - param_space: list[tuple[float, float]], - natural_param_space: list[tuple[float, float]], - name: str = "ExponentialFamily", - theta_from_eta: Callable[[Any], Any] = None, + log_partition: Callable[[ExponentialFamilyParametrization], float], + sufficient_statistics: Callable[[Any], Any], + normalization_constant: Callable[[Any], Any], + support: SpacePredicate, + parameter_space: SpacePredicate, + sufficient_statistics_values: SpacePredicate, + name: str = "NaturalExponentialFamily", distr_type: DistributionType | Callable[[Parametrization], DistributionType], distr_parametrizations: list[ParametrizationName], sampling_strategy: SamplingStrategy, support_by_parametrization: SupportArg = None, ): + self._sufficient = sufficient_statistics + self._log_partition = log_partition + self._normalization = normalization_constant - self._A = A - self._T = T - self._h = h - - self._eta = eta if eta is not None else (lambda th: th) - self._theta_from_eta = theta_from_eta - self._natural_param_space = natural_param_space - self._param_space = param_space self._support = support + self._parameter_space = parameter_space + self._sufficient_statistics_values = sufficient_statistics_values distr_characteristics = { PDF: self.density, @@ -101,23 +115,25 @@ def __init__( sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) - parametrization(family=self, name="theta")((ExponentialClassParametrization)) + parametrization(family=self, name="theta")((ExponentialFamilyParametrization)) @property def log_density(self) -> ParametrizedFunction: def log_density_func( - parametrization: ExponentialClassParametrization, x: Any + parametrization: ExponentialFamilyParametrization, x: Any ) -> Any: - if not accepts(x, self._support): + if not self._support.accepts(x): return float("-inf") - params = cast(ExponentialClassParametrization, parametrization) + params = cast(ExponentialFamilyParametrization, parametrization) theta = params.parameters.get("theta") - eta = self._eta(theta) - sufficient = self._T(x) - dot = np.dot(eta, sufficient) - - result = float(np.log(self._h(x)) + dot + self._A(parametrization)) + sufficient = self._sufficient(x) + dot = np.dot(theta, sufficient) + result = float( + np.log(self._normalization(x)) + + dot + + self._log_partition(parametrization) + ) return result return log_density_func @@ -128,41 +144,59 @@ def density(self) -> ParametrizedFunction: @property def conjugate_prior_family(self): - def conjugate_sufficient(eta: Any): - theta = [self._theta_from_eta(eta)] - if not accepts(theta, self._param_space): + def conjugate_sufficient(theta: Any): + if not self._parameter_space.accepts(theta): return [float("-inf"), float("-inf")] - return [eta, self._A(ExponentialClassParametrization(theta=theta))] + return [ + theta, + self._log_partition(ExponentialFamilyParametrization(theta=[theta])), + ] - def conjugate_log_partition(parametrization: ExponentialClassParametrization): + def conjugate_log_partition(parametrization: ExponentialFamilyParametrization): alpha = parametrization.theta[0] beta = parametrization.theta[1] - def pdf(eta: Any): - theta = self._theta_from_eta(eta) + def pdf(theta: Any): if not hasattr(theta, "__len__"): theta = [theta] - parametrization = ExponentialClassParametrization( + parametrization = ExponentialFamilyParametrization( theta=theta, ) - return np.exp(np.dot(eta, alpha) + beta * self._A(parametrization)) + return np.exp( + np.dot(theta, alpha) + beta * self._log_partition(parametrization) + )[0] - all_value = nquad(pdf, self._natural_param_space)[0] + all_value = nquad( + lambda x: pdf(x) if self._parameter_space.accepts(x) else 0, + [(float("-inf"), float("+inf"))], + )[0] return -np.log(all_value) - if self._theta_from_eta is None: - raise RuntimeError("Theta from eta wasn't specified") - - return ExponentialFamily( - A=conjugate_log_partition, - T=conjugate_sufficient, - h=lambda _: 1, - eta=lambda x: x, - theta_from_eta=lambda eta: eta, - support=self._natural_param_space, - natural_param_space=[(float("-inf"), float("inf"))] * 2, - param_space=[(float("-inf"), float("inf"))] * 2, + # TODO: remove hardcoding - Done, all hardcoding is only on user's hands + # 1. pr with prototype/draft - in progress + # 2. write instruction about to add distributions as member of exponential family - not started + # 3. parametrization's spaces (передавать в конструктор) - maybe impossible, discuss this with desiment on meeting + + def conjugate_sufficient_accepts( + parametrization: ExponentialFamilyParametrization, + ): + parametrization = cast(parametrization, ExponentialFamilyParametrization) + theta = parametrization.parameters.get("theta") + xi = theta[:-1] + nu = theta[-1] + + return self._sufficient_statistics_values(xi) and SpacePredicateArray( + [(0, float("+inf"))] + ).accepts(nu) + + return NaturalExponentialFamily( + log_partition=conjugate_log_partition, + sufficient_statistics=conjugate_sufficient, + normalization_constant=lambda _: 1, + support=self._parameter_space, + sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this + parameter_space=SpacePredicate(conjugate_sufficient_accepts), sampling_strategy=self.sampling_strategy, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, @@ -172,13 +206,15 @@ def pdf(eta: Any): @property def _mean(self) -> ParametrizedFunction: def mean_func(parametrization: Parametrization, x: Any) -> Any: + dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) - else: - dimension_size = 1 - print(dimension_size) return nquad( - lambda x: np.dot(x, self.density(parametrization, x)), + lambda x: ( + np.dot(x, self.density(parametrization, x)) + if self._support.accepts(x) + else 0 + ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -187,12 +223,15 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: @property def _second_moment(self) -> ParametrizedFunction: def func(parametrization: Parametrization, x: Any) -> Any: + dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) - else: - dimension_size = 1 return nquad( - lambda x: x**2 * self.density(parametrization, x), + lambda x: ( + x**2 * self.density(parametrization, x) + if self._support.accepts(x) + else 0 + ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -217,12 +256,62 @@ def posterior_hyperparameters( alpha_post = None beta_post = None if hasattr(sample, "__iter__") and not isinstance(sample, str): - alpha_post = np.sum([self._T(x) for x in sample], axis=0) + alpha_post = np.sum([self._sufficient(x) for x in sample], axis=0) beta_post = len(sample) else: - alpha_post = self.T(sample) + alpha_post = self._sufficient(sample) beta_post = 1 return ExponentialConjugateHyperparameters( alpha=alpha + alpha_post, beta=beta + beta_post ) + + +class ExponentialFamily(NaturalExponentialFamily): + def __init__( + self, + *, + log_partition: Callable[[ExponentialFamilyParametrization], float], + sufficient_statistics: Callable[[Any], Any], + normalization_constant: Callable[[Any], Any], + parameter_from_natural_parameter: Callable[[Any], Any], + support: SpacePredicate, + parameter_space: SpacePredicate, + sufficient_statistics_values: SpacePredicate, + distr_type: DistributionType | Callable[[Parametrization], DistributionType], + distr_parametrizations: list[ParametrizationName], + sampling_strategy: SamplingStrategy, + name: str = "ExponentialFamily", + support_by_parametrization: SupportArg = None, + ): + def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): + eta_parametrizaion = cast( + ExponentialFamilyParametrization, eta_parametrizaion + ) + eta = eta_parametrizaion.parameters.get("theta") + theta = parameter_from_natural_parameter(eta) + return log_partition(ExponentialFamilyParametrization(theta=[theta])) + + natural_sufficient_statistics_values = SpacePredicate( + lambda eta: sufficient_statistics_values.accepts( + parameter_from_natural_parameter(eta) + ) + ) + natural_parameter_space = SpacePredicate( + lambda eta: parameter_space.accepts(parameter_from_natural_parameter(eta)), + ) + + NaturalExponentialFamily.__init__( + self, + log_partition=natural_log_partition, + sufficient_statistics=sufficient_statistics, + normalization_constant=normalization_constant, + support=support, + parameter_space=natural_parameter_space, + sufficient_statistics_values=natural_sufficient_statistics_values, + name=name, + distr_parametrizations=distr_parametrizations, + distr_type=distr_type, + sampling_strategy=sampling_strategy, + support_by_parametrization=support_by_parametrization, + ) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index 620ec32..171d052 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,53 +1,101 @@ -from typing import cast -import pytest import numpy as np +import pytest +import scipy +from typing import cast # from pysatl_core.distributions.computation import PDF from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy -from pysatl_core.distributions.support import ContinuousSupport from pysatl_core.families import ( ExponentialFamily, - ExponentialConjugateHyperparameters, - ExponentialClassParametrization, + ExponentialFamilyParametrization, + SpacePredicateArray, ) from pysatl_core.families.registry import ParametricFamilyRegister from pysatl_core.types import UnivariateContinuous -import math -# TODO: WRITE TEEEEEEESTS +# TODO: WRITE TEEEEEEESTS, MANY TESTS. def test_exponential(): - pass - # fam = ExponentialFamily( - # A=lambda parametrization: np.log(parametrization.theta[0]), - # T=lambda x: x, - # h=lambda _: 1, - # eta=lambda theta: -1 * theta, - # theta_from_eta=lambda eta: -1 * eta, - # param_space=[(0, float("+inf"))], - # support=[(0, float("+inf"))], - # natural_param_space=[(float("-inf"), 0)], + # pass + # fam = NaturalExponentialFamily( + # log_partition=lambda parametrization: np.log(-parametrization.theta[0]), + # sufficient_statistics=lambda x: x, + # normalization_constant=lambda _: 1, + # # param_space=SpacePredicateArray([(0, float("+inf"))]), + # support=SpacePredicateArray([(0, float("+inf"))]), + # parameter_space=SpacePredicateArray([(float("-inf"), 0)]), + # sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), # distr_type=UnivariateContinuous, # distr_parametrizations=["theta"], # sampling_strategy=DefaultSamplingUnivariateStrategy(), # ) - # conjugate_fam = fam - # conjugate_fam = fam.conjugate_prior_family - # params = ExponentialClassParametrization(theta=np.array([2, -1])) - # print(conjugate_fam._A(params)) - # ParametricFamilyRegister().register(conjugate_fam) - # # print( - # # fam.posterior_hyperparameters( - # # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] - # # ) - # # ) - # gamma_family: ExponentialFamily = cast( - # ExponentialFamily, ParametricFamilyRegister().get("ExponentialFamily") + def get_parameter_from_natural_parameter( + eta_parametrization: ExponentialFamilyParametrization, + ): + if hasattr(eta_parametrization, "__len__"): + if len(eta_parametrization) > 1: + return list(-1 * np.array(eta_parametrization)) + eta_parametrization = eta_parametrization[0] + return -eta_parametrization + + fam = ExponentialFamily( + log_partition=lambda parametrization: np.log(parametrization.theta[0]), + sufficient_statistics=lambda x: x, + normalization_constant=lambda _: 1, + parameter_from_natural_parameter=get_parameter_from_natural_parameter, + parameter_space=SpacePredicateArray([(0, float("+inf"))]), + sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), + support=SpacePredicateArray([(0, float("+inf"))]), + distr_type=UnivariateContinuous, + distr_parametrizations=["theta"], + sampling_strategy=DefaultSamplingUnivariateStrategy(), + ) + + conjugate_fam = fam + conjugate_fam = fam.conjugate_prior_family + ParametricFamilyRegister().register(conjugate_fam) + # print( + # fam.posterior_hyperparameters( + # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] + # ) # ) - # print(type(gamma_family)) - # # conjugate = gamma_family.conjugate_prior_family - # # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") - # exponential = gamma_family(theta=np.array([0, -1]), parametrization_name="theta") - # pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) - # print(pdf(-1)) + gamma_family: ExponentialFamily = cast( + ExponentialFamily, ParametricFamilyRegister().get("NaturalExponentialFamily") + ) + print(type(gamma_family)) + # conjugate = gamma_family.conjugate_prior_family + # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") + theta1 = 4 + theta2 = 4 + + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family( + theta=np.array([theta1, theta2]), parametrization_name="theta" + ) + pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) + + def gamma_pdf(alpha: float, beta: float, x: float): + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + + x = [i / 10 for i in range(-100, 100)] + # print(pdf(-x)) + import matplotlib.pyplot as plt + + plt.plot(x, [pdf(-xx) for xx in x], label="conjugate") + plt.plot( + x, + [gamma_pdf(alpha, beta, xx) for xx in x], + label=f"gamma({alpha}, {beta}) test", + ) + + from scipy.integrate import quad + + print(quad(pdf, float("-inf"), float("inf"))) + # mean = exponential.computation_strategy.query_method("mean", distr=exponential) + # print(mean(12)) + plt.legend() + plt.savefig("a.png") + # print(gamma_pdf(alpha, beta, x)) From 3186e9614b5c0120091b9aaf382923a2b2e9fa88 Mon Sep 17 00:00:00 2001 From: domosedy Date: Tue, 3 Feb 2026 14:35:00 +0300 Subject: [PATCH 3/5] [feat] added transform method to ExponentialFamily --- src/pysatl_core/families/__init__.py | 14 +- .../families/exponential_family.py | 152 +++++++++++++----- .../unit/families/test_exponential_family.py | 114 +++++++------ 3 files changed, 184 insertions(+), 96 deletions(-) diff --git a/src/pysatl_core/families/__init__.py b/src/pysatl_core/families/__init__.py index 3fd94c7..44994c2 100644 --- a/src/pysatl_core/families/__init__.py +++ b/src/pysatl_core/families/__init__.py @@ -14,13 +14,6 @@ from .builtins import __all__ as _builtins_all from .configuration import configure_families_register from .distribution import ParametricFamilyDistribution -from .parametric_family import ParametricFamily -from .parametrizations import ( - Parametrization, - ParametrizationConstraint, - constraint, - parametrization, -) from .exponential_family import ( ExponentialConjugateHyperparameters, ExponentialFamily, @@ -29,6 +22,13 @@ SpacePredicate, SpacePredicateArray, ) +from .parametric_family import ParametricFamily +from .parametrizations import ( + Parametrization, + ParametrizationConstraint, + constraint, + parametrization, +) from .registry import ParametricFamilyRegister __all__ = [ diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index dfa14f1..f2c0fcc 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -1,19 +1,24 @@ from __future__ import annotations + from collections.abc import Callable -from typing import Any, cast, TYPE_CHECKING -from scipy.integrate import nquad, quad +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Iterable, Sized, cast + import numpy as np +from scipy.integrate import nquad +from scipy.linalg import det +from scipy.differentiate import jacobian -from pysatl_core.distributions.fitters import _ppf_brentq_from_cdf +from pysatl_core.distributions import ( + SamplingStrategy, +) from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import Parametrization, parametrization from pysatl_core.types import ( + GenericCharacteristicName, DistributionType, ParametrizationName, ) -from pysatl_core.distributions import ( - SamplingStrategy, -) if TYPE_CHECKING: from pysatl_core.distributions.support import Support @@ -32,12 +37,13 @@ KURT = "kurtosis" +@dataclass class ExponentialFamilyParametrization(Parametrization): """ Standard parametrization of Exponential Family. """ - theta: list[Callable[[float], float]] # TODO: mb more clever + theta: list[float] # TODO: mb more clever class ExponentialConjugateHyperparameters: @@ -45,21 +51,23 @@ def __init__(self, alpha: Any, beta: int): self.alpha = alpha self.beta = beta - def __str__(self): + def __str__(self) -> str: return f"alpha={self.alpha}, beta={self.beta}" -def doesAccept(x, support): +def doesAccept(x: list[float] | float, support: list[tuple[float, float]]) -> bool: if not hasattr(x, "__len__"): x = [x] - def accept_1D(x, borders): + x = cast(list[float], x) + + def accept_1D(x: float, borders: tuple[float, float]) -> bool: left, right = borders if abs(x) == 0 and (abs(left) == 0 or abs(right) == 0): return False return left <= x <= right - return all(accept_1D(x_i, border) for x_i, border in zip(x, support)) + return all(accept_1D(x_i, border) for x_i, border in zip(x, support, strict=False)) class SpacePredicate: @@ -100,7 +108,10 @@ def __init__( self._parameter_space = parameter_space self._sufficient_statistics_values = sufficient_statistics_values - distr_characteristics = { + distr_characteristics: dict[ + GenericCharacteristicName, + dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction, + ] = { PDF: self.density, MEAN: self._mean, VAR: self._var, @@ -115,20 +126,29 @@ def __init__( sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) - parametrization(family=self, name="theta")((ExponentialFamilyParametrization)) + parametrization(family=self, name="theta")(ExponentialFamilyParametrization) + + def _transform_to_natural_parametrization( + self, theta_parametrization: ExponentialFamilyParametrization + ) -> ExponentialFamilyParametrization: + return theta_parametrization @property def log_density(self) -> ParametrizedFunction: - def log_density_func( - parametrization: ExponentialFamilyParametrization, x: Any - ) -> Any: + def log_density_func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) + parametrization = self._transform_to_natural_parametrization( + parametrization + ) if not self._support.accepts(x): return float("-inf") - params = cast(ExponentialFamilyParametrization, parametrization) - theta = params.parameters.get("theta") + theta = parametrization.theta sufficient = self._sufficient(x) dot = np.dot(theta, sufficient) + if hasattr(dot, "__len__"): + dot = dot[0] + result = float( np.log(self._normalization(x)) + dot @@ -143,26 +163,31 @@ def density(self) -> ParametrizedFunction: return lambda parametrization, x: np.exp(self.log_density(parametrization, x)) @property - def conjugate_prior_family(self): - def conjugate_sufficient(theta: Any): + def conjugate_prior_family(self) -> NaturalExponentialFamily: + def conjugate_sufficient( + theta: float, + ) -> list[Any]: if not self._parameter_space.accepts(theta): return [float("-inf"), float("-inf")] + parametrization = ExponentialFamilyParametrization([theta]) + # parametrization.theta = [theta] return [ theta, - self._log_partition(ExponentialFamilyParametrization(theta=[theta])), + self._log_partition(parametrization), ] - def conjugate_log_partition(parametrization: ExponentialFamilyParametrization): + def conjugate_log_partition( + parametrization: ExponentialFamilyParametrization, + ) -> Any: alpha = parametrization.theta[0] beta = parametrization.theta[1] - def pdf(theta: Any): + def pdf(theta: Any) -> Any: if not hasattr(theta, "__len__"): theta = [theta] - parametrization = ExponentialFamilyParametrization( - theta=theta, - ) + parametrization = ExponentialFamilyParametrization(theta=theta) + # parametrization.theta = theta return np.exp( np.dot(theta, alpha) + beta * self._log_partition(parametrization) )[0] @@ -180,15 +205,14 @@ def pdf(theta: Any): def conjugate_sufficient_accepts( parametrization: ExponentialFamilyParametrization, - ): - parametrization = cast(parametrization, ExponentialFamilyParametrization) - theta = parametrization.parameters.get("theta") + ) -> bool: + theta = parametrization.theta xi = theta[:-1] nu = theta[-1] - return self._sufficient_statistics_values(xi) and SpacePredicateArray( - [(0, float("+inf"))] - ).accepts(nu) + return self._sufficient_statistics_values.accepts( + xi + ) and SpacePredicateArray([(0, float("+inf"))]).accepts(nu) return NaturalExponentialFamily( log_partition=conjugate_log_partition, @@ -197,15 +221,50 @@ def conjugate_sufficient_accepts( support=self._parameter_space, sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this parameter_space=SpacePredicate(conjugate_sufficient_accepts), + name=self.name, sampling_strategy=self.sampling_strategy, distr_type=self._distr_type, distr_parametrizations=self.parametrization_names, support_by_parametrization=self.support_resolver, ) + def transform( + self, + transform_function: Callable[[Any], Any], + ) -> NaturalExponentialFamily: + def calculate_jacobian(x: Any) -> Any: + if type(x) is not list: + x = np.array([x]) + + return np.abs(det(jacobian(transform_function, x).df)) + + def new_support(x: Any) -> bool: + return self._support.accepts(transform_function(x)) + + def new_sufficient(x: Any) -> Any: + return self._sufficient(transform_function(x)) + + def new_normalization(x: Any) -> Any: + return self._normalization(x) * calculate_jacobian(x) + + return NaturalExponentialFamily( + log_partition=self._log_partition, + sufficient_statistics=new_sufficient, + normalization_constant=new_normalization, + support=SpacePredicate(new_support), + parameter_space=self._parameter_space, + sufficient_statistics_values=self._sufficient_statistics_values, + name=f"Transformed{self._name}", + distr_type=self._distr_type, + distr_parametrizations=self.parametrization_names, + sampling_strategy=self.sampling_strategy, + support_by_parametrization=self.support_resolver, + ) + @property def _mean(self) -> ParametrizedFunction: def mean_func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) @@ -223,6 +282,7 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: @property def _second_moment(self) -> ParametrizedFunction: def func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) dimension_size = 1 if hasattr(x, "__len__"): dimension_size = len(x) @@ -238,8 +298,9 @@ def func(parametrization: Parametrization, x: Any) -> Any: return func @property - def _var(self): - def func(parametrization, x: Any): + def _var(self) -> ParametrizedFunction: + def func(parametrization: Parametrization, x: Any) -> Any: + parametrization = cast(ExponentialFamilyParametrization, parametrization) return ( self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2 @@ -248,8 +309,8 @@ def func(parametrization, x: Any): return func def posterior_hyperparameters( - self, prior_hyper: ExponentialConjugateHyperparameters, sample - ): + self, prior_hyper: ExponentialConjugateHyperparameters, sample: list[Any] + ) -> ExponentialConjugateHyperparameters: alpha = prior_hyper.alpha beta = prior_hyper.beta @@ -275,6 +336,9 @@ def __init__( sufficient_statistics: Callable[[Any], Any], normalization_constant: Callable[[Any], Any], parameter_from_natural_parameter: Callable[[Any], Any], + natural_parameter: Callable[ + [ExponentialFamilyParametrization], ExponentialFamilyParametrization + ], support: SpacePredicate, parameter_space: SpacePredicate, sufficient_statistics_values: SpacePredicate, @@ -284,11 +348,10 @@ def __init__( name: str = "ExponentialFamily", support_by_parametrization: SupportArg = None, ): - def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): - eta_parametrizaion = cast( - ExponentialFamilyParametrization, eta_parametrizaion - ) - eta = eta_parametrizaion.parameters.get("theta") + def natural_log_partition( + eta_parametrizaion: ExponentialFamilyParametrization, + ) -> Any: + eta = eta_parametrizaion.theta theta = parameter_from_natural_parameter(eta) return log_partition(ExponentialFamilyParametrization(theta=[theta])) @@ -297,6 +360,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): parameter_from_natural_parameter(eta) ) ) + + self._natural_parameter = natural_parameter natural_parameter_space = SpacePredicate( lambda eta: parameter_space.accepts(parameter_from_natural_parameter(eta)), ) @@ -315,3 +380,8 @@ def natural_log_partition(eta_parametrizaion: ExponentialFamilyParametrization): sampling_strategy=sampling_strategy, support_by_parametrization=support_by_parametrization, ) + + def _transform_to_natural_parametrization( + self, theta_parametrization: ExponentialFamilyParametrization + ) -> ExponentialFamilyParametrization: + return self._natural_parameter(theta_parametrization) diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index 171d052..35ea7cc 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -1,9 +1,10 @@ +from typing import Any, cast + import numpy as np import pytest import scipy -from typing import cast +from numpy.testing import assert_allclose -# from pysatl_core.distributions.computation import PDF from pysatl_core.distributions.strategies import DefaultSamplingUnivariateStrategy from pysatl_core.families import ( ExponentialFamily, @@ -14,22 +15,12 @@ from pysatl_core.types import UnivariateContinuous -# TODO: WRITE TEEEEEEESTS, MANY TESTS. -def test_exponential(): - # pass - # fam = NaturalExponentialFamily( - # log_partition=lambda parametrization: np.log(-parametrization.theta[0]), - # sufficient_statistics=lambda x: x, - # normalization_constant=lambda _: 1, - # # param_space=SpacePredicateArray([(0, float("+inf"))]), - # support=SpacePredicateArray([(0, float("+inf"))]), - # parameter_space=SpacePredicateArray([(float("-inf"), 0)]), - # sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), - # distr_type=UnivariateContinuous, - # distr_parametrizations=["theta"], - # sampling_strategy=DefaultSamplingUnivariateStrategy(), - # ) +def gamma_pdf(alpha: float, beta: float, x: float) -> float: + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + +@pytest.fixture(scope="function") +def conjugate_for_exponential() -> ExponentialFamily: def get_parameter_from_natural_parameter( eta_parametrization: ExponentialFamilyParametrization, ): @@ -39,11 +30,29 @@ def get_parameter_from_natural_parameter( eta_parametrization = eta_parametrization[0] return -eta_parametrization + def natural_parameter( + theta_parametrization: Any, + ) -> Any: + if type(theta_parametrization) is ExponentialFamilyParametrization: + theta_parametrization = cast( + ExponentialFamilyParametrization, theta_parametrization + ) + eta = -theta_parametrization.theta + return ExponentialFamilyParametrization(theta=eta) + + return -1 * theta_parametrization + + def transform_function(x: list[Any]) -> list[Any]: + if type(x) is not list: + return -x + return [-x[0]] + fam = ExponentialFamily( log_partition=lambda parametrization: np.log(parametrization.theta[0]), sufficient_statistics=lambda x: x, normalization_constant=lambda _: 1, parameter_from_natural_parameter=get_parameter_from_natural_parameter, + natural_parameter=natural_parameter, parameter_space=SpacePredicateArray([(0, float("+inf"))]), sufficient_statistics_values=SpacePredicateArray([(0, float("+inf"))]), support=SpacePredicateArray([(0, float("+inf"))]), @@ -52,22 +61,18 @@ def get_parameter_from_natural_parameter( sampling_strategy=DefaultSamplingUnivariateStrategy(), ) - conjugate_fam = fam - conjugate_fam = fam.conjugate_prior_family + conjugate_fam = fam.conjugate_prior_family.transform(transform_function) ParametricFamilyRegister().register(conjugate_fam) - # print( - # fam.posterior_hyperparameters( - # ExponentialConjugateHyperparameters(alpha=10, beta=1), [12] - # ) - # ) - gamma_family: ExponentialFamily = cast( - ExponentialFamily, ParametricFamilyRegister().get("NaturalExponentialFamily") + return cast( + ExponentialFamily, + ParametricFamilyRegister().get("TransformedExponentialFamily"), ) - print(type(gamma_family)) - # conjugate = gamma_family.conjugate_prior_family - # exponential = gamma_family(theta=np.array([2]), parametrization_name="theta") - theta1 = 4 - theta2 = 4 + + +@pytest.mark.parametrize("theta1", range(2, 5)) +@pytest.mark.parametrize("theta2", range(2, 5)) +def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): + gamma_family: ExponentialFamily = conjugate_for_exponential alpha = theta2 + 1 beta = theta1 @@ -77,25 +82,38 @@ def get_parameter_from_natural_parameter( ) pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) - def gamma_pdf(alpha: float, beta: float, x: float): - return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + x = [i / 10 for i in range(100)] + + assert_allclose( + [pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6 + ) + + +@pytest.mark.parametrize("theta1", range(2, 5)) +@pytest.mark.parametrize("theta2", range(2, 5)) +def test_exponential_mean(theta1, theta2, conjugate_for_exponential): + gamma_family: ExponentialFamily = conjugate_for_exponential - x = [i / 10 for i in range(-100, 100)] - # print(pdf(-x)) - import matplotlib.pyplot as plt + alpha = theta2 + 1 + beta = theta1 - plt.plot(x, [pdf(-xx) for xx in x], label="conjugate") - plt.plot( - x, - [gamma_pdf(alpha, beta, xx) for xx in x], - label=f"gamma({alpha}, {beta}) test", + exponential = gamma_family( + theta=np.array([theta1, theta2]), parametrization_name="theta" ) + mean = exponential.computation_strategy.query_method("mean", distr=exponential) + assert np.isclose(mean(12), alpha / beta, rtol=1e-6) + - from scipy.integrate import quad +@pytest.mark.parametrize("theta1", range(2, 5)) +@pytest.mark.parametrize("theta2", range(2, 5)) +def test_exponential_var(theta1, theta2, conjugate_for_exponential): + gamma_family: ExponentialFamily = conjugate_for_exponential - print(quad(pdf, float("-inf"), float("inf"))) - # mean = exponential.computation_strategy.query_method("mean", distr=exponential) - # print(mean(12)) - plt.legend() - plt.savefig("a.png") - # print(gamma_pdf(alpha, beta, x)) + alpha = theta2 + 1 + beta = theta1 + + exponential = gamma_family( + theta=np.array([theta1, theta2]), parametrization_name="theta" + ) + var = exponential.computation_strategy.query_method("var", distr=exponential) + assert np.isclose(var(12), alpha / beta**2, rtol=1e-6) From 0eae845a1e3f073cd2609de9ae174eebecc6ec97 Mon Sep 17 00:00:00 2001 From: domosedy Date: Tue, 3 Feb 2026 14:46:50 +0300 Subject: [PATCH 4/5] [refactor] refactoring for mypy --- .../families/exponential_family.py | 56 ++++++------------- .../unit/families/test_exponential_family.py | 35 ++++-------- 2 files changed, 29 insertions(+), 62 deletions(-) diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index f2c0fcc..21c24ba 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -2,12 +2,12 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Iterable, Sized, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np +from scipy.differentiate import jacobian from scipy.integrate import nquad from scipy.linalg import det -from scipy.differentiate import jacobian from pysatl_core.distributions import ( SamplingStrategy, @@ -15,8 +15,8 @@ from pysatl_core.families.parametric_family import ParametricFamily from pysatl_core.families.parametrizations import Parametrization, parametrization from pysatl_core.types import ( - GenericCharacteristicName, DistributionType, + GenericCharacteristicName, ParametrizationName, ) @@ -137,9 +137,7 @@ def _transform_to_natural_parametrization( def log_density(self) -> ParametrizedFunction: def log_density_func(parametrization: Parametrization, x: Any) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) - parametrization = self._transform_to_natural_parametrization( - parametrization - ) + parametrization = self._transform_to_natural_parametrization(parametrization) if not self._support.accepts(x): return float("-inf") @@ -150,9 +148,7 @@ def log_density_func(parametrization: Parametrization, x: Any) -> Any: dot = dot[0] result = float( - np.log(self._normalization(x)) - + dot - + self._log_partition(parametrization) + np.log(self._normalization(x)) + dot + self._log_partition(parametrization) ) return result @@ -188,21 +184,14 @@ def pdf(theta: Any) -> Any: theta = [theta] parametrization = ExponentialFamilyParametrization(theta=theta) # parametrization.theta = theta - return np.exp( - np.dot(theta, alpha) + beta * self._log_partition(parametrization) - )[0] + return np.exp(np.dot(theta, alpha) + beta * self._log_partition(parametrization))[0] all_value = nquad( - lambda x: pdf(x) if self._parameter_space.accepts(x) else 0, + lambda x: pdf(x) if self._parameter_space.accepts(x) else 0, # type: ignore[arg-type] [(float("-inf"), float("+inf"))], )[0] return -np.log(all_value) - # TODO: remove hardcoding - Done, all hardcoding is only on user's hands - # 1. pr with prototype/draft - in progress - # 2. write instruction about to add distributions as member of exponential family - not started - # 3. parametrization's spaces (передавать в конструктор) - maybe impossible, discuss this with desiment on meeting - def conjugate_sufficient_accepts( parametrization: ExponentialFamilyParametrization, ) -> bool: @@ -210,9 +199,9 @@ def conjugate_sufficient_accepts( xi = theta[:-1] nu = theta[-1] - return self._sufficient_statistics_values.accepts( - xi - ) and SpacePredicateArray([(0, float("+inf"))]).accepts(nu) + return self._sufficient_statistics_values.accepts(xi) and SpacePredicateArray( + [(0, float("+inf"))] + ).accepts(nu) return NaturalExponentialFamily( log_partition=conjugate_log_partition, @@ -269,10 +258,8 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any: if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: ( - np.dot(x, self.density(parametrization, x)) - if self._support.accepts(x) - else 0 + lambda x: ( # type: ignore[arg-type] + np.dot(x, self.density(parametrization, x)) if self._support.accepts(x) else 0 ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -287,10 +274,8 @@ def func(parametrization: Parametrization, x: Any) -> Any: if hasattr(x, "__len__"): dimension_size = len(x) return nquad( - lambda x: ( - x**2 * self.density(parametrization, x) - if self._support.accepts(x) - else 0 + lambda x: ( # type: ignore[arg-type] + x**2 * self.density(parametrization, x) if self._support.accepts(x) else 0 ), [(float("-inf"), float("inf"))] * dimension_size, )[0] @@ -301,10 +286,7 @@ def func(parametrization: Parametrization, x: Any) -> Any: def _var(self) -> ParametrizedFunction: def func(parametrization: Parametrization, x: Any) -> Any: parametrization = cast(ExponentialFamilyParametrization, parametrization) - return ( - self._second_moment(parametrization, x) - - self._mean(parametrization, x) ** 2 - ) + return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2 return func @@ -323,9 +305,7 @@ def posterior_hyperparameters( alpha_post = self._sufficient(sample) beta_post = 1 - return ExponentialConjugateHyperparameters( - alpha=alpha + alpha_post, beta=beta + beta_post - ) + return ExponentialConjugateHyperparameters(alpha=alpha + alpha_post, beta=beta + beta_post) class ExponentialFamily(NaturalExponentialFamily): @@ -356,9 +336,7 @@ def natural_log_partition( return log_partition(ExponentialFamilyParametrization(theta=[theta])) natural_sufficient_statistics_values = SpacePredicate( - lambda eta: sufficient_statistics_values.accepts( - parameter_from_natural_parameter(eta) - ) + lambda eta: sufficient_statistics_values.accepts(parameter_from_natural_parameter(eta)) ) self._natural_parameter = natural_parameter diff --git a/tests/unit/families/test_exponential_family.py b/tests/unit/families/test_exponential_family.py index 35ea7cc..5b52154 100644 --- a/tests/unit/families/test_exponential_family.py +++ b/tests/unit/families/test_exponential_family.py @@ -16,14 +16,14 @@ def gamma_pdf(alpha: float, beta: float, x: float) -> float: - return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() + return scipy.stats.gamma(a=alpha, scale=1 / beta).pdf(x).item() # type: ignore[attr-defined] @pytest.fixture(scope="function") def conjugate_for_exponential() -> ExponentialFamily: def get_parameter_from_natural_parameter( - eta_parametrization: ExponentialFamilyParametrization, - ): + eta_parametrization: Any, + ) -> Any: if hasattr(eta_parametrization, "__len__"): if len(eta_parametrization) > 1: return list(-1 * np.array(eta_parametrization)) @@ -34,18 +34,15 @@ def natural_parameter( theta_parametrization: Any, ) -> Any: if type(theta_parametrization) is ExponentialFamilyParametrization: - theta_parametrization = cast( - ExponentialFamilyParametrization, theta_parametrization - ) - eta = -theta_parametrization.theta + eta = list(-np.array(theta_parametrization.theta)) return ExponentialFamilyParametrization(theta=eta) return -1 * theta_parametrization - def transform_function(x: list[Any]) -> list[Any]: - if type(x) is not list: - return -x - return [-x[0]] + def transform_function(x: list[float] | float) -> list[float] | float: + if type(x) is list: + return [-x[0]] + return -x # type: ignore[operator] fam = ExponentialFamily( log_partition=lambda parametrization: np.log(parametrization.theta[0]), @@ -77,16 +74,12 @@ def test_exponential_pdf(theta1, theta2, conjugate_for_exponential): alpha = theta2 + 1 beta = theta1 - exponential = gamma_family( - theta=np.array([theta1, theta2]), parametrization_name="theta" - ) + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") pdf = exponential.computation_strategy.query_method("pdf", distr=exponential) x = [i / 10 for i in range(100)] - assert_allclose( - [pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6 - ) + assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6) @pytest.mark.parametrize("theta1", range(2, 5)) @@ -97,9 +90,7 @@ def test_exponential_mean(theta1, theta2, conjugate_for_exponential): alpha = theta2 + 1 beta = theta1 - exponential = gamma_family( - theta=np.array([theta1, theta2]), parametrization_name="theta" - ) + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") mean = exponential.computation_strategy.query_method("mean", distr=exponential) assert np.isclose(mean(12), alpha / beta, rtol=1e-6) @@ -112,8 +103,6 @@ def test_exponential_var(theta1, theta2, conjugate_for_exponential): alpha = theta2 + 1 beta = theta1 - exponential = gamma_family( - theta=np.array([theta1, theta2]), parametrization_name="theta" - ) + exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta") var = exponential.computation_strategy.query_method("var", distr=exponential) assert np.isclose(var(12), alpha / beta**2, rtol=1e-6) From c1c0054f816dee626f1d05b161126eeec3d30a8b Mon Sep 17 00:00:00 2001 From: domosedy Date: Tue, 3 Feb 2026 14:54:27 +0300 Subject: [PATCH 5/5] [refactor] remove comments in code --- src/pysatl_core/families/exponential_family.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pysatl_core/families/exponential_family.py b/src/pysatl_core/families/exponential_family.py index 21c24ba..d43d027 100644 --- a/src/pysatl_core/families/exponential_family.py +++ b/src/pysatl_core/families/exponential_family.py @@ -43,7 +43,7 @@ class ExponentialFamilyParametrization(Parametrization): Standard parametrization of Exponential Family. """ - theta: list[float] # TODO: mb more clever + theta: list[float] class ExponentialConjugateHyperparameters: @@ -167,7 +167,6 @@ def conjugate_sufficient( return [float("-inf"), float("-inf")] parametrization = ExponentialFamilyParametrization([theta]) - # parametrization.theta = [theta] return [ theta, self._log_partition(parametrization), @@ -183,7 +182,6 @@ def pdf(theta: Any) -> Any: if not hasattr(theta, "__len__"): theta = [theta] parametrization = ExponentialFamilyParametrization(theta=theta) - # parametrization.theta = theta return np.exp(np.dot(theta, alpha) + beta * self._log_partition(parametrization))[0] all_value = nquad(