diff --git a/freegs4e/bilinear_interpolation.py b/freegs4e/bilinear_interpolation.py index b66ca06..f34b3cd 100644 --- a/freegs4e/bilinear_interpolation.py +++ b/freegs4e/bilinear_interpolation.py @@ -1,7 +1,14 @@ import numpy as np +try: + from numba import njit +except ImportError: -# + def njit(*args, **kwargs): + return lambda f: f + + +@njit(cache=True, fastmath=True) def biliint(R, Z, psi, points): """Simple bilinear interpolation of 2d map @@ -44,35 +51,27 @@ def biliint(R, Z, psi, points): idxs_R = np.where(idxs_R < nx, idxs_R, nx - 1) idxs_Z = np.where(idxs_Z < ny, idxs_Z, ny - 1) - iR = idxs_R[:, np.newaxis, np.newaxis] - iZ = idxs_Z[:, np.newaxis, np.newaxis] - qq = psi[ - np.concatenate( - ( - np.concatenate((iR - 1, iR - 1), axis=2), - np.concatenate((iR, iR), axis=2), - ), - axis=1, - ), - np.concatenate( - ( - np.concatenate((iZ - 1, iZ), axis=2), - np.concatenate((iZ - 1, iZ), axis=2), - ), - axis=1, - ), - ] - - iR = idxs_R[:, np.newaxis] - iZ = idxs_Z[:, np.newaxis] - xx = points_R[ - np.concatenate((iR, iR - 1), axis=1), - np.arange(len_points)[:, np.newaxis], - ] * (np.array([[1, -1]])) - yy = points_Z[ - np.concatenate((iZ, iZ - 1), axis=1), - np.arange(len_points)[:, np.newaxis], - ] * (np.array([[1, -1]])) + qq = np.empty((len_points, 2, 2)) + + for i in range(len_points): + qq[i, 0, 0] = psi[idxs_R[i] - 1, idxs_Z[i] - 1] + qq[i, 0, 1] = psi[idxs_R[i] - 1, idxs_Z[i]] + qq[i, 1, 0] = psi[idxs_R[i], idxs_Z[i] - 1] + qq[i, 1, 1] = psi[idxs_R[i], idxs_Z[i]] + + xx = np.empty((len_points, 2)) + for i in range(len_points): + xx[i, 0] = points_R[idxs_R[i], i] + xx[i, 1] = points_R[idxs_R[i] - 1, i] + + xx = xx * np.array([[1, -1]]) + + yy = np.empty((len_points, 2)) + for i in range(len_points): + yy[i, 0] = points_Z[idxs_Z[i], i] + yy[i, 1] = points_Z[idxs_Z[i] - 1, i] + + yy = yy * np.array([[1, -1]]) vals = ( np.sum(np.sum(qq * yy[:, np.newaxis, :], axis=-1) * xx, axis=-1) / dRdZ