Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7"]
python-version: ["3.12"]

steps:

Expand Down
2 changes: 1 addition & 1 deletion test/convolution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy.testing as npt
from scipy import signal, ndimage

from jax.config import config
from jax import config
config.update("jax_enable_x64", True) # makes a difference when comparing to scipy's routines!!

from utax.convolution import *
Expand Down
4 changes: 2 additions & 2 deletions test/wavelet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.testing as npt
from test.convolution_test import gaussian_kernel

from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

import utax
Expand All @@ -16,7 +16,7 @@

class TestWaveletTransform(object):

def setup(self):
def setup_method(self):
utax_path = os.path.dirname(utax.__path__[0])
data_path = os.path.join(utax_path, 'test', 'data')

Expand Down
38 changes: 27 additions & 11 deletions utax/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,33 @@ class BilinearInterpolator(object):

"""
def __init__(self, x, y, z, allow_extrapolation=True):
self.z = jnp.array(z) # z
if np.all(np.diff(x) >= 0): # check if sorted in increasing order
self.x = jnp.array(x)
else:
self.x = jnp.array(np.sort(x))
self.z = jnp.flip(self.z, axis=0)
if np.all(np.diff(y) >= 0): # check if sorted in increasing order
self.y = jnp.array(y)
else:
self.y = jnp.array(np.sort(y))
self.z = jnp.flip(self.z, axis=1)
self.z = jnp.array(z)

# Sort x if not increasing
x = jnp.array(x)
x_sorted = jnp.sort(x)
flip_x = ~jnp.all(jnp.diff(x) >= 0)

def x_keep_fn(_):
return x, self.z

def x_sort_fn(_):
return x_sorted, jnp.flip(self.z, axis=0)

self.x, self.z = lax.cond(flip_x, x_sort_fn, x_keep_fn, operand=None)

# Sort y if not increasing
y = jnp.array(y)
y_sorted = jnp.sort(y)
flip_y = ~jnp.all(jnp.diff(y) >= 0)

def y_keep_fn(_):
return y, self.z

def y_sort_fn(_):
return y_sorted, jnp.flip(self.z, axis=1)

self.y, self.z = lax.cond(flip_y, y_sort_fn, y_keep_fn, operand=None)
self._extrapol_bool = allow_extrapolation

def __call__(self, x, y, dx=0, dy=0):
Expand Down