diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index 6880f06..e7cc4b3 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7"] + python-version: ["3.12"] steps: diff --git a/test/convolution_test.py b/test/convolution_test.py index f760bc4..87429b2 100644 --- a/test/convolution_test.py +++ b/test/convolution_test.py @@ -5,7 +5,7 @@ import numpy.testing as npt from scipy import signal, ndimage -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # makes a difference when comparing to scipy's routines!! from utax.convolution import * diff --git a/test/wavelet_test.py b/test/wavelet_test.py index fc76cc2..7dd97f3 100644 --- a/test/wavelet_test.py +++ b/test/wavelet_test.py @@ -7,7 +7,7 @@ import numpy.testing as npt from test.convolution_test import gaussian_kernel -from jax.config import config +from jax import config config.update("jax_enable_x64", True) import utax @@ -16,7 +16,7 @@ class TestWaveletTransform(object): - def setup(self): + def setup_method(self): utax_path = os.path.dirname(utax.__path__[0]) data_path = os.path.join(utax_path, 'test', 'data') diff --git a/utax/interpolation.py b/utax/interpolation.py index 63475bb..35ad64f 100644 --- a/utax/interpolation.py +++ b/utax/interpolation.py @@ -16,17 +16,33 @@ class BilinearInterpolator(object): """ def __init__(self, x, y, z, allow_extrapolation=True): - self.z = jnp.array(z) # z - if np.all(np.diff(x) >= 0): # check if sorted in increasing order - self.x = jnp.array(x) - else: - self.x = jnp.array(np.sort(x)) - self.z = jnp.flip(self.z, axis=0) - if np.all(np.diff(y) >= 0): # check if sorted in increasing order - self.y = jnp.array(y) - else: - self.y = jnp.array(np.sort(y)) - self.z = jnp.flip(self.z, axis=1) + self.z = jnp.array(z) + + # Sort x if not increasing + x = jnp.array(x) + x_sorted = jnp.sort(x) + flip_x = ~jnp.all(jnp.diff(x) >= 0) + + def x_keep_fn(_): + return x, self.z + + def x_sort_fn(_): + return x_sorted, jnp.flip(self.z, axis=0) + + self.x, self.z = lax.cond(flip_x, x_sort_fn, x_keep_fn, operand=None) + + # Sort y if not increasing + y = jnp.array(y) + y_sorted = jnp.sort(y) + flip_y = ~jnp.all(jnp.diff(y) >= 0) + + def y_keep_fn(_): + return y, self.z + + def y_sort_fn(_): + return y_sorted, jnp.flip(self.z, axis=1) + + self.y, self.z = lax.cond(flip_y, y_sort_fn, y_keep_fn, operand=None) self._extrapol_bool = allow_extrapolation def __call__(self, x, y, dx=0, dy=0):