Skip to content
187 changes: 168 additions & 19 deletions HARK/distributions/discrete.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Union

import warnings
from copy import deepcopy
import numpy as np
import xarray as xr
Expand All @@ -26,7 +27,8 @@ def __init__(
dist : rv_discrete
Discrete distribution from scipy.stats.
seed : int, optional
Seed for random number generator, by default 0
Seed for random number generator. If None (default), a random
seed is generated from system entropy.
"""

rv_discrete_frozen.__init__(self, dist, *args, **kwds)
Expand Down Expand Up @@ -118,12 +120,16 @@ def __init__(

def __repr__(self):
out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, "
if self.atoms.shape[0] > 1:
out += "inf=" + str(tuple(self.limit["infimum"])) + ", "
out += "sup=" + str(tuple(self.limit["supremum"])) + ", "
inf = self.limit.get("infimum", np.array([]))
sup = self.limit.get("supremum", np.array([]))
if inf.size == 0:
out += "inf=[], sup=[], "
elif self.atoms.shape[0] > 1:
out += "inf=" + str(tuple(inf)) + ", "
out += "sup=" + str(tuple(sup)) + ", "
else:
out += "inf=" + str(self.limit["infimum"][0]) + ", "
out += "sup=" + str(self.limit["supremum"][0]) + ", "
out += "inf=" + str(inf[0]) + ", "
out += "sup=" + str(sup[0]) + ", "
out += "seed=" + str(self.seed)
return out

Expand Down Expand Up @@ -189,8 +195,7 @@ def draw(
K = np.floor(K_exact).astype(int) # number of slots allocated to each atom
M = N - np.sum(K) # number of unallocated slots
J = P.size
eps = 1.0 / N
Q = K_exact - eps * K # "missing" probability mass
Q = K_exact - K # fractional part: slots still owed to each atom
draws = self._rng.random(M) # uniform draws for "extra" slots

# Fill in each unallocated slot, one by one
Expand Down Expand Up @@ -469,18 +474,107 @@ def from_unlabeled(
return ldd

@classmethod
def from_dataset(cls, x_obj, pmf, seed=None):
    """
    Construct a DiscreteDistributionLabeled from xarray objects.

    Parameters
    ----------
    x_obj : xr.Dataset, xr.DataArray, or dict
        The data containing distribution values. If a Dataset, variables
        become the distribution's variables. If a DataArray, it becomes a
        single variable (unnamed DataArrays get the name "var_0"). If a
        dict, it's converted to a Dataset.
    pmf : xr.DataArray
        Probability mass values; must be one-dimensional with
        dimension "atom".
    seed : int, optional
        Seed for random number generator. If None (default), a random
        seed is generated from system entropy.

    Returns
    -------
    ldd : DiscreteDistributionLabeled
        A properly initialized labeled distribution.

    Raises
    ------
    TypeError
        If x_obj is not an xr.Dataset, xr.DataArray, or dict, or if
        pmf is not an xr.DataArray.
    ValueError
        If pmf does not have a one-dimensional "atom" dimension, if the
        length of pmf disagrees with the dataset's "atom" dimension, or
        if variables carrying the "atom" dimension have incompatible
        shapes.
    """
    if not isinstance(pmf, xr.DataArray):
        raise TypeError(
            f"from_dataset() requires pmf to be an xr.DataArray with "
            f"dimension 'atom', but got {type(pmf).__name__}. "
            f"Wrap your probabilities: pmf = xr.DataArray(array, dims=('atom',))"
        )
    # Enforce the documented contract: pmf is a 1-D array over "atom".
    if "atom" not in pmf.dims or pmf.ndim != 1:
        raise ValueError(
            f"from_dataset() requires pmf to be one-dimensional with "
            f"dimension 'atom', but got dimensions {tuple(pmf.dims)}."
        )

    ldd = cls.__new__(cls)

    if isinstance(x_obj, xr.Dataset):
        ldd.dataset = x_obj
    elif isinstance(x_obj, xr.DataArray):
        # Unnamed DataArrays cannot be Dataset variables; assign a default name.
        name = x_obj.name if x_obj.name is not None else "var_0"
        ldd.dataset = xr.Dataset({name: x_obj})
    elif isinstance(x_obj, dict):
        ldd.dataset = xr.Dataset(x_obj)
    else:
        raise TypeError(
            f"from_dataset() expected x_obj to be an xr.Dataset, "
            f"xr.DataArray, or dict, but got {type(x_obj).__name__}."
        )

    # The probability mass must line up with the dataset's "atom" dimension
    # when the dataset has one; otherwise the distribution is inconsistent.
    if "atom" in ldd.dataset.sizes and ldd.dataset.sizes["atom"] != pmf.sizes["atom"]:
        raise ValueError(
            f"from_dataset(): pmf has {pmf.sizes['atom']} atoms but the "
            f"dataset's 'atom' dimension has length "
            f"{ldd.dataset.sizes['atom']}."
        )

    ldd.probability = pmf

    # Extract pmv from probability DataArray
    ldd.pmv = np.asarray(pmf.values)

    # Extract atoms from dataset variables that have the "atom" dimension.
    # For DiscreteDistribution, atoms has shape (..., n_atoms) where the last
    # dimension indexes "atom" (the random realization).
    # Variables without the "atom" dimension (e.g., scalar summaries) are kept
    # in the dataset but not included in atoms.
    var_names = list(ldd.dataset.data_vars)
    if var_names:
        # Filter to only include variables that have the "atom" dimension
        vars_with_atom = [
            var for var in var_names if "atom" in ldd.dataset[var].dims
        ]

        if vars_with_atom:
            # Transpose so "atom" is the trailing axis of every variable;
            # DiscreteDistribution treats the LAST axis of atoms as the
            # atom index, so dims ordered like ('atom', 'grid') must be
            # reordered before extracting raw values.
            var_arrays = [
                ldd.dataset[var].transpose(..., "atom").values
                for var in vars_with_atom
            ]
            # Check if all arrays have the same shape (required for np.stack)
            shapes = [arr.shape for arr in var_arrays]
            if len(set(shapes)) == 1:
                ldd.atoms = np.atleast_2d(np.stack(var_arrays, axis=0))
            else:
                raise ValueError(
                    f"from_dataset(): variables with 'atom' dimension have "
                    f"incompatible shapes ({dict(zip(vars_with_atom, shapes))}). "
                    f"Cannot construct a valid distribution with mixed-shape "
                    f"atoms. Ensure all variables have the same shape."
                )
        else:
            ldd.atoms = np.atleast_2d(np.array([]))
    else:
        ldd.atoms = np.atleast_2d(np.array([]))

    # Compute limit from atoms: per-variable bounds over the atom axis.
    if ldd.atoms.size > 0:
        ldd.limit = {
            "infimum": np.min(ldd.atoms, axis=-1),
            "supremum": np.max(ldd.atoms, axis=-1),
        }
    else:
        ldd.limit = {"infimum": np.array([]), "supremum": np.array([])}

    # Initialize base class attributes that __init__ would normally set
    ldd.infimum = ldd.limit["infimum"].copy()
    ldd.supremum = ldd.limit["supremum"].copy()

    # Initialize seed and RNG using the property setter from Distribution base class
    ldd.seed = seed

    return ldd

@property
Expand All @@ -490,6 +584,26 @@ def _weighted(self):
"""
return self.dataset.weighted(self.probability)

def _weighted_mean_of(self, f_query):
    """
    Compute the probability-weighted mean of function output over
    the "atom" dimension without constructing an intermediate
    distribution. Handles Dataset, DataArray, and dict results.
    """
    # Normalize the function output to an xarray object, then take a
    # single weighted mean over "atom" for all supported result types.
    if isinstance(f_query, (xr.Dataset, xr.DataArray)):
        data = f_query
    elif isinstance(f_query, dict):
        data = xr.Dataset(f_query)
    else:
        raise TypeError(
            f"expected() function returned unsupported type "
            f"{type(f_query).__name__}. Function must return an "
            f"xr.Dataset, xr.DataArray, or dict."
        )
    return data.weighted(self.probability).mean("atom")

@property
def variables(self):
"""
Expand Down Expand Up @@ -580,18 +694,33 @@ def expected(
\\*args :
Other inputs for func, representing the non-stochastic arguments.
The expectation is computed at ``f(dstn, *args)``.
labels : bool
If True, the function should use labeled indexing instead of integer
indexing using the distribution's underlying rv coordinates. For example,
if `dims = ('rv', 'x')` and `coords = {'rv': ['a', 'b'], }`, then
the function can be `lambda x: x["a"] + x["b"]`.
labels : bool, optional
Controls whether the function receives labeled or raw indexing.
Defaults to True. When True (default), the function receives a dict
with variable names as keys (e.g., ``lambda x: x["a"] + x["b"]``).
When False, the function receives raw numpy arrays and should use
integer indexing (e.g., ``lambda x: x[0] + x[1]``).
Note: ``labels`` has no effect when the dataset has dimensions
beyond "atom" or when keyword arguments are passed, since those
paths always use xarray-labeled operations.
**kwargs :
Keyword arguments forwarded to func when using xarray operations.
Note: ``labels`` is a reserved parameter for this method and is
never forwarded to func.

Returns
-------
f_exp : np.array or scalar
The expectation of the function at the queried values.
Scalar if only one value.
"""
# Extract the 'labels' parameter from kwargs since it's a reserved parameter
# for this method, not for the user function
labels = kwargs.pop("labels", True)

# Check if the dataset has dimensions beyond "atom", indicating
# multi-dimensional xarray data that requires xarray operations
requires_xarray_ops = len(set(self.dataset.sizes.keys()) - {"atom"}) > 0

def func_wrapper(x, *args):
"""
Expand All @@ -604,12 +733,32 @@ def func_wrapper(x, *args):
return func(wrapped, *args)

if len(kwargs):
if func is None:
raise ValueError(
"expected(): keyword arguments were provided but func is "
"None. Provide a callable for func, or remove the keyword "
"arguments."
)
f_query = func(self.dataset, *args, **kwargs)
ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

return ldd._weighted.mean("atom")
return self._weighted_mean_of(f_query)
elif requires_xarray_ops:
if not labels:
warnings.warn(
"expected(): labels=False is not supported for distributions "
"with dimensions beyond 'atom'. Falling back to xarray-path "
"operations, which always use labeled indexing.",
stacklevel=2,
)
if func is None:
# Compute weighted mean directly using xarray weighted operations
return self._weighted.mean("atom")
else:
f_query = func(self.dataset, *args)
return self._weighted_mean_of(f_query)
else:
if func is None:
return super().expected()
else:
elif labels:
return super().expected(func_wrapper, *args)
else:
return super().expected(func, *args)
4 changes: 3 additions & 1 deletion HARK/distributions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,9 @@ def calc_expectation(dstn, func=None, *args, **kwargs):
f_query = []
for i in range(len(dstn.pmv)):
temp_dict = {
key: float(dstn.variables[key][i]) for key in dstn.variables.keys()
key: float(dstn.variables[key][i])
for key in dstn.variables.keys()
if "atom" in dstn.dataset[key].dims
}
f_query.append(func(temp_dict, *args, **kwargs))
Comment on lines +516 to 520
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the labeled-dstn branch, calc_expectation() indexes with dstn.variables[key][i] and coerces to float. This breaks or mis-indexes when a variable has dims beyond 'atom' (e.g. ('grid','atom')) because [i] selects along the first dimension and float() fails for non-scalars. Use explicit xarray indexing along the atom dimension (e.g., isel(atom=i)) and avoid forcing float so functions can receive arrays when appropriate.

Copilot uses AI. Check for mistakes.
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/ConsumptionSaving/test_ConsAggShockModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_economy(self):

self.economy.AFunc = self.economy.dynamics.AFunc
self.assertAlmostEqual(
self.economy.AFunc[0].slope, 1.08797, places=HARK_PRECISION
self.economy.AFunc[0].slope, 1.09061, places=HARK_PRECISION
)

def test_small_open_economy(self):
Expand Down
Loading