Fix DiscreteDistributionLabeled: labels parameter and from_dataset initialization #1717

Open
Copilot wants to merge 8 commits into main from copilot/improve-discretedistributionlabeled-design

Conversation

Copilot AI commented Feb 5, 2026

DiscreteDistributionLabeled had two design issues: the documented labels parameter of expected() was passed through to the user function, causing a TypeError, and from_dataset() created incomplete objects that were missing essential attributes.

Changes

Fix labels parameter in expected()

  • Extract labels from kwargs as a reserved parameter instead of passing it to the user function
  • Default labels=True for dict-based indexing, labels=False for raw array indexing
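
A minimal sketch of the reserved-keyword idea described above, in plain Python. The function `expected_with_labels` and its arguments are hypothetical stand-ins for illustration, not HARK's implementation; the only point is that `labels` is popped from `kwargs` before the user function is called.

```python
def expected_with_labels(func, atoms_by_name, pmv, *args, **kwargs):
    """Hypothetical stand-in for a labeled expected(); not HARK's code."""
    # Pop the reserved keyword so it never reaches the user function.
    labels = kwargs.pop("labels", True)
    names = list(atoms_by_name)
    total = 0.0
    for i, prob in enumerate(pmv):
        if labels:
            # Dict-based indexing: func receives {"gamma": value, ...}
            point = {name: atoms_by_name[name][i] for name in names}
        else:
            # Raw indexing: func receives [value, ...] in variable order
            point = [atoms_by_name[name][i] for name in names]
        total += prob * func(point, *args, **kwargs)
    return total


atoms = {"gamma": [-1.0, 0.0, 1.0]}
pmv = [0.25, 0.5, 0.25]
by_name = expected_with_labels(lambda x: x["gamma"] ** 2, atoms, pmv, labels=True)
by_index = expected_with_labels(lambda x: x[0] ** 2, atoms, pmv, labels=False)
```

Both calls compute the same probability-weighted sum; neither lambda ever sees a `labels` argument.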

Fix from_dataset() to create valid distributions

  • Filter variables to only include those with "atom" dimension
  • Extract and stack atoms when shapes are compatible
  • For incompatible shapes (mixed dimension structures), use a valid placeholder based on pmv length
  • Initialize pmv, seed, limit, _rng properly
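
The filter-and-stack step can be sketched with NumPy alone. This is an illustration under stated assumptions: `stack_atom_variables` is a hypothetical helper, and each variable is represented as a `(dims, array)` pair rather than an xarray object. The sketch simply raises on mismatched shapes.

```python
import numpy as np


def stack_atom_variables(variables):
    """variables: name -> (dims tuple, ndarray). Hypothetical helper, not HARK API."""
    # Keep only variables that carry an "atom" dimension.
    with_atom = {name: arr for name, (dims, arr) in variables.items() if "atom" in dims}
    if not with_atom:
        raise ValueError("no variable has an 'atom' dimension")
    # np.stack requires every array to have the same shape.
    shapes = {name: arr.shape for name, arr in with_atom.items()}
    if len(set(shapes.values())) != 1:
        raise ValueError(f"incompatible shapes: {shapes}")
    atoms = np.atleast_2d(np.stack(list(with_atom.values()), axis=0))
    return list(with_atom), atoms


variables = {
    "x": (("atom",), np.array([1.0, 2.0, 3.0])),
    "y": (("atom",), np.array([4.0, 5.0, 6.0])),
    "grid": (("grid",), np.array([0.0, 1.0])),  # no "atom" dim, so it is skipped
}
names, atoms = stack_atom_variables(variables)
```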

Improve expected() dispatch logic

  • Use xarray operations when the dataset has dimensions beyond "atom"
  • Rename internal variable to requires_xarray_ops for clarity
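
One way such a dispatch flag could be computed, shown as a hedged sketch. Here `requires_xarray_ops` is written as a hypothetical function over a mapping of variable names to dimension tuples (mirroring `DataArray.dims`); the PR only states that an internal variable has this name.

```python
def requires_xarray_ops(variable_dims):
    """variable_dims: name -> tuple of dimension names. Illustrative only."""
    # True as soon as any variable has a dimension other than "atom".
    return any(dim != "atom" for dims in variable_dims.values() for dim in dims)


simple = requires_xarray_ops({"gamma": ("atom",)})
gridded = requires_xarray_ops({"m": ("grid", "atom"), "gamma": ("atom",)})
```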

Example

from HARK.distributions import expected, Normal, DiscreteDistributionLabeled

gamma = DiscreteDistributionLabeled.from_unlabeled(
    Normal(mu=0, sigma=1).discretize(N=7), var_names=["gamma"]
)

# Previously: TypeError: <lambda>() got an unexpected keyword argument 'labels'
# Now works:
expected(func=lambda x: x["gamma"], dist=gamma, labels=True)

# Raw array indexing
expected(func=lambda x: x[0], dist=gamma, labels=False)
Original prompt

improve design of DiscreteDistributionLabeled


Copilot AI and others added 2 commits February 5, 2026 00:36
… and improve consistency

Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Improve design of DiscreteDistributionLabeled" to "Fix DiscreteDistributionLabeled.expected() labels parameter and improve consistency" Feb 5, 2026
Copilot AI requested a review from alanlujan91 February 5, 2026 00:42
Copilot AI and others added 2 commits February 5, 2026 00:57
…support

Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
Copilot AI changed the title from "Fix DiscreteDistributionLabeled.expected() labels parameter and improve consistency" to "Fix DiscreteDistributionLabeled: labels parameter and from_dataset initialization" Feb 5, 2026
… incompatible shapes

Co-authored-by: alanlujan91 <5382704+alanlujan91@users.noreply.github.com>
…nd documentation

- Added warnings for random seed generation in DiscreteFrozenDistribution.
- Improved the __repr__ method in DiscreteDistribution for better output.
- Refined error messages in from_dataset methods of DiscreteDistributionLabeled.
- Updated tests to validate new functionality and ensure accuracy in expectations.
alanlujan91 marked this pull request as ready for review February 26, 2026 18:09
Copilot AI review requested due to automatic review settings February 26, 2026 18:09
Copilot AI left a comment

Pull request overview

This PR fixes DiscreteDistributionLabeled expectation behavior and improves construction from xarray inputs so labeled discrete distributions can be safely created and used with expected()/draw().

Changes:

  • Treat labels as a reserved kwarg in DiscreteDistributionLabeled.expected() and improve dispatch between numpy vs xarray-based expectation paths.
  • Rework DiscreteDistributionLabeled.from_dataset() to initialize required distribution attributes (pmv, atoms, limit, RNG/seed).
  • Extend distribution tests to cover labels behavior, from_dataset() initialization, and multidimensional xarray expectation.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| HARK/distributions/discrete.py | Updates DiscreteDistributionLabeled.from_dataset() and expected() dispatch; adds xarray-weighted expectation helper. |
| HARK/distributions/utils.py | Tweaks calc_expectation() for labeled distributions to ignore non-atom variables. |
| tests/test_distribution.py | Adds tests for labels kwarg, from_dataset() correctness, and multidimensional expectation. |
| tests/ConsumptionSaving/test_ConsAggShockModel.py | Updates a numeric assertion (slope) to reflect changed behavior. |

Comment on lines 527 to +531
ldd.probability = pmf

# Extract pmv from probability DataArray
ldd.pmv = np.asarray(pmf.values)

Copilot AI Feb 26, 2026

from_dataset() documents that pmf must have dimension 'atom', but it only validates pmf’s type. Please also validate that 'atom' is in pmf.dims (and ideally pmf.ndim==1), and that pmf.sizes['atom'] matches the dataset’s 'atom' length when present, to avoid constructing inconsistent distributions.
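
A hedged sketch of the validation requested here. `validate_pmf` is a hypothetical helper that takes plain tuples and dicts shaped like xarray's `DataArray.dims` and `.sizes`, so it runs without xarray; it is not the code in this PR.

```python
def validate_pmf(pmf_dims, pmf_sizes, dataset_sizes):
    """Check pmf dimensions against the dataset. Illustrative helper only."""
    if "atom" not in pmf_dims:
        raise ValueError(f"pmf must have an 'atom' dimension; got {tuple(pmf_dims)}")
    if len(pmf_dims) != 1:
        raise ValueError(f"pmf must be 1-dimensional; got dims {tuple(pmf_dims)}")
    # Only compare lengths when the dataset actually has an "atom" dimension.
    if "atom" in dataset_sizes and dataset_sizes["atom"] != pmf_sizes["atom"]:
        raise ValueError(
            f"pmf has {pmf_sizes['atom']} atoms but the dataset has "
            f"{dataset_sizes['atom']}"
        )


# Consistent inputs pass silently.
ok = validate_pmf(("atom",), {"atom": 7}, {"atom": 7, "grid": 3})

# Mismatched atom lengths are rejected with a descriptive message.
try:
    validate_pmf(("atom",), {"atom": 7}, {"atom": 5})
    mismatch_message = None
except ValueError as err:
    mismatch_message = str(err)
```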
Comment on lines +545 to +550
var_arrays = [ldd.dataset[var].values for var in vars_with_atom]
# Check if all arrays have the same shape (required for np.stack)
shapes = [arr.shape for arr in var_arrays]
if len(set(shapes)) == 1:
    ldd.atoms = np.atleast_2d(np.stack(var_arrays, axis=0))
else:

Copilot AI Feb 26, 2026

from_dataset() uses .values directly when stacking variables into atoms, but DiscreteDistribution assumes the last axis of atoms corresponds to 'atom'. If a DataArray’s dims are ordered like ('atom','grid') instead of ('grid','atom'), this will build atoms with the wrong axis semantics and break expected()/draw(). Consider transposing each variable so 'atom' is last before extracting values.
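
The axis reordering suggested here can be illustrated with NumPy. `atom_last` is a hypothetical helper operating on a raw array plus its dimension names; with xarray objects the analogous call would presumably be `DataArray.transpose(..., "atom")` before reading `.values`.

```python
import numpy as np


def atom_last(values, dims):
    """Move the 'atom' axis to the end. Illustrative helper, not HARK API."""
    axis = dims.index("atom")
    new_dims = tuple(d for d in dims if d != "atom") + ("atom",)
    return np.moveaxis(values, axis, -1), new_dims


# 3 atoms along the first axis, 2 grid points along the second.
values = np.arange(6).reshape(3, 2)
moved, new_dims = atom_last(values, ("atom", "grid"))
```

After the move, `moved[g, a]` holds what `values[a, g]` held, so code that assumes the last axis is "atom" indexes the data correctly.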
Comment on lines +551 to +555
raise ValueError(
    f"from_dataset(): variables with 'atom' dimension have "
    f"incompatible shapes ({dict(zip(vars_with_atom, shapes))}). "
    f"Cannot construct a valid distribution with mixed-shape "
    f"atoms. Ensure all variables have the same shape."

Copilot AI Feb 26, 2026

The PR description mentions using a placeholder strategy for incompatible atom shapes, but from_dataset() currently raises ValueError when shapes differ. Either implement the placeholder behavior or update the PR description/docstring to reflect that mixed-shape atom variables are unsupported.
Comment on lines +593 to +598
if isinstance(f_query, xr.Dataset):
    return f_query.weighted(self.probability).mean("atom")
elif isinstance(f_query, xr.DataArray):
    return f_query.weighted(self.probability).mean("atom")
elif isinstance(f_query, dict):
    ds = xr.Dataset(f_query)

Copilot AI Feb 26, 2026

_weighted_mean_of() will raise for xr.DataArray outputs that don’t have an 'atom' dimension (e.g., if the user function already reduced over 'atom'). Consider checking whether 'atom' is in f_query.dims before calling .mean('atom'), and either return f_query unchanged or raise a clearer error about required output dims.

Suggested change
-if isinstance(f_query, xr.Dataset):
-    return f_query.weighted(self.probability).mean("atom")
-elif isinstance(f_query, xr.DataArray):
-    return f_query.weighted(self.probability).mean("atom")
-elif isinstance(f_query, dict):
-    ds = xr.Dataset(f_query)
+if isinstance(f_query, xr.Dataset):
+    if "atom" not in f_query.dims:
+        raise ValueError(
+            "expected() function must return an xr.Dataset with an "
+            "'atom' dimension to compute a probability-weighted mean; "
+            f"got dimensions {tuple(f_query.dims)} instead."
+        )
+    return f_query.weighted(self.probability).mean("atom")
+elif isinstance(f_query, xr.DataArray):
+    if "atom" not in f_query.dims:
+        raise ValueError(
+            "expected() function must return an xr.DataArray with an "
+            "'atom' dimension to compute a probability-weighted mean; "
+            f"got dimensions {tuple(f_query.dims)} instead."
+        )
+    return f_query.weighted(self.probability).mean("atom")
+elif isinstance(f_query, dict):
+    ds = xr.Dataset(f_query)
+    if "atom" not in ds.dims:
+        raise ValueError(
+            "expected() function must return data with an 'atom' "
+            "dimension to compute a probability-weighted mean; "
+            f"got dimensions {tuple(ds.dims)} instead."
+        )
Comment on lines +516 to 520
    key: float(dstn.variables[key][i])
    for key in dstn.variables.keys()
    if "atom" in dstn.dataset[key].dims
}
f_query.append(func(temp_dict, *args, **kwargs))

Copilot AI Feb 26, 2026

In the labeled-dstn branch, calc_expectation() indexes with dstn.variables[key][i] and coerces to float. This breaks or mis-indexes when a variable has dims beyond 'atom' (e.g. ('grid','atom')) because [i] selects along the first dimension and float() fails for non-scalars. Use explicit xarray indexing along the atom dimension (e.g., isel(atom=i)) and avoid forcing float so functions can receive arrays when appropriate.
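
To make the indexing concern concrete, here is a NumPy-only sketch of selecting along the atom axis by name rather than by position. `point_at_atom` and the `(dims, array)` representation are assumptions for illustration; in xarray the equivalent selection would be `isel(atom=i)`, as the comment suggests.

```python
import numpy as np


def point_at_atom(variables, i):
    """variables: name -> (dims, ndarray). Returns name -> value(s) at atom i."""
    point = {}
    for name, (dims, arr) in variables.items():
        if "atom" not in dims:
            continue  # non-atom variables are not part of a realization
        # Select along the axis named "atom", wherever it sits; no float() coercion.
        point[name] = np.take(arr, i, axis=dims.index("atom"))
    return point


variables = {
    "gamma": (("atom",), np.array([-1.0, 0.0, 1.0])),
    "m": (("grid", "atom"), np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
}
point = point_at_atom(variables, 1)
```

Here `point["gamma"]` is a scalar while `point["m"]` keeps its grid dimension, so a user function can receive arrays where appropriate.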
Comment on lines +270 to +272
# expected() with kwargs returns an xr.Dataset; extract scalar
if hasattr(expected_result, "to_array"):
    expected_result = float(expected_result.to_array().values.item())

Copilot AI Feb 26, 2026

This test only unwraps xr.Dataset results, but expected(my_func, F, z=3.0) returns an xr.DataArray in the kwargs/xarray path (not a Dataset). If xarray’s DataArray-to-float conversion is not supported in the test environment, assertAlmostEqual will fail. Consider also unwrapping 0-d DataArrays (e.g., via .item()) before comparing.

Suggested change
-# expected() with kwargs returns an xr.Dataset; extract scalar
-if hasattr(expected_result, "to_array"):
-    expected_result = float(expected_result.to_array().values.item())
+# expected() with kwargs may return an xarray Dataset or DataArray; extract scalar
+if isinstance(expected_result, xr.Dataset):
+    expected_result = float(expected_result.to_array().values.item())
+elif isinstance(expected_result, xr.DataArray):
+    expected_result = float(expected_result.values.item())