Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 91 additions & 97 deletions lectures/arellano.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
jupytext_version: 1.16.7
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -73,14 +73,13 @@ Let’s start with some imports:

import matplotlib.pyplot as plt
import quantecon as qe
import random

import jax
import jax.numpy as jnp
from collections import namedtuple
from typing import NamedTuple
```

Let's check the GPU we are running
Let's check the GPU we are using

```{code-cell} ipython3
!nvidia-smi
Expand Down Expand Up @@ -144,9 +143,9 @@ The bond market has the following features
- The bond matures in one period and is not state contingent.
- A purchase of a bond with face value $ B' $ is a claim to $ B' $ units of the
consumption good next period.
- To purchase $ B' $ next period costs $ q B' $ now, or, what is equivalent.
- For selling $ -B' $ units of next period goods the seller earns $ - q B' $ of
todays goods.
- To purchase $ B' $ next period costs $ q B' $ now, or equivalently, $ q $ per unit of next period's goods.
- For selling $ -B' $ units of next period goods the seller earns $ - q B' $ of
today's goods.
- If $ B' < 0 $, then $ -q B' $ units of the good are received in the current
period, for a promise to repay $ -B' $ units next period.
- There is an equilibrium price function $ q(B', y) $ that makes $ q $ depend
Expand Down Expand Up @@ -365,40 +364,41 @@ We use simple discretization on a grid of asset holdings and income levels.

The output process is discretized using a [quadrature method due to Tauchen](https://github.com/QuantEcon/QuantEcon.py/blob/master/quantecon/markov/approximation.py).

As we have in other places, we accelerate our code using Numba.
As we have in other places, we accelerate our code using JAX and JIT compilation.

We define a namedtuple to store parameters, grids and transition
probabilities.

```{code-cell} ipython3
ArellanoEconomy = namedtuple('ArellanoEconomy',
('β', # Time discount parameter
'γ', # Utility parameter
'r', # Lending rate
'ρ', # Persistence in the income process
'η', # Standard deviation of the income process
'θ', # Prob of re-entering financial markets
'B_size', # Grid size for bonds
'y_size', # Grid size for income
'P', # Markov matrix governing the income process
'B_grid', # Bond unit grid
'y_grid', # State values of the income process
'def_y')) # Default income process
class ArellanoEconomy(NamedTuple):
β: float # Time discount parameter
γ: float # Utility parameter
r: float # Lending rate
ρ: float # Persistence in the income process
η: float # Standard deviation of the income process
θ: float # Prob of re-entering financial markets
B_size: int # Grid size for bonds
y_size: int # Grid size for income
P: jnp.ndarray # Markov matrix governing the income process
B_grid: jnp.ndarray # Bond unit grid
y_grid: jnp.ndarray # State values of the income process
def_y: jnp.ndarray # Default income process
```

```{code-cell} ipython3
def create_arellano(B_size=251, # Grid size for bonds
B_min=-0.45, # Smallest B value
B_max=0.45, # Largest B value
y_size=51, # Grid size for income
β=0.953, # Time discount parameter
γ=2.0, # Utility parameter
r=0.017, # Lending rate
ρ=0.945, # Persistence in the income process
η=0.025, # Standard deviation of the income process
θ=0.282, # Prob of re-entering financial markets
def_y_param=0.969): # Parameter governing income in default

def create_arellano(
B_size=251, # Grid size for bonds
B_min=-0.45, # Smallest B value
B_max=0.45, # Largest B value
y_size=51, # Grid size for income
β=0.953, # Time discount parameter
γ=2.0, # Utility parameter
r=0.017, # Lending rate
ρ=0.945, # Persistence in the income process
η=0.025, # Standard deviation of the income process
θ=0.282, # Prob of re-entering financial markets
def_y_param=0.969,
): # Parameter governing income in default
# Set up grids
B_grid = jnp.linspace(B_min, B_max, B_size)
mc = qe.markov.tauchen(y_size, ρ, η)
Expand All @@ -409,11 +409,21 @@ def create_arellano(B_size=251, # Grid size for bonds

# Output received while in default, with same shape as y_grid
def_y = jnp.minimum(def_y_param * jnp.mean(y_grid), y_grid)

return ArellanoEconomy(β=β, γ=γ, r=r, ρ=ρ, η=η, θ=θ, B_size=B_size,
y_size=y_size, P=P,
B_grid=B_grid, y_grid=y_grid,
def_y=def_y)

return ArellanoEconomy(
β=β,
γ=γ,
r=r,
ρ=ρ,
η=η,
θ=θ,
B_size=B_size,
y_size=y_size,
P=P,
B_grid=B_grid,
y_grid=y_grid,
def_y=def_y,
)
```

Here is the utility function.
Expand All @@ -423,7 +433,7 @@ Here is the utility function.

@jax.jit
def u(c, γ):
return c**(1-γ)/(1-γ)
return c ** (1 - γ) / (1 - γ)
```

Here is a function to compute the bond price at each state, given $ v_c $ and
Expand All @@ -438,7 +448,6 @@ def compute_q(v_c, v_d, params, sizes, arrays):
step is to calculate the default probabilities

δ(B, y) := Σ_{y'} 1{v_c(B, y') < v_d(y')} P(y, y') dy'

"""

# Unpack
Expand All @@ -455,7 +464,7 @@ def compute_q(v_c, v_d, params, sizes, arrays):
default_states = v_c < v_d
delta = jnp.sum(default_states * P, axis=(2,))

q = (1 - delta ) / (1 + r)
q = (1 - delta) / (1 + r)
return q
```

Expand All @@ -474,7 +483,6 @@ def T_d(v_c, v_d, params, sizes, arrays):
B_size, y_size = sizes
P, B_grid, y_grid, def_y = arrays


B0_idx = jnp.searchsorted(B_grid, 1e-10) # Index at which B is near zero

current_utility = u(def_y, γ)
Expand Down Expand Up @@ -576,15 +584,9 @@ def solve(model, tol=1e-8, max_iter=10_000):
policy and value functions.
"""
# Unpack

β, γ, r, ρ, η, θ, B_size, y_size, P, B_grid, y_grid, def_y = model

params = β, γ, r, ρ, η, θ
sizes = B_size, y_size
arrays = P, B_grid, y_grid, def_y


β, γ, r, ρ, η, θ, B_size, y_size, P, B_grid, y_grid, def_y = model

params = β, γ, r, ρ, η, θ
sizes = B_size, y_size
arrays = P, B_grid, y_grid, def_y
Expand All @@ -598,9 +600,12 @@ def solve(model, tol=1e-8, max_iter=10_000):
while (current_iter < max_iter) and (error > tol):
if current_iter % 100 == 0:
print(f"Entering iteration {current_iter} with error {error}.")
new_v_c, new_v_d = update_values_and_prices(v_c, v_d, params,
sizes, arrays)
error = jnp.max(jnp.abs(new_v_c - v_c)) + jnp.max(jnp.abs(new_v_d - v_d))
new_v_c, new_v_d = update_values_and_prices(
v_c, v_d, params, sizes, arrays
)
error = jnp.max(jnp.abs(new_v_c - v_c)) + jnp.max(
jnp.abs(new_v_d - v_d)
)
v_c, v_d = new_v_c, new_v_d
current_iter += 1

Expand All @@ -620,17 +625,17 @@ ae = create_arellano()
```

```{code-cell} ipython3
%%time
v_c, v_d, q, B_star = solve(ae)
with qe.Timer():
v_c, v_d, q, B_star = solve(ae)
```

We run it again to get rid of compile time.

```{code-cell} ipython3
:hide-output: false

%%time
v_c, v_d, q, B_star = solve(ae)
with qe.Timer():
v_c, v_d, q, B_star = solve(ae)
```

Finally, we write a function that will allow us to simulate the economy once
Expand All @@ -646,7 +651,6 @@ def simulate(model, T, v_c, v_d, q, B_star, key):
Here `model` is an instance of `ArellanoEconomy` and `T` is the length of
the simulation. Endogenous objects `v_c`, `v_d`, `q` and `B_star` are
assumed to come from a solution to `model`.

"""
# Unpack elements of the model
B_size, y_size = model.B_size, model.y_size
Expand All @@ -661,47 +665,41 @@ def simulate(model, T, v_c, v_d, q, B_star, key):

# Create Markov chain and simulate income process
mc = qe.MarkovChain(P, y_grid)
y_sim_indices = mc.simulate_indices(T+1, init=y_idx)
y_sim_indices = mc.simulate_indices(T + 1, init=y_idx)

# Allocate memory for outputs
y_sim = jnp.empty(T)
y_a_sim = jnp.empty(T)
B_sim = jnp.empty(T)
q_sim = jnp.empty(T)
d_sim = jnp.empty(T, dtype=int)
y_sim, y_a_sim = [], []
B_sim, q_sim, d_sim = [], [], []

# Perform simulation
t = 0
while t < T:

# Update y_sim and B_sim
y_sim = y_sim.at[t].set(y_grid[y_idx])
B_sim = B_sim.at[t].set(B_grid[B_idx])

# if in default:
y_sim.append(y_grid[y_idx])
B_sim.append(B_grid[B_idx])
if v_c[B_idx, y_idx] < v_d[y_idx] or in_default:
# Update y_a_sim
y_a_sim = y_a_sim.at[t].set(model.def_y[y_idx])
d_sim = d_sim.at[t].set(1)
y_a_sim.append(model.def_y[y_idx])
d_sim.append(1)
Bp_idx = B0_idx
# Re-enter financial markets next period with prob θ
# in_default = False if jnp.random.rand() < model.θ else True
in_default = False if random.uniform(key) < model.θ else True
key, _ = random.split(key) # Update the random key
in_default = False if jax.random.uniform(key) < model.θ else True
key, _ = jax.random.split(key) # Update the random key
else:
# Update y_a_sim
y_a_sim = y_a_sim.at[t].set(y_sim[t])
d_sim = d_sim.at[t].set(0)
y_a_sim.append(y_sim[t])
d_sim.append(0)
Bp_idx = B_star[B_idx, y_idx]

q_sim = q_sim.at[t].set(q[Bp_idx, y_idx])

q_sim.append(q[Bp_idx, y_idx])
# Update time and indices
t += 1
y_idx = y_sim_indices[t]
B_idx = Bp_idx

return y_sim, y_a_sim, B_sim, q_sim, d_sim
return (
jnp.array(y_sim),
jnp.array(y_a_sim),
jnp.array(B_sim),
jnp.array(q_sim),
jnp.array(d_sim),
)
```

## Results
Expand All @@ -727,7 +725,7 @@ values of output $ y $.


The grid used to compute this figure was relatively fine (`y_size, B_size = 51, 251`),
which explains the minor differences between this and Arrelano’s figure.
which explains the minor differences between this and Arellano's figure.

The figure shows that

Expand Down Expand Up @@ -803,7 +801,7 @@ B_size, y_size = ae.B_size, ae.y_size
r = ae.r

# Create "Y High" and "Y Low" values as 5% devs from mean
high, low = jnp.mean(y_grid) * 1.05, jnp.mean(y_grid) * .95
high, low = jnp.mean(y_grid) * 1.05, jnp.mean(y_grid) * 0.95
iy_high, iy_low = (jnp.searchsorted(y_grid, x) for x in (high, low))

fig, ax = plt.subplots(figsize=(10, 6.5))
Expand All @@ -821,7 +819,7 @@ for i, B in enumerate(B_grid):
ax.plot(x, q_high, label="$y_H$", lw=2, alpha=0.7)
ax.plot(x, q_low, label="$y_L$", lw=2, alpha=0.7)
ax.set_xlabel("$B'$")
ax.legend(loc='upper left', frameon=False)
ax.legend(loc="upper left", frameon=False)
plt.show()
```

Expand All @@ -836,7 +834,7 @@ fig, ax = plt.subplots(figsize=(10, 6.5))
ax.set_title("Value Functions")
ax.plot(B_grid, v[:, iy_high], label="$y_H$", lw=2, alpha=0.7)
ax.plot(B_grid, v[:, iy_low], label="$y_L$", lw=2, alpha=0.7)
ax.legend(loc='upper left')
ax.legend(loc="upper left")
ax.set(xlabel="$B$", ylabel="$v(y, B)$")
ax.set_xlim(min(B_grid), max(B_grid))
plt.show()
Expand All @@ -859,7 +857,7 @@ delta = jnp.sum(default_states * shaped_P, axis=(2,))
# Create figure
fig, ax = plt.subplots(figsize=(10, 6.5))
hm = ax.pcolormesh(B_grid, y_grid, delta.T)
cax = fig.add_axes([.92, .1, .02, .8])
cax = fig.add_axes([0.92, 0.1, 0.02, 0.8])
fig.colorbar(hm, cax=cax)
ax.axis([B_grid.min(), 0.05, y_grid.min(), y_grid.max()])
ax.set(xlabel="$B'$", ylabel="$y$", title="Probability of Default")
Expand All @@ -871,14 +869,9 @@ Plot a time series of major variables simulated from the model
```{code-cell} ipython3
:hide-output: false

import jax.random as random
T = 250
key = random.PRNGKey(42)
key = jax.random.PRNGKey(42)
y_sim, y_a_sim, B_sim, q_sim, d_sim = simulate(ae, T, v_c, v_d, q, B_star, key)

# T = 250
# jnp.random.seed(42)
# y_sim, y_a_sim, B_sim, q_sim, d_sim = simulate(ae, T, v_c, v_d, q, B_star)
```

```{code-cell} ipython3
Expand All @@ -899,7 +892,7 @@ while i < len(d_sim):
start_end_pairs.append((start_default, end_default))

plot_series = (y_sim, B_sim, q_sim)
titles = 'output', 'foreign assets', 'bond price'
titles = "output", "foreign assets", "bond price"

fig, axes = plt.subplots(len(plot_series), 1, figsize=(10, 12))
fig.subplots_adjust(hspace=0.3)
Expand All @@ -912,8 +905,9 @@ for ax, series, title in zip(axes, plot_series, titles):
y_min = s_min - s_range * 0.1
ax.set_ylim(y_min, y_max)
for pair in start_end_pairs:
ax.fill_between(pair, (y_min, y_min), (y_max, y_max),
color='k', alpha=0.3)
ax.fill_between(
pair, (y_min, y_min), (y_max, y_max), color="k", alpha=0.3
)
ax.grid()
ax.plot(range(T), series, lw=2, alpha=0.7)
ax.set(title=title, xlabel="time")
Expand Down
Loading