Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 195 additions & 1 deletion HARK/mat_methods.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,203 @@
from typing import List

from dataclasses import dataclass
import numpy as np
from numba import njit


# %% Class and methods that facilitates simulating populations in discretized
# state spaces
@dataclass
class DiscreteTransitions:
"""
Class to facilitate simulating transitions of populations in discretized state spaces,
supporting both life-cycle and infinite-horizon models.
The class assumes that:
- Death is exogenous and independent of every state.
- Agents that die are replaced by newborns.
- Newborns draw their state from a distribution that is constatn over time.
Copy link

Copilot AI Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct the typo 'constatn' to 'constant'.

Suggested change
- Newborns draw their state from a distribution that is constatn over time.
- Newborns draw their state from a distribution that is constant over time.

Copilot uses AI. Check for mistakes.

Parameters
----------
living_tmats : [np.array]
List of transition matrices conditional on survival for each period (life-cycle) or
a single matrix (infinite-horizon).
surv_probs : list
List of survival probabilities for each period (life-cycle) or a single probability (infinite-horizon).
life_cycle : bool, optional
If True, use life-cycle mode; otherwise, use infinite-horizon mode.
newborn_dstn : np.array
Distribution of newborns (initial distribution).
"""

living_tmats: list
surv_probs: list
life_cycle: bool = False
newborn_dstn: np.array

def __post_init__(self):
"""
Initialize the DiscreteTransitions object and check parameter consistency.
"""
if self.life_cycle:
self.T = len(self.living_tmats) + 1
if len(self.surv_probs) != (self.T - 1):
raise ValueError("surv_probs must have length T-1.")
else:
self.T = 1

def iterate_dstn_forward(self, dstn_init):
"""
Propagate a distribution forward one period.

Parameters
----------
dstn_init : [np.array]
Initial distribution to propagate. Must be a list of length T (life-cycle) or a
single array (infinite-horizon).

Returns
-------
[np.array]
The propagated distribution(s).
"""
if self.life_cycle:
return _iterate_dstn_forward_lc(
dstn_init, self.living_tmats, self.surv_probs, self.newborn_dstn
)
else:
return _iterate_dstn_forward_ih(
dstn_init[0], self.living_tmat[0], self.surv_prob[0], self.newborn_dstn
Copy link

Copilot AI Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In infinite-horizon mode, the properties 'self.living_tmat' and 'self.surv_prob' are referenced, but the class defines 'living_tmats' and 'surv_probs'. Please update the variable names to ensure consistency.

Suggested change
dstn_init[0], self.living_tmat[0], self.surv_prob[0], self.newborn_dstn
dstn_init[0], self.living_tmats[0], self.surv_probs[0], self.newborn_dstn

Copilot uses AI. Check for mistakes.
)

def find_conditional_age_dsnt(self, dstn_init):
"""
Given a distribution of agents over states for the first period of life,
find the distribution of agents over states in every age conditional on
their survival.

Parameters
----------
dstn_init : [np.array]
Initial distribution.

Returns
-------
[np.array]
List of distributions by age (life-cycle) or a single-element list (infinite-horizon).
"""
if self.life_cycle:
return _find_conditional_age_dsnt(dstn_init, self.living_tmats)
else:
return dstn_init

def find_steady_state_dstn(self, **kwargs):
"""
Find the steady-state distribution.

Parameters
----------
**kwargs : dict
Additional arguments for infinite-horizon steady-state solver.

Returns
-------
[np.array]
List of steady-state distributions by age (life-cycle) or a single-element list (infinite-horizon).
"""
if self.life_cycle:
if kwargs:
raise ValueError(
"kwargs are not used in the life cycle version of find_steady_state_dstn."
)
return _find_steady_state_dstn_lc(
self.surv_probs, self.newborn_dstn, self.living_tmats
)
else:
return [
_find_steady_state_dstn_ih(
self.newborn_dstn,
self.living_tmats[0],
self.surv_probs[0],
**kwargs,
)
]


# Life cycle methods
@njit
def _iterate_dstn_forward_lc(dstn_init, living_tmats, surv_probs, newborn_dstn):
new_dstn = [newborn_dstn]
dead_mass = 0.0
for i, (d0, tmat) in enumerate(zip(dstn_init, living_tmats)):
new_dstn.append((surv_probs[i] * d0) @ tmat)
dead_mass += (1.0 - surv_probs[i]) * np.sum(d0)
dead_mass += np.sum(dstn_init[-1])
new_dstn[0] *= dead_mass

return new_dstn


def _find_conditional_age_dsnt(dstn_init, living_tmats):
dstns = [dstn_init]
for tmat in living_tmats:
dstns.append(dstns[-1] @ tmat)
return dstns


def _find_steady_state_dstn_lc(surv_probs, newborn_dstn, living_tmats):

ss_age_mass = np.concatenate([np.array([1.0]), np.cumprod(surv_probs)])
ss_age_mass /= np.sum(ss_age_mass)
age_dstns = _find_conditional_age_dsnt(newborn_dstn, living_tmats)
return [age_dstn * age_mass for age_dstn, age_mass in zip(age_dstns, ss_age_mass)]


# Infinite horizon methods
@njit
def _iterate_dstn_forward_ih(dstn_init, living_tmat, surv_prob, newborn_dstn):
dead_mass = 1.0 - surv_prob
new_dstn = surv_prob * dstn_init @ living_tmat
new_dstn += dead_mass * newborn_dstn

return new_dstn


@njit
def _find_steady_state_dstn_ih(
newborn_dstn,
living_tmat,
surv_prob,
max_iter=10000,
tol=1e-10,
normalize_every=100,
dstn_init=None,
):
if dstn_init is None:
dstn = newborn_dstn
else:
dstn = dstn_init
go = True
i = 0
while go:
new_dstn = _iterate_dstn_forward_ih(dstn, living_tmat, surv_prob, newborn_dstn)
if np.max(np.abs(new_dstn - dstn)) < tol:
go = False
dstn = new_dstn
i += 1
if i > max_iter:
go = False
# Renormalize every given number of iterations
if i % normalize_every == 0:
dstn /= np.sum(dstn)

# Return as list just for compatibility with LC methods that return
# a list of age dstns
return dstn


# %% Methods to distribute mass to a grid


@njit
def ravel_index(ind_mat: np.ndarray, dims: np.ndarray) -> np.ndarray:
"""
Expand Down
Loading