
Update interp.py and Test the Speed #3

Open
HumphreyYang wants to merge 11 commits into QuantEcon:main from HumphreyYang:speed_up

Conversation


@HumphreyYang HumphreyYang commented Jun 13, 2023

Speed test result (compile time included):

Number of interpolation points: 10
Number of grid points: 10
Time, Numba version: 0.32627010345458984 s
Time, JAX version: 0.047280073165893555 s
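To separate compile time from run time when benchmarking JAX, one common pattern (a sketch with a toy function, not the PR's code) is to time the first call and a subsequent call separately, synchronizing with block_until_ready():

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Toy stand-in for the interpolation kernel
    return jnp.sin(x).sum()

x = jnp.linspace(0.0, 1.0, 1000)

start = time.perf_counter()
f(x).block_until_ready()   # first call: trace + compile + run
first = time.perf_counter() - start

start = time.perf_counter()
f(x).block_until_ready()   # later calls reuse the cached executable
second = time.perf_counter() - start
```

On most runs the first call is orders of magnitude slower than the second, which is why first-run and second-run timings are reported separately in this thread.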


HumphreyYang commented Jun 13, 2023

Hi @jstac,

Sorry, I do not have write permission in this repo, so I cannot ping everyone in this PR.

I have updated the code so that the interpolation dimension follows the dimension of the input points. I have run a quick performance analysis:

[Plot: linear_interp (run 1 comparison of Numba and JAX)]

The first-run (compile-time) performance of the JAX version is better, but the jitted Numba function runs faster on subsequent runs. I tried removing the static_argnames, but the trend still holds. Perhaps we can start optimizing from here.
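For reference on static_argnames: marking an argument static lets it drive Python-level control flow or slice sizes, but each new value of that argument, like each new input shape, triggers a fresh compilation. A minimal sketch (hypothetical toy function, not the PR's lin_interp):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n")
def first_n_sum(x, n):
    # n is static, so it can be used as a Python-level slice size;
    # each distinct n compiles a separate specialization.
    return jnp.sum(x[:n])

x = jnp.arange(10.0)
print(float(first_n_sum(x, 3)))  # 3.0
```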

Please kindly let me know your opinion on this.

Many thanks in advance.


jstac commented Jun 13, 2023

Thanks @HumphreyYang, this is a good start. I've given you write permissions.


JunnanZ commented Jun 13, 2023

I suspect the JAX version recompiles on each loop iteration when you change the shapes of the grids and points. It should be much faster if you run it a second time in each loop. @HumphreyYang


HumphreyYang commented Jun 13, 2023

Many thanks for your input! @JunnanZ @Smit-create @jstac

> I feel like the jax version probably recompiles in each loop when you change the shapes of grids and points. And it should be much faster if you run it a second time in each loop. @HumphreyYang

Yes, I don't think we can avoid this, given how JAX handles dynamic shapes, but please see the speed comparison below for the second run in the for loop:

[Plot: linear_interp_run2 (second-run comparison of Numba and JAX)]

I think the results would improve on a faster GPU, and further still when optimized for multi-GPU computing.

I also tried vectorizing the for loop, but the JAX version still needs to recompile each time the shape changes (I reverted the vectorized version because it looked complicated once we mix JAX and Numba).
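The recompilation behavior discussed here can be made visible with a small sketch: a Python side effect inside a jitted function runs only while JAX traces it, so it counts compilations rather than calls (toy function, for illustration only):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1   # executes only during tracing, not on cached calls
    return x * 2.0

double(jnp.ones(3))
double(jnp.ones(3))    # same shape: cache hit, no re-trace
double(jnp.ones(4))    # new shape: traced and compiled again
print(trace_count)  # 2
```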


HumphreyYang commented Jun 13, 2023

(This might help the discussion about shapes and optimization.)

Please see below for the jaxpr of the lin_interp function:

`jaxpr` of `lin_interp`
{ lambda ; a:f64[202] b:f64[202] c:f64[202,202] d:f64[2,2]. let
    e:f64[2] = pjit[
      jaxpr={ lambda ; f:f64[202] g:f64[202] h:f64[202,202] i:f64[2,2]. let
          j:f64[2,2] = pjit[
            jaxpr={ lambda ; k:f64[202] l:f64[202] m:f64[2,2] n:i64[]. let
                o:f64[1] = dynamic_slice[slice_sizes=(1,)] k 1
                p:f64[] = squeeze[dimensions=(0,)] o
                q:f64[1] = dynamic_slice[slice_sizes=(1,)] k 0
                r:f64[] = squeeze[dimensions=(0,)] q
                s:f64[] = sub p r
                t:f64[1] = dynamic_slice[slice_sizes=(1,)] l 1
                u:f64[] = squeeze[dimensions=(0,)] t
                v:f64[1] = dynamic_slice[slice_sizes=(1,)] l 0
                w:f64[] = squeeze[dimensions=(0,)] v
                x:f64[] = sub u w
                y:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
                z:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] x
                ba:f64[2] = concatenate[dimension=0] y z
                bb:f64[1] = dynamic_slice[slice_sizes=(1,)] k 0
                bc:f64[] = squeeze[dimensions=(0,)] bb
                bd:f64[1] = dynamic_slice[slice_sizes=(1,)] l 0
                be:f64[] = squeeze[dimensions=(0,)] bd
                bf:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bc
                bg:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] be
                bh:f64[2] = concatenate[dimension=0] bf bg
                bi:f64[2,1] = reshape[dimensions=None new_sizes=(2, 1)] ba
                bj:f64[2,1] = reshape[dimensions=None new_sizes=(2, 1)] bh
                bk:f64[2,2] = sub m bj
                bl:f64[2,2] = div bk bi
              in (bl,) }
            name=vals_to_coords
          ] f g i 2
          bm:f64[2] = pjit[
            jaxpr={ lambda ; bn:f64[202,202] bo:f64[2,2]. let
                bp:f64[2] = pjit[
                  jaxpr={ lambda ; bq:f64[202,202] br:f64[2,2]. let
                      bs:f64[1,2] = slice[
                        limit_indices=(1, 2)
                        start_indices=(0, 0)
                        strides=(1, 1)
                      ] br
                      bt:f64[2] = squeeze[dimensions=(0,)] bs
                      bu:f64[1,2] = slice[
                        limit_indices=(2, 2)
                        start_indices=(1, 0)
                        strides=(1, 1)
                      ] br
                      bv:f64[2] = squeeze[dimensions=(0,)] bu
                      bw:f64[2] = floor bt
                      bx:f64[2] = sub bt bw
                      by:f64[2] = sub 1.0 bx
                      bz:i32[2] = convert_element_type[
                        new_dtype=int32
                        weak_type=False
                      ] bw
                      ca:i32[2] = add bz 1
                      cb:i32[2] = pjit[
                        jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
                            cf:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] cd
                            cg:i32[2] = max cf cc
                            ch:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] ce
                            ci:i32[2] = min ch cg
                          in (ci,) }
                        name=clip
                      ] bz 0 201
                      cj:i32[2] = pjit[
                        jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
                            cf:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] cd
                            cg:i32[2] = max cf cc
                            ch:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] ce
                            ci:i32[2] = min ch cg
                          in (ci,) }
                        name=clip
                      ] ca 0 201
                      ck:f64[2] = floor bv
                      cl:f64[2] = sub bv ck
                      cm:f64[2] = sub 1.0 cl
                      cn:i32[2] = convert_element_type[
                        new_dtype=int32
                        weak_type=False
                      ] ck
                      co:i32[2] = add cn 1
                      cp:i32[2] = pjit[
                        jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
                            cf:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] cd
                            cg:i32[2] = max cf cc
                            ch:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] ce
                            ci:i32[2] = min ch cg
                          in (ci,) }
                        name=clip
                      ] cn 0 201
                      cq:i32[2] = pjit[
                        jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
                            cf:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] cd
                            cg:i32[2] = max cf cc
                            ch:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] ce
                            ci:i32[2] = min ch cg
                          in (ci,) }
                        name=clip
                      ] co 0 201
                      cr:bool[2] = lt cb 0
                      cs:i32[2] = add cb 202
                      ct:i32[2] = select_n cr cb cs
                      cu:bool[2] = lt cp 0
                      cv:i32[2] = add cp 202
                      cw:i32[2] = select_n cu cp cv
                      cx:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] ct
                      cy:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] cw
                      cz:i32[2,2] = concatenate[dimension=1] cx cy
                      da:f64[2] = gather[
                        dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
                        fill_value=None
                        indices_are_sorted=False
                        mode=GatherScatterMode.PROMISE_IN_BOUNDS
                        slice_sizes=(1, 1)
                        unique_indices=False
                      ] bq cz
                      db:f64[2] = mul by cm
                      dc:f64[2] = mul db da
                      dd:bool[2] = lt cb 0
                      de:i32[2] = add cb 202
                      df:i32[2] = select_n dd cb de
                      dg:bool[2] = lt cq 0
                      dh:i32[2] = add cq 202
                      di:i32[2] = select_n dg cq dh
                      dj:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] df
                      dk:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] di
                      dl:i32[2,2] = concatenate[dimension=1] dj dk
                      dm:f64[2] = gather[
                        dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
                        fill_value=None
                        indices_are_sorted=False
                        mode=GatherScatterMode.PROMISE_IN_BOUNDS
                        slice_sizes=(1, 1)
                        unique_indices=False
                      ] bq dl
                      dn:f64[2] = mul by cl
                      do:f64[2] = mul dn dm
                      dp:bool[2] = lt cj 0
                      dq:i32[2] = add cj 202
                      dr:i32[2] = select_n dp cj dq
                      ds:bool[2] = lt cp 0
                      dt:i32[2] = add cp 202
                      du:i32[2] = select_n ds cp dt
                      dv:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] dr
                      dw:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] du
                      dx:i32[2,2] = concatenate[dimension=1] dv dw
                      dy:f64[2] = gather[
                        dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
                        fill_value=None
                        indices_are_sorted=False
                        mode=GatherScatterMode.PROMISE_IN_BOUNDS
                        slice_sizes=(1, 1)
                        unique_indices=False
                      ] bq dx
                      dz:f64[2] = mul bx cm
                      ea:f64[2] = mul dz dy
                      eb:bool[2] = lt cj 0
                      ec:i32[2] = add cj 202
                      ed:i32[2] = select_n eb cj ec
                      ee:bool[2] = lt cq 0
                      ef:i32[2] = add cq 202
                      eg:i32[2] = select_n ee cq ef
                      eh:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] ed
                      ei:i32[2,1] = broadcast_in_dim[
                        broadcast_dimensions=(0,)
                        shape=(2, 1)
                      ] eg
                      ej:i32[2,2] = concatenate[dimension=1] eh ei
                      ek:f64[2] = gather[
                        dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
                        fill_value=None
                        indices_are_sorted=False
                        mode=GatherScatterMode.PROMISE_IN_BOUNDS
                        slice_sizes=(1, 1)
                        unique_indices=False
                      ] bq ej
                      el:f64[2] = mul bx cl
                      em:f64[2] = mul el ek
                      en:f64[2] = add dc do
                      eo:f64[2] = add en ea
                      ep:f64[2] = add eo em
                    in (ep,) }
                  name=_map_coordinates
                ] bn bo
              in (bp,) }
            name=jit_map_coordinates
          ] h j
        in (bm,) }
      name=lin_interp
    ] a b c d
  in (e,) }
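Stripped of the clip/gather plumbing, the jaxpr above computes standard bilinear interpolation. A plain-NumPy sketch of the same arithmetic (hypothetical helper; assumes unit grid spacing and in-bounds points):

```python
import numpy as np

def bilinear(values, x, y):
    # Corner indices and fractional offsets
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    x1 = min(x0 + 1, values.shape[0] - 1)
    y1 = min(y0 + 1, values.shape[1] - 1)
    tx, ty = x - x0, y - y0
    # Weighted sum of the four surrounding grid values
    return ((1 - tx) * (1 - ty) * values[x0, y0]
            + (1 - tx) * ty * values[x0, y1]
            + tx * (1 - ty) * values[x1, y0]
            + tx * ty * values[x1, y1])

v = np.array([[0.0, 1.0],
              [2.0, 3.0]])
print(bilinear(v, 0.5, 0.5))  # 1.5
```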

@Smit-create

I don't think there is any optimization we can do with JAX right now. I tried the following function:

@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')

Everything is minimal and inlined in the above function (intervals and low_bounds are passed as arguments for testing purposes). I get the following results for the first and second runs, respectively. Looking at this, I think the overhead comes from map_coordinates.

Run 1:

[Plot: linear_interp (run 1 with the minimal function)]

Run 2:
[Plot: linear_interp2 (run 2 with the minimal function)]


jstac commented Jun 14, 2023

Thanks @Smit-create and @HumphreyYang for the tests.

This is a bit disappointing. Perhaps we should look at other options?

What about https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/batch_interp_regular_nd_grid ?


JunnanZ commented Jun 14, 2023

Hi @Smit-create, would you mind sharing how you benchmarked their performance?

I did something like this:

np.random.seed(12)
m = 5000
points_list = [np.random.random((m, 2)) for i in range(100)]
points_jnp_list = [jnp.asarray(points.T) for points in points_list]

...

%%time
for points in points_list:
    # run eval_linear

%%time
for points in points_jnp_list:
    # run lin_interp in Jax

For m=5000, the computation time is roughly the same for the two methods. I think this is close to our typical use cases where we interpolate a small batch of points (e.g., quadrature points) but we do that multiple times (for every point in the state space).

However, if I only run each method one time, eval_linear is much faster, and m would have to be at least 100000 for JAX to outperform Numba.
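The elided loop above can be completed as a runnable skeleton; kernel here is a hypothetical stand-in for eval_linear or the jitted lin_interp, just to make the timing harness concrete:

```python
import time
import numpy as np

rng = np.random.default_rng(12)
m = 5000
points_list = [rng.random((m, 2)) for _ in range(100)]

def kernel(points):
    # Placeholder for eval_linear / lin_interp: any per-batch computation
    return points.sum(axis=1)

start = time.perf_counter()
results = [kernel(p) for p in points_list]
elapsed = time.perf_counter() - start
# 100 calls on small batches: per-call overhead dominates, matching the
# "small batch of points, many calls" regime described above.
print(len(results), results[0].shape)  # 100 (5000,)
```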

@Smit-create

Hi @JunnanZ, I used the following script

Benchmark

import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import time
from interpolation.splines import UCGrid, nodes, eval_linear
import matplotlib.pyplot as plt


@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    # Interpolate using coordinates
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')



def test_linear_interp(N_points, N_grid):
    f = lambda x,y: np.sin(np.sqrt(x**2+y**2+0.00001))/np.sqrt(x**2+y**2+0.00001)

    grid = UCGrid((-1.0, 1.0, N_grid), (-1.0, 1.0, N_grid))
    # get grid points
    gp = nodes(grid)   # (N_grid * N_grid, 2) matrix

    # compute values on grid points
    values = f(gp[:,0], gp[:,1]).reshape((N_grid, N_grid))

    points = np.random.random((N_points,2))

    time_start = time.time()
    linear_interp_numba = eval_linear(grid, values, points) # N_points vector
    time_numba = time.time() - time_start

    time_start = time.time()
    linear_interp_numba_2 = eval_linear(grid, values, points) # N_points vector
    time_numba_2 = time.time() - time_start

    grids = (jnp.linspace(*grid[0]), jnp.linspace(*grid[1]))
    intervals = jnp.asarray([grid[1] - grid[0] for grid in grids]).reshape(-1, 1)
    low_bounds = jnp.asarray([grid[0] for grid in grids]).reshape(-1, 1)
    values_jnp = jnp.asarray(values)
    points_jnp = jnp.asarray(points.T)

    time_start = time.time()
    linear_interp_jax = lin_interp(values_jnp, points_jnp,
                                   intervals, low_bounds).block_until_ready()
    time_jax = time.time() - time_start

    time_start = time.time()
    linear_interp_jax_2 = lin_interp(values_jnp, points_jnp,
                                   intervals, low_bounds).block_until_ready()
    time_jax_2 = time.time() - time_start

    assert jnp.allclose(linear_interp_numba, linear_interp_jax), 'Interpolation results are not the same'

    return time_numba, time_jax, time_numba_2, time_jax_2


time_numba_list = []
time_jax_list = []
time_numba_list_2 = []
time_jax_list_2 = []
for i in np.arange(2, 1000, 200):
    for j in np.arange(2, 1000, 200):
        time_numba, time_jax, time_numba_2, time_jax_2 = test_linear_interp(i, j)
        time_numba_list.append(time_numba)
        time_jax_list.append(time_jax)
        time_numba_list_2.append(time_numba_2)
        time_jax_list_2.append(time_jax_2)

# plot the results
plt.plot(time_numba_list, label='Numba')
plt.plot(time_jax_list, label='JAX')
plt.legend()
plt.xlabel('# Trial for Run 1')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba and JAX versions of linear interpolation - 1')
plt.savefig('linear_interp.png')
plt.show()

plt.plot(time_numba_list_2, label='Numba')
plt.plot(time_jax_list_2, label='JAX')
plt.legend()
plt.xlabel('# Trial for Run 2')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba and JAX versions of linear interpolation - 2')
plt.savefig('linear_interp2.png')
plt.show()

@Smit-create

> What about https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/batch_interp_regular_nd_grid ?

Please see the results with TensorFlow included:

Run 1:
[Plot: linear_interp (1) (run 1 with TensorFlow included)]

Run 2:
[Plot: linear_interp2 (1) (run 2 with TensorFlow included)]

Script to reproduce this:

Script

import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import time
from interpolation.splines import UCGrid, nodes, eval_linear
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    # Interpolate using coordinates
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')



def test_linear_interp(N_points, N_grid):
    f = lambda x,y: np.sin(np.sqrt(x**2+y**2+0.00001))/np.sqrt(x**2+y**2+0.00001)

    grid = UCGrid((-1.0, 1.0, N_grid), (-1.0, 1.0, N_grid))
    # get grid points
    gp = nodes(grid)   # (N_grid * N_grid, 2) matrix

    # compute values on grid points
    values = f(gp[:,0], gp[:,1]).reshape((N_grid, N_grid))

    points = np.random.random((N_points,2))

    time_start = time.time()
    linear_interp_numba = eval_linear(grid, values, points) # N_points vector
    time_numba = time.time() - time_start

    time_start = time.time()
    linear_interp_numba_2 = eval_linear(grid, values, points) # N_points vector
    time_numba_2 = time.time() - time_start

    grids = (jnp.linspace(*grid[0]), jnp.linspace(*grid[1]))
    intervals = jnp.asarray([grid[1] - grid[0] for grid in grids]).reshape(-1, 1)
    low_bounds = jnp.asarray([grid[0] for grid in grids]).reshape(-1, 1)
    values_jnp = jnp.asarray(values)
    points_jnp = jnp.asarray(points.T)

    time_start = time.time()
    linear_interp_jax = lin_interp(values_jnp, points_jnp,
                                   intervals, low_bounds).block_until_ready()
    time_jax = time.time() - time_start

    time_start = time.time()
    linear_interp_jax_2 = lin_interp(values_jnp, points_jnp,
                                   intervals, low_bounds).block_until_ready()
    time_jax_2 = time.time() - time_start
    x_ref_min = tf.constant([-1.0, -1.0], dtype=tf.float64)
    x_ref_max = tf.constant([1.0, 1.0],  dtype=tf.float64)
    values_tf = tf.convert_to_tensor(values)
    points_tf = tf.convert_to_tensor(points)
    time_start = time.time()
    tf_res = tfp.math.batch_interp_regular_nd_grid(points_tf, x_ref_min, x_ref_max, values_tf, axis=0)
    time_tf = time.time() - time_start
    time_start = time.time()
    tf_res = tfp.math.batch_interp_regular_nd_grid(points_tf, x_ref_min, x_ref_max, values_tf, axis=0)
    time_tf_2 = time.time() - time_start
    assert jnp.allclose(linear_interp_numba, linear_interp_jax), 'Interpolation results are not the same JAX'
    assert np.allclose(linear_interp_numba, tf_res), 'Interpolation results are not the same TF'

    return time_numba, time_jax, time_tf, time_numba_2, time_jax_2, time_tf_2


time_numba_list = []
time_jax_list = []
time_tf_list = []
time_numba_list_2 = []
time_jax_list_2 = []
time_tf_list_2 = []
for i in np.arange(2, 1000, 200):
    for j in np.arange(2, 1000, 200):
        time_numba, time_jax, time_tf, time_numba_2, time_jax_2, time_tf_2 = test_linear_interp(i, j)
        time_numba_list.append(time_numba)
        time_jax_list.append(time_jax)
        time_numba_list_2.append(time_numba_2)
        time_jax_list_2.append(time_jax_2)
        time_tf_list.append(time_tf)
        time_tf_list_2.append(time_tf_2)

# plot the results
plt.plot(time_numba_list, label='Numba')
plt.plot(time_jax_list, label='JAX')
plt.plot(time_tf_list, label='TF')
plt.legend()
plt.xlabel('# Trial for Run 1')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba and JAX versions of linear interpolation - 1')
plt.savefig('linear_interp.png')
plt.show()

plt.plot(time_numba_list_2, label='Numba')
plt.plot(time_jax_list_2, label='JAX')
plt.plot(time_tf_list_2, label='TF')
plt.legend()
plt.xlabel('# Trial for Run 2')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba and JAX versions of linear interpolation - 2')
plt.savefig('linear_interp2.png')
plt.show()


jstac commented Jun 14, 2023

@HumphreyYang @Smit-create @chappiewuzefan

Could someone please write up a notebook with a summary of what we have learned so far. Please add a small description of your experiments.

What is being tested? What is on the horizontal axis? Are the tests with or without compile time? On what machine are they being run?

It will be very strange if the GPU options cannot be competitive with Numba on a CPU.

If that's the conclusion backed by these experiments, then do we need to extract the small piece of code that directly implements linear interpolation from interpolation (Numba library) and try to optimize it for JAX directly?

@HumphreyYang

> Could someone please write up a notebook with a summary of what we have learned so far. Please add a small description of your experiments.

Hi @Smit-create and @chappiewuzefan, I will work on this if you have not started yet : )

@Smit-create

Sure, thanks @HumphreyYang

@chappiewuzefan

> Could someone please write up a notebook with a summary of what we have learned so far. Please add a small description of your experiments.

> Hi @Smit-create and @chappiewuzefan, I will work on this if you have not started yet : )

Sure. Sorry for the late reply; I was busy with other tasks. If you need my help, just let me know!


HumphreyYang commented Jun 15, 2023

Hi @jstac, please see the summary for the experiments here

The conclusion is that when the shapes of the grid and the interpolation points stay fixed, JAX achieves a notable improvement over Numba, but when the shapes are dynamic, JAX performs worse than Numba.
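If fixed shapes are the requirement, one standard workaround (not tested in this PR; offered as an assumption) is to pad each batch of points to a fixed maximum size, so the jitted function always sees one shape and compiles once, masking the padded rows afterwards:

```python
import numpy as np

def pad_to(points, m):
    # Pad an (n, d) batch up to (m, d) so a jitted kernel always sees
    # the same shape; the mask marks which rows are real.
    n, d = points.shape
    padded = np.zeros((m, d))
    padded[:n] = points
    mask = np.arange(m) < n
    return padded, mask

pts = np.random.random((7, 2))
padded, mask = pad_to(pts, 16)
print(padded.shape, int(mask.sum()))  # (16, 2) 7
```

The trade-off is wasted computation on the padded rows, which is usually cheap next to a recompilation.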

@Smit-create

Thanks @HumphreyYang. The notebook looks great, and thanks for the detailed analysis. As I said earlier, even the most minimal inline JAX function (#3 (comment)) cannot beat Numba when shapes change, due to the overhead from map_coordinates.
