-
-
Notifications
You must be signed in to change notification settings - Fork 204
Fix DiscreteDistributionLabeled: labels parameter and from_dataset initialization #1717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c892c94
6f79d83
914a4cc
1eb5915
f395114
f245038
819916a
795219d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Callable, Dict, List, Optional, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from copy import deepcopy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import xarray as xr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -26,7 +27,8 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist : rv_discrete | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Discrete distribution from scipy.stats. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed : int, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Seed for random number generator, by default 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Seed for random number generator. If None (default), a random | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed is generated from system entropy. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rv_discrete_frozen.__init__(self, dist, *args, **kwds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -118,12 +120,16 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __repr__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.atoms.shape[0] > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "inf=" + str(tuple(self.limit["infimum"])) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "sup=" + str(tuple(self.limit["supremum"])) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inf = self.limit.get("infimum", np.array([])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sup = self.limit.get("supremum", np.array([])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if inf.size == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "inf=[], sup=[], " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.atoms.shape[0] > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "inf=" + str(tuple(inf)) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "sup=" + str(tuple(sup)) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "inf=" + str(self.limit["infimum"][0]) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "sup=" + str(self.limit["supremum"][0]) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "inf=" + str(inf[0]) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "sup=" + str(sup[0]) + ", " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out += "seed=" + str(self.seed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -189,8 +195,7 @@ def draw( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| K = np.floor(K_exact).astype(int) # number of slots allocated to each atom | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M = N - np.sum(K) # number of unallocated slots | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| J = P.size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| eps = 1.0 / N | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Q = K_exact - eps * K # "missing" probability mass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Q = K_exact - K # fractional part: slots still owed to each atom | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| draws = self._rng.random(M) # uniform draws for "extra" slots | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Fill in each unallocated slot, one by one | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -469,18 +474,107 @@ def from_unlabeled( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ldd | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def from_dataset(cls, x_obj, pmf): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def from_dataset(cls, x_obj, pmf, seed=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Construct a DiscreteDistributionLabeled from xarray objects. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x_obj : xr.Dataset, xr.DataArray, or dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The data containing distribution values. If a Dataset, variables | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| become the distribution's variables. If a DataArray, it becomes a | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| single variable (unnamed DataArrays get the name "var_0"). If a | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dict, it's converted to a Dataset. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pmf : xr.DataArray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Probability mass values with dimension "atom". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed : int, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Seed for random number generator. If None (default), a random | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed is generated from system entropy. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd : DiscreteDistributionLabeled | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A properly initialized labeled distribution. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| If x_obj is not an xr.Dataset, xr.DataArray, or dict, or if | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pmf is not an xr.DataArray. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(pmf, xr.DataArray): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"from_dataset() requires pmf to be an xr.DataArray with " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"dimension 'atom', but got {type(pmf).__name__}. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Wrap your probabilities: pmf = xr.DataArray(array, dims=('atom',))" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd = cls.__new__(cls) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(x_obj, xr.Dataset): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.dataset = x_obj | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(x_obj, xr.DataArray): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.dataset = xr.Dataset({x_obj.name: x_obj}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name = x_obj.name if x_obj.name is not None else "var_0" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.dataset = xr.Dataset({name: x_obj}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(x_obj, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.dataset = xr.Dataset(x_obj) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"from_dataset() expected x_obj to be an xr.Dataset, " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"xr.DataArray, or dict, but got {type(x_obj).__name__}." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.probability = pmf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Extract pmv from probability DataArray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.pmv = np.asarray(pmf.values) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Extract atoms from dataset variables that have the "atom" dimension. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For DiscreteDistribution, atoms has shape (..., n_atoms) where the last | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # dimension indexes "atom" (the random realization). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Variables without the "atom" dimension (e.g., scalar summaries) are kept | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # in the dataset but not included in atoms. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var_names = list(ldd.dataset.data_vars) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if var_names: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Filter to only include variables that have the "atom" dimension | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vars_with_atom = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var for var in var_names if "atom" in ldd.dataset[var].dims | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if vars_with_atom: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var_arrays = [ldd.dataset[var].values for var in vars_with_atom] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if all arrays have the same shape (required for np.stack) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shapes = [arr.shape for arr in var_arrays] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(set(shapes)) == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.atoms = np.atleast_2d(np.stack(var_arrays, axis=0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+545
to
+550
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"from_dataset(): variables with 'atom' dimension have " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"incompatible shapes ({dict(zip(vars_with_atom, shapes))}). " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Cannot construct a valid distribution with mixed-shape " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"atoms. Ensure all variables have the same shape." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+551
to
+555
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.atoms = np.atleast_2d(np.array([])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.atoms = np.atleast_2d(np.array([])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Compute limit from atoms | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ldd.atoms.size > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.limit = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "infimum": np.min(ldd.atoms, axis=-1), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "supremum": np.max(ldd.atoms, axis=-1), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.limit = {"infimum": np.array([]), "supremum": np.array([])} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Initialize base class attributes that __init__ would normally set | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.infimum = ldd.limit["infimum"].copy() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.supremum = ldd.limit["supremum"].copy() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Initialize seed and RNG using the property setter from Distribution base class | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ldd.seed = seed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ldd | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -490,6 +584,26 @@ def _weighted(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.dataset.weighted(self.probability) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _weighted_mean_of(self, f_query): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Compute the probability-weighted mean of function output over | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| the "atom" dimension without constructing an intermediate | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| distribution. Handles Dataset, DataArray, and dict results. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(f_query, xr.Dataset): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f_query.weighted(self.probability).mean("atom") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(f_query, xr.DataArray): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f_query.weighted(self.probability).mean("atom") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(f_query, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ds = xr.Dataset(f_query) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+593
to
+598
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(f_query, xr.Dataset): | |
| return f_query.weighted(self.probability).mean("atom") | |
| elif isinstance(f_query, xr.DataArray): | |
| return f_query.weighted(self.probability).mean("atom") | |
| elif isinstance(f_query, dict): | |
| ds = xr.Dataset(f_query) | |
| if isinstance(f_query, xr.Dataset): | |
| if "atom" not in f_query.dims: | |
| raise ValueError( | |
| "expected() function must return an xr.Dataset with an " | |
| "'atom' dimension to compute a probability-weighted mean; " | |
| f"got dimensions {tuple(f_query.dims)} instead." | |
| ) | |
| return f_query.weighted(self.probability).mean("atom") | |
| elif isinstance(f_query, xr.DataArray): | |
| if "atom" not in f_query.dims: | |
| raise ValueError( | |
| "expected() function must return an xr.DataArray with an " | |
| "'atom' dimension to compute a probability-weighted mean; " | |
| f"got dimensions {tuple(f_query.dims)} instead." | |
| ) | |
| return f_query.weighted(self.probability).mean("atom") | |
| elif isinstance(f_query, dict): | |
| ds = xr.Dataset(f_query) | |
| if "atom" not in ds.dims: | |
| raise ValueError( | |
| "expected() function must return data with an 'atom' " | |
| "dimension to compute a probability-weighted mean; " | |
| f"got dimensions {tuple(ds.dims)} instead." | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -513,7 +513,9 @@ def calc_expectation(dstn, func=None, *args, **kwargs): | |
| f_query = [] | ||
| for i in range(len(dstn.pmv)): | ||
| temp_dict = { | ||
| key: float(dstn.variables[key][i]) for key in dstn.variables.keys() | ||
| key: float(dstn.variables[key][i]) | ||
| for key in dstn.variables.keys() | ||
| if "atom" in dstn.dataset[key].dims | ||
| } | ||
| f_query.append(func(temp_dict, *args, **kwargs)) | ||
|
Comment on lines
+516
to
520
|
||
| else: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from_dataset() documents that pmf must have dimension 'atom', but it only validates pmf’s type. Please also validate that 'atom' is in pmf.dims (and ideally pmf.ndim==1), and that pmf.sizes['atom'] matches the dataset’s 'atom' length when present, to avoid constructing inconsistent distributions.