Fix DiscreteDistributionLabeled: labels parameter and from_dataset initialization #1717
… and improve consistency Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
…support Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
… incompatible shapes Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
…cretedistributionlabeled-design
…nd documentation
- Added warnings for random seed generation in `DiscreteFrozenDistribution`.
- Improved the `__repr__` method in `DiscreteDistribution` for better output.
- Refined error messages in the `from_dataset` methods of `DiscreteDistributionLabeled`.
- Updated tests to validate new functionality and ensure accuracy in expectations.
Pull request overview
This PR fixes DiscreteDistributionLabeled expectation behavior and improves construction from xarray inputs so labeled discrete distributions can be safely created and used with expected()/draw().
Changes:
- Treat `labels` as a reserved kwarg in `DiscreteDistributionLabeled.expected()` and improve dispatch between numpy and xarray-based expectation paths.
- Rework `DiscreteDistributionLabeled.from_dataset()` to initialize required distribution attributes (`pmv`, `atoms`, `limit`, RNG/seed).
- Extend distribution tests to cover `labels` behavior, `from_dataset()` initialization, and multidimensional xarray expectation.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `HARK/distributions/discrete.py` | Updates `DiscreteDistributionLabeled.from_dataset()` and `expected()` dispatch; adds xarray-weighted expectation helper. |
| `HARK/distributions/utils.py` | Tweaks `calc_expectation()` for labeled distributions to ignore non-atom variables. |
| `tests/test_distribution.py` | Adds tests for `labels` kwarg, `from_dataset()` correctness, and multidimensional expectation. |
| `tests/ConsumptionSaving/test_ConsAggShockModel.py` | Updates a numeric assertion (slope) to reflect changed behavior. |

    ldd.probability = pmf

    # Extract pmv from probability DataArray
    ldd.pmv = np.asarray(pmf.values)

from_dataset() documents that pmf must have dimension 'atom', but it only validates pmf’s type. Please also validate that 'atom' is in pmf.dims (and ideally pmf.ndim==1), and that pmf.sizes['atom'] matches the dataset’s 'atom' length when present, to avoid constructing inconsistent distributions.
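The checks requested above could look roughly like the following. This is a minimal sketch, not the code in the PR: `check_pmf` is a hypothetical helper name, and it assumes only that `pmf` is an `xr.DataArray` and that the dataset may or may not carry an `'atom'` dimension.

```python
import numpy as np
import xarray as xr


def check_pmf(pmf, dataset):
    """Hypothetical validator (not part of HARK): pmf must be 1-D over 'atom'."""
    if not isinstance(pmf, xr.DataArray):
        raise TypeError("pmf must be an xr.DataArray")
    # Requiring dims == ("atom",) covers both "'atom' in dims" and "ndim == 1".
    if pmf.dims != ("atom",):
        raise ValueError(f"pmf must have the single dimension 'atom'; got {pmf.dims}")
    # Only compare lengths when the dataset actually has an 'atom' dimension.
    if "atom" in dataset.sizes and dataset.sizes["atom"] != pmf.sizes["atom"]:
        raise ValueError(
            f"pmf has {pmf.sizes['atom']} atoms but the dataset has "
            f"{dataset.sizes['atom']}"
        )
    return pmf


ds = xr.Dataset({"x": ("atom", np.array([1.0, 2.0, 3.0]))})
good = xr.DataArray(np.array([0.2, 0.3, 0.5]), dims=("atom",))
short = xr.DataArray(np.array([0.5, 0.5]), dims=("atom",))
two_d = xr.DataArray(np.ones((2, 3)), dims=("grid", "atom"))
```

Here `check_pmf(good, ds)` passes, while `short` (length mismatch) and `two_d` (extra dimension) both raise `ValueError`.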

    var_arrays = [ldd.dataset[var].values for var in vars_with_atom]
    # Check if all arrays have the same shape (required for np.stack)
    shapes = [arr.shape for arr in var_arrays]
    if len(set(shapes)) == 1:
        ldd.atoms = np.atleast_2d(np.stack(var_arrays, axis=0))
    else:

from_dataset() uses .values directly when stacking variables into atoms, but DiscreteDistribution assumes the last axis of atoms corresponds to 'atom'. If a DataArray’s dims are ordered like ('atom','grid') instead of ('grid','atom'), this will build atoms with the wrong axis semantics and break expected()/draw(). Consider transposing each variable so 'atom' is last before extracting values.
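To illustrate the axis-order concern, here is a small sketch of moving `'atom'` to the last axis before extracting values. The helper name `atom_last_values` and the variable names are illustrative assumptions, not identifiers from the PR; the only library behavior relied on is `DataArray.transpose(..., "atom")`, where the ellipsis keeps the remaining dimensions in their existing order.

```python
import numpy as np
import xarray as xr

# The same data stored with two different dimension orders.
grid_first = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("grid", "atom"))
atom_first = grid_first.transpose("atom", "grid")


def atom_last_values(da):
    """Hypothetical helper: raw values with 'atom' guaranteed to be the last axis."""
    return da.transpose(..., "atom").values


# Without the transpose, atom_first.values has shape (3, 2) and cannot be
# stacked consistently with grid_first.values, which has shape (2, 3).
stacked = np.stack(
    [atom_last_values(grid_first), atom_last_values(atom_first)], axis=0
)
```

After the transpose both arrays have shape `(2, 3)`, so the stacked result has `'atom'` on its last axis regardless of how each variable was originally laid out.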

    raise ValueError(
        f"from_dataset(): variables with 'atom' dimension have "
        f"incompatible shapes ({dict(zip(vars_with_atom, shapes))}). "
        f"Cannot construct a valid distribution with mixed-shape "
        f"atoms. Ensure all variables have the same shape."

The PR description mentions using a placeholder strategy for incompatible atom shapes, but from_dataset() currently raises ValueError when shapes differ. Either implement the placeholder behavior or update the PR description/docstring to reflect that mixed-shape atom variables are unsupported.

    if isinstance(f_query, xr.Dataset):
        return f_query.weighted(self.probability).mean("atom")
    elif isinstance(f_query, xr.DataArray):
        return f_query.weighted(self.probability).mean("atom")
    elif isinstance(f_query, dict):
        ds = xr.Dataset(f_query)

_weighted_mean_of() will raise for xr.DataArray outputs that don’t have an 'atom' dimension (e.g., if the user function already reduced over 'atom'). Consider checking whether 'atom' is in f_query.dims before calling .mean('atom'), and either return f_query unchanged or raise a clearer error about required output dims.

Suggested change:

    - if isinstance(f_query, xr.Dataset):
    -     return f_query.weighted(self.probability).mean("atom")
    - elif isinstance(f_query, xr.DataArray):
    -     return f_query.weighted(self.probability).mean("atom")
    - elif isinstance(f_query, dict):
    -     ds = xr.Dataset(f_query)
    + if isinstance(f_query, xr.Dataset):
    +     if "atom" not in f_query.dims:
    +         raise ValueError(
    +             "expected() function must return an xr.Dataset with an "
    +             "'atom' dimension to compute a probability-weighted mean; "
    +             f"got dimensions {tuple(f_query.dims)} instead."
    +         )
    +     return f_query.weighted(self.probability).mean("atom")
    + elif isinstance(f_query, xr.DataArray):
    +     if "atom" not in f_query.dims:
    +         raise ValueError(
    +             "expected() function must return an xr.DataArray with an "
    +             "'atom' dimension to compute a probability-weighted mean; "
    +             f"got dimensions {tuple(f_query.dims)} instead."
    +         )
    +     return f_query.weighted(self.probability).mean("atom")
    + elif isinstance(f_query, dict):
    +     ds = xr.Dataset(f_query)
    +     if "atom" not in ds.dims:
    +         raise ValueError(
    +             "expected() function must return data with an 'atom' "
    +             "dimension to compute a probability-weighted mean; "
    +             f"got dimensions {tuple(ds.dims)} instead."
    +         )


        key: float(dstn.variables[key][i])
        for key in dstn.variables.keys()
        if "atom" in dstn.dataset[key].dims
    }
    f_query.append(func(temp_dict, *args, **kwargs))

In the labeled-dstn branch, calc_expectation() indexes with dstn.variables[key][i] and coerces to float. This breaks or mis-indexes when a variable has dims beyond 'atom' (e.g. ('grid','atom')) because [i] selects along the first dimension and float() fails for non-scalars. Use explicit xarray indexing along the atom dimension (e.g., isel(atom=i)) and avoid forcing float so functions can receive arrays when appropriate.
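A sketch of the indexing suggested above, using `isel(atom=i)` so the selection follows the named dimension rather than the first axis. The dataset and the helper name `atom_slice` are hypothetical, chosen only to show the difference.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "shock": ("atom", np.array([0.9, 1.0, 1.1])),
        "path": (("grid", "atom"), np.arange(6.0).reshape(2, 3)),
    }
)


def atom_slice(dataset, i):
    """Hypothetical helper: values of each 'atom'-indexed variable at atom i."""
    point = dataset.isel(atom=i)
    return {
        name: point[name].values
        for name in dataset.data_vars
        if "atom" in dataset[name].dims
    }


at_one = atom_slice(ds, 1)

# Positional indexing on the raw array selects along 'grid' instead,
# returning a whole row of atoms rather than one atom across the grid.
wrong = ds["path"].values[1]
```

`at_one["path"]` is the length-2 array of grid values at atom 1, whereas `wrong` is the length-3 row for grid index 1. Leaving the values as arrays (no `float()` coercion) lets the user function receive non-scalar inputs.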

    # expected() with kwargs returns an xr.Dataset; extract scalar
    if hasattr(expected_result, "to_array"):
        expected_result = float(expected_result.to_array().values.item())

This test only unwraps xr.Dataset results, but expected(my_func, F, z=3.0) returns an xr.DataArray in the kwargs/xarray path (not a Dataset). If xarray’s DataArray-to-float conversion is not supported in the test environment, assertAlmostEqual will fail. Consider also unwrapping 0-d DataArrays (e.g., via .item()) before comparing.

Suggested change:

    - # expected() with kwargs returns an xr.Dataset; extract scalar
    - if hasattr(expected_result, "to_array"):
    -     expected_result = float(expected_result.to_array().values.item())
    + # expected() with kwargs may return an xarray Dataset or DataArray; extract scalar
    + if isinstance(expected_result, xr.Dataset):
    +     expected_result = float(expected_result.to_array().values.item())
    + elif isinstance(expected_result, xr.DataArray):
    +     expected_result = float(expected_result.values.item())

`DiscreteDistributionLabeled` had two design issues: the documented `labels` parameter in `expected()` was passed through to user functions, causing a `TypeError`, and `from_dataset()` created incomplete objects missing essential attributes.

Changes

Fix `labels` parameter in `expected()`
- … `labels` from kwargs as a reserved parameter instead of passing to user function
- … `labels=True` for dict-based indexing, `labels=False` for raw array indexing

Fix `from_dataset()` to create valid distributions
- … `atoms` when shapes are compatible
- … `pmv`, `seed`, `limit`, `_rng` properly

Improve `expected()` dispatch logic
- … `requires_xarray_ops` for clarity

Example
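The reserved-kwarg pattern described in this PR can be sketched in plain Python as follows. This is an illustration under stated assumptions, not HARK's implementation: `expected_sketch`, its argument order, and the default `labels=False` are all hypothetical, and atoms are represented as a simple dict of lists.

```python
def expected_sketch(atoms, pmv, func, *args, **kwargs):
    """Hypothetical stand-in for expected(); 'labels' never reaches func."""
    # Pop the reserved keyword so it is not forwarded to the user function.
    labels = kwargs.pop("labels", False)  # default is an assumption
    total = 0.0
    for i, p in enumerate(pmv):
        if labels:
            # Dict-based indexing: the function receives values by variable name.
            point = {name: vals[i] for name, vals in atoms.items()}
        else:
            # Raw indexing: the function receives values by position.
            point = [vals[i] for vals in atoms.values()]
        total += p * func(point, *args, **kwargs)
    return total


atoms = {"x": [1.0, 2.0], "y": [10.0, 20.0]}
pmv = [0.25, 0.75]

by_name = expected_sketch(atoms, pmv, lambda d, z: d["x"] + z, labels=True, z=3.0)
by_position = expected_sketch(atoms, pmv, lambda a, z: a[0] + z, z=3.0)
```

Both calls evaluate 0.25·(1 + 3) + 0.75·(2 + 3) = 4.75. Because `labels` is popped before the call, a user function that does not accept a `labels` argument no longer raises `TypeError`.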