From 910da08796862da5844c2bf8bd7f27a4a3f17f0a Mon Sep 17 00:00:00 2001 From: martin-millon Date: Fri, 11 Jul 2025 12:15:19 +0200 Subject: [PATCH 1/5] fixing the Bilinear Interpolator --- utax/interpolation.py | 38 +++++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/utax/interpolation.py b/utax/interpolation.py index 63475bb..35ad64f 100644 --- a/utax/interpolation.py +++ b/utax/interpolation.py @@ -16,17 +16,33 @@ class BilinearInterpolator(object): """ def __init__(self, x, y, z, allow_extrapolation=True): - self.z = jnp.array(z) # z - if np.all(np.diff(x) >= 0): # check if sorted in increasing order - self.x = jnp.array(x) - else: - self.x = jnp.array(np.sort(x)) - self.z = jnp.flip(self.z, axis=0) - if np.all(np.diff(y) >= 0): # check if sorted in increasing order - self.y = jnp.array(y) - else: - self.y = jnp.array(np.sort(y)) - self.z = jnp.flip(self.z, axis=1) + self.z = jnp.array(z) + + # Sort x if not increasing + x = jnp.array(x) + x_sorted = jnp.sort(x) + flip_x = ~jnp.all(jnp.diff(x) >= 0) + + def x_keep_fn(_): + return x, self.z + + def x_sort_fn(_): + return x_sorted, jnp.flip(self.z, axis=0) + + self.x, self.z = lax.cond(flip_x, x_sort_fn, x_keep_fn, operand=None) + + # Sort y if not increasing + y = jnp.array(y) + y_sorted = jnp.sort(y) + flip_y = ~jnp.all(jnp.diff(y) >= 0) + + def y_keep_fn(_): + return y, self.z + + def y_sort_fn(_): + return y_sorted, jnp.flip(self.z, axis=1) + + self.y, self.z = lax.cond(flip_y, y_sort_fn, y_keep_fn, operand=None) self._extrapol_bool = allow_extrapolation def __call__(self, x, y, dx=0, dy=0): From dd3eece8cf50c00d5eab2b0bfc01cf66c2bf045c Mon Sep 17 00:00:00 2001 From: martin-millon Date: Fri, 11 Jul 2025 14:26:48 +0200 Subject: [PATCH 2/5] updating python version workflow --- .github/workflows/ci_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index 6880f06..e7cc4b3 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7"] + python-version: ["3.12"] steps: From da3da24a95e0df0bd9bf2c5db80ec003300a5cca Mon Sep 17 00:00:00 2001 From: martin-millon Date: Fri, 11 Jul 2025 14:31:24 +0200 Subject: [PATCH 3/5] fixing tests --- test/convolution_test.py | 2 +- test/wavelet_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/convolution_test.py b/test/convolution_test.py index f760bc4..87429b2 100644 --- a/test/convolution_test.py +++ b/test/convolution_test.py @@ -5,7 +5,7 @@ import numpy.testing as npt from scipy import signal, ndimage -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # makes a difference when comparing to scipy's routines!! from utax.convolution import * diff --git a/test/wavelet_test.py b/test/wavelet_test.py index fc76cc2..ac6344d 100644 --- a/test/wavelet_test.py +++ b/test/wavelet_test.py @@ -7,7 +7,7 @@ import numpy.testing as npt from test.convolution_test import gaussian_kernel -from jax.config import config +from jax import config config.update("jax_enable_x64", True) import utax From 1ae163b1a732febf4fee4966c608e8102c51a3a2 Mon Sep 17 00:00:00 2001 From: martin-millon Date: Fri, 11 Jul 2025 14:36:41 +0200 Subject: [PATCH 4/5] switching to python 3.11 --- .github/workflows/ci_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index e7cc4b3..793ff46 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.11"] steps: From 2c3226aaa295ccf34a2aeec7e832211753dd5554 Mon Sep 17 00:00:00 2001 From: martin-millon Date: Fri, 11 Jul 2025 14:39:48 +0200 Subject: [PATCH 5/5] fixing tests --- .github/workflows/ci_tests.yml | 2 +- test/wavelet_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index 793ff46..e7cc4b3 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.11"] + python-version: ["3.12"] steps: diff --git a/test/wavelet_test.py b/test/wavelet_test.py index ac6344d..7dd97f3 100644 --- a/test/wavelet_test.py +++ b/test/wavelet_test.py @@ -16,7 +16,7 @@ class TestWaveletTransform(object): - def setup(self): + def setup_method(self): utax_path = os.path.dirname(utax.__path__[0]) data_path = os.path.join(utax_path, 'test', 'data')