Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ poetry.lock
.cursorindexingignore

# Docs

docs/build/*
docs/source/api/*
docs/source/examples/*
Expand All @@ -137,3 +138,9 @@ docs/source/examples/*

# Unuran build artifacts
src/pysatl_core/sampling/unuran/bindings/_unuran_cffi.c

# ai

.ai/*
AGENTS.md
CLAUDE.md
154 changes: 136 additions & 18 deletions src/pysatl_core/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,28 @@
__license__ = "SPDX-License-Identifier: MIT"

from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from typing import TYPE_CHECKING, Self, cast

from pysatl_core.types import NumericArray
from pysatl_core.distributions.strategies import (
ComputationStrategy,
SamplingStrategy,
)
from pysatl_core.types import DEFAULT_ANALYTICAL_COMPUTATION_LABEL, NumericArray

_KEEP: object = object()


if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any

from pysatl_core.distributions.computation import AnalyticalComputation
from pysatl_core.distributions.strategies import (
ComputationStrategy,
SamplingStrategy,
)
from pysatl_core.distributions.support import Support
from pysatl_core.types import (
DistributionType,
GenericCharacteristicName,
LabelName,
Method,
)

Expand All @@ -48,7 +50,13 @@ class Distribution(ABC):
----------
distribution_type : DistributionType
Type information about the distribution (kind, dimension, etc.).
analytical_computations : Mapping[str, AnalyticalComputation]
analytical_computations : Mapping[
GenericCharacteristicName,
(
AnalyticalComputation[Any, Any]
| Mapping[LabelName, AnalyticalComputation[Any, Any]]
),
]
Direct analytical computations provided by the distribution.
sampling_strategy : SamplingStrategy
Strategy for generating random samples.
Expand All @@ -58,35 +66,145 @@ class Distribution(ABC):
Support of the distribution, if defined.
"""

def __init__(
self,
distribution_type: DistributionType,
analytical_computations: Mapping[
GenericCharacteristicName,
(AnalyticalComputation[Any, Any] | Mapping[LabelName, AnalyticalComputation[Any, Any]]),
],
support: Support | None = None,
sampling_strategy: SamplingStrategy | None = None,
computation_strategy: ComputationStrategy | None = None,
) -> None:
"""
Initialize common distribution state.

Parameters
----------
distribution_type : DistributionType
Type information about the distribution (kind, dimension, etc.).
analytical_computations :
Mapping[
GenericCharacteristicName,
(
AnalyticalComputation[Any, Any]
| Mapping[LabelName, AnalyticalComputation[Any, Any]]
),
]
Analytical computations provided by the distribution.
support : Support or None, default=None
Support of the distribution.
sampling_strategy : SamplingStrategy or None, default=None
Sampling strategy instance. If omitted, univariate default is used.
computation_strategy : ComputationStrategy or None, default=None
Computation strategy instance. If omitted, default strategy is used.
"""
from pysatl_core.distributions.strategies import (
DefaultComputationStrategy,
DefaultSamplingUnivariateStrategy,
)

self._distribution_type = distribution_type
normalized_analytical: dict[
GenericCharacteristicName, dict[LabelName, AnalyticalComputation[Any, Any]]
] = {}
for characteristic_name, methods in analytical_computations.items():
if isinstance(methods, Mapping):
normalized_analytical[characteristic_name] = dict(methods)
else:
normalized_analytical[characteristic_name] = {
DEFAULT_ANALYTICAL_COMPUTATION_LABEL: methods
}

if not normalized_analytical:
raise ValueError("Distribution requires at least one analytical computation.")

for characteristic_name, labeled_methods in normalized_analytical.items():
if not labeled_methods:
raise ValueError(
f"Characteristic '{characteristic_name}' must provide at least one "
"analytical computation."
)

self._analytical_computations = normalized_analytical
self._support = support
self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy()
self._computation_strategy = computation_strategy or DefaultComputationStrategy()

@property
@abstractmethod
def distribution_type(self) -> DistributionType: ...
def distribution_type(self) -> DistributionType:
"""Return type metadata of the distribution (kind, dimension, etc.)."""
return self._distribution_type

@property
@abstractmethod
def analytical_computations(
self,
) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: ...
) -> Mapping[GenericCharacteristicName, Mapping[LabelName, AnalyticalComputation[Any, Any]]]:
"""Return analytical computations provided directly by this distribution."""
return self._analytical_computations

@property
@abstractmethod
def sampling_strategy(self) -> SamplingStrategy: ...
def sampling_strategy(self) -> SamplingStrategy:
"""Return the currently attached sampling strategy."""
return self._sampling_strategy

@property
@abstractmethod
def computation_strategy(self) -> ComputationStrategy: ...
def computation_strategy(self) -> ComputationStrategy:
"""Return the currently attached computation strategy."""
return self._computation_strategy

@property
@abstractmethod
def support(self) -> Support | None: ...
def support(self) -> Support | None:
"""Return the support of the distribution, if it is defined."""
return self._support

@abstractmethod
def _clone_with_strategies(
self,
*,
sampling_strategy: SamplingStrategy | None | object = _KEEP,
computation_strategy: ComputationStrategy | None | object = _KEEP,
) -> Distribution: ...
) -> Distribution:
"""
Return a cloned distribution with updated strategies.

The ``_KEEP`` sentinel means the existing strategy should be preserved
for that side.
"""
...

def _new_sampling_strategy(
self,
sampling_strategy: SamplingStrategy | None | object = _KEEP,
) -> SamplingStrategy | None:
"""
Resolve sampling strategy for cloning.

When ``sampling_strategy`` is ``_KEEP``, returns a deep copy of the
current sampling strategy.
"""
return cast(
SamplingStrategy | None,
deepcopy(self._sampling_strategy) if sampling_strategy is _KEEP else sampling_strategy,
)

def _new_computation_strategy(
self,
computation_strategy: ComputationStrategy | None | object = _KEEP,
) -> ComputationStrategy | None:
"""
Resolve computation strategy for cloning.

When ``computation_strategy`` is ``_KEEP``, returns a deep copy of the
current computation strategy.
"""
return cast(
ComputationStrategy | None,
deepcopy(self._computation_strategy)
if computation_strategy is _KEEP
else computation_strategy,
)

def with_sampling_strategy(self, sampling_strategy: SamplingStrategy | None) -> Self:
"""Return a copy of this distribution with an updated sampling strategy."""
Expand Down
4 changes: 4 additions & 0 deletions src/pysatl_core/distributions/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
)
from .graph_primitives import (
DEFAULT_COMPUTATION_KEY,
AnalyticalLoopEdgeMeta,
ComputationEdgeMeta,
EdgeMeta,
GraphInvariantError,
)
Expand All @@ -33,6 +35,8 @@
# Graph primitives and constants
"DEFAULT_COMPUTATION_KEY",
"EdgeMeta",
"ComputationEdgeMeta",
"AnalyticalLoopEdgeMeta",
"GraphInvariantError",
# Constraint types
"Constraint",
Expand Down
Loading
Loading