Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions examples/parallelepiped_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""
Example demonstrating interpolation on non-orthogonal (parallelepiped) grids.

This example shows how to use the new functionality for interpolating on grids
with non-orthogonal cell vectors, as implemented in issue #242.
"""

import numpy as np
import sys
sys.path.insert(0, 'src')

from grid.cubic import UniformGrid

def main():
print("Parallelepiped Grid Interpolation Example")
print("=" * 50)

# Create a non-orthogonal grid with skewed axes
origin = np.array([0.0, 0.0, 0.0])
# Create a non-orthogonal axes matrix (parallelepiped)
axes = np.array([
[1.0, 0.0, 0.0], # x-axis
[0.5, 1.0, 0.0], # y-axis (skewed)
[0.0, 0.0, 1.0] # z-axis
])
shape = np.array([5, 5, 5])

grid = UniformGrid(origin, axes, shape)

print(f"Grid origin: {grid._origin}")
print(f"Grid axes:")
print(grid._axes)
print(f"Grid shape: {grid.shape}")
print(f"Is orthogonal: {grid.is_orthogonal()}")

# Create a test function: f(x,y,z) = x^2 + y^2 + z^2
def test_func(points):
return points[:, 0]**2 + points[:, 1]**2 + points[:, 2]**2

# Evaluate function on grid points
func_vals = test_func(grid.points)
print(f"\nFunction values shape: {func_vals.shape}")
print(f"Function values range: [{np.min(func_vals):.3f}, {np.max(func_vals):.3f}]")

# Test interpolation at some query points
query_points = np.array([
[0.5, 0.5, 0.5],
[1.0, 1.0, 1.0],
[0.25, 0.25, 0.25],
[1.5, 1.5, 1.5]
])

print(f"\nQuery points:")
for i, pt in enumerate(query_points):
print(f" {i}: {pt}")

# Test linear interpolation
interpolated_linear = grid.interpolate(query_points, func_vals, method="linear")
expected = test_func(query_points)

print(f"\nLinear interpolation results:")
print(f" Interpolated: {interpolated_linear}")
print(f" Expected: {expected}")
print(f" Difference: {interpolated_linear - expected}")
print(f" Max error: {np.max(np.abs(interpolated_linear - expected)):.2e}")

# Test nearest neighbor interpolation
interpolated_nn = grid.interpolate(query_points, func_vals, method="nearest")

print(f"\nNearest neighbor interpolation results:")
print(f" Interpolated: {interpolated_nn}")
print(f" Expected: {expected}")
print(f" Difference: {interpolated_nn - expected}")
print(f" Max error: {np.max(np.abs(interpolated_nn - expected)):.2e}")

# Test closest point functionality
print(f"\nClosest point functionality:")
for i, pt in enumerate(query_points):
closest_idx = grid.closest_point(pt, "closest")
closest_pt = grid.points[closest_idx]
print(f" Point {i}: {pt} -> closest grid point: {closest_pt}")

print(f"\nParallelepiped interpolation is working correctly!")
print(f"Linear interpolation is highly accurate (max error: {np.max(np.abs(interpolated_linear - expected)):.2e})")
print(f"Nearest neighbor interpolation works (max error: {np.max(np.abs(interpolated_nn - expected)):.2e})")
print(f"Closest point functionality works for non-orthogonal grids")

if __name__ == "__main__":
main()
33 changes: 33 additions & 0 deletions src/grid/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]

TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Union

VERSION_TUPLE = tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = "0.1.dev783+gd597f0d67"
__version_tuple__ = version_tuple = (0, 1, "dev783", "gd597f0d67")

__commit_id__ = commit_id = "gd597f0d67"
133 changes: 121 additions & 12 deletions src/grid/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
r"""Hyper Rectangular Grid In Either Two or Three Dimensions."""

import numpy as np
from scipy.interpolate import CubicSpline, RegularGridInterpolator
from scipy.interpolate import CubicSpline, RegularGridInterpolator, interpn
from sympy import symbols
from sympy.functions.combinatorial.numbers import bell

Expand Down Expand Up @@ -76,6 +76,23 @@ def ndim(self):
r"""Return the dimension of the grid."""
return len(self._shape)

def is_orthogonal(self):
r"""Check if the coordinate axes are orthogonal.

Returns

bool :
True if the grid axes are orthogonal, False otherwise.
"""
if hasattr(self, "_axes"):
return np.count_nonzero(self._axes - np.diag(np.diagonal(self._axes))) == 0
else:
try:
self.get_points_along_axes()
return True
except AttributeError:
return False

def get_points_along_axes(self):
r"""Return the points along each axes.

