From 24a5cf532c8b71e8a959df7078ad938483e2d3fa Mon Sep 17 00:00:00 2001 From: Timothy Nunn Date: Thu, 11 Sep 2025 17:02:08 +0100 Subject: [PATCH] Rewrite biliint as a numba-compilable function --- freegs4e/bilinear_interpolation.py | 59 +++++++++++++++--------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/freegs4e/bilinear_interpolation.py b/freegs4e/bilinear_interpolation.py index b66ca06..f34b3cd 100644 --- a/freegs4e/bilinear_interpolation.py +++ b/freegs4e/bilinear_interpolation.py @@ -1,7 +1,14 @@ import numpy as np +try: + from numba import njit +except ImportError: -# + def njit(*args, **kwargs): + return lambda f: f + + +@njit(cache=True, fastmath=True) def biliint(R, Z, psi, points): """Simple bilinear interpolation of 2d map @@ -44,35 +51,27 @@ def biliint(R, Z, psi, points): idxs_R = np.where(idxs_R < nx, idxs_R, nx - 1) idxs_Z = np.where(idxs_Z < ny, idxs_Z, ny - 1) - iR = idxs_R[:, np.newaxis, np.newaxis] - iZ = idxs_Z[:, np.newaxis, np.newaxis] - qq = psi[ - np.concatenate( - ( - np.concatenate((iR - 1, iR - 1), axis=2), - np.concatenate((iR, iR), axis=2), - ), - axis=1, - ), - np.concatenate( - ( - np.concatenate((iZ - 1, iZ), axis=2), - np.concatenate((iZ - 1, iZ), axis=2), - ), - axis=1, - ), - ] - - iR = idxs_R[:, np.newaxis] - iZ = idxs_Z[:, np.newaxis] - xx = points_R[ - np.concatenate((iR, iR - 1), axis=1), - np.arange(len_points)[:, np.newaxis], - ] * (np.array([[1, -1]])) - yy = points_Z[ - np.concatenate((iZ, iZ - 1), axis=1), - np.arange(len_points)[:, np.newaxis], - ] * (np.array([[1, -1]])) + qq = np.empty((len_points, 2, 2)) + + for i in range(len_points): + qq[i, 0, 0] = psi[idxs_R[i] - 1, idxs_Z[i] - 1] + qq[i, 0, 1] = psi[idxs_R[i] - 1, idxs_Z[i]] + qq[i, 1, 0] = psi[idxs_R[i], idxs_Z[i] - 1] + qq[i, 1, 1] = psi[idxs_R[i], idxs_Z[i]] + + xx = np.empty((len_points, 2)) + for i in range(len_points): + xx[i, 0] = points_R[idxs_R[i], i] + xx[i, 1] = points_R[idxs_R[i] - 1, i] + + xx = xx * np.array([[1, -1]]) + + yy = np.empty((len_points, 2)) + for i in range(len_points): + yy[i, 0] = points_Z[idxs_Z[i], i] + yy[i, 1] = points_Z[idxs_Z[i] - 1, i] + + yy = yy * np.array([[1, -1]]) vals = ( np.sum(np.sum(qq * yy[:, np.newaxis, :], axis=-1) * xx, axis=-1) / dRdZ