diff --git a/HARK/distributions/discrete.py b/HARK/distributions/discrete.py index a689707bb..6d31e3134 100644 --- a/HARK/distributions/discrete.py +++ b/HARK/distributions/discrete.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Union +import warnings from copy import deepcopy import numpy as np import xarray as xr @@ -26,7 +27,8 @@ def __init__( dist : rv_discrete Discrete distribution from scipy.stats. seed : int, optional - Seed for random number generator, by default 0 + Seed for random number generator. If None (default), a random + seed is generated from system entropy. """ rv_discrete_frozen.__init__(self, dist, *args, **kwds) @@ -118,12 +120,16 @@ def __init__( def __repr__(self): out = self.__class__.__name__ + " with " + str(self.pmv.size) + " atoms, " - if self.atoms.shape[0] > 1: - out += "inf=" + str(tuple(self.limit["infimum"])) + ", " - out += "sup=" + str(tuple(self.limit["supremum"])) + ", " + inf = self.limit.get("infimum", np.array([])) + sup = self.limit.get("supremum", np.array([])) + if inf.size == 0: + out += "inf=[], sup=[], " + elif self.atoms.shape[0] > 1: + out += "inf=" + str(tuple(inf)) + ", " + out += "sup=" + str(tuple(sup)) + ", " else: - out += "inf=" + str(self.limit["infimum"][0]) + ", " - out += "sup=" + str(self.limit["supremum"][0]) + ", " + out += "inf=" + str(inf[0]) + ", " + out += "sup=" + str(sup[0]) + ", " out += "seed=" + str(self.seed) return out @@ -189,8 +195,7 @@ def draw( K = np.floor(K_exact).astype(int) # number of slots allocated to each atom M = N - np.sum(K) # number of unallocated slots J = P.size - eps = 1.0 / N - Q = K_exact - eps * K # "missing" probability mass + Q = K_exact - K # fractional part: slots still owed to each atom draws = self._rng.random(M) # uniform draws for "extra" slots # Fill in each unallocated slot, one by one @@ -469,18 +474,107 @@ def from_unlabeled( return ldd @classmethod - def from_dataset(cls, x_obj, pmf): + def from_dataset(cls, x_obj, pmf, seed=None): + """ + Construct a DiscreteDistributionLabeled from xarray objects. + + Parameters + ---------- + x_obj : xr.Dataset, xr.DataArray, or dict + The data containing distribution values. If a Dataset, variables + become the distribution's variables. If a DataArray, it becomes a + single variable (unnamed DataArrays get the name "var_0"). If a + dict, it's converted to a Dataset. + pmf : xr.DataArray + Probability mass values with dimension "atom". + seed : int, optional + Seed for random number generator. If None (default), a random + seed is generated from system entropy. + + Returns + ------- + ldd : DiscreteDistributionLabeled + A properly initialized labeled distribution. + + Raises + ------ + TypeError + If x_obj is not an xr.Dataset, xr.DataArray, or dict, or if + pmf is not an xr.DataArray. + """ + if not isinstance(pmf, xr.DataArray): + raise TypeError( + f"from_dataset() requires pmf to be an xr.DataArray with " + f"dimension 'atom', but got {type(pmf).__name__}. " + f"Wrap your probabilities: pmf = xr.DataArray(array, dims=('atom',))" + ) + ldd = cls.__new__(cls) if isinstance(x_obj, xr.Dataset): ldd.dataset = x_obj elif isinstance(x_obj, xr.DataArray): - ldd.dataset = xr.Dataset({x_obj.name: x_obj}) + name = x_obj.name if x_obj.name is not None else "var_0" + ldd.dataset = xr.Dataset({name: x_obj}) elif isinstance(x_obj, dict): ldd.dataset = xr.Dataset(x_obj) + else: + raise TypeError( + f"from_dataset() expected x_obj to be an xr.Dataset, " + f"xr.DataArray, or dict, but got {type(x_obj).__name__}." + ) ldd.probability = pmf + # Extract pmv from probability DataArray + ldd.pmv = np.asarray(pmf.values) + + # Extract atoms from dataset variables that have the "atom" dimension. + # For DiscreteDistribution, atoms has shape (..., n_atoms) where the last + # dimension indexes "atom" (the random realization). + # Variables without the "atom" dimension (e.g., scalar summaries) are kept + # in the dataset but not included in atoms. + var_names = list(ldd.dataset.data_vars) + if var_names: + # Filter to only include variables that have the "atom" dimension + vars_with_atom = [ + var for var in var_names if "atom" in ldd.dataset[var].dims + ] + + if vars_with_atom: + var_arrays = [ldd.dataset[var].values for var in vars_with_atom] + # Check if all arrays have the same shape (required for np.stack) + shapes = [arr.shape for arr in var_arrays] + if len(set(shapes)) == 1: + ldd.atoms = np.atleast_2d(np.stack(var_arrays, axis=0)) + else: + raise ValueError( + f"from_dataset(): variables with 'atom' dimension have " + f"incompatible shapes ({dict(zip(vars_with_atom, shapes))}). " + f"Cannot construct a valid distribution with mixed-shape " + f"atoms. Ensure all variables have the same shape." + ) + else: + ldd.atoms = np.atleast_2d(np.array([])) + else: + ldd.atoms = np.atleast_2d(np.array([])) + + # Compute limit from atoms + if ldd.atoms.size > 0: + ldd.limit = { + "infimum": np.min(ldd.atoms, axis=-1), + "supremum": np.max(ldd.atoms, axis=-1), + } + else: + ldd.limit = {"infimum": np.array([]), "supremum": np.array([])} + + # Initialize base class attributes that __init__ would normally set + ldd.infimum = ldd.limit["infimum"].copy() + ldd.supremum = ldd.limit["supremum"].copy() + + # Initialize seed and RNG using the property setter from Distribution base class + ldd.seed = seed + return ldd @property @@ -490,6 +584,26 @@ def _weighted(self): """ return self.dataset.weighted(self.probability) + def _weighted_mean_of(self, f_query): + """ + Compute the probability-weighted mean of function output over + the "atom" dimension without constructing an intermediate + distribution. Handles Dataset, DataArray, and dict results. + """ + if isinstance(f_query, xr.Dataset): + return f_query.weighted(self.probability).mean("atom") + elif isinstance(f_query, xr.DataArray): + return f_query.weighted(self.probability).mean("atom") + elif isinstance(f_query, dict): + ds = xr.Dataset(f_query) + return ds.weighted(self.probability).mean("atom") + else: + raise TypeError( + f"expected() function returned unsupported type " + f"{type(f_query).__name__}. Function must return an " + f"xr.Dataset, xr.DataArray, or dict." + ) + @property def variables(self): """ @@ -580,11 +694,19 @@ def expected( \\*args : Other inputs for func, representing the non-stochastic arguments. The the expectation is computed at ``f(dstn, *args)``. - labels : bool - If True, the function should use labeled indexing instead of integer - indexing using the distribution's underlying rv coordinates. For example, - if `dims = ('rv', 'x')` and `coords = {'rv': ['a', 'b'], }`, then - the function can be `lambda x: x["a"] + x["b"]`. + labels : bool, optional + Controls whether the function receives labeled or raw indexing. + Defaults to True. When True (default), the function receives a dict + with variable names as keys (e.g., ``lambda x: x["a"] + x["b"]``). + When False, the function receives raw numpy arrays and should use + integer indexing (e.g., ``lambda x: x[0] + x[1]``). + Note: ``labels`` has no effect when the dataset has dimensions + beyond "atom" or when keyword arguments are passed, since those + paths always use xarray-labeled operations. + **kwargs : + Keyword arguments forwarded to func when using xarray operations. + Note: ``labels`` is a reserved parameter for this method and is + never forwarded to func. Returns ------- @@ -592,6 +714,13 @@ def expected( The expectation of the function at the queried values. Scalar if only one value. """ + # Extract the 'labels' parameter from kwargs since it's a reserved parameter + # for this method, not for the user function + labels = kwargs.pop("labels", True) + + # Check if the dataset has dimensions beyond "atom", indicating + # multi-dimensional xarray data that requires xarray operations + requires_xarray_ops = len(set(self.dataset.sizes.keys()) - {"atom"}) > 0 def func_wrapper(x, *args): """ @@ -604,12 +733,32 @@ def func_wrapper(x, *args): return func(wrapped, *args) if len(kwargs): + if func is None: + raise ValueError( + "expected(): keyword arguments were provided but func is " + "None. Provide a callable for func, or remove the keyword " + "arguments." + ) f_query = func(self.dataset, *args, **kwargs) - ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability) - - return ldd._weighted.mean("atom") + return self._weighted_mean_of(f_query) + elif requires_xarray_ops: + if not labels: + warnings.warn( + "expected(): labels=False is not supported for distributions " + "with dimensions beyond 'atom'. Falling back to xarray-path " + "operations, which always use labeled indexing.", + stacklevel=2, + ) + if func is None: + # Compute weighted mean directly using xarray weighted operations + return self._weighted.mean("atom") + else: + f_query = func(self.dataset, *args) + return self._weighted_mean_of(f_query) else: if func is None: return super().expected() - else: + elif labels: return super().expected(func_wrapper, *args) + else: + return super().expected(func, *args) diff --git a/HARK/distributions/utils.py b/HARK/distributions/utils.py index 589efbe76..3210770f9 100644 --- a/HARK/distributions/utils.py +++ b/HARK/distributions/utils.py @@ -513,7 +513,9 @@ def calc_expectation(dstn, func=None, *args, **kwargs): f_query = [] for i in range(len(dstn.pmv)): temp_dict = { - key: float(dstn.variables[key][i]) for key in dstn.variables.keys() + key: float(dstn.variables[key][i]) + for key in dstn.variables.keys() + if "atom" in dstn.dataset[key].dims } f_query.append(func(temp_dict, *args, **kwargs)) else: diff --git a/tests/ConsumptionSaving/test_ConsAggShockModel.py b/tests/ConsumptionSaving/test_ConsAggShockModel.py index 3c6ec047e..586c09f3e 100644 --- a/tests/ConsumptionSaving/test_ConsAggShockModel.py +++ b/tests/ConsumptionSaving/test_ConsAggShockModel.py @@ -88,7 +88,7 @@ def test_economy(self): self.economy.AFunc = self.economy.dynamics.AFunc self.assertAlmostEqual( - self.economy.AFunc[0].slope, 1.08797, places=HARK_PRECISION + self.economy.AFunc[0].slope, 1.09061, places=HARK_PRECISION ) def test_small_open_economy(self): diff --git a/tests/test_distribution.py b/tests/test_distribution.py index ad0035582..084490414 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -264,9 +264,16 @@ def test_calc_exp_labeled(self): def my_func(S, z): return S["x"] + 2 * S["y"] + 3 * z + calc_exp_result = calc_expectation(F, my_func, z=3.0) + expected_result = expected(my_func, F, z=3.0) + + # expected() with kwargs returns an xr.Dataset; extract scalar + if hasattr(expected_result, "to_array"): + expected_result = float(expected_result.to_array().values.item()) + self.assertAlmostEqual( - calc_expectation(F, my_func, z=3.0), - expected(my_func, F, z=3.0), + calc_exp_result, + expected_result, places=HARK_PRECISION, ) @@ -672,6 +679,28 @@ def test_self_expected_value(self): self.assertAlmostEqual(ce2[3], 9.51802, places=HARK_PRECISION) + def test_labels_parameter(self): + """Test the labels parameter works correctly (issue #1487).""" + gamma = DiscreteDistributionLabeled.from_unlabeled( + Normal(mu=5.0, sigma=1).discretize(N=7), var_names=["gamma"] + ) + + # Test with labels=True (explicit) using labeled indexing + result_labeled = expected(func=lambda x: x["gamma"], dist=gamma, labels=True) + self.assertAlmostEqual(result_labeled, 5.0, places=4) + + # Test without labels argument (defaults to True for labeled dist) + result_default = expected(func=lambda x: x["gamma"], dist=gamma) + self.assertAlmostEqual(result_default, 5.0, places=4) + + # Test with labels=False using raw array indexing + result_unlabeled = expected(func=lambda x: x[0], dist=gamma, labels=False) + self.assertAlmostEqual(result_unlabeled, 5.0, places=4) + + # All results should be equal + self.assertAlmostEqual(result_labeled, result_default, places=10) + self.assertAlmostEqual(result_labeled, result_unlabeled, places=10) + def test_getters_setters(self): # Create some dummy dsnt dist = DiscreteDistributionLabeled( @@ -751,6 +780,122 @@ def test_Bernoulli_to_labeled(self): bern = DiscreteDistributionLabeled.from_unlabeled(foo, var_names=["foo"]) self.assertTrue(np.allclose(bern.expected(), p)) + def test_from_dataset_creates_valid_distribution(self): + """Test that from_dataset sets atoms and other attributes for a valid distribution.""" + n_variables = 2 + n_atoms = 2 + pmf = xr.DataArray([0.3, 0.7], dims=("atom",)) + ds = xr.Dataset( + { + "var1": xr.DataArray([1.0, 2.0], dims=("atom",)), + "var2": xr.DataArray([3.0, 4.0], dims=("atom",)), + } + ) + ldd = DiscreteDistributionLabeled.from_dataset(ds, pmf) + + # Verify essential attributes have correct types + self.assertIsInstance(ldd.seed, int) + self.assertIsNotNone(ldd._rng) + self.assertIsInstance(ldd._rng, np.random.Generator) + + # Verify atoms and pmv have correct values + self.assertEqual(ldd.atoms.shape, (n_variables, n_atoms)) + np.testing.assert_array_almost_equal(ldd.atoms[0], [1.0, 2.0]) + np.testing.assert_array_almost_equal(ldd.atoms[1], [3.0, 4.0]) + np.testing.assert_array_almost_equal(ldd.pmv, [0.3, 0.7]) + + # Verify limit is computed correctly + np.testing.assert_array_almost_equal(ldd.limit["infimum"], [1.0, 3.0]) + np.testing.assert_array_almost_equal(ldd.limit["supremum"], [2.0, 4.0]) + + # Verify base class attributes are set + np.testing.assert_array_almost_equal(ldd.infimum, [1.0, 3.0]) + np.testing.assert_array_almost_equal(ldd.supremum, [2.0, 4.0]) + + # Verify expected() works correctly + exp = ldd.expected() + np.testing.assert_array_almost_equal( + exp, [0.3 * 1.0 + 0.7 * 2.0, 0.3 * 3.0 + 0.7 * 4.0] + ) + + # Verify draw() works correctly + draws = ldd.draw(100) + self.assertEqual(draws.shape, (n_variables, 100)) + self.assertTrue(np.all(np.isin(draws[0], [1.0, 2.0]))) + self.assertTrue(np.all(np.isin(draws[1], [3.0, 4.0]))) + + def test_from_dataset_with_dataarray_input(self): + """Test from_dataset with a named DataArray as input.""" + pmf = xr.DataArray([0.5, 0.5], dims=("atom",)) + da = xr.DataArray([1.0, 2.0], dims=("atom",), name="x") + ldd = DiscreteDistributionLabeled.from_dataset(da, pmf) + self.assertIn("x", ldd.dataset.data_vars) + np.testing.assert_array_almost_equal(ldd.atoms[0], [1.0, 2.0]) + self.assertAlmostEqual(float(ldd.expected()[0]), 1.5) + + def test_from_dataset_with_dict_input(self): + """Test from_dataset with a dict as input.""" + pmf = xr.DataArray([0.5, 0.5], dims=("atom",)) + d = {"v": xr.DataArray([1.0, 2.0], dims=("atom",))} + ldd = DiscreteDistributionLabeled.from_dataset(d, pmf) + np.testing.assert_array_almost_equal(ldd.pmv, [0.5, 0.5]) + self.assertIn("v", ldd.dataset.data_vars) + + def test_from_dataset_type_validation(self): + """Test that from_dataset raises errors for invalid input types.""" + pmf = xr.DataArray([0.5, 0.5], dims=("atom",)) + + # Invalid x_obj type + with self.assertRaises(TypeError): + DiscreteDistributionLabeled.from_dataset([1.0, 2.0], pmf) + + # Invalid pmf type (numpy array instead of DataArray) + ds = xr.Dataset({"v": xr.DataArray([1.0, 2.0], dims=("atom",))}) + with self.assertRaises(TypeError): + DiscreteDistributionLabeled.from_dataset(ds, np.array([0.5, 0.5])) + + def test_from_dataset_unnamed_dataarray(self): + """Test that unnamed DataArrays get auto-named to 'var_0'.""" + pmf = xr.DataArray([0.5, 0.5], dims=("atom",)) + da = xr.DataArray([1.0, 2.0], dims=("atom",)) + ldd = DiscreteDistributionLabeled.from_dataset(da, pmf) + self.assertIn("var_0", ldd.dataset.data_vars) + + def test_from_dataset_seed_reproducibility(self): + """Test that seed produces reproducible draws.""" + pmf = xr.DataArray([0.5, 0.5], dims=("atom",)) + ds = xr.Dataset({"v": xr.DataArray([1.0, 2.0], dims=("atom",))}) + ldd_a = DiscreteDistributionLabeled.from_dataset(ds, pmf, seed=42) + ldd_b = DiscreteDistributionLabeled.from_dataset(ds, pmf, seed=42) + draws_a = ldd_a.draw(20) + draws_b = ldd_b.draw(20) + np.testing.assert_array_equal(draws_a, draws_b) + + # None seed should produce a valid integer seed + ldd_c = DiscreteDistributionLabeled.from_dataset(ds, pmf, seed=None) + self.assertIsInstance(ldd_c.seed, int) + + def test_expected_with_func_on_multidim_distribution(self): + """Test the requires_xarray_ops path when a function is provided.""" + pmf = xr.DataArray([0.5, 0.5], dims=("atom",)) + ds = xr.Dataset( + { + "x": xr.DataArray([[1.0, 2.0], [3.0, 4.0]], dims=("grid", "atom")), + } + ) + ldd = DiscreteDistributionLabeled.from_dataset(ds, pmf) + # E[x^2] = 0.5 * [1, 9] + 0.5 * [4, 16] = [2.5, 12.5] + result = ldd.expected(func=lambda d: d["x"] ** 2) + np.testing.assert_array_almost_equal(result.values, [2.5, 12.5]) + + def test_expected_func_none_with_kwargs_raises(self): + """Test that expected() raises ValueError when func=None but kwargs given.""" + gamma = DiscreteDistributionLabeled.from_unlabeled( + Normal(mu=0, sigma=1).discretize(N=7), var_names=["gamma"] + ) + with self.assertRaises(ValueError): + gamma.expected(some_kwarg=1.0) + class labeled_transition_tests(unittest.TestCase): def setUp(self) -> None: @@ -781,10 +926,8 @@ def transition(shocks, state): exp1 = base_dist.expected(transition, state=state_grid) # Expectation after transformation new_state_dstn = base_dist.dist_of_func(transition, state=state_grid) - # TODO: needs a cluncky identity function with an extra argument because - # DDL.expected() behavior is very different with and without kwargs. - # Fix! - exp2 = new_state_dstn.expected(lambda x, unused: x, unused=0) + # Now we can simply call expected() without a function to get the mean + exp2 = new_state_dstn.expected() assert np.all(exp1["m"] == exp2["m"]).item() assert np.all(exp1["n"] == exp2["n"]).item()