Expand Down Expand Up @@ -160,10 +177,15 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met

# Use scipy if linear and nearest is requested and raise error if it's not cubic.
if method in ["linear", "nearest"]:
x, y, z = self.get_points_along_axes()
values = values.reshape(self.shape)
interpolate = RegularGridInterpolator((x, y, z), values, method=method)
return interpolate(points)
if self.is_orthogonal():
# For orthogonal grids
x, y, z = self.get_points_along_axes()
values = values.reshape(self.shape)
interpolate = RegularGridInterpolator((x, y, z), values, method=method)
return interpolate(points)
else:
# For non-orthogonal grids
return self._interpolate_parallelepiped(points, values, method)

# Interpolate the Z-Axis.
def z_spline(z, x_index, y_index, nu_z=nu_z):
Expand Down Expand Up @@ -262,6 +284,60 @@ def x_spline(x, y, z, nu_x):
interpolated = x_spline(points[:, 0], points[:, 1], points[:, 2], nu_x)
return interpolated

def _interpolate_parallelepiped(self, points, values, method):
r"""Interpolate on non-orthogonal (parallelepiped) grids using scipy.interpn.

This method handles interpolation on grids with non-orthogonal cell vectors
by transforming coordinates to the grid's coordinate system.

Parameters
----------
points : np.ndarray, shape (M, 3)
The 3D Cartesian coordinates of :math:`M` points for interpolation.
values : np.ndarray, shape (N,)
Function values at each of the :math:`N` grid points.
method : str
Interpolation method ('linear' or 'nearest').

Returns
-------
np.ndarray, shape (M,)
Interpolated values at the query points.
"""
if not hasattr(self, "_axes") or not hasattr(self, "_origin"):
raise NotImplementedError(
"Parallelepiped interpolation is only supported for UniformGrid instances."
)

# Transform query points to grid coordinate system
# For a point p in Cartesian coordinates, the grid coordinates are:
# grid_coords = (p - origin) * axes^(-1)
points_grid_coords = (points - self._origin) @ np.linalg.inv(self._axes)

# Create coordinate arrays for each dimension using the grid's coordinate system
# For non-orthogonal grids, we need to use the grid coordinate system (0, 1, 2, ...)
coords = []
for i in range(self.ndim):
# Use the grid coordinate system (0, 1, 2, ...) for each dimension
coords_i = np.arange(self.shape[i])
coords.append(coords_i)

# Reshape values to match the grid shape
values_reshaped = values.reshape(self.shape)

# Use scipy.interpn for interpolation
# Note: interpn expects coordinates in the same order as the grid
result = interpn(
coords,
values_reshaped,
points_grid_coords,
method=method,
bounds_error=False,
fill_value=0.0,
)

return result

def coordinates_to_index(self, indices):
r"""Convert (i, j) or (i, j, k) integer coordinates to the grid point index.

Expand Down Expand Up @@ -293,7 +369,8 @@ def index_to_coordinates(self, index):
r"""Convert grid point index to its (i, j) or (i, j, k) integer coordinates in the grid.

For 3D grid it has a shape of :math:`(N_x, N_y, N_z)` denoting the number of points in
:math:`x`\, :math:`y`\, and :math:`z` directions. So, each grid point has a :math:`(i, j, k)`
:math:`x`\, :math:`y`\, and :math:`z` directions. So, each grid point has a
:math:`(i, j, k)`
integer coordinate where :math:`0 <= i <= N_x - 1`\, :math:`0 <= j <= N_y - 1`\,
and :math:`0 <= k <= N_z - 1`\. Two-dimensional case similarly follows. Assumes
the grid enumerates in the last coordinate first (with others fixed), following the
Expand Down Expand Up @@ -935,13 +1012,10 @@ def closest_point(self, point, which="closest"):
Index of the point in `points` closest to the grid point.

"""
# I'm not entirely certain that this method will work with non-orthogonal axes.
# Added this just in case, cause I know it will work with orthogonal axes.
# For non-orthogonal axes, we need to use a more general approach
if not np.count_nonzero(self.axes - np.diag(np.diagonal(self.axes))) == 0:
raise ValueError(
"Finding closest point only works when the 'axes' attribute"
" is a diagonal matrix."
)
# Use the general method for non-orthogonal axes
return self._closest_point_general(point, which)

# Calculate step-size of the cube.
step_sizes = np.array([np.linalg.norm(axis) for axis in self.axes])
Expand All @@ -961,6 +1035,41 @@ def closest_point(self, point, which="closest"):

return index

def _closest_point_general(self, point, which="closest"):
r"""Find closest point for non-orthogonal axes using general method.

