Update interp.py and Test the Speed #3
Conversation
|
Hi @jstac, sorry, I do not have write permission in this repo, so I cannot ping everyone in this PR. I have updated the code so that the dimension changes with the dimension of the points. I have run a quick performance analysis: compile-time performance is better for the JAX version, but the jitted Numba function runs faster in the following runs. I tried to remove the
Please kindly let me know your opinion on this. Many thanks in advance. |
|
Thanks @HumphreyYang, this is a good start. I've given you write permissions. |
|
I suspect the JAX version recompiles in each loop iteration when you change the shapes of the grids and points. It should be much faster if you run it a second time within each loop. @HumphreyYang |
|
Many thanks for your input! @JunnanZ @Smit-create @jstac
Yes, I think we cannot avoid this due to dynamic shapes in JAX, but please see the speed comparison below for the second run. I think the results would be better on a faster GPU and when optimized for multi-GPU computing. I also tried to vectorize the |
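To illustrate the recompilation point, here is a small sketch (not from the PR) that counts traces via a Python side effect: `jax.jit` re-traces, and hence recompiles, whenever it sees a new input shape, so the counter below increments once per compiled shape, not once per call.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    # This Python side effect runs only while JAX traces the function,
    # i.e. once per new input shape/dtype, not on cached calls.
    global trace_count
    trace_count += 1
    return jnp.sin(x).sum()

f(jnp.ones(3))
f(jnp.ones(3))  # same shape: cache hit, no re-trace
f(jnp.ones(4))  # new shape: re-trace and recompile
```

This is why a benchmark that varies grid and point shapes on every iteration pays the compile cost repeatedly, while a fixed-shape second run does not.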
|
(this might help the discussion about shaping and optimization) Please kindly see below for the `jaxpr` of `lin_interp`:

{ lambda ; a:f64[202] b:f64[202] c:f64[202,202] d:f64[2,2]. let
e:f64[2] = pjit[
jaxpr={ lambda ; f:f64[202] g:f64[202] h:f64[202,202] i:f64[2,2]. let
j:f64[2,2] = pjit[
jaxpr={ lambda ; k:f64[202] l:f64[202] m:f64[2,2] n:i64[]. let
o:f64[1] = dynamic_slice[slice_sizes=(1,)] k 1
p:f64[] = squeeze[dimensions=(0,)] o
q:f64[1] = dynamic_slice[slice_sizes=(1,)] k 0
r:f64[] = squeeze[dimensions=(0,)] q
s:f64[] = sub p r
t:f64[1] = dynamic_slice[slice_sizes=(1,)] l 1
u:f64[] = squeeze[dimensions=(0,)] t
v:f64[1] = dynamic_slice[slice_sizes=(1,)] l 0
w:f64[] = squeeze[dimensions=(0,)] v
x:f64[] = sub u w
y:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
z:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] x
ba:f64[2] = concatenate[dimension=0] y z
bb:f64[1] = dynamic_slice[slice_sizes=(1,)] k 0
bc:f64[] = squeeze[dimensions=(0,)] bb
bd:f64[1] = dynamic_slice[slice_sizes=(1,)] l 0
be:f64[] = squeeze[dimensions=(0,)] bd
bf:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bc
bg:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] be
bh:f64[2] = concatenate[dimension=0] bf bg
bi:f64[2,1] = reshape[dimensions=None new_sizes=(2, 1)] ba
bj:f64[2,1] = reshape[dimensions=None new_sizes=(2, 1)] bh
bk:f64[2,2] = sub m bj
bl:f64[2,2] = div bk bi
in (bl,) }
name=vals_to_coords
] f g i 2
bm:f64[2] = pjit[
jaxpr={ lambda ; bn:f64[202,202] bo:f64[2,2]. let
bp:f64[2] = pjit[
jaxpr={ lambda ; bq:f64[202,202] br:f64[2,2]. let
bs:f64[1,2] = slice[
limit_indices=(1, 2)
start_indices=(0, 0)
strides=(1, 1)
] br
bt:f64[2] = squeeze[dimensions=(0,)] bs
bu:f64[1,2] = slice[
limit_indices=(2, 2)
start_indices=(1, 0)
strides=(1, 1)
] br
bv:f64[2] = squeeze[dimensions=(0,)] bu
bw:f64[2] = floor bt
bx:f64[2] = sub bt bw
by:f64[2] = sub 1.0 bx
bz:i32[2] = convert_element_type[
new_dtype=int32
weak_type=False
] bw
ca:i32[2] = add bz 1
cb:i32[2] = pjit[
jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
cf:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] cd
cg:i32[2] = max cf cc
ch:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] ce
ci:i32[2] = min ch cg
in (ci,) }
name=clip
] bz 0 201
cj:i32[2] = pjit[
jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
cf:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] cd
cg:i32[2] = max cf cc
ch:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] ce
ci:i32[2] = min ch cg
in (ci,) }
name=clip
] ca 0 201
ck:f64[2] = floor bv
cl:f64[2] = sub bv ck
cm:f64[2] = sub 1.0 cl
cn:i32[2] = convert_element_type[
new_dtype=int32
weak_type=False
] ck
co:i32[2] = add cn 1
cp:i32[2] = pjit[
jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
cf:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] cd
cg:i32[2] = max cf cc
ch:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] ce
ci:i32[2] = min ch cg
in (ci,) }
name=clip
] cn 0 201
cq:i32[2] = pjit[
jaxpr={ lambda ; cc:i32[2] cd:i64[] ce:i64[]. let
cf:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] cd
cg:i32[2] = max cf cc
ch:i32[] = convert_element_type[
new_dtype=int32
weak_type=False
] ce
ci:i32[2] = min ch cg
in (ci,) }
name=clip
] co 0 201
cr:bool[2] = lt cb 0
cs:i32[2] = add cb 202
ct:i32[2] = select_n cr cb cs
cu:bool[2] = lt cp 0
cv:i32[2] = add cp 202
cw:i32[2] = select_n cu cp cv
cx:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] ct
cy:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] cw
cz:i32[2,2] = concatenate[dimension=1] cx cy
da:f64[2] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bq cz
db:f64[2] = mul by cm
dc:f64[2] = mul db da
dd:bool[2] = lt cb 0
de:i32[2] = add cb 202
df:i32[2] = select_n dd cb de
dg:bool[2] = lt cq 0
dh:i32[2] = add cq 202
di:i32[2] = select_n dg cq dh
dj:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] df
dk:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] di
dl:i32[2,2] = concatenate[dimension=1] dj dk
dm:f64[2] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bq dl
dn:f64[2] = mul by cl
do:f64[2] = mul dn dm
dp:bool[2] = lt cj 0
dq:i32[2] = add cj 202
dr:i32[2] = select_n dp cj dq
ds:bool[2] = lt cp 0
dt:i32[2] = add cp 202
du:i32[2] = select_n ds cp dt
dv:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] dr
dw:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] du
dx:i32[2,2] = concatenate[dimension=1] dv dw
dy:f64[2] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bq dx
dz:f64[2] = mul bx cm
ea:f64[2] = mul dz dy
eb:bool[2] = lt cj 0
ec:i32[2] = add cj 202
ed:i32[2] = select_n eb cj ec
ee:bool[2] = lt cq 0
ef:i32[2] = add cq 202
eg:i32[2] = select_n ee cq ef
eh:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] ed
ei:i32[2,1] = broadcast_in_dim[
broadcast_dimensions=(0,)
shape=(2, 1)
] eg
ej:i32[2,2] = concatenate[dimension=1] eh ei
ek:f64[2] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1)
unique_indices=False
] bq ej
el:f64[2] = mul bx cl
em:f64[2] = mul el ek
en:f64[2] = add dc do
eo:f64[2] = add en ea
ep:f64[2] = add eo em
in (ep,) }
name=_map_coordinates
] bn bo
in (bp,) }
name=jit_map_coordinates
] h j
in (bm,) }
name=lin_interp
] a b c d
in (e,) } |
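For reference, a jaxpr like the one above can be produced with `jax.make_jaxpr`. The sketch below uses the minimal inline version of `lin_interp` discussed later in this thread, with argument shapes chosen to mirror the trace; the exact signature traced in the PR may differ.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')

values = jnp.zeros((202, 202))       # values on a 202x202 grid
points = jnp.zeros((2, 2))           # one row of query coordinates per grid dimension
intervals = jnp.full((2, 1), 2.0 / 201)
low_bounds = jnp.full((2, 1), -1.0)

# Trace the function without running it, and print the resulting jaxpr
jaxpr = jax.make_jaxpr(lin_interp)(values, points, intervals, low_bounds)
print(jaxpr)
```

Note that the shapes (e.g. `f64[202,202]`) are baked into the jaxpr, which is why different input shapes force a fresh trace.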
|
I think there isn't any optimization that we can do with JAX right now. I tried the following function:

@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')

Everything is minimal and inline in the above function, where we also provide
Run 1: |
|
Thanks @Smit-create and @HumphreyYang for the tests. This is a bit disappointing. Perhaps we should look at other options. What about https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/batch_interp_regular_nd_grid ? |
|
Hi @Smit-create, would you mind sharing how you benchmarked their performance? I did something like this:

np.random.seed(12)
m = 5000
points_list = [np.random.random((m, 2)) for i in range(100)]
points_jnp_list = [jnp.asarray(points.T) for points in points_list]
...
%%time
for points in points_list:
    # run eval_linear

%%time
for points in points_jnp_list:
    # run lin_interp in JAX

However, if I only run each method one time, |
|
Hi @JunnanZ, I used the following script:
import jax
import jax.numpy as jnp
from functools import partial
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import time
from interpolation.splines import UCGrid, nodes, eval_linear
import matplotlib.pyplot as plt
@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    # Interpolate using coordinates
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')

def test_linear_interp(N_points, N_grid):
    f = lambda x, y: np.sin(np.sqrt(x**2 + y**2 + 0.00001)) / np.sqrt(x**2 + y**2 + 0.00001)
    grid = UCGrid((-1.0, 1.0, N_grid), (-1.0, 1.0, N_grid))
    # get grid points
    gp = nodes(grid)  # (N_grid**2, 2) matrix of grid nodes
    # compute values on grid points
    values = f(gp[:, 0], gp[:, 1]).reshape((N_grid, N_grid))
    points = np.random.random((N_points, 2))

    time_start = time.time()
    linear_interp_numba = eval_linear(grid, values, points)  # length-N_points vector
    time_numba = time.time() - time_start

    time_start = time.time()
    linear_interp_numba_2 = eval_linear(grid, values, points)
    time_numba_2 = time.time() - time_start

    grids = (jnp.linspace(*grid[0]), jnp.linspace(*grid[1]))
    intervals = jnp.asarray([g[1] - g[0] for g in grids]).reshape(-1, 1)
    low_bounds = jnp.asarray([g[0] for g in grids]).reshape(-1, 1)
    values_jnp = jnp.asarray(values)
    points_jnp = jnp.asarray(points.T)

    time_start = time.time()
    linear_interp_jax = lin_interp(values_jnp, points_jnp,
                                   intervals, low_bounds).block_until_ready()
    time_jax = time.time() - time_start

    time_start = time.time()
    linear_interp_jax_2 = lin_interp(values_jnp, points_jnp,
                                     intervals, low_bounds).block_until_ready()
    time_jax_2 = time.time() - time_start

    assert jnp.allclose(linear_interp_numba, linear_interp_jax), 'Interpolation results are not the same'
    return time_numba, time_jax, time_numba_2, time_jax_2

time_numba_list = []
time_jax_list = []
time_numba_list_2 = []
time_jax_list_2 = []
for i in np.arange(2, 1000, 200):
    for j in np.arange(2, 1000, 200):
        time_numba, time_jax, time_numba_2, time_jax_2 = test_linear_interp(i, j)
        time_numba_list.append(time_numba)
        time_jax_list.append(time_jax)
        time_numba_list_2.append(time_numba_2)
        time_jax_list_2.append(time_jax_2)

# plot the results
plt.plot(time_numba_list, label='Numba')
plt.plot(time_jax_list, label='JAX')
plt.legend()
plt.xlabel('# Trial for Run 1')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba and JAX versions of linear interpolation - 1')
plt.savefig('linear_interp.png')
plt.show()

plt.plot(time_numba_list_2, label='Numba')
plt.plot(time_jax_list_2, label='JAX')
plt.legend()
plt.xlabel('# Trial for Run 2')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba and JAX versions of linear interpolation - 2')
plt.savefig('linear_interp2.png')
plt.show() |
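One caveat with single `time.time()` pairs as above: the first measurement mixes compile time into the runtime, and one-shot timings are noisy. A hedged alternative (a sketch reusing the same inline `lin_interp`, with arbitrary fixed shapes) is to warm up once, then average over repeats with `timeit`, keeping `block_until_ready()` so that asynchronous dispatch is not mistaken for execution time:

```python
import timeit

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')

values = jnp.zeros((202, 202))
points = jnp.zeros((2, 5000))
intervals = jnp.full((2, 1), 2.0 / 201)
low_bounds = jnp.full((2, 1), -1.0)

# Warm-up call: triggers tracing and compilation for these shapes
lin_interp(values, points, intervals, low_bounds).block_until_ready()

# Steady-state timing: block_until_ready ensures we measure the computation,
# not just the asynchronous dispatch
n_reps = 100
t = timeit.timeit(
    lambda: lin_interp(values, points, intervals, low_bounds).block_until_ready(),
    number=n_reps) / n_reps
print(f"mean runtime: {t:.6f} s")
```

Separating the warm-up from the averaged loop makes the "Run 1" (compile-inclusive) and "Run 2" (steady-state) distinction explicit rather than dependent on call order.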
Please see the results with TensorFlow included. Script to reproduce this:
import jax
import jax.numpy as jnp
from functools import partial
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import time
from interpolation.splines import UCGrid, nodes, eval_linear
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
@jax.jit
def lin_interp(values, points, intervals, low_bounds):
    coords = (points - low_bounds) / intervals
    # Interpolate using coordinates
    return jax.scipy.ndimage.map_coordinates(values, coords, order=1, mode='nearest')

def test_linear_interp(N_points, N_grid):
    f = lambda x, y: np.sin(np.sqrt(x**2 + y**2 + 0.00001)) / np.sqrt(x**2 + y**2 + 0.00001)
    grid = UCGrid((-1.0, 1.0, N_grid), (-1.0, 1.0, N_grid))
    # get grid points
    gp = nodes(grid)  # (N_grid**2, 2) matrix of grid nodes
    # compute values on grid points
    values = f(gp[:, 0], gp[:, 1]).reshape((N_grid, N_grid))
    points = np.random.random((N_points, 2))

    time_start = time.time()
    linear_interp_numba = eval_linear(grid, values, points)  # length-N_points vector
    time_numba = time.time() - time_start

    time_start = time.time()
    linear_interp_numba_2 = eval_linear(grid, values, points)
    time_numba_2 = time.time() - time_start

    grids = (jnp.linspace(*grid[0]), jnp.linspace(*grid[1]))
    intervals = jnp.asarray([g[1] - g[0] for g in grids]).reshape(-1, 1)
    low_bounds = jnp.asarray([g[0] for g in grids]).reshape(-1, 1)
    values_jnp = jnp.asarray(values)
    points_jnp = jnp.asarray(points.T)

    time_start = time.time()
    linear_interp_jax = lin_interp(values_jnp, points_jnp,
                                   intervals, low_bounds).block_until_ready()
    time_jax = time.time() - time_start

    time_start = time.time()
    linear_interp_jax_2 = lin_interp(values_jnp, points_jnp,
                                     intervals, low_bounds).block_until_ready()
    time_jax_2 = time.time() - time_start

    x_ref_min = tf.constant([-1.0, -1.0], dtype=tf.float64)
    x_ref_max = tf.constant([1.0, 1.0], dtype=tf.float64)
    values_tf = tf.convert_to_tensor(values)
    points_tf = tf.convert_to_tensor(points)

    time_start = time.time()
    tf_res = tfp.math.batch_interp_regular_nd_grid(points_tf, x_ref_min, x_ref_max, values_tf, axis=0)
    time_tf = time.time() - time_start

    time_start = time.time()
    tf_res = tfp.math.batch_interp_regular_nd_grid(points_tf, x_ref_min, x_ref_max, values_tf, axis=0)
    time_tf_2 = time.time() - time_start

    assert jnp.allclose(linear_interp_numba, linear_interp_jax), 'Interpolation results are not the same (JAX)'
    assert np.allclose(linear_interp_numba, tf_res), 'Interpolation results are not the same (TF)'
    return time_numba, time_jax, time_tf, time_numba_2, time_jax_2, time_tf_2

time_numba_list = []
time_jax_list = []
time_tf_list = []
time_numba_list_2 = []
time_jax_list_2 = []
time_tf_list_2 = []
for i in np.arange(2, 1000, 200):
    for j in np.arange(2, 1000, 200):
        time_numba, time_jax, time_tf, time_numba_2, time_jax_2, time_tf_2 = test_linear_interp(i, j)
        time_numba_list.append(time_numba)
        time_jax_list.append(time_jax)
        time_numba_list_2.append(time_numba_2)
        time_jax_list_2.append(time_jax_2)
        time_tf_list.append(time_tf)
        time_tf_list_2.append(time_tf_2)

# plot the results
plt.plot(time_numba_list, label='Numba')
plt.plot(time_jax_list, label='JAX')
plt.plot(time_tf_list, label='TF')
plt.legend()
plt.xlabel('# Trial for Run 1')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba, JAX and TF versions of linear interpolation - 1')
plt.savefig('linear_interp.png')
plt.show()

plt.plot(time_numba_list_2, label='Numba')
plt.plot(time_jax_list_2, label='JAX')
plt.plot(time_tf_list_2, label='TF')
plt.legend()
plt.xlabel('# Trial for Run 2')
plt.ylabel('Time (s)')
plt.title('Comparison of Numba, JAX and TF versions of linear interpolation - 2')
plt.savefig('linear_interp2.png')
plt.show() |
|
@HumphreyYang @Smit-create @chappiewuzefan Could someone please write up a notebook with a summary of what we have learned so far? Please add a small description of your experiments: What is being tested? What is on the horizontal axis? Are the tests with or without compile time? On what machine are they being run? It will be very strange if the GPU options cannot be competitive with Numba on a CPU. If that's the conclusion backed by these experiments, then do we need to extract the small piece of code that directly implements linear interpolation from |
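If we do end up extracting a standalone kernel, a minimal 2-D regular-grid bilinear routine might look like the sketch below. This is pure NumPy; the function name and signature are hypothetical, and `numba.njit` could be layered on top for a like-for-like comparison with `eval_linear`.

```python
import numpy as np

def bilinear_eval(grid_min, grid_max, values, points):
    """Bilinear interpolation on a regular 2-D grid (hypothetical extracted kernel).

    grid_min, grid_max: length-2 sequences of grid bounds per dimension
    values: (n0, n1) array of values on the grid nodes
    points: (m, 2) array of query points
    """
    n0, n1 = values.shape
    s0 = (grid_max[0] - grid_min[0]) / (n0 - 1)   # grid step, dimension 0
    s1 = (grid_max[1] - grid_min[1]) / (n1 - 1)   # grid step, dimension 1
    t0 = (points[:, 0] - grid_min[0]) / s0        # fractional index, dimension 0
    t1 = (points[:, 1] - grid_min[1]) / s1
    i0 = np.clip(t0.astype(np.int64), 0, n0 - 2)  # lower cell corner, clipped to grid
    i1 = np.clip(t1.astype(np.int64), 0, n1 - 2)
    w0 = t0 - i0                                  # weights within the cell
    w1 = t1 - i1
    return ((1 - w0) * (1 - w1) * values[i0, i1]
            + (1 - w0) * w1 * values[i0, i1 + 1]
            + w0 * (1 - w1) * values[i0 + 1, i1]
            + w0 * w1 * values[i0 + 1, i1 + 1])
```

A small kernel like this would also make it easier to pinpoint where the frameworks spend their time, since it removes the generic N-dimensional machinery from the comparison.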
Hi @Smit-create and @chappiewuzefan, I will work on this if you have not started yet :) |
|
Sure, thanks @HumphreyYang |
Sure. Sorry for the late reply; I was busy with other tasks. If you need my help, just let me know! |
|
Thanks @HumphreyYang. The notebook looks great, and thanks for the detailed analysis. As I said earlier, even the most inline function in JAX (#3 (comment)) is not able to beat Numba with changing shapes, as a result of the overhead from |






Speed Test Result (Compile Time):