From 59f94c59b0ff30f432999e13e7a52287d9f50e29 Mon Sep 17 00:00:00 2001 From: LeonidElkin Date: Sun, 15 Mar 2026 18:34:43 +0300 Subject: [PATCH 1/5] chore(gitignore): ignore AI-related local artifacts --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index a75017c..75e2774 100644 --- a/.gitignore +++ b/.gitignore @@ -128,6 +128,7 @@ poetry.lock .cursorindexingignore # Docs + docs/build/* docs/source/api/* docs/source/examples/* @@ -137,3 +138,9 @@ docs/source/examples/* # Unuran build artifacts src/pysatl_core/sampling/unuran/bindings/_unuran_cffi.c + +# ai + +.ai/* +AGENTS.md +CLAUDE.md From fe62ce9b884cb3c3b8e3fd37326c68251c74458b Mon Sep 17 00:00:00 2001 From: LeonidElkin Date: Sun, 15 Mar 2026 18:39:15 +0300 Subject: [PATCH 2/5] refactor(distribution): add __init__ to `Distribution` to minimize boilerplate code --- src/pysatl_core/distributions/distribution.py | 126 +++++++++++++++--- src/pysatl_core/families/distribution.py | 87 +++--------- src/pysatl_core/families/parametric_family.py | 4 +- .../unit/families/test_distribution_cache.py | 33 +++-- .../families/test_family_and_distribution.py | 10 ++ tests/utils/mocks.py | 46 ++----- 6 files changed, 170 insertions(+), 136 deletions(-) diff --git a/src/pysatl_core/distributions/distribution.py b/src/pysatl_core/distributions/distribution.py index 6a072bd..9728462 100644 --- a/src/pysatl_core/distributions/distribution.py +++ b/src/pysatl_core/distributions/distribution.py @@ -12,22 +12,23 @@ __license__ = "SPDX-License-Identifier: MIT" from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping +from copy import deepcopy from typing import TYPE_CHECKING, Self, cast +from pysatl_core.distributions.strategies import ( + ComputationStrategy, + SamplingStrategy, +) from pysatl_core.types import NumericArray _KEEP: object = object() if TYPE_CHECKING: - from collections.abc import Mapping from typing import Any - from pysatl_core.distributions.computation import AnalyticalComputation - from pysatl_core.distributions.strategies import ( - ComputationStrategy, - SamplingStrategy, - ) + from pysatl_core.distributions.computation import AnalyticalComputation, Method from pysatl_core.distributions.support import Support from pysatl_core.types import ( DistributionType, @@ -58,27 +59,79 @@ class Distribution(ABC): Support of the distribution, if defined. """ + def __init__( + self, + distribution_type: DistributionType, + analytical_computations: ( + Iterable[AnalyticalComputation[Any, Any]] + | Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] + ), + support: Support | None = None, + sampling_strategy: SamplingStrategy | None = None, + computation_strategy: ComputationStrategy | None = None, + ) -> None: + """ + Initialize common distribution state. + + Parameters + ---------- + distribution_type : DistributionType + Type information about the distribution (kind, dimension, etc.). + analytical_computations : + Iterable[AnalyticalComputation] | Mapping[str, AnalyticalComputation] + Analytical computations provided by the distribution. + support : Support or None, default=None + Support of the distribution. + sampling_strategy : SamplingStrategy or None, default=None + Sampling strategy instance. If omitted, univariate default is used. + computation_strategy : ComputationStrategy or None, default=None + Computation strategy instance. If omitted, default strategy is used. + """ + from pysatl_core.distributions.strategies import ( + DefaultComputationStrategy, + DefaultSamplingUnivariateStrategy, + ) + + self._distribution_type = distribution_type + if isinstance(analytical_computations, Mapping): + normalized_analytical = dict(analytical_computations) + else: + normalized_analytical = {ac.target: ac for ac in analytical_computations} + + if not normalized_analytical: + raise ValueError("Distribution requires at least one analytical computation.") + + self._analytical_computations = normalized_analytical + self._support = support + self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy() + self._computation_strategy = computation_strategy or DefaultComputationStrategy() + @property - @abstractmethod - def distribution_type(self) -> DistributionType: ... + def distribution_type(self) -> DistributionType: + """Return type metadata of the distribution (kind, dimension, etc.).""" + return self._distribution_type @property - @abstractmethod def analytical_computations( self, - ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: ... + ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: + """Return analytical computations provided directly by this distribution.""" + return self._analytical_computations @property - @abstractmethod - def sampling_strategy(self) -> SamplingStrategy: ... + def sampling_strategy(self) -> SamplingStrategy: + """Return the currently attached sampling strategy.""" + return self._sampling_strategy @property - @abstractmethod - def computation_strategy(self) -> ComputationStrategy: ... + def computation_strategy(self) -> ComputationStrategy: + """Return the currently attached computation strategy.""" + return self._computation_strategy @property - @abstractmethod - def support(self) -> Support | None: ... + def support(self) -> Support | None: + """Return the support of the distribution, if it is defined.""" + return self._support @abstractmethod def _clone_with_strategies( @@ -86,7 +139,46 @@ def _clone_with_strategies( *, sampling_strategy: SamplingStrategy | None | object = _KEEP, computation_strategy: ComputationStrategy | None | object = _KEEP, - ) -> Distribution: ... + ) -> Distribution: + """ + Return a cloned distribution with updated strategies. + + The ``_KEEP`` sentinel means the existing strategy should be preserved + for that side. + """ + ... + + def _new_sampling_strategy( + self, + sampling_strategy: SamplingStrategy | None | object = _KEEP, + ) -> SamplingStrategy | None: + """ + Resolve sampling strategy for cloning. + + When ``sampling_strategy`` is ``_KEEP``, returns a deep copy of the + current sampling strategy. + """ + return cast( + SamplingStrategy | None, + deepcopy(self._sampling_strategy) if sampling_strategy is _KEEP else sampling_strategy, + ) + + def _new_computation_strategy( + self, + computation_strategy: ComputationStrategy | None | object = _KEEP, + ) -> ComputationStrategy | None: + """ + Resolve computation strategy for cloning. + + When ``computation_strategy`` is ``_KEEP``, returns a deep copy of the + current computation strategy. + """ + return cast( + ComputationStrategy | None, + deepcopy(self._computation_strategy) + if computation_strategy is _KEEP + else computation_strategy, + ) def with_sampling_strategy(self, sampling_strategy: SamplingStrategy | None) -> Self: """Return a copy of this distribution with an updated sampling strategy.""" diff --git a/src/pysatl_core/families/distribution.py b/src/pysatl_core/families/distribution.py index 45c77c4..7b57945 100644 --- a/src/pysatl_core/families/distribution.py +++ b/src/pysatl_core/families/distribution.py @@ -14,10 +14,6 @@ from typing import TYPE_CHECKING, cast from pysatl_core.distributions.distribution import _KEEP, Distribution -from pysatl_core.distributions.strategies import ( - DefaultComputationStrategy, - DefaultSamplingUnivariateStrategy, -) from pysatl_core.families.registry import ParametricFamilyRegister from pysatl_core.types import NumericArray @@ -72,34 +68,29 @@ def __init__( self, family_name: str, distribution_type: DistributionType, + analytical_computations: Mapping[ + GenericCharacteristicName, AnalyticalComputation[Any, Any] + ], parametrization: Parametrization, support: Support | None, sampling_strategy: SamplingStrategy | None = None, computation_strategy: ComputationStrategy | None = None, ): - self._distribution_type = distribution_type + super().__init__( + distribution_type=distribution_type, + analytical_computations=analytical_computations, + support=support, + sampling_strategy=sampling_strategy, + computation_strategy=computation_strategy, + ) self._family_name = family_name self._parametrization = parametrization - self._support = support - - self._computation_strategy = computation_strategy or DefaultComputationStrategy() - self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy() - - self._analytical_cache_key: tuple[int, GenericCharacteristicName] | None = None - self._analytical_cache_val: ( - Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] | None - ) = None @property def family_name(self) -> str: "Get the name of the family this distribution belongs to." return self._family_name - @property - def distribution_type(self) -> DistributionType: - """Get the distribution type.""" - return self._distribution_type - @property def parametrization(self) -> Parametrization: """ @@ -160,36 +151,6 @@ def family(self) -> ParametricFamily: """ return ParametricFamilyRegister.get(self.family_name) - @property - def analytical_computations( - self, - ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: - """ - Get analytical computations for this distribution. - - Lazily computed and cached per instance. Cache invalidates when - parametrization object or name changes. - """ - key = (id(self.parametrization), self.parametrization_name) - - if self._analytical_cache_key != key or self._analytical_cache_val is None: - self._analytical_cache_val = self.family.build_analytical_computations( - self.parametrization - ) - self._analytical_cache_key = key - - return self._analytical_cache_val - - @property - def sampling_strategy(self) -> SamplingStrategy: - """Get the sampling strategy for this distribution.""" - return self._sampling_strategy - - @property - def computation_strategy(self) -> ComputationStrategy: - """Get the computation strategy for this distribution.""" - return self._computation_strategy - def _clone_with_strategies( self, *, @@ -197,32 +158,18 @@ def _clone_with_strategies( computation_strategy: ComputationStrategy | None | object = _KEEP, ) -> ParametricFamilyDistribution: """Return a copy of this distribution with updated strategies.""" - new_sampling: SamplingStrategy | None = ( - self._sampling_strategy - if sampling_strategy is _KEEP - else cast(SamplingStrategy | None, sampling_strategy) - ) - - new_computation: ComputationStrategy | None = ( - self._computation_strategy - if computation_strategy is _KEEP - else cast(ComputationStrategy | None, computation_strategy) - ) - return ParametricFamilyDistribution( family_name=self._family_name, - distribution_type=self._distribution_type, + distribution_type=self.distribution_type, + analytical_computations=self.analytical_computations, parametrization=self._parametrization, - support=self._support, - sampling_strategy=new_sampling, - computation_strategy=new_computation, + support=self.support, + sampling_strategy=self._new_sampling_strategy(sampling_strategy=sampling_strategy), + computation_strategy=self._new_computation_strategy( + computation_strategy=computation_strategy + ), ) - @property - def support(self) -> Support | None: - """Get the support of this distribution.""" - return self._support - def sample(self, n: int, **options: Any) -> NumericArray: """ Generate random samples from the distribution. diff --git a/src/pysatl_core/families/parametric_family.py b/src/pysatl_core/families/parametric_family.py index 240480c..e9709f5 100644 --- a/src/pysatl_core/families/parametric_family.py +++ b/src/pysatl_core/families/parametric_family.py @@ -263,7 +263,7 @@ def _bind_parametrization[In, Out]( else func, ) - def build_analytical_computations( + def _build_analytical_computations( self, parameters: Parametrization ) -> dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: """ @@ -334,9 +334,11 @@ def distribution( parameters.validate() base_parameters = self.to_base(parameters) distribution_type = self._distr_type(base_parameters) + analytical_computations = self._build_analytical_computations(parameters) return ParametricFamilyDistribution( family_name=self.name, distribution_type=distribution_type, + analytical_computations=analytical_computations, parametrization=parameters, support=self.support_resolver(parameters), sampling_strategy=sampling_strategy, diff --git a/tests/unit/families/test_distribution_cache.py b/tests/unit/families/test_distribution_cache.py index d5eb8d8..8952a23 100644 --- a/tests/unit/families/test_distribution_cache.py +++ b/tests/unit/families/test_distribution_cache.py @@ -18,7 +18,7 @@ def _fallback_characteristics(self) -> dict[GenericCharacteristicName, dict[str, CharacteristicName.CDF: {"base": lambda params, x: params.value}, } - def test_cache_auto_invalidation(self) -> None: + def test_analytical_computations_are_built_at_distribution_creation(self) -> None: family = self.make_default_family(distr_characteristics=self._fallback_characteristics()) ParametricFamilyRegister.register(family) @@ -27,17 +27,18 @@ def test_cache_auto_invalidation(self) -> None: computations1 = distribution.analytical_computations computations1_again = distribution.analytical_computations - assert computations1 is computations1_again # cache hit + assert computations1 is computations1_again - # Replacing with a *new* object of the same parametrization should rebuild the cache + # Replacing with a *new* object of the same parametrization does not + # rebuild computations because they are now materialized at creation time. distribution._parametrization = family.parametrizations["alt"](value=5.0) # type: ignore[call-arg] computations2 = distribution.analytical_computations - assert computations2 is not computations1 + assert computations2 is computations1 - # Switching to the base parametrization should also rebuild the cache + # Switching to the base parametrization also does not rebuild computations. distribution._parametrization = family.parametrizations["base"](value=7.0) # type: ignore[call-arg] computations3 = distribution.analytical_computations - assert computations3 is not computations2 + assert computations3 is computations2 # Both mappings must contain PDF and CDF and be callable for mapping in (computations2, computations3): @@ -46,13 +47,12 @@ def test_cache_auto_invalidation(self) -> None: mapping[CharacteristicName.CDF] ) - # For alt(value=5.0) → fallback to base(value=5.0) - assert computations2[CharacteristicName.PDF](1.23) == pytest.approx(5.0) - assert computations2[CharacteristicName.CDF](0.5) == pytest.approx(5.0) + # Computations remain bound to initial alt(value=2.0) -> base(value=2.0) + assert computations2[CharacteristicName.PDF](1.23) == pytest.approx(2.0) + assert computations2[CharacteristicName.CDF](0.5) == pytest.approx(2.0) - # For base(value=7.0) - assert computations3[CharacteristicName.PDF](42.0) == pytest.approx(7.0) - assert computations3[CharacteristicName.CDF](0.0) == pytest.approx(7.0) + assert computations3[CharacteristicName.PDF](42.0) == pytest.approx(2.0) + assert computations3[CharacteristicName.CDF](0.0) == pytest.approx(2.0) def test_fallback_to_base_for_missing_form(self) -> None: family = self.make_default_family(distr_characteristics=self._fallback_characteristics()) @@ -67,3 +67,12 @@ def test_fallback_to_base_for_missing_form(self) -> None: # For alt(value=2.0) → base(value=2.0) assert computations[CharacteristicName.PDF](1.23) == pytest.approx(2.0) assert computations[CharacteristicName.CDF](0.5) == pytest.approx(2.0) + + def test_distribution_creation_requires_analytical_computations(self) -> None: + family = self.make_default_family(distr_characteristics={}) + ParametricFamilyRegister.register(family) + + with pytest.raises( + ValueError, match="Distribution requires at least one analytical computation." + ): + family.distribution("alt", value=2.0) diff --git a/tests/unit/families/test_family_and_distribution.py b/tests/unit/families/test_family_and_distribution.py index b13f614..a8e6659 100644 --- a/tests/unit/families/test_family_and_distribution.py +++ b/tests/unit/families/test_family_and_distribution.py @@ -46,3 +46,13 @@ def test_family_registration_and_distribution_sampling(self) -> None: } assert computations[CharacteristicName.CDF](0.25) == pytest.approx(0.25) assert computations[CharacteristicName.PPF](0.75) == pytest.approx(0.75) + + def test_distribution_clone_with_keep_strategies_copies_strategies(self) -> None: + fam = self.make_default_family() + ParametricFamilyRegister.register(fam) + + distr = fam.distribution("base", value=0.0) + cloned = distr.with_strategies() + + assert cloned.sampling_strategy is not distr.sampling_strategy + assert cloned.computation_strategy is not distr.computation_strategy diff --git a/tests/utils/mocks.py b/tests/utils/mocks.py index 2802699..31c48f7 100644 --- a/tests/utils/mocks.py +++ b/tests/utils/mocks.py @@ -13,8 +13,6 @@ from pysatl_core.distributions import ( AnalyticalComputation, ComputationStrategy, - DefaultComputationStrategy, - DefaultSamplingUnivariateStrategy, Distribution, SamplingStrategy, ) @@ -45,7 +43,7 @@ class StandaloneEuclideanUnivariateDistribution(Distribution): """ _distribution_type: EuclideanDistributionType - _analytical: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]] + _analytical_computations: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]] _support: Support | None def __init__( @@ -57,38 +55,11 @@ def __init__( ) = (), support: Support | None = None, ) -> None: - self._distribution_type = EuclideanDistributionType(kind, 1) - self._support = support - if isinstance(analytical_computations, Mapping): - self._analytical = dict(analytical_computations) - else: - self._analytical = {ac.target: ac for ac in analytical_computations} - - @property - def distribution_type(self) -> EuclideanDistributionType: - """Distribution type descriptor (kind and dimension).""" - return self._distribution_type - - @property - def analytical_computations( - self, - ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: - """Mapping from characteristic name to analytical callable.""" - return self._analytical - - @property - def sampling_strategy(self) -> SamplingStrategy: - """Sampling strategy instance.""" - return DefaultSamplingUnivariateStrategy() - - @property - def computation_strategy(self) -> ComputationStrategy: - """Computation strategy instance.""" - return DefaultComputationStrategy() - - @property - def support(self): - return self._support + super(StandaloneEuclideanUnivariateDistribution, self).__init__( + distribution_type=EuclideanDistributionType(kind, 1), + analytical_computations=analytical_computations, + support=support, + ) def _clone_with_strategies( self, @@ -97,4 +68,7 @@ def _clone_with_strategies( computation_strategy: ComputationStrategy | None | object = _KEEP, ) -> StandaloneEuclideanUnivariateDistribution: # Actually a stub - return StandaloneEuclideanUnivariateDistribution(Kind.CONTINUOUS) + return StandaloneEuclideanUnivariateDistribution( + Kind.CONTINUOUS, + analytical_computations=self.analytical_computations, + ) From 3704e70819180e6b9dcd7c395c19b797be5502ed Mon Sep 17 00:00:00 2001 From: LeonidElkin Date: Tue, 17 Mar 2026 19:46:05 +0300 Subject: [PATCH 3/5] refactor(distribution): now there can be several analytical calculations for one characteristic --- src/pysatl_core/distributions/distribution.py | 54 ++++++++--- src/pysatl_core/distributions/strategies.py | 33 ++++++- src/pysatl_core/families/distribution.py | 4 +- src/pysatl_core/families/parametric_family.py | 91 ++++++++++++++----- src/pysatl_core/types.py | 8 ++ tests/unit/distributions/test_basic.py | 81 +++++++++++++---- tests/unit/families/test_basic.py | 13 ++- .../unit/families/test_distribution_cache.py | 36 +++++--- .../families/test_family_and_distribution.py | 14 ++- tests/utils/mocks.py | 38 ++++++-- 10 files changed, 286 insertions(+), 86 deletions(-) diff --git a/src/pysatl_core/distributions/distribution.py b/src/pysatl_core/distributions/distribution.py index 9728462..40ebbb4 100644 --- a/src/pysatl_core/distributions/distribution.py +++ b/src/pysatl_core/distributions/distribution.py @@ -12,7 +12,7 @@ __license__ = "SPDX-License-Identifier: MIT" from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping +from collections.abc import Mapping from copy import deepcopy from typing import TYPE_CHECKING, Self, cast @@ -20,7 +20,7 @@ ComputationStrategy, SamplingStrategy, ) -from pysatl_core.types import NumericArray +from pysatl_core.types import DEFAULT_ANALYTICAL_COMPUTATION_LABEL, NumericArray _KEEP: object = object() @@ -28,11 +28,12 @@ if TYPE_CHECKING: from typing import Any - from pysatl_core.distributions.computation import AnalyticalComputation, Method + from pysatl_core.distributions.computation import AnalyticalComputation from pysatl_core.distributions.support import Support from pysatl_core.types import ( DistributionType, GenericCharacteristicName, + LabelName, Method, ) @@ -49,7 +50,13 @@ class Distribution(ABC): ---------- distribution_type : DistributionType Type information about the distribution (kind, dimension, etc.). - analytical_computations : Mapping[str, AnalyticalComputation] + analytical_computations : Mapping[ + GenericCharacteristicName, + ( + AnalyticalComputation[Any, Any] + | Mapping[LabelName, AnalyticalComputation[Any, Any]] + ), + ] Direct analytical computations provided by the distribution. sampling_strategy : SamplingStrategy Strategy for generating random samples. @@ -62,10 +69,10 @@ class Distribution(ABC): def __init__( self, distribution_type: DistributionType, - analytical_computations: ( - Iterable[AnalyticalComputation[Any, Any]] - | Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] - ), + analytical_computations: Mapping[ + GenericCharacteristicName, + (AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]]), + ], support: Support | None = None, sampling_strategy: SamplingStrategy | None = None, computation_strategy: ComputationStrategy | None = None, @@ -78,7 +85,13 @@ def __init__( distribution_type : DistributionType Type information about the distribution (kind, dimension, etc.). analytical_computations : - Iterable[AnalyticalComputation] | Mapping[str, AnalyticalComputation] + Mapping[ + GenericCharacteristicName, + ( + AnalyticalComputation[Any, Any] + | Mapping[LabelName, AnalyticalComputation[Any, Any]] + ), + ] Analytical computations provided by the distribution. support : Support or None, default=None Support of the distribution. @@ -93,14 +106,27 @@ def __init__( ) self._distribution_type = distribution_type - if isinstance(analytical_computations, Mapping): - normalized_analytical = dict(analytical_computations) - else: - normalized_analytical = {ac.target: ac for ac in analytical_computations} + normalized_analytical: dict[ + GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]] + ] = {} + for characteristic_name, methods in analytical_computations.items(): + if isinstance(methods, Mapping): + normalized_analytical[characteristic_name] = dict(methods) + else: + normalized_analytical[characteristic_name] = { + DEFAULT_ANALYTICAL_COMPUTATION_LABEL: methods + } if not normalized_analytical: raise ValueError("Distribution requires at least one analytical computation.") + for characteristic_name, labeled_methods in normalized_analytical.items(): + if not labeled_methods: + raise ValueError( + f"Characteristic '{characteristic_name}' must provide at least one " + "analytical computation." + ) + self._analytical_computations = normalized_analytical self._support = support self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy() @@ -114,7 +140,7 @@ def distribution_type(self) -> DistributionType: @property def analytical_computations( self, - ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: + ) -> Mapping[GenericCharacteristicName, Mapping[LabelName, AnalyticalComputation[Any, Any]]]: """Return analytical computations provided directly by this distribution.""" return self._analytical_computations diff --git a/src/pysatl_core/distributions/strategies.py b/src/pysatl_core/distributions/strategies.py index f044385..77425e9 100644 --- a/src/pysatl_core/distributions/strategies.py +++ b/src/pysatl_core/distributions/strategies.py @@ -19,11 +19,16 @@ from pysatl_core.types import CharacteristicName, Method, NumericArray if TYPE_CHECKING: + from collections.abc import Mapping from typing import Any - from pysatl_core.distributions.computation import FittedComputationMethod + from pysatl_core.distributions.computation import ( + AnalyticalComputation, + FittedComputationMethod, + Method, + ) from pysatl_core.distributions.distribution import Distribution - from pysatl_core.types import GenericCharacteristicName + from pysatl_core.types import GenericCharacteristicName, LabelName class ComputationStrategy(Protocol): @@ -101,6 +106,26 @@ def _pop_guard(self, distr: Distribution, state: GenericCharacteristicName) -> N if not seen: self._resolving.pop(key, None) + @staticmethod + def _pick_analytical_method( + state: GenericCharacteristicName, + methods: Mapping[LabelName, AnalyticalComputation[Any, Any]], + ) -> AnalyticalComputation[Any, Any]: + """ + Pick the first available analytical method for a characteristic. + + Raises + ------ + RuntimeError + If no labeled analytical methods are available for the characteristic. + """ + try: + return next(iter(methods.values())) + except StopIteration as exc: + raise RuntimeError( + f"Characteristic '{state}' provides no labeled analytical computations." + ) from exc + def query_method( self, state: GenericCharacteristicName, distr: Distribution, **options: Any ) -> Method[Any, Any]: @@ -134,7 +159,7 @@ def query_method( """ # 1. Check for analytical implementation if state in distr.analytical_computations: - return distr.analytical_computations[state] + return self._pick_analytical_method(state, distr.analytical_computations[state]) # 2. Check cache if enabled if self._enable_caching: @@ -156,7 +181,7 @@ def query_method( # 5. Try each analytical characteristic as a source for src in distr.analytical_computations: if src == state: - return distr.analytical_computations[src] + return self._pick_analytical_method(src, distr.analytical_computations[src]) # Find conversion path in the graph path = reg.find_path(src, state) diff --git a/src/pysatl_core/families/distribution.py b/src/pysatl_core/families/distribution.py index 7b57945..166b98c 100644 --- a/src/pysatl_core/families/distribution.py +++ b/src/pysatl_core/families/distribution.py @@ -35,6 +35,7 @@ from pysatl_core.types import ( DistributionType, GenericCharacteristicName, + LabelName, ParametrizationName, ) @@ -69,7 +70,8 @@ def __init__( family_name: str, distribution_type: DistributionType, analytical_computations: Mapping[ - GenericCharacteristicName, AnalyticalComputation[Any, Any] + GenericCharacteristicName, + Mapping[LabelName, AnalyticalComputation[Any, Any]], ], parametrization: Parametrization, support: Support | None, diff --git a/src/pysatl_core/families/parametric_family.py b/src/pysatl_core/families/parametric_family.py index e9709f5..091eac6 100644 --- a/src/pysatl_core/families/parametric_family.py +++ b/src/pysatl_core/families/parametric_family.py @@ -19,7 +19,11 @@ from pysatl_core.distributions.computation import AnalyticalComputation from pysatl_core.families.distribution import ParametricFamilyDistribution -from pysatl_core.types import ComputationFunc, DistributionType +from pysatl_core.types import ( + DEFAULT_ANALYTICAL_COMPUTATION_LABEL, + ComputationFunc, + DistributionType, +) if TYPE_CHECKING: from collections.abc import Callable @@ -31,13 +35,17 @@ ) from pysatl_core.types import ( GenericCharacteristicName, + LabelName, ParametrizationName, ) type SupportArg = Callable[[Parametrization], Support | None] | None type SupportResolver = Callable[[Parametrization], Support | None] + type LabeledCharacteristicProvider = ( + Mapping[LabelName, CharacteristicFunction[Any, Any]] | CharacteristicFunction[Any, Any] + ) type CharacteristicProvider = ( - Mapping[ParametrizationName, CharacteristicFunction[Any, Any]] + Mapping[ParametrizationName, LabeledCharacteristicProvider] | CharacteristicFunction[Any, Any] ) type CharacteristicsMap = Mapping[GenericCharacteristicName, CharacteristicProvider] @@ -78,7 +86,9 @@ class ParametricFamily: - nullary characteristics (e.g., mean, var): provider(params, **kwargs) -> Any - pointwise characteristics (e.g., pdf, cdf, ppf): provider(params, x, **kwargs) -> Any - If a single callable is provided, it is treated as defined for the base parametrization. + Providers are grouped by parametrization and may define multiple labeled methods. + If a single callable is provided, it is treated as the base-parametrization method + under ``DEFAULT_ANALYTICAL_COMPUTATION_LABEL``. support_by_parametrization : Callable or None, optional Function that returns support for given parameters. """ @@ -109,31 +119,63 @@ def __init__( # Runtime registry of parametrization classes self._parametrizations: dict[ParametrizationName, type[Parametrization]] = {} - def _normalize_characteristic( - value: Mapping[ParametrizationName, CharacteristicFunction[Any, Any]] - | CharacteristicFunction[Any, Any], - ) -> dict[ParametrizationName, CharacteristicFunction[Any, Any]]: - return ( - dict(value) - if isinstance(value, Mapping) - else {self.base_parametrization_name: value} + def _normalize_labeled_provider( + characteristic_name: GenericCharacteristicName, + parametrization_name: ParametrizationName, + provider: LabeledCharacteristicProvider, + ) -> dict[LabelName, CharacteristicFunction[Any, Any]]: + normalized = ( + dict(provider) + if isinstance(provider, Mapping) + else {DEFAULT_ANALYTICAL_COMPUTATION_LABEL: provider} ) + if not normalized: + raise ValueError( + f"Characteristic '{characteristic_name}' has no labeled providers for " + f"parametrization '{parametrization_name}'." + ) + return normalized + + def _normalize_characteristic( + characteristic_name: GenericCharacteristicName, + value: CharacteristicProvider, + ) -> dict[ParametrizationName, dict[LabelName, CharacteristicFunction[Any, Any]]]: + if not isinstance(value, Mapping): + base_name = self.base_parametrization_name + return { + base_name: _normalize_labeled_provider(characteristic_name, base_name, value) + } + + normalized_by_parametrization: dict[ + ParametrizationName, dict[LabelName, CharacteristicFunction[Any, Any]] + ] = {} + for parametrization_name, provider in value.items(): + normalized_by_parametrization[parametrization_name] = _normalize_labeled_provider( + characteristic_name, + parametrization_name, + provider, + ) + return normalized_by_parametrization self.distr_characteristics: dict[ - GenericCharacteristicName, dict[ParametrizationName, CharacteristicFunction[Any, Any]] - ] = {k: _normalize_characteristic(v) for k, v in distr_characteristics.items()} + GenericCharacteristicName, + dict[ParametrizationName, dict[LabelName, CharacteristicFunction[Any, Any]]], + ] = { + characteristic_name: _normalize_characteristic(characteristic_name, provider) + for characteristic_name, provider in distr_characteristics.items() + } # Validate characteristic providers valid_names = set(self.parametrization_names) for char_name, forms in self.distr_characteristics.items(): + if not forms: + raise ValueError(f"Characteristic '{char_name}' has no providers.") unknown = set(forms) - valid_names if unknown: raise ValueError( f"Characteristic '{char_name}' has providers for unknown parametrizations: " f"{sorted(unknown)}." ) - if self.base_parametrization_name not in forms and len(forms) == 0: - raise ValueError(f"Characteristic '{char_name}' has no providers.") # Precompute analytical plan: for each parametrization pick provider (self or base) self._analytical_plan: dict[ @@ -265,14 +307,16 @@ def _bind_parametrization[In, Out]( def _build_analytical_computations( self, parameters: Parametrization - ) -> dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: + ) -> dict[GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]]]: """ Build analytical computations for given parameters. Uses precomputed provider plan for efficient computation. """ plan = self._analytical_plan.get(parameters.name, {}) - result: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]] = {} + result: dict[ + GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]] + ] = {} base_params: Parametrization | None = None for characteristic, provider_name in plan.items(): @@ -282,11 +326,14 @@ def _build_analytical_computations( base_params = base_params or self.to_base(parameters) params_obj = base_params - func_factory = self.distr_characteristics[characteristic][provider_name] - result[characteristic] = AnalyticalComputation( - target=characteristic, - func=self._bind_parametrization(func_factory, params_obj), - ) + labeled_providers = self.distr_characteristics[characteristic][provider_name] + result[characteristic] = { + label_name: AnalyticalComputation( + target=characteristic, + func=self._bind_parametrization(func_factory, params_obj), + ) + for label_name, func_factory in labeled_providers.items() + } return result diff --git a/src/pysatl_core/types.py b/src/pysatl_core/types.py index 4abfce0..150c9f1 100644 --- a/src/pysatl_core/types.py +++ b/src/pysatl_core/types.py @@ -243,6 +243,12 @@ def shape(self) -> ContinuousSupportShape1D: type GenericCharacteristicName = str """Type alias for characteristic names (e.g., 'pdf', 'cdf').""" +type LabelName = str +"""Type alias for labels that distinguish computation variants.""" + +DEFAULT_ANALYTICAL_COMPUTATION_LABEL: LabelName = "PySATL_default_analytical_computation" +"""Default label for analytical methods when a label is not explicitly provided.""" + type ParametrizationName = str """Type alias for parametrization names.""" @@ -302,6 +308,8 @@ class FamilyName(StrEnum): "UnivariateContinuous", "UnivariateDiscrete", "GenericCharacteristicName", + "LabelName", + "DEFAULT_ANALYTICAL_COMPUTATION_LABEL", "ParametrizationName", "ComputationFunc", "DistributionType", diff --git a/tests/unit/distributions/test_basic.py b/tests/unit/distributions/test_basic.py index 0dccbe1..2c386c6 100644 --- a/tests/unit/distributions/test_basic.py +++ b/tests/unit/distributions/test_basic.py @@ -8,6 +8,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any, cast +import pytest from mypy_extensions import KwArg from pysatl_core.distributions.computation import ( @@ -19,7 +20,11 @@ ContinuousSupport, ExplicitTableDiscreteSupport, ) -from pysatl_core.types import CharacteristicName, Kind +from pysatl_core.types import ( + DEFAULT_ANALYTICAL_COMPUTATION_LABEL, + CharacteristicName, + Kind, +) from tests.utils.mocks import ( StandaloneEuclideanUnivariateDistribution, ) @@ -27,6 +32,8 @@ if TYPE_CHECKING: from collections.abc import Sequence +DEFAULT_ANALYTICAL_LABEL = "default" + class DistributionTestBase: def make_uniform_ppf_distribution( @@ -35,9 +42,13 @@ def make_uniform_ppf_distribution( ppf_func = cast(Callable[[float, KwArg(Any)], float], lambda q, **kwargs: q) return StandaloneEuclideanUnivariateDistribution( kind=Kind.CONTINUOUS, - analytical_computations=[ - AnalyticalComputation[float, float](target=CharacteristicName.PPF, func=ppf_func), - ], + analytical_computations={ + CharacteristicName.PPF: { + DEFAULT_ANALYTICAL_LABEL: AnalyticalComputation[float, float]( + target=CharacteristicName.PPF, func=ppf_func + ) + } + }, support=ContinuousSupport(0, 1), ) @@ -50,11 +61,13 @@ def logistic_cdf(x: float, **_: Any) -> float: logistic_cdf_func = cast(Callable[[float, KwArg(Any)], float], logistic_cdf) return StandaloneEuclideanUnivariateDistribution( kind=Kind.CONTINUOUS, - analytical_computations=[ - AnalyticalComputation[float, float]( - target=CharacteristicName.CDF, func=logistic_cdf_func - ), - ], + analytical_computations={ + CharacteristicName.CDF: { + DEFAULT_ANALYTICAL_LABEL: AnalyticalComputation[float, float]( + target=CharacteristicName.CDF, func=logistic_cdf_func + ) + } + }, support=ContinuousSupport(), ) @@ -68,11 +81,13 @@ def uniform_pdf(x: float, **_: Any) -> float: return StandaloneEuclideanUnivariateDistribution( kind=Kind.CONTINUOUS, - analytical_computations=[ - AnalyticalComputation[float, float]( - target=CharacteristicName.PDF, func=uniform_pdf_func - ), - ], + analytical_computations={ + CharacteristicName.PDF: { + DEFAULT_ANALYTICAL_LABEL: AnalyticalComputation[float, float]( + target=CharacteristicName.PDF, func=uniform_pdf_func + ) + } + }, support=ContinuousSupport(0, 1), ) @@ -90,9 +105,13 @@ def pmf(x: float) -> float: return StandaloneEuclideanUnivariateDistribution( kind=Kind.DISCRETE, - analytical_computations=[ - AnalyticalComputation[float, float](target=CharacteristicName.PMF, func=pmf_func), - ], + analytical_computations={ + CharacteristicName.PMF: { + DEFAULT_ANALYTICAL_LABEL: AnalyticalComputation[float, float]( + target=CharacteristicName.PMF, func=pmf_func + ) + } + }, support=support, ) @@ -109,3 +128,31 @@ def _impl(*_args: Any, **_kwargs: Any) -> Any: return ComputationMethod( target=target, sources=sources, fitter=lambda *_a, **_k: _fitted_const(None) ) + + +class TestDistributionInitialization: + def test_distribution_accepts_unlabeled_analytical_mapping(self) -> None: + ppf_func = cast(Callable[[float, KwArg(Any)], float], lambda q, **_kwargs: q) + ppf_method = AnalyticalComputation[float, float]( + target=CharacteristicName.PPF, func=ppf_func + ) + + distr = StandaloneEuclideanUnivariateDistribution( + kind=Kind.CONTINUOUS, + analytical_computations={CharacteristicName.PPF: ppf_method}, + support=ContinuousSupport(0, 1), + ) + + methods = distr.analytical_computations[CharacteristicName.PPF] + assert set(methods.keys()) == {DEFAULT_ANALYTICAL_COMPUTATION_LABEL} + assert methods[DEFAULT_ANALYTICAL_COMPUTATION_LABEL](0.42) == pytest.approx(0.42) + + def test_distribution_rejects_empty_labeled_analytical_computations(self) -> None: + with pytest.raises( + ValueError, + match="Characteristic 'cdf' must provide at least one analytical computation.", + ): + StandaloneEuclideanUnivariateDistribution( + kind=Kind.CONTINUOUS, + analytical_computations={CharacteristicName.CDF: {}}, + ) diff --git a/tests/unit/families/test_basic.py b/tests/unit/families/test_basic.py index eff186a..4e9fa2e 100644 --- a/tests/unit/families/test_basic.py +++ b/tests/unit/families/test_basic.py @@ -15,13 +15,18 @@ class TestBaseFamily: def make_default_family( self, - distr_characteristics: dict[GenericCharacteristicName, dict[str, object]] | None = None, + distr_characteristics: ( + dict[GenericCharacteristicName, dict[str, dict[str, object]]] | None + ) = None, ) -> ParametricFamily: if distr_characteristics is None: distr_characteristics = { - CharacteristicName.PDF: {"base": lambda p, x: x}, - CharacteristicName.CDF: {"alt": lambda p, x: x, "base": lambda p, x: x}, - CharacteristicName.PPF: {"base": lambda p, x: x}, + CharacteristicName.PDF: {"base": {"default": lambda p, x: x}}, + CharacteristicName.CDF: { + "alt": {"default": lambda p, x: x}, + "base": {"default": lambda p, x: x}, + }, + CharacteristicName.PPF: {"base": {"default": lambda p, x: x}}, } fam = ParametricFamily( name="Default", diff --git a/tests/unit/families/test_distribution_cache.py b/tests/unit/families/test_distribution_cache.py index 8952a23..7d97b7c 100644 --- a/tests/unit/families/test_distribution_cache.py +++ b/tests/unit/families/test_distribution_cache.py @@ -12,10 +12,12 @@ class TestAnalyticalComputationCache(TestBaseFamily): - def _fallback_characteristics(self) -> dict[GenericCharacteristicName, dict[str, object]]: + def _fallback_characteristics( + self, + ) -> dict[GenericCharacteristicName, dict[str, dict[str, object]]]: return { - CharacteristicName.PDF: {"base": lambda params, x: params.value}, - CharacteristicName.CDF: {"base": lambda params, x: params.value}, + CharacteristicName.PDF: {"base": {"default": lambda params, x: params.value}}, + CharacteristicName.CDF: {"base": {"default": lambda params, x: params.value}}, } def test_analytical_computations_are_built_at_distribution_creation(self) -> None: @@ -43,16 +45,15 @@ def test_analytical_computations_are_built_at_distribution_creation(self) -> Non # Both mappings must contain PDF and CDF and be callable for mapping in (computations2, computations3): assert CharacteristicName.PDF in mapping and CharacteristicName.CDF in mapping - assert callable(mapping[CharacteristicName.PDF]) and callable( - mapping[CharacteristicName.CDF] - ) + assert callable(mapping[CharacteristicName.PDF]["default"]) + assert callable(mapping[CharacteristicName.CDF]["default"]) # Computations remain bound to initial alt(value=2.0) -> base(value=2.0) - assert computations2[CharacteristicName.PDF](1.23) == pytest.approx(2.0) - assert computations2[CharacteristicName.CDF](0.5) == pytest.approx(2.0) + assert computations2[CharacteristicName.PDF]["default"](1.23) == pytest.approx(2.0) + assert computations2[CharacteristicName.CDF]["default"](0.5) == pytest.approx(2.0) - assert computations3[CharacteristicName.PDF](42.0) == pytest.approx(2.0) - assert computations3[CharacteristicName.CDF](0.0) == pytest.approx(2.0) + assert computations3[CharacteristicName.PDF]["default"](42.0) == pytest.approx(2.0) + assert computations3[CharacteristicName.CDF]["default"](0.0) == pytest.approx(2.0) def test_fallback_to_base_for_missing_form(self) -> None: family = self.make_default_family(distr_characteristics=self._fallback_characteristics()) @@ -65,8 +66,8 @@ def test_fallback_to_base_for_missing_form(self) -> None: assert CharacteristicName.PDF in computations and CharacteristicName.CDF in computations # For alt(value=2.0) → base(value=2.0) - assert computations[CharacteristicName.PDF](1.23) == pytest.approx(2.0) - assert computations[CharacteristicName.CDF](0.5) == pytest.approx(2.0) + assert computations[CharacteristicName.PDF]["default"](1.23) == pytest.approx(2.0) + assert computations[CharacteristicName.CDF]["default"](0.5) == pytest.approx(2.0) def test_distribution_creation_requires_analytical_computations(self) -> None: family = self.make_default_family(distr_characteristics={}) @@ -76,3 +77,14 @@ def test_distribution_creation_requires_analytical_computations(self) -> None: ValueError, match="Distribution requires at least one analytical computation." ): family.distribution("alt", value=2.0) + + def test_family_creation_rejects_empty_labeled_providers(self) -> None: + with pytest.raises( + ValueError, + match=("Characteristic 'pdf' has no labeled providers for parametrization 'base'."), + ): + self.make_default_family( + distr_characteristics={ + CharacteristicName.PDF: {"base": {}}, + } + ) diff --git a/tests/unit/families/test_family_and_distribution.py b/tests/unit/families/test_family_and_distribution.py index a8e6659..d06359b 100644 --- a/tests/unit/families/test_family_and_distribution.py +++ b/tests/unit/families/test_family_and_distribution.py @@ -16,11 +16,15 @@ class TestFamilyRegistrationAndSampling(TestBaseFamily): def test_family_registration_and_distribution_sampling(self) -> None: fam = self.make_default_family( distr_characteristics={ - CharacteristicName.PDF: {"base": lambda p, x: 1.0 if 0.0 <= x <= 1.0 else 0.0}, + CharacteristicName.PDF: { + "base": {"default": lambda p, x: 1.0 if 0.0 <= x <= 1.0 else 0.0} + }, CharacteristicName.CDF: { - "base": lambda p, x: x if 0.0 <= x <= 1.0 else (0.0 if x < 0.0 else 1.0) + "base": { + "default": lambda p, x: x if 0.0 <= x <= 1.0 else (0.0 if x < 0.0 else 1.0) + } }, - CharacteristicName.PPF: {"base": lambda p, q: q}, + CharacteristicName.PPF: {"base": {"default": lambda p, q: q}}, }, ) @@ -44,8 +48,8 @@ def test_family_registration_and_distribution_sampling(self) -> None: CharacteristicName.CDF, CharacteristicName.PPF, } - assert computations[CharacteristicName.CDF](0.25) == pytest.approx(0.25) - assert computations[CharacteristicName.PPF](0.75) == pytest.approx(0.75) + assert computations[CharacteristicName.CDF]["default"](0.25) == pytest.approx(0.25) + assert computations[CharacteristicName.PPF]["default"](0.75) == pytest.approx(0.75) def test_distribution_clone_with_keep_strategies_copies_strategies(self) -> None: fam = self.make_default_family() diff --git a/tests/utils/mocks.py b/tests/utils/mocks.py index 31c48f7..767f236 100644 --- a/tests/utils/mocks.py +++ b/tests/utils/mocks.py @@ -4,7 +4,7 @@ __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" -from collections.abc import Iterable, Mapping +from collections.abc import Mapping, Sequence from dataclasses import dataclass from typing import Any @@ -19,12 +19,19 @@ from pysatl_core.distributions.distribution import _KEEP from pysatl_core.distributions.support import Support from pysatl_core.types import ( + CharacteristicName, EuclideanDistributionType, GenericCharacteristicName, Kind, + LabelName, NumericArray, ) +type MockAnalyticalComputations = Mapping[ + GenericCharacteristicName, + AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]], +] + class MockSamplingStrategy(SamplingStrategy): def sample(self, n: int, distr: Distribution, **options: Any) -> NumericArray: @@ -43,23 +50,40 @@ class StandaloneEuclideanUnivariateDistribution(Distribution): """ _distribution_type: EuclideanDistributionType - _analytical_computations: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]] + _analytical_computations: dict[ + GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]] + ] _support: Support | None def __init__( self, kind: Kind, - analytical_computations: ( - Iterable[AnalyticalComputation[Any, Any]] - | Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] - ) = (), + analytical_computations: MockAnalyticalComputations + | Sequence[AnalyticalComputation[Any, Any]], support: Support | None = None, ) -> None: + force_empty_analytical = False + if isinstance(analytical_computations, Mapping): + normalized_analytical: MockAnalyticalComputations = analytical_computations + else: + normalized_analytical = {comp.target: comp for comp in analytical_computations} + if not normalized_analytical: + # Keep backward compatibility with legacy tests that passed an empty list. + normalized_analytical = { + CharacteristicName.MEAN: AnalyticalComputation[Any, float]( + target=CharacteristicName.MEAN, + func=lambda **_kwargs: 0.0, + ) + } + force_empty_analytical = True + super(StandaloneEuclideanUnivariateDistribution, self).__init__( distribution_type=EuclideanDistributionType(kind, 1), - analytical_computations=analytical_computations, + analytical_computations=normalized_analytical, support=support, ) + if force_empty_analytical: + self._analytical_computations = {} def _clone_with_strategies( self, From e5c30bb128a684707dff805f29ed326e703588e7 Mon Sep 17 00:00:00 2001 From: LeonidElkin Date: Tue, 17 Mar 2026 20:53:23 +0300 Subject: [PATCH 4/5] feat(registry): add looped edges for graph view for analytical computations --- .../distributions/registry/__init__.py | 4 + .../distributions/registry/graph.py | 105 ++++++++++++++---- .../registry/graph_primitives.py | 55 +++++++-- src/pysatl_core/distributions/strategies.py | 61 +++++++--- tests/unit/distributions/test_registry.py | 72 +++++++++++- 5 files changed, 245 insertions(+), 52 deletions(-) diff --git a/src/pysatl_core/distributions/registry/__init__.py b/src/pysatl_core/distributions/registry/__init__.py index 60b429a..9aa13bd 100644 --- a/src/pysatl_core/distributions/registry/__init__.py +++ b/src/pysatl_core/distributions/registry/__init__.py @@ -25,6 +25,8 @@ ) from .graph_primitives import ( DEFAULT_COMPUTATION_KEY, + AnalyticalLoopEdgeMeta, + ComputationEdgeMeta, EdgeMeta, GraphInvariantError, ) @@ -33,6 +35,8 @@ # Graph primitives and constants "DEFAULT_COMPUTATION_KEY", "EdgeMeta", + "ComputationEdgeMeta", + "AnalyticalLoopEdgeMeta", "GraphInvariantError", # Constraint types "Constraint", diff --git a/src/pysatl_core/distributions/registry/graph.py b/src/pysatl_core/distributions/registry/graph.py index ca379d8..cddc8ed 100644 --- a/src/pysatl_core/distributions/registry/graph.py +++ b/src/pysatl_core/distributions/registry/graph.py @@ -29,6 +29,8 @@ from pysatl_core.distributions.registry.constraint import GraphPrimitiveConstraint from pysatl_core.distributions.registry.graph_primitives import ( DEFAULT_COMPUTATION_KEY, + AnalyticalLoopEdgeMeta, + ComputationEdgeMeta, EdgeMeta, GraphInvariantError, ) @@ -36,7 +38,7 @@ if TYPE_CHECKING: from pysatl_core.distributions.computation import ComputationMethod from pysatl_core.distributions.distribution import Distribution - from pysatl_core.types import GenericCharacteristicName + from pysatl_core.types import GenericCharacteristicName, LabelName # --------------------------------------------------------------------------- # @@ -89,7 +91,7 @@ def __init__(self) -> None: # Adjacency: src → dst → label → [EdgeMeta] self._adj: dict[ GenericCharacteristicName, - dict[GenericCharacteristicName, dict[str, list[EdgeMeta]]], + dict[GenericCharacteristicName, dict[LabelName, list[EdgeMeta]]], ] = {} self._all_nodes: set[GenericCharacteristicName] = set() @@ -98,7 +100,7 @@ def __init__(self) -> None: self._def_rules: dict[GenericCharacteristicName, GraphPrimitiveConstraint] = {} # Label preference for path finding - self.label_preference: tuple[str, ...] = (DEFAULT_COMPUTATION_KEY,) + self.label_preference: tuple[LabelName, ...] = (DEFAULT_COMPUTATION_KEY,) self._initialized = True @@ -128,6 +130,11 @@ def _ensure_node(self, node: GenericCharacteristicName) -> bool: """Check if a node has been declared via add_characteristic().""" return node in self._all_nodes + @property + def declared_characteristics(self) -> set[GenericCharacteristicName]: + """Return declared registry characteristics.""" + return set(self._all_nodes) + def _add_presence_rule( self, name: GenericCharacteristicName, constraint: GraphPrimitiveConstraint | None ) -> None: @@ -168,7 +175,7 @@ def add_computation( self, method: ComputationMethod[Any, Any], *, - label: str = DEFAULT_COMPUTATION_KEY, + label: LabelName = DEFAULT_COMPUTATION_KEY, constraint: GraphPrimitiveConstraint | None = None, ) -> None: """ @@ -178,7 +185,7 @@ def add_computation( ---------- method : ComputationMethod Computation object with exactly one source and one target. - label : str, default=DEFAULT_COMPUTATION_KEY + label : LabelName, default=DEFAULT_COMPUTATION_KEY Variant label for the edge. constraint : GraphPrimitiveConstraint, optional Edge applicability constraint. If None, a pass-through constraint is used. @@ -208,7 +215,7 @@ def add_computation( # constraints self._adj[src][dst].setdefault(label, []) self._adj[src][dst][label].append( - EdgeMeta( + ComputationEdgeMeta( method=method, constraint=constraint or GraphPrimitiveConstraint(), ) @@ -292,6 +299,33 @@ def _compute_definitive_nodes(self, distr: Distribution) -> set[GenericCharacter definitive.add(name) return definitive + @staticmethod + def _attach_analytical_loops( + adj: dict[ + GenericCharacteristicName, + dict[GenericCharacteristicName, dict[LabelName, EdgeMeta]], + ], + distr: Distribution, + present_nodes: set[GenericCharacteristicName], + ) -> None: + """ + Attach analytical self-loops for distribution-provided computations. + + Notes + ----- + Analytical loops are only added for characteristics present in this view. + Each labeled analytical computation becomes one loop edge ``char -> char``. + """ + for characteristic_name, labeled_methods in distr.analytical_computations.items(): + if characteristic_name not in present_nodes: + continue + + loop_variants = adj.setdefault(characteristic_name, {}).setdefault( + characteristic_name, {} + ) + for label_name, analytical_method in labeled_methods.items(): + loop_variants[label_name] = AnalyticalLoopEdgeMeta(method=analytical_method) + def view(self, distr: Distribution) -> RegistryView: """ Create a filtered view of the graph for the given distribution. @@ -310,16 +344,17 @@ def view(self, distr: Distribution) -> RegistryView: ----- 1. Filters edges by their constraints 2. Removes edges touching absent nodes - 3. Computes definitive nodes from the remaining present nodes - 4. Validates graph invariants + 3. Adds analytical self-loops from distribution analytical computations + 4. Computes definitive nodes from the remaining present nodes + 5. Validates graph invariants """ # 1) Filter edges by applicability adj: dict[ - GenericCharacteristicName, dict[GenericCharacteristicName, dict[str, EdgeMeta]] + GenericCharacteristicName, dict[GenericCharacteristicName, dict[LabelName, EdgeMeta]] ] = {} for src, d in self._adj.items(): for dst, variants in d.items(): - kept: dict[str, EdgeMeta] = {} + kept: dict[LabelName, EdgeMeta] = {} for label, metas in variants.items(): for edge in metas: if edge.constraint.allows(distr): @@ -343,7 +378,10 @@ def view(self, distr: Distribution) -> RegistryView: for node in present_nodes: adj.setdefault(node, {}) - # 3) Compute definitive nodes (must be present) + # 3) Attach analytical loops + self._attach_analytical_loops(adj, distr, present_nodes) + + # 4) Compute definitive nodes (must be present) definitive_nodes = self._compute_definitive_nodes(distr) & present_nodes return RegistryView(adj, definitive_nodes, present_nodes) @@ -387,14 +425,15 @@ def __init__( self, adj: Mapping[ GenericCharacteristicName, - Mapping[GenericCharacteristicName, Mapping[str, EdgeMeta]], + Mapping[GenericCharacteristicName, Mapping[LabelName, EdgeMeta]], ], definitive_nodes: set[GenericCharacteristicName], present_nodes: set[GenericCharacteristicName], ) -> None: # Deep copy adjacency to ensure immutability self._adj: dict[ - GenericCharacteristicName, dict[GenericCharacteristicName, dict[str, EdgeMeta]] + GenericCharacteristicName, + dict[GenericCharacteristicName, dict[LabelName, EdgeMeta]], ] = {} for src, d in adj.items(): self._adj[src] = {dst: dict(variants) for dst, variants in d.items()} @@ -419,7 +458,7 @@ def indefinitive_characteristics(self) -> set[GenericCharacteristicName]: def successors( self, v: GenericCharacteristicName - ) -> Mapping[GenericCharacteristicName, Mapping[str, EdgeMeta]]: + ) -> Mapping[GenericCharacteristicName, Mapping[LabelName, EdgeMeta]]: """ Get outgoing edges from a characteristic. @@ -430,7 +469,7 @@ def successors( Returns ------- - Mapping[str, Mapping[str, EdgeMeta]] + Mapping[str, Mapping[LabelName, EdgeMeta]] Destination → label → edge metadata. """ return self._adj.get(v, {}) @@ -473,7 +512,7 @@ def predecessors(self, v: GenericCharacteristicName) -> set[GenericCharacteristi def variants( self, src: GenericCharacteristicName, dst: GenericCharacteristicName - ) -> Mapping[str, EdgeMeta]: + ) -> Mapping[LabelName, EdgeMeta]: """ Get all labeled edges between two characteristics. @@ -484,17 +523,35 @@ def variants( Returns ------- - Mapping[str, EdgeMeta] + Mapping[LabelName, EdgeMeta] Label → edge metadata mapping. """ return self._adj.get(src, {}).get(dst, {}) + def analytical_variants(self, state: GenericCharacteristicName) -> Mapping[LabelName, EdgeMeta]: + """ + Get analytical self-loop variants for a characteristic. + + Parameters + ---------- + state : str + Characteristic name. + + Returns + ------- + Mapping[LabelName, EdgeMeta] + Label → analytical loop metadata for ``state -> state``. + """ + return { + label: edge for label, edge in self.variants(state, state).items() if edge.is_analytical + } + def find_path( self, src: GenericCharacteristicName, dst: GenericCharacteristicName, *, - prefer_label: str | None = None, + prefer_label: LabelName | None = None, ) -> list[Any] | None: """ Find a computation path from src to dst using BFS. @@ -503,12 +560,12 @@ def find_path( ---------- src, dst : str Source and destination characteristics. - prefer_label : str, optional + prefer_label : LabelName, optional Preferred edge label to use when multiple options exist. Returns ------- - list of ComputationMethod or None + list of Any or None List of computation methods forming the path, or None if no path exists. Notes @@ -668,17 +725,17 @@ def _reachable_from( @staticmethod def _pick_method( - variants: Mapping[str, EdgeMeta], - prefer_label: str | None, + variants: Mapping[LabelName, EdgeMeta], + prefer_label: LabelName | None, ) -> Any: """ Select a method from label variants. Parameters ---------- - variants : Mapping[str, EdgeMeta] + variants : Mapping[LabelName, EdgeMeta] Available edge variants. - prefer_label : str, optional + prefer_label : LabelName, optional Preferred label. Returns diff --git a/src/pysatl_core/distributions/registry/graph_primitives.py b/src/pysatl_core/distributions/registry/graph_primitives.py index 5235af4..6a2c741 100644 --- a/src/pysatl_core/distributions/registry/graph_primitives.py +++ b/src/pysatl_core/distributions/registry/graph_primitives.py @@ -8,36 +8,73 @@ __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" +from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import Any +from pysatl_core.distributions.computation import ( + AnalyticalComputation, + ComputationMethod, +) from pysatl_core.distributions.registry.constraint import GraphPrimitiveConstraint +from pysatl_core.types import LabelName -if TYPE_CHECKING: - from typing import Any +type EdgeMethod = ComputationMethod[Any, Any] | AnalyticalComputation[Any, Any] - from pysatl_core.distributions.computation import ComputationMethod - -DEFAULT_COMPUTATION_KEY: str = "PySATL_default_computation" +DEFAULT_COMPUTATION_KEY: LabelName = "PySATL_default_computation" """Default label for computation edges when no specific label is provided.""" @dataclass(frozen=True, slots=True) -class EdgeMeta: +class EdgeMeta(ABC): """ Metadata for a computation edge in the characteristic graph. Parameters ---------- - method : ComputationMethod + method : EdgeMethod The computation method that defines the edge. constraint : GraphPrimitiveConstraint Constraint determining when this edge is applicable to a distribution. Defaults to a pass-through constraint that always allows. + is_analytical : bool + Whether this edge represents an analytical computation. """ - method: ComputationMethod[Any, Any] + method: EdgeMethod constraint: GraphPrimitiveConstraint = field(default_factory=GraphPrimitiveConstraint) + is_analytical: bool = field(default=False) + + @abstractmethod + def edge_kind(self) -> str: + """Return edge kind identifier.""" + ... + + +@dataclass(frozen=True, slots=True) +class ComputationEdgeMeta(EdgeMeta): + """ + Edge metadata for conversion computations from the registry graph. + """ + + method: ComputationMethod[Any, Any] + is_analytical: bool = field(default=False) + + def edge_kind(self) -> str: + return "computation" + + +@dataclass(frozen=True, slots=True) +class AnalyticalLoopEdgeMeta(EdgeMeta): + """ + Edge metadata for self-loop analytical computations from a distribution. + """ + + method: AnalyticalComputation[Any, Any] + is_analytical: bool = field(default=True) + + def edge_kind(self) -> str: + return "analytical_loop" class GraphInvariantError(RuntimeError): diff --git a/src/pysatl_core/distributions/strategies.py b/src/pysatl_core/distributions/strategies.py index 77425e9..b7ca0b1 100644 --- a/src/pysatl_core/distributions/strategies.py +++ b/src/pysatl_core/distributions/strategies.py @@ -11,7 +11,7 @@ __copyright__ = "Copyright (c) 2025 PySATL project" __license__ = "SPDX-License-Identifier: MIT" -from typing import TYPE_CHECKING, Protocol, cast +from typing import TYPE_CHECKING, Any, Protocol, cast import numpy as np @@ -20,14 +20,13 @@ if TYPE_CHECKING: from collections.abc import Mapping - from typing import Any from pysatl_core.distributions.computation import ( AnalyticalComputation, FittedComputationMethod, - Method, ) from pysatl_core.distributions.distribution import Distribution + from pysatl_core.distributions.registry.graph import RegistryView from pysatl_core.types import GenericCharacteristicName, LabelName @@ -126,6 +125,19 @@ def _pick_analytical_method( f"Characteristic '{state}' provides no labeled analytical computations." ) from exc + @staticmethod + def _pick_analytical_loop_method( + state: GenericCharacteristicName, + view: RegistryView, + ) -> Method[Any, Any] | None: + """ + Pick the first analytical self-loop method for a characteristic in a view. + """ + loops = view.analytical_variants(state) + if not loops: + return None + return cast(Method[Any, Any], next(iter(loops.values())).method) + def query_method( self, state: GenericCharacteristicName, distr: Distribution, **options: Any ) -> Method[Any, Any]: @@ -133,9 +145,10 @@ def query_method( Resolve a computation method for the target characteristic. Resolution order: - 1. Analytical implementation from the distribution - 2. Cached fitted method (if caching enabled) - 3. Conversion path from an analytical characteristic via the graph + 1. Cached fitted method (if caching enabled) + 2. Analytical implementation for non-registry characteristics + 3. Analytical self-loop from the registry view + 4. Conversion path from analytical-loop characteristics via the graph Parameters ---------- @@ -157,34 +170,46 @@ def query_method( If no analytical base exists, no conversion path is found, or a cycle is detected. """ - # 1. Check for analytical implementation - if state in distr.analytical_computations: - return self._pick_analytical_method(state, distr.analytical_computations[state]) - - # 2. Check cache if enabled + # 1. Check cache if enabled if self._enable_caching: cached = self._cache.get(state) if cached is not None: return cached - # 3. Require at least one analytical characteristic + # 2. Require at least one analytical characteristic if not distr.analytical_computations: raise RuntimeError( "Distribution provides no analytical computations to ground conversions." ) - # 4. Get filtered graph view for this distribution - reg = characteristic_registry().view(distr) + # 3. Non-registry characteristics are resolved directly. + # It covers the situation where user is providing their analytical computation which isn't + # in the graph + registry = characteristic_registry() + if state not in registry.declared_characteristics: + if state in distr.analytical_computations: + return self._pick_analytical_method(state, distr.analytical_computations[state]) + raise RuntimeError( + f"Characteristic '{state}' is not declared in the registry and has no " + "analytical implementation in the distribution." + ) + + # 4. Get filtered graph view for this distribution. + view = registry.view(distr) self._push_guard(distr, state) try: - # 5. Try each analytical characteristic as a source + loop_method = self._pick_analytical_loop_method(state, view) + if loop_method is not None: + return loop_method + + # 5. Try each analytical-loop characteristic as a source for src in distr.analytical_computations: - if src == state: - return self._pick_analytical_method(src, distr.analytical_computations[src]) + if not view.analytical_variants(src): + continue # Find conversion path in the graph - path = reg.find_path(src, state) + path = view.find_path(src, state) if not path: continue diff --git a/tests/unit/distributions/test_registry.py b/tests/unit/distributions/test_registry.py index b2a38e5..25ee46f 100644 --- a/tests/unit/distributions/test_registry.py +++ b/tests/unit/distributions/test_registry.py @@ -5,9 +5,14 @@ __license__ = "SPDX-License-Identifier: MIT" +from collections.abc import Callable +from typing import Any, cast + import numpy as np import pytest +from mypy_extensions import KwArg +from pysatl_core.distributions.computation import AnalyticalComputation from pysatl_core.distributions.registry import ( DEFAULT_COMPUTATION_KEY, CharacteristicRegistry, @@ -18,8 +23,10 @@ reset_characteristic_registry, ) from pysatl_core.distributions.strategies import DefaultComputationStrategy -from pysatl_core.types import CharacteristicName +from pysatl_core.distributions.support import ContinuousSupport +from pysatl_core.types import CharacteristicName, Kind from tests.unit.distributions.test_basic import DistributionTestBase +from tests.utils.mocks import StandaloneEuclideanUnivariateDistribution class TestCharacteristicRegistry(DistributionTestBase): @@ -51,6 +58,68 @@ def test_configuration_continuous_presence_and_connectivity(self) -> None: errs = [abs(float(cdf(float(ppf(float(q))))) - q) for q in qs] assert max(errs) < 5e-3 + def test_view_adds_analytical_self_loops_with_labels(self) -> None: + def cdf_primary(x: float, **_kwargs: Any) -> float: + return 1.0 / (1.0 + np.exp(-x)) + + def cdf_secondary(x: float, **_kwargs: Any) -> float: + return 1.0 / (1.0 + np.exp(-x)) + + cdf_primary_func = cast(Callable[[float, KwArg(Any)], float], cdf_primary) + cdf_secondary_func = cast(Callable[[float, KwArg(Any)], float], cdf_secondary) + + distr = StandaloneEuclideanUnivariateDistribution( + kind=Kind.CONTINUOUS, + analytical_computations={ + CharacteristicName.CDF: { + "primary": AnalyticalComputation[float, float]( + target=CharacteristicName.CDF, func=cdf_primary_func + ), + "secondary": AnalyticalComputation[float, float]( + target=CharacteristicName.CDF, func=cdf_secondary_func + ), + } + }, + support=ContinuousSupport(), + ) + + view = characteristic_registry().view(distr) + loops = view.variants(CharacteristicName.CDF, CharacteristicName.CDF) + + assert set(loops.keys()) == {"primary", "secondary"} + assert all(edge.is_analytical for edge in loops.values()) + assert view.analytical_variants(CharacteristicName.CDF) == loops + + def test_strategy_prefers_first_analytical_loop(self) -> None: + def cdf_first(_x: float, **_kwargs: Any) -> float: + return 0.25 + + def cdf_second(_x: float, **_kwargs: Any) -> float: + return 0.75 + + cdf_first_func = cast(Callable[[float, KwArg(Any)], float], cdf_first) + cdf_second_func = cast(Callable[[float, KwArg(Any)], float], cdf_second) + + distr = StandaloneEuclideanUnivariateDistribution( + kind=Kind.CONTINUOUS, + analytical_computations={ + CharacteristicName.CDF: { + "first": AnalyticalComputation[float, float]( + target=CharacteristicName.CDF, func=cdf_first_func + ), + "second": AnalyticalComputation[float, float]( + target=CharacteristicName.CDF, func=cdf_second_func + ), + } + }, + support=ContinuousSupport(), + ) + + strategy = DefaultComputationStrategy(enable_caching=False) + cdf = strategy.query_method(CharacteristicName.CDF, distr) + + assert cdf(0.0) == pytest.approx(0.25) + def test_configuration_discrete_requires_support_then_ok(self) -> None: reg = characteristic_registry() distr = self.make_discrete_point_pmf_distribution(is_with_support=False) @@ -185,6 +254,7 @@ def test_label_variants_and_picking(self) -> None: variants = view.variants("src", "dst") assert set(variants.keys()) == {DEFAULT_COMPUTATION_KEY, "fast"} + assert all(not edge.is_analytical for edge in variants.values()) path = view.find_path("src", "dst", prefer_label="fast") assert path == [alternative_method] From 33fd6ed6c3f5ad88f4ff847de1aa30dc38926370 Mon Sep 17 00:00:00 2001 From: LeonidElkin Date: Tue, 17 Mar 2026 21:10:18 +0300 Subject: [PATCH 5/5] feat(registry): added support for hyper edges in the graph --- .../distributions/registry/graph.py | 170 ++++++++------ tests/unit/distributions/test_registry.py | 215 +++++++++++++++++- 2 files changed, 308 insertions(+), 77 deletions(-) diff --git a/src/pysatl_core/distributions/registry/graph.py b/src/pysatl_core/distributions/registry/graph.py index cddc8ed..485db67 100644 --- a/src/pysatl_core/distributions/registry/graph.py +++ b/src/pysatl_core/distributions/registry/graph.py @@ -10,7 +10,7 @@ Core concepts: - **Nodes**: Characteristics (PDF, CDF, etc.) with presence and definitiveness rules - - **Edges**: Unary computation methods between characteristics + - **Edges**: Computation methods from one-or-many sources to one target - **Constraints**: Rules that determine when nodes/edges are applicable - **View**: A filtered subgraph for a specific distribution - **Definitive characteristics**: Starting points for computations @@ -64,14 +64,14 @@ class CharacteristicRegistry: add_characteristic(name, is_definitive, presence_constraint=None, definitive_constraint=None) Declare a characteristic with presence and optional definitiveness rules. add_computation(method, label=DEFAULT_COMPUTATION_KEY, constraint=None) - Add a unary computation edge between declared nodes. + Add a computation edge between declared nodes. view(distr) Create a filtered view for the given distribution. Notes ----- - Nodes must be declared before adding computations - - Only unary computations (1 source → 1 target) are supported + - Only many-to-one computations (n sources → 1 target) are supported - No invariant validation happens during mutation; validation occurs when creating a view with view() """ @@ -88,10 +88,12 @@ def __init__(self) -> None: if getattr(self, "_initialized", False): return - # Adjacency: src → dst → label → [EdgeMeta] + # Adjacency projection: src → dst → label → [ComputationEdgeMeta] + # For hyperedges (many sources -> one target), the same edge metadata object + # is projected under each source to preserve graph reachability semantics. self._adj: dict[ GenericCharacteristicName, - dict[GenericCharacteristicName, dict[LabelName, list[EdgeMeta]]], + dict[GenericCharacteristicName, dict[LabelName, list[ComputationEdgeMeta]]], ] = {} self._all_nodes: set[GenericCharacteristicName] = set() @@ -179,12 +181,12 @@ def add_computation( constraint: GraphPrimitiveConstraint | None = None, ) -> None: """ - Add a labeled unary computation edge. + Add a labeled computation edge. Parameters ---------- method : ComputationMethod - Computation object with exactly one source and one target. + Computation object with one-or-many sources and one target. label : LabelName, default=DEFAULT_COMPUTATION_KEY Variant label for the edge. constraint : GraphPrimitiveConstraint, optional @@ -193,33 +195,36 @@ def add_computation( Raises ------ ValueError - If method is not unary, or source/target nodes are not declared. + If method has no sources, or source/target nodes are not declared. Notes ----- - Multiple edges with different labels can exist between the same nodes - The first matching edge for each label is kept when creating views + - Hyperedges are represented as projected edges from each source to target, + while preserving one shared underlying computation method. """ - if len(method.sources) != 1: - raise ValueError("Only unary computations are supported (1 source → 1 target).") + if not method.sources: + raise ValueError("Computation must define at least one source characteristic.") - src = method.sources[0] + unique_sources = tuple(dict.fromkeys(method.sources)) dst = method.target - if not self._ensure_node(src) or not self._ensure_node(dst): + if not self._ensure_node(dst) or any(not self._ensure_node(src) for src in unique_sources): raise ValueError("Source characteristic or destination characteristic is invalid.") - self._adj[src].setdefault(dst, {}) + edge_meta = ComputationEdgeMeta( + method=method, + constraint=constraint or GraphPrimitiveConstraint(), + ) + # TODO: We need to be careful here if some constraint more general and with the same label # than other it can consume it. Actually, the same label methods should not intersect their # constraints - self._adj[src][dst].setdefault(label, []) - self._adj[src][dst][label].append( - ComputationEdgeMeta( - method=method, - constraint=constraint or GraphPrimitiveConstraint(), - ) - ) + for src in unique_sources: + self._adj[src].setdefault(dst, {}) + self._adj[src][dst].setdefault(label, []) + self._adj[src][dst][label].append(edge_meta) def add_characteristic( self, @@ -342,41 +347,36 @@ def view(self, distr: Distribution) -> RegistryView: Notes ----- - 1. Filters edges by their constraints - 2. Removes edges touching absent nodes + 1. Computes present nodes for the distribution + 2. Filters edges by node presence and edge constraints 3. Adds analytical self-loops from distribution analytical computations 4. Computes definitive nodes from the remaining present nodes 5. Validates graph invariants """ - # 1) Filter edges by applicability + # 1) Compute present nodes once and pre-create adjacency. + present_nodes = self._compute_present_nodes(distr) adj: dict[ GenericCharacteristicName, dict[GenericCharacteristicName, dict[LabelName, EdgeMeta]] - ] = {} - for src, d in self._adj.items(): - for dst, variants in d.items(): + ] = {node: {} for node in present_nodes} + + # 2) Filter edges by node presence and applicability. + for src in present_nodes: + for dst, variants in self._adj.get(src, {}).items(): + if dst not in present_nodes: + continue kept: dict[LabelName, EdgeMeta] = {} for label, metas in variants.items(): for edge in metas: - if edge.constraint.allows(distr): + if edge.constraint.allows(distr) and all( + source in present_nodes for source in edge.method.sources + ): kept[label] = edge # TODO: It is possible that there are two edges under the same label # that fit the same distribution, this should not be the case. # Taking the first one for now break if kept: - adj.setdefault(src, {}).setdefault(dst, {}).update(kept) - - # 2) Filter by node presence - present_nodes = self._compute_present_nodes(distr) - if present_nodes: - adj = { - src: {dst: dict(variants) for dst, variants in d.items() if dst in present_nodes} - for src, d in adj.items() - if src in present_nodes - } - # Ensure isolated present nodes are preserved - for node in present_nodes: - adj.setdefault(node, {}) + adj[src][dst] = kept # 3) Attach analytical loops self._attach_analytical_loops(adj, distr, present_nodes) @@ -438,6 +438,14 @@ def __init__( for src, d in adj.items(): self._adj[src] = {dst: dict(variants) for dst, variants in d.items()} + self._rev_adj: dict[GenericCharacteristicName, set[GenericCharacteristicName]] = { + node: set() for node in self._adj + } + for src, d in self._adj.items(): + for dst, variants in d.items(): + if variants: + self._rev_adj.setdefault(dst, set()).add(src) + self.definitive_characteristics: set[GenericCharacteristicName] = set(definitive_nodes) self.all_characteristics: set[GenericCharacteristicName] = set(present_nodes) @@ -504,11 +512,7 @@ def predecessors(self, v: GenericCharacteristicName) -> set[GenericCharacteristi set of str Characteristics that can reach v directly. """ - res: set[GenericCharacteristicName] = set() - for src, d in self._adj.items(): - if v in d and d[v]: - res.add(src) - return res + return set(self._rev_adj.get(v, set())) def variants( self, src: GenericCharacteristicName, dst: GenericCharacteristicName @@ -643,16 +647,8 @@ def _definitive_strongly_connected(self) -> bool: if fwd != (defs - {start}): return False - # Check reverse reachability - seen: set[GenericCharacteristicName] = {start} - stack = [start] - while stack: - v = stack.pop() - for w in self.predecessors(v): - if w in defs and w not in seen: - seen.add(w) - stack.append(w) - return seen == defs + rev = self._reachable_from_many({start}, allowed=defs, reverse=True) + return rev == (defs - {start}) def _all_indefinitives_reachable_from_definitives(self) -> bool: """ @@ -667,9 +663,7 @@ def _all_indefinitives_reachable_from_definitives(self) -> bool: if not indefs: return True - total: set[GenericCharacteristicName] = set() - for d in self.definitive_characteristics: - total |= self._reachable_from(d) + total = self._reachable_from_many(self.definitive_characteristics) return indefs.issubset(total) def _exists_path_from_indefinitive_to_definitive(self) -> bool: @@ -682,46 +676,78 @@ def _exists_path_from_indefinitive_to_definitive(self) -> bool: True if such a path exists (which would violate invariants). """ defs = self.definitive_characteristics - return any(self._reachable_from(i) & defs for i in self.indefinitive_characteristics) + if not defs: + return False - def _reachable_from( + can_reach_definitive = self._reachable_from_many(defs, reverse=True) + return bool(can_reach_definitive & self.indefinitive_characteristics) + + def _reachable_from_many( self, - src: GenericCharacteristicName, + sources: set[GenericCharacteristicName], *, allowed: set[GenericCharacteristicName] | None = None, + reverse: bool = False, ) -> set[GenericCharacteristicName]: """ - Compute forward reachable nodes from src. + Compute reachable nodes from multiple sources. Parameters ---------- - src : str - Starting node. + sources : set of str + Starting nodes. allowed : set of str, optional - Restrict to this set of nodes. + Restrict traversal to this set of nodes. + reverse : bool, default=False + If True, traverse reverse edges. Returns ------- set of str - Nodes reachable from src (excluding src itself). + Nodes reachable from sources (excluding sources themselves). """ - if allowed is not None and src not in allowed: + starts = {s for s in sources if allowed is None or s in allowed} + if not starts: return set() visited: set[GenericCharacteristicName] = set() - stack = [src] + stack = list(starts) while stack: v = stack.pop() if v in visited: continue visited.add(v) - for w in self.successors_nodes(v): + neighbors = self._rev_adj.get(v, set()) if reverse else self.successors_nodes(v) + for w in neighbors: if allowed is not None and w not in allowed: continue if w not in visited: stack.append(w) - visited.discard(src) - return visited + + return visited - starts + + def _reachable_from( + self, + src: GenericCharacteristicName, + *, + allowed: set[GenericCharacteristicName] | None = None, + ) -> set[GenericCharacteristicName]: + """ + Compute forward reachable nodes from src. + + Parameters + ---------- + src : str + Starting node. + allowed : set of str, optional + Restrict to this set of nodes. + + Returns + ------- + set of str + Nodes reachable from src (excluding src itself). + """ + return self._reachable_from_many({src}, allowed=allowed) @staticmethod def _pick_method( @@ -747,5 +773,5 @@ def _pick_method( return variants[prefer_label].method if DEFAULT_COMPUTATION_KEY in variants: return variants[DEFAULT_COMPUTATION_KEY].method - label = sorted(variants.keys())[0] + label = min(variants) return variants[label].method diff --git a/tests/unit/distributions/test_registry.py b/tests/unit/distributions/test_registry.py index 25ee46f..3a8e263 100644 --- a/tests/unit/distributions/test_registry.py +++ b/tests/unit/distributions/test_registry.py @@ -12,7 +12,13 @@ import pytest from mypy_extensions import KwArg -from pysatl_core.distributions.computation import AnalyticalComputation +from pysatl_core.distributions import strategies as strategies_module +from pysatl_core.distributions.computation import ( + AnalyticalComputation, + ComputationMethod, + FittedComputationMethod, +) +from pysatl_core.distributions.distribution import Distribution from pysatl_core.distributions.registry import ( DEFAULT_COMPUTATION_KEY, CharacteristicRegistry, @@ -182,12 +188,14 @@ def test_add_computation_validation_and_duplicate_rules(self) -> None: reg = CharacteristicRegistry() reg.add_characteristic("a", is_definitive=True) reg.add_characteristic("b", is_definitive=True) + reg.add_characteristic("c", is_definitive=False) - # non-unary validation + # empty-sources validation with pytest.raises(ValueError): - reg.add_computation( - self.make_fictitious_computation_method(target="b", sources=["a", "b"]) - ) + reg.add_computation(self.make_fictitious_computation_method(target="b", sources=[])) + + # many-to-one computation is supported + reg.add_computation(self.make_fictitious_computation_method(target="c", sources=["a", "b"])) # undeclared nodes with pytest.raises(ValueError): @@ -261,3 +269,200 @@ def test_label_variants_and_picking(self) -> None: path = view.find_path("src", "dst") assert path == [default_method] + + def test_hyperedge_many_to_one_projection_and_single_fitter( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + reg = CharacteristicRegistry() + reg.add_characteristic("A", is_definitive=True) + reg.add_characteristic("B", is_definitive=True) + reg.add_characteristic("C", is_definitive=False) + + reg.add_computation(self.make_fictitious_computation_method(target="B", sources=["A"])) + reg.add_computation(self.make_fictitious_computation_method(target="A", sources=["B"])) + + fitter_calls: dict[str, int] = {"count": 0} + + def fit_ab_to_c( + distribution: Distribution, **_kwargs: Any + ) -> FittedComputationMethod[Any, Any]: + fitter_calls["count"] += 1 + a_method = distribution.query_method("A") + b_method = distribution.query_method("B") + + def c_func(**_options: Any) -> float: + return float(a_method() + b_method()) + + return FittedComputationMethod( + target="C", + sources=("A", "B"), + func=cast(Callable[[KwArg(Any)], float], c_func), + ) + + hyper_method = ComputationMethod[Any, Any]( + target="C", + sources=("A", "B"), + fitter=cast( + Callable[[Distribution, KwArg(Any)], FittedComputationMethod[Any, Any]], + fit_ab_to_c, + ), + ) + reg.add_computation(hyper_method, label="ab_to_c") + + a_func = cast(Callable[[KwArg(Any)], float], lambda **_kwargs: 2.0) + b_func = cast(Callable[[KwArg(Any)], float], lambda **_kwargs: 3.0) + distr = StandaloneEuclideanUnivariateDistribution( + kind=Kind.CONTINUOUS, + analytical_computations={ + "A": {"default": AnalyticalComputation[Any, Any](target="A", func=a_func)}, + "B": {"default": AnalyticalComputation[Any, Any](target="B", func=b_func)}, + }, + support=ContinuousSupport(), + ) + + view = reg.view(distr) + assert view.variants("A", "C")["ab_to_c"].method is hyper_method + assert view.variants("B", "C")["ab_to_c"].method is hyper_method + assert view.find_path("A", "C") == [hyper_method] + assert view.find_path("B", "C") == [hyper_method] + + monkeypatch.setattr(strategies_module, "characteristic_registry", lambda: reg) + strategy = DefaultComputationStrategy(enable_caching=True) + c_method_first = strategy.query_method("C", distr) + c_method_second = strategy.query_method("C", distr) + + assert c_method_first() == pytest.approx(5.0) + assert c_method_second() == pytest.approx(5.0) + assert fitter_calls["count"] == 1 + + def test_hyperedge_requires_all_sources_present_in_view(self) -> None: + reg = CharacteristicRegistry() + reg.add_characteristic("A", is_definitive=True) + reg.add_characteristic( + "B", + is_definitive=False, + presence_constraint=GraphPrimitiveConstraint( + distribution_type_feature_constraints={ + "dimension": NumericConstraint(allowed=frozenset({2})) + } + ), + ) + reg.add_characteristic("C", is_definitive=False) + + reg.add_computation(self.make_fictitious_computation_method(target="C", sources=["A", "B"])) + + # Source B is absent for 1D distributions, so hyperedge A,B -> C must be filtered out. + # Then C becomes unreachable from definitive nodes and invariant validation must fail. + with pytest.raises(GraphInvariantError): + reg.view(self.make_logistic_cdf_distribution()) + + def test_strategy_resolves_diamond_graph_from_single_analytical_source( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + reg = CharacteristicRegistry() + reg.add_characteristic("pdf", is_definitive=True) + reg.add_characteristic("mean", is_definitive=False) + reg.add_characteristic("second_moment", is_definitive=False) + reg.add_characteristic("mean_sq", is_definitive=False) + reg.add_characteristic("var", is_definitive=False) + + def fit_pdf_to_mean( + _distribution: Distribution, **_kwargs: Any + ) -> FittedComputationMethod[Any, Any]: + return FittedComputationMethod( + target="mean", + sources=("pdf",), + func=cast(Callable[[KwArg(Any)], float], lambda **_opts: 2.0), + ) + + def fit_pdf_to_second_moment( + _distribution: Distribution, **_kwargs: Any + ) -> FittedComputationMethod[Any, Any]: + return FittedComputationMethod( + target="second_moment", + sources=("pdf",), + func=cast(Callable[[KwArg(Any)], float], lambda **_opts: 5.0), + ) + + def fit_mean_to_mean_sq( + distribution: Distribution, **_kwargs: Any + ) -> FittedComputationMethod[Any, Any]: + mean_method = distribution.query_method("mean") + return FittedComputationMethod( + target="mean_sq", + sources=("mean",), + func=cast( + Callable[[KwArg(Any)], float], + lambda **_opts: float(mean_method() ** 2), + ), + ) + + def fit_second_moment_and_mean_sq_to_var( + distribution: Distribution, **_kwargs: Any + ) -> FittedComputationMethod[Any, Any]: + second_moment_method = distribution.query_method("second_moment") + mean_sq_method = distribution.query_method("mean_sq") + return FittedComputationMethod( + target="var", + sources=("second_moment", "mean_sq"), + func=cast( + Callable[[KwArg(Any)], float], + lambda **_opts: float(second_moment_method() - mean_sq_method()), + ), + ) + + reg.add_computation( + ComputationMethod[Any, Any]( + target="mean", + sources=("pdf",), + fitter=cast( + Callable[[Distribution, KwArg(Any)], FittedComputationMethod[Any, Any]], + fit_pdf_to_mean, + ), + ) + ) + reg.add_computation( + ComputationMethod[Any, Any]( + target="second_moment", + sources=("pdf",), + fitter=cast( + Callable[[Distribution, KwArg(Any)], FittedComputationMethod[Any, Any]], + fit_pdf_to_second_moment, + ), + ) + ) + reg.add_computation( + ComputationMethod[Any, Any]( + target="mean_sq", + sources=("mean",), + fitter=cast( + Callable[[Distribution, KwArg(Any)], FittedComputationMethod[Any, Any]], + fit_mean_to_mean_sq, + ), + ) + ) + reg.add_computation( + ComputationMethod[Any, Any]( + target="var", + sources=("second_moment", "mean_sq"), + fitter=cast( + Callable[[Distribution, KwArg(Any)], FittedComputationMethod[Any, Any]], + fit_second_moment_and_mean_sq_to_var, + ), + ) + ) + + pdf_func = cast(Callable[[float, KwArg(Any)], float], lambda _x, **_opts: 1.0) + distr = StandaloneEuclideanUnivariateDistribution( + kind=Kind.CONTINUOUS, + analytical_computations={ + "pdf": {"default": AnalyticalComputation[Any, Any](target="pdf", func=pdf_func)} + }, + support=ContinuousSupport(), + ) + + monkeypatch.setattr(strategies_module, "characteristic_registry", lambda: reg) + strategy = DefaultComputationStrategy(enable_caching=True) + var_method = strategy.query_method("var", distr) + + assert var_method() == pytest.approx(1.0)