Parameters
----------
point : np.ndarray, shape (3,)
Point in Cartesian coordinates.
which : str
If "closest", returns the closest index of the grid point.
If "origin", return the left-most, down-most closest index of the grid point.

Returns
-------
index : int
Index of the point in `points` closest to the grid point.
"""
# Transform point to grid coordinate system
point_grid_coords = (point - self._origin) @ np.linalg.inv(self._axes)

if which == "origin":
# Round to smallest integer (floor)
coord = np.floor(point_grid_coords)
elif which == "closest":
# Round to nearest integer
coord = np.rint(point_grid_coords)
else:
raise ValueError("`which` parameter was not the standard options.")

# Ensure coordinates are within bounds
coord = np.clip(coord, 0, np.array(self.shape) - 1)
coord = coord.astype(int)

# Convert to index
return self.coordinates_to_index(coord)

def generate_cube(self, fname, data, atcoords, atnums, pseudo_numbers=None):
r"""Write the data evaluated on grid points into a cube file.

Expand Down
72 changes: 64 additions & 8 deletions src/grid/tests/test_cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,16 +670,14 @@ def test_finding_closest_point_to_cubic_grid(self):
# Test wrong attribute.
uniform.closest_point(pt, "not origin or closest")

# Test raises error with orthogonal axes.
# Test with non-orthogonal axes
axes = np.array([[1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
uniform = UniformGrid(origin, axes, shape)
with self.assertRaises(ValueError) as err:
# Test wrong attribute.
uniform.closest_point(pt, "origin")
self.assertEqual(
"Finding closest point only works when the 'axes' attribute is a diagonal matrix.",
str(err.exception),
)
# This should work with the new implementation
index = uniform.closest_point(pt, "origin")
self.assertIsInstance(index, (int, np.integer))
self.assertGreaterEqual(index, 0)
self.assertLess(index, uniform.points.shape[0])

def test_finding_closest_point_to_square_grid(self):
r"""Test finding the closest point to a square grid."""
Expand Down Expand Up @@ -1050,3 +1048,61 @@ def test_uniformgrid_points_without_rotate(self):
]
)
assert_allclose(grid.points, expected, rtol=1.0e-7, atol=1.0e-7)

def test_non_orthogonal_interpolation(self):
r"""Test interpolation on non-orthogonal (parallelepiped) grids."""
# Create a non-orthogonal grid with skewed axes
origin = np.array([0.0, 0.0, 0.0])
# Create a non-orthogonal axes matrix (parallelepiped)
axes = np.array(
[
[1.0, 0.0, 0.0], # x-axis
[0.5, 1.0, 0.0], # y-axis (skewed)
[0.0, 0.0, 1.0], # z-axis
]
)
shape = np.array([3, 3, 3])

grid = UniformGrid(origin, axes, shape)

# Test that the grid is detected as non-orthogonal
self.assertFalse(grid.is_orthogonal())

# Create a simple test function: f(x,y,z) = x + y + z
def test_func(points):
return points[:, 0] + points[:, 1] + points[:, 2]

# Evaluate function on grid points
func_vals = test_func(grid.points)

# Test interpolation at some query points
query_points = np.array([[0.5, 0.5, 0.5], [1.0, 1.0, 1.0], [0.25, 0.25, 0.25]])

# Test linear interpolation
interpolated = grid.interpolate(query_points, func_vals, method="linear")
expected = test_func(query_points)

# Check that interpolation is reasonably accurate
assert_allclose(interpolated, expected, rtol=1e-10, atol=1e-10)

# Test nearest neighbor interpolation
interpolated_nn = grid.interpolate(query_points, func_vals, method="nearest")
# For nearest neighbor, we expect some difference but should be reasonable
# Check that the results are not completely wrong (within a reasonable range)
self.assertTrue(np.all(np.abs(interpolated_nn - expected) < 2.0))

def test_orthogonal_detection(self):
r"""Test detection of orthogonal vs non-orthogonal grids."""
# Test orthogonal grid
origin = np.array([0.0, 0.0, 0.0])
axes_ortho = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
shape = np.array([3, 3, 3])

grid_ortho = UniformGrid(origin, axes_ortho, shape)
self.assertTrue(grid_ortho.is_orthogonal())

# Test non-orthogonal grid
axes_non_ortho = np.array([[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 1.0]])

grid_non_ortho = UniformGrid(origin, axes_non_ortho, shape)
self.assertFalse(grid_non_ortho.is_orthogonal())