diff --git a/lectures/lucas_model.md b/lectures/lucas_model.md index 3112a11..9b7740e 100644 --- a/lectures/lucas_model.md +++ b/lectures/lucas_model.md @@ -4,21 +4,13 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.7 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -(lucas_asset)= -```{raw} html -
- - QuantEcon - -
-``` # Asset Pricing: The Lucas Asset Pricing Model @@ -51,19 +43,22 @@ large as it is in some other lectures. Nonetheless, the gain is nontrivial. -Let's start with some imports: +Let's start with installing `quantecon` package and some imports: + +```{code-cell} ipython3 +!pip install --upgrade quantecon +``` ```{code-cell} ipython3 import jax.numpy as jnp import jax import numpy as np import numba -from scipy.stats import lognorm +import quantecon as qe import matplotlib.pyplot as plt -from time import time ``` -## The Lucas Model +## The Lucas model ```{index} single: Lucas Model ``` @@ -80,7 +75,7 @@ eradicate desires to trade. This makes it very easy to compute competitive equilibrium prices. -### Basic Setup +### Basic setup Let's review the setup. @@ -128,7 +123,7 @@ Here * $u$ is a strictly increasing, strictly concave, continuously differentiable period utility function. * $\mathbb{E}$ is a mathematical expectation. -### Pricing a Lucas Tree +### Pricing a Lucas tree ```{index} single: Lucas Model; Pricing ``` @@ -161,7 +156,7 @@ The decision to hold share $\pi_t$ is actually made at time $t-1$. But this value is inherited as a state variable at time $t$, which explains the choice of subscript. -#### The Dynamic Program +#### The dynamic program ```{index} single: Lucas Model; Dynamic Program ``` @@ -175,7 +170,7 @@ information is the current state $y \in \mathbb R_+$ (dropping the time subscrip This leads us to guess an equilibrium where price is a function $p$ of $y$. -Remarks on the solution method +Remarks on the solution method: * Since this is a competitive (read: price taking) model, the consumer will take this function $p$ as given. * In this way, we determine consumer behavior given $p$ and then use equilibrium conditions to recover $p$. @@ -213,7 +208,7 @@ The solution to this dynamic programming problem is an optimal policy expressing * Each one determines the other, since $c(\pi, y) = \pi (y + p(y))- \pi' (\pi, y) p(y)$ -#### Next Steps +#### Next steps What we need to do now is determine equilibrium prices. @@ -225,7 +220,7 @@ It seems that to obtain these, we will have to However, as Lucas showed, there is a related but more straightforward way to do this. -#### Equilibrium Constraints +#### Equilibrium constraints ```{index} single: Lucas Model; Equilibrium Constraints ``` @@ -239,7 +234,7 @@ In particular, the representative consumer owns the whole tree in every period, Prices must adjust to satisfy these two constraints. -#### The Equilibrium Price Function +#### The equilibrium price function ```{index} single: Lucas Model; Equilibrium Price Function ``` @@ -280,7 +275,7 @@ This is the famous consumption-based asset pricing equation. Before discussing it further we want to solve out for prices. -### Solving the Model +### Solving the model ```{index} single: Lucas Model; Solving ``` @@ -291,7 +286,7 @@ The solution is an equilibrium price function $p^*$. Let's look at how to obtain it. -#### Setting up the Problem +#### Setting up the problem Instead of solving for it directly we'll follow Lucas' indirect approach, first setting @@ -335,17 +330,17 @@ In other words, a solution is a *fixed point* of $T$. This means that we can use fixed point theory to obtain and compute the solution. -#### A Little Fixed Point Theory +#### A little fixed point theory ```{index} single: Fixed Point Theory ``` -Let $cb\mathbb{R}_+$ be the set of continuous bounded functions $f \colon \mathbb{R}_+ \to \mathbb{R}_+$. +Let $C_b(\mathbb{R}_+)$ be the set of continuous bounded functions $f \colon \mathbb{R}_+ \to \mathbb{R}_+$. We now show that -1. $T$ has exactly one fixed point $f^*$ in $cb\mathbb{R}_+$. -1. For any $f \in cb\mathbb{R}_+$, the sequence $T^k f$ converges +1. $T$ has exactly one fixed point $f^*$ in $C_b(\mathbb{R}_+)$. +1. For any $f \in C_b(\mathbb{R}_+)$, the sequence $T^k f$ converges uniformly to $f^*$. ```{note} @@ -361,12 +356,12 @@ $\alpha < 1$ such that :label: ltbc \| Tf - Tg \| \leq \alpha \| f - g \|, -\qquad \forall \, f, g \in cb\mathbb{R}_+ +\qquad \forall \, f, g \in C_b(\mathbb{R}_+) ``` Here $\|h\| := \sup_{x \in \mathbb{R}_+} |h(x)|$. -To see that {eq}`ltbc` is valid, pick any $f,g \in cb\mathbb{R}_+$ and any $y \in \mathbb{R}_+$. +To see that {eq}`ltbc` is valid, pick any $f,g \in C_b(\mathbb{R}_+)$ and any $y \in \mathbb{R}_+$. Observe that, since integrals get larger when absolute values are moved to the inside, @@ -395,7 +390,7 @@ on the left-hand side gives {eq}`ltbc` with $\alpha := \beta$. ``` -The preceding discussion tells that we can compute $f^*$ by picking any arbitrary $f \in cb\mathbb{R}_+$ and then iterating with $T$. +The preceding discussion tells us that we can compute $f^*$ by picking any arbitrary $f \in C_b(\mathbb{R}_+)$ and then iterating with $T$. The equilibrium price function $p^*$ can then be recovered by $p^*(y) = f^*(y) / u'(y)$. @@ -414,7 +409,7 @@ Monte Carlo is not always the fastest method for computing low-dimensional integrals, but it is extremely flexible (for example, it's straightforward to change the underlying state process). -### Numba Code +### Numba code Let's start with code using NumPy / Numba (and then compare it to code using JAX). @@ -423,30 +418,31 @@ We create a function that returns tuples containing parameters and arrays needed for computation. ```{code-cell} ipython3 -def create_lucas_tree_model(γ=2, # CRRA utility parameter - β=0.95, # Discount factor - α=0.90, # Correlation coefficient - σ=0.1, # Volatility coefficient - grid_size=500, - draw_size=1_000, - seed=11): - # Set the grid interval to contain most of the mass of the - # stationary distribution of the consumption endowment - ssd = σ / np.sqrt(1 - α**2) - grid_min, grid_max = np.exp(-4 * ssd), np.exp(4 * ssd) - grid = np.linspace(grid_min, grid_max, grid_size) - # Set up distribution for shocks - np.random.seed(seed) - ϕ = lognorm(σ) - draws = ϕ.rvs(500) - # And the vector h - h = np.empty(grid_size) - for i, y in enumerate(grid): - h[i] = β * np.mean((y**α * draws)**(1 - γ)) - # Pack and return - params = γ, β, α, σ - arrays = grid, draws, h - return params, arrays +def create_lucas_tree_model( + γ=2, # CRRA utility parameter + β=0.95, # Discount factor + α=0.90, # Correlation coefficient + σ=0.1, # Volatility coefficient + grid_size=500, + draw_size=1_000, + seed=11, +): + # Set the grid interval to contain most of the mass of the + # stationary distribution of the consumption endowment + ssd = σ / np.sqrt(1 - α**2) + grid_min, grid_max = np.exp(-4 * ssd), np.exp(4 * ssd) + grid = np.linspace(grid_min, grid_max, grid_size) + # Set up distribution for shocks + np.random.seed(seed) + draws = np.random.lognormal(mean=0, sigma=σ, size=500) + # And the vector h + h = np.empty(grid_size) + for i, y in enumerate(grid): + h[i] = β * np.mean((y**α * draws) ** (1 - γ)) + # Pack and return + params = γ, β, α, σ + arrays = grid, draws, h + return params, arrays ``` Here's a Numba-jitted version of the Lucas operator @@ -478,7 +474,6 @@ to find the fixed point. def solve_model(params, arrays, tol=1e-6, max_iter=500): """ Compute the equilibrium price function. - """ # Unpack γ, β, α, σ = params @@ -504,25 +499,21 @@ params, arrays = create_lucas_tree_model() grid, draws, h = arrays # Solve once to compile -start = time() -price_vals = solve_model(params, arrays) -numba_with_compile_time = time() - start -print("Numba compile plus execution time = ", numba_with_compile_time) +with qe.Timer() as numba_with_compile_time: + price_vals = solve_model(params, arrays) ``` ```{code-cell} ipython3 # Now time execution without compile time -start = time() -price_vals = solve_model(params, arrays) -numba_without_compile_time = time() - start -print("Numba execution time = ", numba_without_compile_time) +with qe.Timer() as numba_without_compile_time: + price_vals = solve_model(params, arrays) ``` ```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(10, 6)) -ax.plot(grid, price_vals, label='$p*(y)$') -ax.set_xlabel('$y$') -ax.set_ylabel('price') +ax.plot(grid, price_vals, label="$p*(y)$") +ax.set_xlabel("$y$") +ax.set_ylabel("price") ax.legend() plt.show() ``` @@ -536,33 +527,35 @@ The price must therefore rise to induce the household to consume the entire endo -### JAX Code +### JAX code Here's a JAX version of the same problem. ```{code-cell} ipython3 -def create_lucas_tree_model(γ=2, # CRRA utility parameter - β=0.95, # Discount factor - α=0.90, # Correlation coefficient - σ=0.1, # Volatility coefficient - grid_size=500, - draw_size=1_000, - seed=11): - # Set the grid interval to contain most of the mass of the - # stationary distribution of the consumption endowment - ssd = σ / jnp.sqrt(1 - α**2) - grid_min, grid_max = jnp.exp(-4 * ssd), jnp.exp(4 * ssd) - grid = jnp.linspace(grid_min, grid_max, grid_size) - - # Set up distribution for shocks - key = jax.random.key(seed) - draws = jax.random.lognormal(key, σ, shape=(draw_size,)) - grid_reshaped = grid.reshape((grid_size, 1)) - draws_reshaped = draws.reshape((-1, draw_size)) - h = β * jnp.mean((grid_reshaped**α * draws_reshaped) ** (1-γ), axis=1) - params = γ, β, α, σ - arrays = grid, draws, h - return params, arrays +def create_lucas_tree_model( + γ=2, # CRRA utility parameter + β=0.95, # Discount factor + α=0.90, # Correlation coefficient + σ=0.1, # Volatility coefficient + grid_size=500, + draw_size=1_000, + seed=11, +): + # Set the grid interval to contain most of the mass of the + # stationary distribution of the consumption endowment + ssd = σ / jnp.sqrt(1 - α**2) + grid_min, grid_max = jnp.exp(-4 * ssd), jnp.exp(4 * ssd) + grid = jnp.linspace(grid_min, grid_max, grid_size) + + # Set up distribution for shocks + key = jax.random.key(seed) + draws = jax.random.lognormal(key, σ, shape=(draw_size,)) + grid_reshaped = grid.reshape((grid_size, 1)) + draws_reshaped = draws.reshape((-1, draw_size)) + h = β * jnp.mean((grid_reshaped**α * draws_reshaped) ** (1 - γ), axis=1) + params = γ, β, α, σ + arrays = grid, draws, h + return params, arrays ``` We'll use the following function to simultaneously compute the expectation @@ -574,13 +567,15 @@ $$ over all $y$ in the grid, under the current specifications. ```{code-cell} ipython3 -@jax.jit +@jax.jit def compute_expectation(y, α, draws, grid, f): return jnp.mean(jnp.interp(y**α * draws, grid, f)) + # Vectorize over y -compute_expectation = jax.vmap(compute_expectation, - in_axes=(0, None, None, None, None)) +compute_expectation = jax.vmap( + compute_expectation, in_axes=(0, None, None, None, None) +) ``` Here's the Lucas operator @@ -601,10 +596,12 @@ def T(params, arrays, f): We'll use successive approximation to compute the fixed point. ```{code-cell} ipython3 -def successive_approx_jax(T, # Operator (callable) - x_0, # Initial condition - tol=1e-6 , # Error tolerance - max_iter=10_000): # Max iteration bound +def successive_approx_jax( + T, # Operator (callable) + x_0, # Initial condition + tol=1e-6, # Error tolerance + max_iter=10_000, +): # Max iteration bound def body_fun(k_x_err): k, x, error = k_x_err x_new = T(x) @@ -615,12 +612,11 @@ def successive_approx_jax(T, # Operator (callable) k, x, error = k_x_err return jnp.logical_and(error > tol, k < max_iter) - k, x, error = jax.lax.while_loop(cond_fun, body_fun, - (1, x_0, tol + 1)) + k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1)) return x -successive_approx_jax = \ - jax.jit(successive_approx_jax, static_argnums=(0,)) + +successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,)) ``` Here's a function that solves the model @@ -629,7 +625,6 @@ Here's a function that solves the model def solve_model(params, arrays, tol=1e-6, max_iter=500): """ Compute the equilibrium price function. - """ # Simplify notation grid, draws, h = arrays @@ -652,28 +647,27 @@ grid, draws, h = arrays γ, β, α, σ = params # Solve once to compile -start = time() -price_vals = solve_model(params, arrays).block_until_ready() -jax_with_compile_time = time() - start -print("JAX compile plus execution time = ", jax_with_compile_time) +with qe.Timer() as jax_with_compile_time: + price_vals = solve_model(params, arrays).block_until_ready() ``` ```{code-cell} ipython3 # Now time execution without compile time -start = time() -price_vals = solve_model(params, arrays).block_until_ready() -jax_without_compile_time = time() - start -print("JAX execution time = ", jax_without_compile_time) -print("Speedup factor = ", numba_without_compile_time/jax_without_compile_time) +with qe.Timer() as jax_without_compile_time: + price_vals = solve_model(params, arrays).block_until_ready() +print( + "Speedup factor = ", + numba_without_compile_time.elapsed / jax_without_compile_time.elapsed, +) ``` Let's check the solutions are similar ```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(10, 6)) -ax.plot(grid, price_vals, label='$p*(y)$') -ax.set_xlabel('$y$') -ax.set_ylabel('price') +ax.plot(grid, price_vals, label="$p*(y)$") +ax.set_xlabel("$y$") +ax.set_ylabel("price") ax.legend() plt.show() ``` @@ -700,16 +694,16 @@ Show this by plotting the price function for the Lucas tree when $\beta = 0.95$ ```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(10, 6)) -for β in (.95, 0.98): +for β in (0.95, 0.98): params, arrays = create_lucas_tree_model(β=β) grid, draws, h = arrays γ, beta, α, σ = params price_vals = solve_model(params, arrays) - label = rf'$\beta = {beta}$' + label = rf"$\beta = {beta}$" ax.plot(grid, price_vals, lw=2, alpha=0.7, label=label) -ax.legend(loc='upper left') -ax.set(xlabel='$y$', ylabel='price', xlim=(min(grid), max(grid))) +ax.legend(loc="upper left") +ax.set(xlabel="$y$", ylabel="price", xlim=(min(grid), max(grid))) plt.show() ```