diff --git a/lectures/arellano.md b/lectures/arellano.md index 817a601..daf7e42 100644 --- a/lectures/arellano.md +++ b/lectures/arellano.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.7 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -73,14 +73,13 @@ Let’s start with some imports: import matplotlib.pyplot as plt import quantecon as qe -import random import jax import jax.numpy as jnp -from collections import namedtuple +from typing import NamedTuple ``` -Let's check the GPU we are running +Let's check the GPU we are using ```{code-cell} ipython3 !nvidia-smi @@ -144,9 +143,9 @@ The bond market has the following features - The bond matures in one period and is not state contingent. - A purchase of a bond with face value $ B' $ is a claim to $ B' $ units of the consumption good next period. -- To purchase $ B' $ next period costs $ q B' $ now, or, what is equivalent. -- For selling $ -B' $ units of next period goods the seller earns $ - q B' $ of - today’s goods. +- To purchase $ B' $ next period costs $ q B' $ now, or equivalently, $ q $ per unit of next period's goods. +- For selling $ -B' $ units of next period goods the seller earns $ - q B' $ of + today's goods. - If $ B' < 0 $, then $ -q B' $ units of the good are received in the current period, for a promise to repay $ -B' $ units next period. - There is an equilibrium price function $ q(B', y) $ that makes $ q $ depend @@ -365,40 +364,41 @@ We use simple discretization on a grid of asset holdings and income levels. The output process is discretized using a [quadrature method due to Tauchen](https://github.com/QuantEcon/QuantEcon.py/blob/master/quantecon/markov/approximation.py). -As we have in other places, we accelerate our code using Numba. +As we have in other places, we accelerate our code using JAX and JIT compilation. We define a namedtuple to store parameters, grids and transition probabilities. ```{code-cell} ipython3 -ArellanoEconomy = namedtuple('ArellanoEconomy', - ('β', # Time discount parameter - 'γ', # Utility parameter - 'r', # Lending rate - 'ρ', # Persistence in the income process - 'η', # Standard deviation of the income process - 'θ', # Prob of re-entering financial markets - 'B_size', # Grid size for bonds - 'y_size', # Grid size for income - 'P', # Markov matrix governing the income process - 'B_grid', # Bond unit grid - 'y_grid', # State values of the income process - 'def_y')) # Default income process +class ArellanoEconomy(NamedTuple): + β: float # Time discount parameter + γ: float # Utility parameter + r: float # Lending rate + ρ: float # Persistence in the income process + η: float # Standard deviation of the income process + θ: float # Prob of re-entering financial markets + B_size: int # Grid size for bonds + y_size: int # Grid size for income + P: jnp.ndarray # Markov matrix governing the income process + B_grid: jnp.ndarray # Bond unit grid + y_grid: jnp.ndarray # State values of the income process + def_y: jnp.ndarray # Default income process ``` ```{code-cell} ipython3 -def create_arellano(B_size=251, # Grid size for bonds - B_min=-0.45, # Smallest B value - B_max=0.45, # Largest B value - y_size=51, # Grid size for income - β=0.953, # Time discount parameter - γ=2.0, # Utility parameter - r=0.017, # Lending rate - ρ=0.945, # Persistence in the income process - η=0.025, # Standard deviation of the income process - θ=0.282, # Prob of re-entering financial markets - def_y_param=0.969): # Parameter governing income in default - +def create_arellano( + B_size=251, # Grid size for bonds + B_min=-0.45, # Smallest B value + B_max=0.45, # Largest B value + y_size=51, # Grid size for income + β=0.953, # Time discount parameter + γ=2.0, # Utility parameter + r=0.017, # Lending rate + ρ=0.945, # Persistence in the income process + η=0.025, # Standard deviation of the income process + θ=0.282, # Prob of re-entering financial markets + def_y_param=0.969, +): # Parameter governing income in default # Set up grids B_grid = jnp.linspace(B_min, B_max, B_size) mc = qe.markov.tauchen(y_size, ρ, η) @@ -409,11 +409,21 @@ def create_arellano(B_size=251, # Grid size for bonds # Output received while in default, with same shape as y_grid def_y = jnp.minimum(def_y_param * jnp.mean(y_grid), y_grid) - - return ArellanoEconomy(β=β, γ=γ, r=r, ρ=ρ, η=η, θ=θ, B_size=B_size, - y_size=y_size, P=P, - B_grid=B_grid, y_grid=y_grid, - def_y=def_y) + + return ArellanoEconomy( + β=β, + γ=γ, + r=r, + ρ=ρ, + η=η, + θ=θ, + B_size=B_size, + y_size=y_size, + P=P, + B_grid=B_grid, + y_grid=y_grid, + def_y=def_y, + ) ``` Here is the utility function. @@ -423,7 +433,7 @@ Here is the utility function. @jax.jit def u(c, γ): - return c**(1-γ)/(1-γ) + return c ** (1 - γ) / (1 - γ) ``` Here is a function to compute the bond price at each state, given $ v_c $ and @@ -438,7 +448,6 @@ def compute_q(v_c, v_d, params, sizes, arrays): step is to calculate the default probabilities δ(B, y) := Σ_{y'} 1{v_c(B, y') < v_d(y')} P(y, y') dy' - """ # Unpack @@ -455,7 +464,7 @@ def compute_q(v_c, v_d, params, sizes, arrays): default_states = v_c < v_d delta = jnp.sum(default_states * P, axis=(2,)) - q = (1 - delta ) / (1 + r) + q = (1 - delta) / (1 + r) return q ``` @@ -474,7 +483,6 @@ def T_d(v_c, v_d, params, sizes, arrays): B_size, y_size = sizes P, B_grid, y_grid, def_y = arrays - B0_idx = jnp.searchsorted(B_grid, 1e-10) # Index at which B is near zero current_utility = u(def_y, γ) @@ -576,15 +584,9 @@ def solve(model, tol=1e-8, max_iter=10_000): policy and value functions. """ # Unpack - - β, γ, r, ρ, η, θ, B_size, y_size, P, B_grid, y_grid, def_y = model - - params = β, γ, r, ρ, η, θ - sizes = B_size, y_size - arrays = P, B_grid, y_grid, def_y - + β, γ, r, ρ, η, θ, B_size, y_size, P, B_grid, y_grid, def_y = model - + params = β, γ, r, ρ, η, θ sizes = B_size, y_size arrays = P, B_grid, y_grid, def_y @@ -598,9 +600,12 @@ def solve(model, tol=1e-8, max_iter=10_000): while (current_iter < max_iter) and (error > tol): if current_iter % 100 == 0: print(f"Entering iteration {current_iter} with error {error}.") - new_v_c, new_v_d = update_values_and_prices(v_c, v_d, params, - sizes, arrays) - error = jnp.max(jnp.abs(new_v_c - v_c)) + jnp.max(jnp.abs(new_v_d - v_d)) + new_v_c, new_v_d = update_values_and_prices( + v_c, v_d, params, sizes, arrays + ) + error = jnp.max(jnp.abs(new_v_c - v_c)) + jnp.max( + jnp.abs(new_v_d - v_d) + ) v_c, v_d = new_v_c, new_v_d current_iter += 1 @@ -620,8 +625,8 @@ ae = create_arellano() ``` ```{code-cell} ipython3 -%%time -v_c, v_d, q, B_star = solve(ae) +with qe.Timer(): + v_c, v_d, q, B_star = solve(ae) ``` We run it again to get rid of compile time. @@ -629,8 +634,8 @@ We run it again to get rid of compile time. ```{code-cell} ipython3 :hide-output: false -%%time -v_c, v_d, q, B_star = solve(ae) +with qe.Timer(): + v_c, v_d, q, B_star = solve(ae) ``` Finally, we write a function that will allow us to simulate the economy once @@ -646,7 +651,6 @@ def simulate(model, T, v_c, v_d, q, B_star, key): Here `model` is an instance of `ArellanoEconomy` and `T` is the length of the simulation. Endogenous objects `v_c`, `v_d`, `q` and `B_star` are assumed to come from a solution to `model`. - """ # Unpack elements of the model B_size, y_size = model.B_size, model.y_size @@ -661,47 +665,41 @@ def simulate(model, T, v_c, v_d, q, B_star, key): # Create Markov chain and simulate income process mc = qe.MarkovChain(P, y_grid) - y_sim_indices = mc.simulate_indices(T+1, init=y_idx) + y_sim_indices = mc.simulate_indices(T + 1, init=y_idx) - # Allocate memory for outputs - y_sim = jnp.empty(T) - y_a_sim = jnp.empty(T) - B_sim = jnp.empty(T) - q_sim = jnp.empty(T) - d_sim = jnp.empty(T, dtype=int) + y_sim, y_a_sim = [], [] + B_sim, q_sim, d_sim = [], [], [] # Perform simulation t = 0 while t < T: - - # Update y_sim and B_sim - y_sim = y_sim.at[t].set(y_grid[y_idx]) - B_sim = B_sim.at[t].set(B_grid[B_idx]) - - # if in default: + y_sim.append(y_grid[y_idx]) + B_sim.append(B_grid[B_idx]) if v_c[B_idx, y_idx] < v_d[y_idx] or in_default: - # Update y_a_sim - y_a_sim = y_a_sim.at[t].set(model.def_y[y_idx]) - d_sim = d_sim.at[t].set(1) + y_a_sim.append(model.def_y[y_idx]) + d_sim.append(1) Bp_idx = B0_idx # Re-enter financial markets next period with prob θ - # in_default = False if jnp.random.rand() < model.θ else True - in_default = False if random.uniform(key) < model.θ else True - key, _ = random.split(key) # Update the random key + in_default = False if jax.random.uniform(key) < model.θ else True + key, _ = jax.random.split(key) # Update the random key else: - # Update y_a_sim - y_a_sim = y_a_sim.at[t].set(y_sim[t]) - d_sim = d_sim.at[t].set(0) + y_a_sim.append(y_sim[t]) + d_sim.append(0) Bp_idx = B_star[B_idx, y_idx] - q_sim = q_sim.at[t].set(q[Bp_idx, y_idx]) - + q_sim.append(q[Bp_idx, y_idx]) # Update time and indices t += 1 y_idx = y_sim_indices[t] B_idx = Bp_idx - return y_sim, y_a_sim, B_sim, q_sim, d_sim + return ( + jnp.array(y_sim), + jnp.array(y_a_sim), + jnp.array(B_sim), + jnp.array(q_sim), + jnp.array(d_sim), + ) ``` ## Results @@ -727,7 +725,7 @@ values of output $ y $. The grid used to compute this figure was relatively fine (`y_size, B_size = 51, 251`), -which explains the minor differences between this and Arrelano’s figure. +which explains the minor differences between this and Arellano's figure. The figure shows that @@ -803,7 +801,7 @@ B_size, y_size = ae.B_size, ae.y_size r = ae.r # Create "Y High" and "Y Low" values as 5% devs from mean -high, low = jnp.mean(y_grid) * 1.05, jnp.mean(y_grid) * .95 +high, low = jnp.mean(y_grid) * 1.05, jnp.mean(y_grid) * 0.95 iy_high, iy_low = (jnp.searchsorted(y_grid, x) for x in (high, low)) fig, ax = plt.subplots(figsize=(10, 6.5)) @@ -821,7 +819,7 @@ for i, B in enumerate(B_grid): ax.plot(x, q_high, label="$y_H$", lw=2, alpha=0.7) ax.plot(x, q_low, label="$y_L$", lw=2, alpha=0.7) ax.set_xlabel("$B'$") -ax.legend(loc='upper left', frameon=False) +ax.legend(loc="upper left", frameon=False) plt.show() ``` @@ -836,7 +834,7 @@ fig, ax = plt.subplots(figsize=(10, 6.5)) ax.set_title("Value Functions") ax.plot(B_grid, v[:, iy_high], label="$y_H$", lw=2, alpha=0.7) ax.plot(B_grid, v[:, iy_low], label="$y_L$", lw=2, alpha=0.7) -ax.legend(loc='upper left') +ax.legend(loc="upper left") ax.set(xlabel="$B$", ylabel="$v(y, B)$") ax.set_xlim(min(B_grid), max(B_grid)) plt.show() @@ -859,7 +857,7 @@ delta = jnp.sum(default_states * shaped_P, axis=(2,)) # Create figure fig, ax = plt.subplots(figsize=(10, 6.5)) hm = ax.pcolormesh(B_grid, y_grid, delta.T) -cax = fig.add_axes([.92, .1, .02, .8]) +cax = fig.add_axes([0.92, 0.1, 0.02, 0.8]) fig.colorbar(hm, cax=cax) ax.axis([B_grid.min(), 0.05, y_grid.min(), y_grid.max()]) ax.set(xlabel="$B'$", ylabel="$y$", title="Probability of Default") @@ -871,14 +869,9 @@ Plot a time series of major variables simulated from the model ```{code-cell} ipython3 :hide-output: false -import jax.random as random T = 250 -key = random.PRNGKey(42) +key = jax.random.PRNGKey(42) y_sim, y_a_sim, B_sim, q_sim, d_sim = simulate(ae, T, v_c, v_d, q, B_star, key) - -# T = 250 -# jnp.random.seed(42) -# y_sim, y_a_sim, B_sim, q_sim, d_sim = simulate(ae, T, v_c, v_d, q, B_star) ``` ```{code-cell} ipython3 @@ -899,7 +892,7 @@ while i < len(d_sim): start_end_pairs.append((start_default, end_default)) plot_series = (y_sim, B_sim, q_sim) -titles = 'output', 'foreign assets', 'bond price' +titles = "output", "foreign assets", "bond price" fig, axes = plt.subplots(len(plot_series), 1, figsize=(10, 12)) fig.subplots_adjust(hspace=0.3) @@ -912,8 +905,9 @@ for ax, series, title in zip(axes, plot_series, titles): y_min = s_min - s_range * 0.1 ax.set_ylim(y_min, y_max) for pair in start_end_pairs: - ax.fill_between(pair, (y_min, y_min), (y_max, y_max), - color='k', alpha=0.3) + ax.fill_between( + pair, (y_min, y_min), (y_max, y_max), color="k", alpha=0.3 + ) ax.grid() ax.plot(range(T), series, lw=2, alpha=0.7) ax.set(title=title, xlabel="time")