From 8982e8fef697b93c4825141519d4e951af02dfd5 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sun, 21 Sep 2025 11:13:03 -0500 Subject: [PATCH 1/6] Add interpolated field tracing example with VMEC integration --- essos/interpolated_field.py | 520 ++++++++++++++++++++++ examples/trace_fieldlines_interpolated.py | 158 +++++++ 2 files changed, 678 insertions(+) create mode 100644 essos/interpolated_field.py create mode 100644 examples/trace_fieldlines_interpolated.py diff --git a/essos/interpolated_field.py b/essos/interpolated_field.py new file mode 100644 index 0000000..a8bc1d8 --- /dev/null +++ b/essos/interpolated_field.py @@ -0,0 +1,520 @@ +""" +InterpolatedField (JAX) +----------------------- + +A fast, jittable 3‑D piecewise‑polynomial interpolator on a regular grid, modeled +on the SIMSOPT C++ InterpolatedField/RegularGridInterpolant3D classes. + +Features +- Uniform or Chebyshev nodes per cell (Lagrange polynomials of degree d per axis) +- Optional domain mask (skip function) to avoid filling out‑of‑plasma regions +- Cylindrical (r,phi,z) grid with nfp periodicity and optional stellarator symmetry (z<0) +- Vector‑valued interpolation (value_size = 3 by default for B or ∇|B|) +- Batch evaluation and batched coefficient building +- Fully JIT‑able with Equinox; uses pure JAX (no Python loops at runtime) +- Exposes convenient wrappers to evaluate Cartesian B from an underlying field + that returns B(x,y,z) in Cartesian. + +Dependencies: jax, equinox (eqx) +""" +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import jax +import jax.numpy as jnp +import equinox as eqx +from jax import lax + +# -------------------------------------------------------------------------------------- +# Utility: Lagrange basis with precomputed nodes and scalings (barycentric‑like product) +# -------------------------------------------------------------------------------------- + + +class InterpolationRule(eqx.Module): + degree: int + nodes: jnp.ndarray # (d+1,) + scalings: jnp.ndarray # (d+1,) + + def basis(self, x: jnp.ndarray) -> jnp.ndarray: + """Return p_i(x) for i=0..d as shape (d+1, *x.shape). Vectorized over x. + p_i(x) = (∏_{k!=i} (x - nodes[k])) * scalings[i] + Note: x is in the *cell local* [0,1] coordinate. + """ + d = self.degree + # Evaluate all (x - nodes[k]) for broadcasting: (d+1, *xshape) + diffs = x[None, ...] - self.nodes[:, None] + # For each i, product over k!=i. We compute total product then divide by (x - nodes[i]). + prod_all = jnp.prod(diffs, axis=0) # (*xshape,) + # Guard division by zero when x equals a node: use polynomial limit via L'Hôpital with one‑hot mask + def single_pi(i): + di = diffs[i] + # Where di==0, p_i(x) should be 1 and others 0. Implement stable selection: + # base formula (prod_all / di) * scalings[i] + base = (prod_all / di) * self.scalings[i] + # exact node selection + at_node = (di == 0) + return jnp.where(at_node, jnp.ones_like(base), base) + + pis = jax.vmap(single_pi)(jnp.arange(d + 1)) # (d+1,*xshape) + # When x equals nodes[j], *only* j‑th basis should be 1; others 0. + # Enforce explicitly: + # Find any node hit + hits = (diffs == 0) + any_hit = jnp.any(hits, axis=0) + if pis.ndim == 1: + # scalar x + if any_hit: # type: ignore + j = jnp.argmax(hits) + pis = jax.nn.one_hot(j, d + 1) + else: + # broadcasted x + j = jnp.argmax(jnp.where(hits, 1, 0), axis=0) + pis = jnp.where( + any_hit[None, ...], + jax.nn.one_hot(j, d + 1)[...].swapaxes(0, -1).reshape((d + 1,) + x.shape), + pis, + ) + return pis + + +class UniformInterpolationRule(InterpolationRule): + def __init__(self, degree: int): + nodes = jnp.linspace(0.0, 1.0, degree + 1) + # barycentric‑like scalings: ∏_{k≠i} 1/(x_i - x_k) + diffs = nodes[:, None] - nodes[None, :] + scalings = jnp.prod(jnp.where(jnp.eye(degree + 1, dtype=bool), 1.0, 1.0 / diffs), axis=1) + super().__init__(degree=degree, nodes=nodes, scalings=scalings) + + +class ChebyshevInterpolationRule(InterpolationRule): + def __init__(self, degree: int): + # map Chebyshev nodes from [-1,1] to [0,1] + k = jnp.arange(degree + 1) + nodes = 0.5 * (1.0 - jnp.cos(math.pi * k / degree)) if degree > 0 else jnp.array([0.0]) + diffs = nodes[:, None] - nodes[None, :] + scalings = jnp.prod(jnp.where(jnp.eye(degree + 1, dtype=bool), 1.0, 1.0 / diffs), axis=1) + super().__init__(degree=degree, nodes=nodes, scalings=scalings) + + +# -------------------------------------------------------------------------------------- +# RegularGridInterpolant3D: vector‑valued values at (d+1)^3 dofs per cell +# -------------------------------------------------------------------------------------- + + +@dataclass +class GridSpec: + r_range: Tuple[float, float, int] # (rmin, rmax, nr_cells) + phi_range: Tuple[float, float, int] # (phimin, phimax, nphi_cells) + z_range: Tuple[float, float, int] # (zmin, zmax, nz_cells) + value_size: int = 3 + + +class RegularGridInterpolant3D(eqx.Module): + rule: InterpolationRule + grid: GridSpec + extrapolate: bool + + # Precomputed mesh params (static) + rmin: float = eqx.static_field() + rmax: float = eqx.static_field() + phimin: float = eqx.static_field() + phimax: float = eqx.static_field() + zmin: float = eqx.static_field() + zmax: float = eqx.static_field() + nr: int = eqx.static_field() + nphi: int = eqx.static_field() + nz: int = eqx.static_field() + hr: float = eqx.static_field() + hphi: float = eqx.static_field() + hz: float = eqx.static_field() + + # Reduced DOFs (kept points) in tensor grid ordering; domain masking handled via skip_mask + r_dofs: jnp.ndarray # (Nd,) + phi_dofs: jnp.ndarray # (Nd,) + z_dofs: jnp.ndarray # (Nd,) + dof_is_kept: jnp.ndarray # (Nfull,) boolean mask for tensor grid dofs + dof_full2reduced: jnp.ndarray # (Nfull,) int + dof_reduced2full: jnp.ndarray # (Nd,) int + + # For fast local assembly: for every cell, the flattened indices (in reduced‑dof space) + # of its (d+1)^3 corner DOFs, with -1 for skipped DOFs (not used but kept for shape) + cell_dof_idx: jnp.ndarray # (nr*nphi*nz, (d+1)^3) + skip_cell: jnp.ndarray # (nr*nphi*nz,) bool + + # Interpolated values at reduced DOFs + vals: jnp.ndarray # (Nd, value_size) + + def __init__( + self, + rule: InterpolationRule, + grid: GridSpec, + extrapolate: bool, + skip_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, + ): + object.__setattr__(self, "rule", rule) + object.__setattr__(self, "grid", grid) + object.__setattr__(self, "extrapolate", extrapolate) + + rmin, rmax, nr = grid.r_range + phimin, phimax, nphi = grid.phi_range + zmin, zmax, nz = grid.z_range + hr = (rmax - rmin) / nr + hphi = (phimax - phimin) / nphi + hz = (zmax - zmin) / nz + object.__setattr__(self, "rmin", rmin) + object.__setattr__(self, "rmax", rmax) + object.__setattr__(self, "phimin", phimin) + object.__setattr__(self, "phimax", phimax) + object.__setattr__(self, "zmin", zmin) + object.__setattr__(self, "zmax", zmax) + object.__setattr__(self, "nr", nr) + object.__setattr__(self, "nphi", nphi) + object.__setattr__(self, "nz", nz) + object.__setattr__(self, "hr", hr) + object.__setattr__(self, "hphi", hphi) + object.__setattr__(self, "hz", hz) + + d = rule.degree + # 1D DOF locations for each axis: (n_cells*d + 1) + r_dof_1d = jnp.concatenate( + [rmin + (i * hr + rule.nodes * hr) for i in range(nr)] + [jnp.array([rmax])] + ) + phi_dof_1d = jnp.concatenate( + [phimin + (j * hphi + rule.nodes * hphi) for j in range(nphi)] + [jnp.array([phimax])] + ) + z_dof_1d = jnp.concatenate( + [zmin + (k * hz + rule.nodes * hz) for k in range(nz)] + [jnp.array([zmax])] + ) + # Full tensor grid of DOFs + R, P, Z = jnp.meshgrid(r_dof_1d, phi_dof_1d, z_dof_1d, indexing="ij") + Rf = R.reshape(-1) + Pf = P.reshape(-1) + Zf = Z.reshape(-1) + Nfull = Rf.size + + # Domain mask from skip_fn evaluated on mesh nodes; keep dof if any adjacent cell may use it + if skip_fn is None: + keep_dof = jnp.ones((Nfull,), dtype=bool) + else: + keep_dof = ~skip_fn(Rf, Pf, Zf) + + dof_full2reduced = jnp.cumsum(keep_dof.astype(jnp.int32)) - 1 + Nd = int(keep_dof.sum()) + dof_reduced2full = jnp.nonzero(keep_dof, size=Nd, fill_value=0)[0] + r_dofs = Rf[keep_dof] + phi_dofs = Pf[keep_dof] + z_dofs = Zf[keep_dof] + + # Build per‑cell mapping to its local (d+1)^3 DOFs. Also mark entirely skipped cells + def cell_map(i, j, k): + # local 1D indices in the full dof grid + ii = jnp.arange(i * d, i * d + d + 1) + jj = jnp.arange(j * d, j * d + d + 1) + kk = jnp.arange(k * d, k * d + d + 1) + # convert 1D indices to full tensor DOF index + def idx_full(a, b, c): + return ( + a * (phi_dof_1d.size) * (z_dof_1d.size) + + b * (z_dof_1d.size) + + c + ) + + A, B, C = jnp.meshgrid(ii, jj, kk, indexing="ij") + full_idx = idx_full(A, B, C).reshape(-1) # ((d+1)^3,) + kept = keep_dof[full_idx] + # A cell is skipped iff *all* its corner DOFs are skipped + skip_this = ~jnp.any(kept) + # Map to reduced indices; put -1 for skipped DOFs (unused in eval) + red = jnp.where(kept, dof_full2reduced[full_idx], -1) + return red, skip_this + + cell_idx_list = [] + cell_skip_list = [] + for i in range(nr): + for j in range(nphi): + for k in range(nz): + red, sk = cell_map(i, j, k) + cell_idx_list.append(red) + cell_skip_list.append(sk) + cell_dof_idx = jnp.stack(cell_idx_list, axis=0) # (ncells, (d+1)^3) + skip_cell = jnp.stack(cell_skip_list, axis=0) + + vals = jnp.zeros((Nd, grid.value_size)) + + object.__setattr__(self, "r_dofs", r_dofs) + object.__setattr__(self, "phi_dofs", phi_dofs) + object.__setattr__(self, "z_dofs", z_dofs) + object.__setattr__(self, "dof_is_kept", keep_dof) + object.__setattr__(self, "dof_full2reduced", dof_full2reduced) + object.__setattr__(self, "dof_reduced2full", dof_reduced2full) + object.__setattr__(self, "cell_dof_idx", cell_dof_idx) + object.__setattr__(self, "skip_cell", skip_cell) + object.__setattr__(self, "vals", vals) + + # --------------------------- build (interpolation) --------------------------- + + def build(self, fbatch: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]): + """Fill coefficients by evaluating fbatch at reduced DOFs. + fbatch must map (r_vec, phi_vec, z_vec) -> (Nd, value_size). + """ + vals = fbatch(self.r_dofs, self.phi_dofs, self.z_dofs) + return eqx.tree_at(lambda s: s.vals, self, vals) + + # --------------------------- evaluation helpers ----------------------------- + + def _locate_cell_and_local(self, r, phi, z): + # clamp (or leave) to domain + if self.extrapolate: + rc = r + pc = phi + zc = z + else: + rc = jnp.clip(r, self.rmin, self.rmax - 1e-15) + pc = jnp.clip(phi, self.phimin, self.phimax - 1e-15) + zc = jnp.clip(z, self.zmin, self.zmax - 1e-15) + # integer cell indices + ir = jnp.floor((rc - self.rmin) / self.hr).astype(jnp.int32) + ip = jnp.floor((pc - self.phimin) / self.hphi).astype(jnp.int32) + iz = jnp.floor((zc - self.zmin) / self.hz).astype(jnp.int32) + ir = jnp.clip(ir, 0, self.nr - 1) + ip = jnp.clip(ip, 0, self.nphi - 1) + iz = jnp.clip(iz, 0, self.nz - 1) + # local coords in [0,1] + xr = (rc - (self.rmin + ir * self.hr)) / self.hr + xp = (pc - (self.phimin + ip * self.hphi)) / self.hphi + xz = (zc - (self.zmin + iz * self.hz)) / self.hz + cell_idx = (ir * self.nphi * self.nz) + (ip * self.nz) + iz + return cell_idx, xr, xp, xz + + def _eval_in_cell(self, cell_idx: jnp.ndarray, xr, xp, xz) -> jnp.ndarray: + # Fetch local coefficient block (flattened) for this cell + d = self.rule.degree + local_idx = self.cell_dof_idx[cell_idx] # ((d+1)^3,) + # Gather coefficients: ( (d+1)^3, value_size ) + local_vals = jnp.where( + (local_idx[:, None] >= 0), + self.vals[jnp.maximum(local_idx, 0)], + 0.0, + ) + # Basis on each axis: (d+1,) + br = self.rule.basis(xr) + bp = self.rule.basis(xp) + bz = self.rule.basis(xz) + # Tensor multiply: sum_{a,b,c} br[a]*bp[b]*bz[c]*V[a,b,c,:] + # Reshape to (d+1,d+1,d+1,val) + V = local_vals.reshape((d + 1, d + 1, d + 1, self.grid.value_size)) + tmp = jnp.tensordot(br, V, axes=[[0], [0]]) # (d+1, d+1, val) + tmp = jnp.tensordot(bp, tmp, axes=[[0], [0]]) # (d+1, val) + out = jnp.tensordot(bz, tmp, axes=[[0], [0]]) # (val,) + return out + + def evaluate_batch(self, rphiz: jnp.ndarray) -> jnp.ndarray: + """Evaluate interpolant at a batch of N points (r,phi,z). + rphiz: (N,3) -> returns (N, value_size) + """ + def one(p): + r, phi, z = p + cell_idx, xr, xp, xz = self._locate_cell_and_local(r, phi, z) + return self._eval_in_cell(cell_idx, xr, xp, xz) + + return jax.vmap(one)(rphiz) + + +# -------------------------------------------------------------------------------------- +# Cylindrical symmetry helpers and top‑level InterpolatedField API +# -------------------------------------------------------------------------------------- + + +def _reduce_by_symmetry(rphiz: jnp.ndarray, nfp: int, stellsym: bool): + """Map points into fundamental domain; remember flips for later component fixes. + Returns (rphiz_sym, flags) where flags=bool array whether z<0 reflection was used. + """ + r = rphiz[:, 0] + phi = rphiz[:, 1] + z = rphiz[:, 2] + + period = (2.0 * jnp.pi) / nfp + # mod phi to [0,period) + k = jnp.floor(phi / period) + phi_mod = phi - k * period + + if stellsym: + reflect = z < 0.0 + z_mod = jnp.where(reflect, -z, z) + phi_mod = jnp.where(reflect, 2 * jnp.pi - phi_mod, phi_mod) + # re‑mod to [0,period) + k2 = jnp.floor(phi_mod / period) + phi_mod = phi_mod - k2 * period + else: + reflect = jnp.zeros_like(z, dtype=bool) + z_mod = z + + r_sym = jnp.stack([r, phi_mod, z_mod], axis=1) + return r_sym, reflect + + +def _apply_symmetry_to_B_cyl(Bcyl: jnp.ndarray, reflect: jnp.ndarray) -> jnp.ndarray: + # If reflected (z<0), flip radial component (matches C++ apply_symmetries_to_B_cyl) + Br, Bp, Bz = Bcyl.T + Br = jnp.where(reflect, -Br, Br) + return jnp.stack([Br, Bp, Bz], axis=1) + + +def _apply_symmetry_to_GradAbsB_cyl(grad: jnp.ndarray, reflect: jnp.ndarray) -> jnp.ndarray: + # If reflected, flip phi and z components (matches C++ apply_symmetries_to_GradAbsB_cyl) + Gr, Gp, Gz = grad.T + Gp = jnp.where(reflect, -Gp, Gp) + Gz = jnp.where(reflect, -Gz, Gz) + return jnp.stack([Gr, Gp, Gz], axis=1) + + +def _cyl_to_cart_vectors(phi: jnp.ndarray, vec_cyl: jnp.ndarray) -> jnp.ndarray: + """Rotate cylindrical vector to Cartesian at given phi for each point. + vec_cyl: (N,3) + returns (N,3) + """ + c = jnp.cos(phi) + s = jnp.sin(phi) + Br, Bp, Bz = vec_cyl.T + Bx = c * Br - s * Bp + By = s * Br + c * Bp + return jnp.stack([Bx, By, Bz], axis=1) + + +def _cart_to_cyl_vectors(phi: jnp.ndarray, vec_xyz: jnp.ndarray) -> jnp.ndarray: + c = jnp.cos(phi) + s = jnp.sin(phi) + Bx, By, Bz = vec_xyz.T + Br = c * Bx + s * By + Bp = -s * Bx + c * By + return jnp.stack([Br, Bp, Bz], axis=1) + + +class InterpolatedField(eqx.Module): + # configuration + nfp: int + stellsym: bool + + # underlying field callable: given (x,y,z) -> (3,) Cartesian + base_field_cart: Callable[[jnp.ndarray], jnp.ndarray] = eqx.static_field() + + # interpolants in cylindrical space + interp_B: RegularGridInterpolant3D + interp_GradAbsB: Optional[RegularGridInterpolant3D] + + # Which parts have been built + has_B: bool + has_GradAbsB: bool + + def __init__( + self, + base_field_cart: Callable[[jnp.ndarray], jnp.ndarray], + degree: int, + rrange: Tuple[float, float, int], + phirange: Tuple[float, float, int], + zrange: Tuple[float, float, int], + extrapolate: bool = True, + nfp: int = 1, + stellsym: bool = False, + skip_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, + use_chebyshev: bool = False, + build_gradabsb: bool = False, + ): + rule = ChebyshevInterpolationRule(degree) if use_chebyshev else UniformInterpolationRule(degree) + grid = GridSpec(rrange, phirange, zrange, value_size=3) + interp_B = RegularGridInterpolant3D(rule, grid, extrapolate, skip_fn) + interp_G = RegularGridInterpolant3D(rule, grid, extrapolate, skip_fn) if build_gradabsb else None + object.__setattr__(self, "nfp", nfp) + object.__setattr__(self, "stellsym", stellsym) + object.__setattr__(self, "base_field_cart", base_field_cart) + object.__setattr__(self, "interp_B", interp_B) + object.__setattr__(self, "interp_GradAbsB", interp_G) + object.__setattr__(self, "has_B", False) + object.__setattr__(self, "has_GradAbsB", False) + + # --------------------- builders: fill coefficient arrays --------------------- + + def _fbatch_B(self, r: jnp.ndarray, phi: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray: + # Convert to xyz, call base field, project to cylindrical + x = r * jnp.cos(phi) + y = r * jnp.sin(phi) + pts = jnp.stack([x, y, z], axis=1) + Bxyz = jax.vmap(self.base_field_cart)(pts) + Bcyl = _cart_to_cyl_vectors(phi, Bxyz) + return Bcyl + + def _fbatch_GradAbsB(self, r: jnp.ndarray, phi: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray: + def absB(pt): + return jnp.linalg.norm(self.base_field_cart(pt)) + + grad_abs = jax.vmap(jax.grad(absB))(jnp.stack([r * jnp.cos(phi), r * jnp.sin(phi), z], axis=1)) + # grad in Cartesian -> convert to cylindrical components + return _cart_to_cyl_vectors(phi, grad_abs) + + def build_B(self): + interp = self.interp_B.build(self._fbatch_B) + return eqx.tree_at(lambda s: (s.interp_B, s.has_B), self, (interp, True)) + + def build_GradAbsB(self): + assert self.interp_GradAbsB is not None, "build_gradabsb=False in constructor" + interp = self.interp_GradAbsB.build(self._fbatch_GradAbsB) # type: ignore + return eqx.tree_at(lambda s: (s.interp_GradAbsB, s.has_GradAbsB), self, (interp, True)) + + # --------------------- evaluation on batches of points ----------------------- + + @eqx.filter_jit + def B_cyl(self, rphiz: jnp.ndarray) -> jnp.ndarray: + """Evaluate B in cylindrical components at (r,phi,z) batch. + rphiz shape (N,3). + """ + assert self.has_B, "Coefficients not built; call build_B() first" + rphiz_sym, reflect = _reduce_by_symmetry(rphiz, self.nfp, self.stellsym) + Bcyl = self.interp_B.evaluate_batch(rphiz_sym) + Bcyl = _apply_symmetry_to_B_cyl(Bcyl, reflect) + return Bcyl + + @eqx.filter_jit + def GradAbsB_cyl(self, rphiz: jnp.ndarray) -> jnp.ndarray: + assert self.has_GradAbsB and (self.interp_GradAbsB is not None), "Coefficients not built; call build_GradAbsB() first" + rphiz_sym, reflect = _reduce_by_symmetry(rphiz, self.nfp, self.stellsym) + G = self.interp_GradAbsB.evaluate_batch(rphiz_sym) # type: ignore + G = _apply_symmetry_to_GradAbsB_cyl(G, reflect) + return G + + @eqx.filter_jit + def B_xyz(self, xyz: jnp.ndarray) -> jnp.ndarray: + """Convenience: evaluate B on Cartesian input batch xyz (N,3). + Internally convert to cylindrical, call B_cyl, rotate back to Cartesian. + """ + x, y, z = xyz.T + r = jnp.sqrt(x * x + y * y) + phi = jnp.arctan2(y, x) + rphiz = jnp.stack([r, phi, z], axis=1) + Bcyl = self.B_cyl(rphiz) + return _cyl_to_cart_vectors(phi, Bcyl) + + # --------------------- error estimate (RMS, max) ---------------------------- + + def estimate_error_B(self, key: jax.Array, nsamples: int = 10_000) -> Tuple[float, float]: + assert self.has_B, "Coefficients not built; call build_B() first" + rmin, rmax, _ = self.interp_B.grid.r_range + pmin, pmax, _ = self.interp_B.grid.phi_range + zmin, zmax, _ = self.interp_B.grid.z_range + u = jax.random.uniform(key, (nsamples, 3)) + rphiz = jnp.stack([ + rmin + (rmax - rmin) * u[:, 0], + pmin + (pmax - pmin) * u[:, 1], + zmin + (zmax - zmin) * u[:, 2], + ], axis=1) + x = rphiz[:, 0] * jnp.cos(rphiz[:, 1]) + y = rphiz[:, 0] * jnp.sin(rphiz[:, 1]) + xyz = jnp.stack([x, y, rphiz[:, 2]], axis=1) + B_true = jax.vmap(self.base_field_cart)(xyz) + B_pred = self.B_xyz(xyz) + diff = jnp.linalg.norm(B_true - B_pred, axis=1) + rms = jnp.sqrt(jnp.mean(diff**2)) + mx = jnp.max(diff) + return float(rms), float(mx) diff --git a/examples/trace_fieldlines_interpolated.py b/examples/trace_fieldlines_interpolated.py new file mode 100644 index 0000000..26c88fa --- /dev/null +++ b/examples/trace_fieldlines_interpolated.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +import os +number_of_processors_to_use = 1 # Parallelization, should divide nfieldlines +os.environ["JAX_ENABLE_X64"] = "true" +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' + +from time import time +import jax +import jax.numpy as jnp +from jax import block_until_ready, vmap +import matplotlib.pyplot as plt + +# --- ESSOS imports --- +from essos.fields import Vmec +from essos.dynamics import Tracing +from essos.surfaces import SurfaceClassifier + +# --- Our interpolator (from the canvas code you have) --- +from essos.interpolated_field import InterpolatedField + +# ----------------------------- +# Inputs (same as your example) +# ----------------------------- +tmax = 1500 +nfieldlines_per_core = 6 +nfieldlines = number_of_processors_to_use * nfieldlines_per_core +R0 = jnp.linspace(0.05, 0.6, nfieldlines) +trace_tolerance = 1e-10 +num_steps = 10000 + +# ---------------------------------- +# Load VMEC & set up interpolation +# ---------------------------------- +wout_file = os.path.join(os.path.dirname(__file__), "input_files", "wout_QH_simple_scaled.nc") +vmec = Vmec(wout_file) +nfp = int(vmec.nfp) + +# Grid extents chosen to tightly cover the surface (like SIMSOPT’s example) +# You can widen these a bit for safety if your tracer steps outside frequently. +ntheta, nphi = 40, 180 +x2d, y2d, z2d, R2d = vmec.surface.get_boundary(r=0.0, ntheta=ntheta, nphi=nphi) # r=0 is the plasma boundary in Vmec coords +rs = jnp.sqrt(x2d**2 + y2d**2) +zsurf = z2d + +rrange = (float(rs.min()), float(rs.max()), 24) # (rmin, rmax, nr_cells) +phirange = (0.0, float(2 * jnp.pi / nfp), 48) # fundamental domain +# We’ll use stellarator symmetry, so keep z >= 0 domain only: +zrange = (0.0, float(jnp.abs(zsurf).max()), 16) + +# A small “buffer” expanding the domain (meters) to avoid skipping tangential cells: +BUFFER = 0.04 +sc_trace = SurfaceClassifier(vmec.surface, h=0.03, p=2) + +def skip_fn(rvec: jnp.ndarray, phivec: jnp.ndarray, zvec: jnp.ndarray) -> jnp.ndarray: + """ + Return True where the point is confidently outside the domain. + Evaluated on all dof nodes; the interpolant will skip cells whose 8 corners are all True. + """ + # Convert (r,phi,z) -> XYZ to reuse SurfaceClassifier (which works in Cartesian): + x = rvec * jnp.cos(phivec) + y = rvec * jnp.sin(phivec) + pts = jnp.stack([x, y, zvec], axis=1) + # Signed distance < -(BUFFER) => outside + d = sc_trace.evaluate(pts) # negative = inside, positive = outside + return (d < -(BUFFER)) + +# Wrap vmec.B(xyz) to feed the interpolant +def base_field_cart(pt_xyz: jnp.ndarray) -> jnp.ndarray: + return vmec.B(pt_xyz) + +# Build interpolated field (cubic per axis; change degree as you like) +interp = InterpolatedField( + base_field_cart=base_field_cart, + degree=3, + rrange=rrange, + phirange=phirange, + zrange=zrange, + extrapolate=True, + nfp=nfp, + stellsym=True, # exploit z→-z reflection + skip_fn=skip_fn, + use_chebyshev=False, + build_gradabsb=False, # flip to True if you also need ∇|B| +) +interp = interp.build_B() + +# Tiny adapter so Tracing can treat it like a field with .B and .to_xyz +class FieldAdapter: + def __init__(self, interpolant: InterpolatedField): + self.interpolant = interpolant + def B(self, points_xyz: jnp.ndarray) -> jnp.ndarray: + return self.interpolant.B_xyz(points_xyz) + def AbsB(self, points_xyz: jnp.ndarray) -> jnp.ndarray: + B = self.B(points_xyz) + return jnp.linalg.norm(B, axis=-1) + def to_xyz(self, pts_xyz: jnp.ndarray) -> jnp.ndarray: + # already in Cartesian for tracing + return pts_xyz + +bsh = FieldAdapter(interp) + +# --------------------- +# Initial conditions +# --------------------- +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T + +# --------------------- +# Trace (interpolated) +# --------------------- +time0 = time() +tracing = block_until_ready( + Tracing(field=bsh, model="FieldLineAdaptative", initial_conditions=initial_xyz, + maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance, rtol=trace_tolerance) +) +print(f"ESSOS tracing (InterpolatedField) took {time()-time0:.2f} s") +trajectories = tracing.trajectories # still in Cartesian (we kept to_xyz identity) + +# ------------- +# Plot results +# ------------- +fig = plt.figure(figsize=(9, 5)) +ax1 = fig.add_subplot(121, projection="3d") +ax2 = fig.add_subplot(122) + +# Plot VMEC boundary +vmec.surface.plot(ax=ax1, show=False) + +# Plot trajectories (already xyz) +tracing.plot(ax=ax1, show=False) + +# If your Tracing.poincare_plot expects (s,theta,phi), convert from xyz via vmec inverse map if available. +# Here we reuse vmec.to_xyz for consistency with your original script by projecting to (s,theta,phi) first if you have a helper. +# If not, you can directly do a φ=atan2(y,x) Poincaré at fixed φ planes: +def phi_of(xyz): + x, y, _ = xyz + return jnp.arctan2(y, x) + +# Quick-and-dirty Poincaré at φ = 0 plane: +phis = vmap(vmap(phi_of))(trajectories) +mask = jnp.isclose((phis % (2*jnp.pi/nfp)), 0.0, atol=2e-3) +xy_hits = jnp.where(mask[..., None], trajectories[..., :2], jnp.nan) +for line in xy_hits: + pts = jnp.reshape(line, (-1, 2)) + ax2.plot(pts[:, 0], pts[:, 1], ".", ms=1, alpha=0.6) + +ax2.set_xlabel("X") +ax2.set_ylabel("Y") +ax2.set_title("Poincaré (φ≈0)") + +plt.tight_layout() +plt.show() + +# Optional sanity check: interpolation error +key = jax.random.key(0) +rms, mx = interp.estimate_error_B(key, nsamples=5000) +print(f"Interpolant |B| error — RMS: {rms:.3e}, Max: {mx:.3e}") From 944452e1a1e92f2ad9b1babf4521ffcddc0051dc Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sun, 21 Sep 2025 20:38:52 -0500 Subject: [PATCH 2/6] Add comprehensive tests for interpolated fields and surfaces - Introduced tests for `InterpolatedField` including linear and quadratic fields, ensuring exactness and JIT compatibility. - Implemented tests for `RegularGridInterpolant3D` to validate building and evaluation with linear fields, including skip function behavior. - Added tests for `SurfaceRZFourier` to verify geometry, normals, cross-sectional area, and field evaluations on surfaces. - Included a mock magnetic field for testing `B_on_surface`, `BdotN`, and `BdotN_over_B` functions. - Established a `SurfaceClassifier` with tests for signed distances and vectorized evaluations. - Ensured numerical stability by enabling x64 precision in tests. --- essos/interpolated_field.py | 23 +- essos/surfaces.py | 247 ++++++++------- examples/trace_fieldlines_interpolated.py | 351 ++++++++++++++------- tests/test_interpolated_field.py | 362 ++++++++++++++++++++++ tests/test_surfaces.py | 289 +++++++++++++++++ 5 files changed, 1032 insertions(+), 240 deletions(-) create mode 100644 tests/test_interpolated_field.py create mode 100644 tests/test_surfaces.py diff --git a/essos/interpolated_field.py b/essos/interpolated_field.py index a8bc1d8..0bf72b2 100644 --- a/essos/interpolated_field.py +++ b/essos/interpolated_field.py @@ -178,16 +178,19 @@ def __init__( object.__setattr__(self, "hz", hz) d = rule.degree - # 1D DOF locations for each axis: (n_cells*d + 1) - r_dof_1d = jnp.concatenate( - [rmin + (i * hr + rule.nodes * hr) for i in range(nr)] + [jnp.array([rmax])] - ) - phi_dof_1d = jnp.concatenate( - [phimin + (j * hphi + rule.nodes * hphi) for j in range(nphi)] + [jnp.array([phimax])] - ) - z_dof_1d = jnp.concatenate( - [zmin + (k * hz + rule.nodes * hz) for k in range(nz)] + [jnp.array([zmax])] - ) + + def axis_dofs(xmin, h, n_cells): + # include nodes [0..d-1] for each cell, then add a single final endpoint + base = [] + for i in range(n_cells): + # take all nodes except the last one + base.append(xmin + (i * h + rule.nodes[:d] * h)) + base = jnp.concatenate(base) if base else jnp.array([]) + return jnp.concatenate([base, jnp.array([xmin + n_cells * h])]) + + r_dof_1d = axis_dofs(rmin, hr, nr) + phi_dof_1d = axis_dofs(phimin, hphi, nphi) + z_dof_1d = axis_dofs(zmin, hz, nz) # Full tensor grid of DOFs R, P, Z = jnp.meshgrid(r_dof_1d, phi_dof_1d, z_dof_1d, indexing="ij") Rf = R.reshape(-1) diff --git a/essos/surfaces.py b/essos/surfaces.py index 0048e3c..4b8673c 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -1,10 +1,16 @@ from functools import partial import jax.numpy as jnp from jax.scipy.interpolate import RegularGridInterpolator -from jax import jit, vmap, devices, device_put +from jax import jit, vmap, devices, device_put, block_until_ready from jax.sharding import Mesh, NamedSharding, PartitionSpec from essos.plot import fix_matplotlib_3d import jaxkd +import time +import numpy as np +try: + from scipy.spatial import cKDTree +except ImportError: + raise ImportError("pip install scipy to speed up SurfaceClassifier building.") mesh = Mesh(devices(), ("dev",)) sharding = NamedSharding(mesh, PartitionSpec("dev", None)) @@ -328,129 +334,144 @@ def mean_cross_sectional_area(self): dZ_dtheta = dgamma1[:, :, 2] * Jinv[:, :, 0, 1] + dgamma2[:, :, 2] * Jinv[:, :, 1, 1] mean_cross_sectional_area = jnp.abs(jnp.mean(jnp.sqrt(x2y2) * dZ_dtheta * detJ))/(2 * jnp.pi) return mean_cross_sectional_area - -#This class is based on simsopt classifier but translated to fit jax -class SurfaceClassifier(): + +class SurfaceClassifier: """ - Takes in a toroidal surface and constructs an interpolant of the signed distance function - :math:`f:R^3\to R` that is positive inside the volume contained by the surface, - (approximately) zero on the surface, and negative outisde the volume contained by the surface. + Signed-distance interpolant f: R^3 -> R + + inside the surface + ~ 0 on the surface + - outside """ - def __init__(self, surface,h=0.05): + def __init__(self, surface, h=0.05, use_fundamental_phi=True): """ Args: - surface: the surface to contruct the distance from. - h: grid resolution of the interpolant + surface: SurfaceRZFourier + h: target step for r,phi,z + use_fundamental_phi: if True, restrict phi to [0, 2π/nfp] instead of [0, 2π] """ - gammas = surface.gamma + t0 = time.perf_counter() + print("[SC] init: start") + + self.surface = surface + gammas = surface.gamma # (nphi, ntheta, 3) + nfp = getattr(surface, "nfp", 1) + + # ------------------------- + # Bounds & grid resolution + # ------------------------- r = jnp.linalg.norm(gammas[:, :, :2], axis=2) z = gammas[:, :, 2] - rmin = max(jnp.min(r) - 0.1, 0.) - rmax = jnp.max(r) + 0.1 - zmin = jnp.min(z) - 0.1 - zmax = jnp.max(z) + 0.1 - - self.zrange = (zmin, zmax) - self.rrange = (rmin, rmax) - - nr = int((self.rrange[1]-self.rrange[0])/h) - nphi = int(2*jnp.pi/h) - nz = int((self.zrange[1]-self.zrange[0])/h) - - def fbatch(rs, phis, zs): - xyz = jnp.zeros(( 3)) - xyz=xyz.at[0].set( rs * jnp.cos(phis)) - xyz=xyz.at[1].set(rs * jnp.sin(phis)) - xyz=xyz.at[2].set(zs) - return signed_distance_from_surface_jax(xyz, surface) - #return signed_distance_from_surface_extras(xyz, surface) ####memory bounded - - #rule = sopp.UniformInterpolationRule(p) - #self.dist = RegularGridInterpolator((jnp.linspace(rmin,rmax,nr), - # jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz)), - # vmap(vmap(vmap(fbatch,in_axes=(0,None,None)),in_axes=(None,0,None)),in_axes=(None,None,0))(jnp.linspace(rmin,rmax,nr), - # jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz))) - #self.r_list=jnp.linspace(16.9,17.1,nr) - #self.phi_list=jnp.linspace(0., 0.01, nphi) - #self.z_list=jnp.linspace(-0.1, 0.1, nz) - #self.test= vmap(vmap(vmap(fbatch,in_axes=(0,None,None)),in_axes=(None,0,None)),in_axes=(None,None,0))(self.r_list, - # self.phi_list, self.z_list) - #self.r_list=jnp.linspace(rmin,rmax,nr) - #self.phi_list=jnp.linspace(0., 2*jnp.pi, nphi) - #self.z_list=jnp.linspace(zmin, zmax, nz) - #self.test= vmap(vmap(vmap(fbatch,in_axes=(None,None,0)),in_axes=(None,0,None)),in_axes=(0,None,None))(jnp.linspace(rmin,rmax,nr), - # jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz)) - #self.dist = RegularGridInterpolator((self.r_list,self.phi_list, self.z_list), - # vmap(vmap(vmap(fbatch,in_axes=(None,None,0)),in_axes=(None,0,None)),in_axes=(0,None,None))(self.r_list,self.phi_list, self.z_list),fill_value=-1.) - self.dist = RegularGridInterpolator((jnp.linspace(rmin,rmax,nr), - jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz)), - vmap(vmap(vmap(fbatch,in_axes=(None,None,0)),in_axes=(None,0,None)),in_axes=(0,None,None))(jnp.linspace(rmin,rmax,nr), - jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz)),fill_value=-1.) - #self.dist.interpolate_batch(fbatch) + rmin = float(max(jnp.min(r) - 0.1, 0.0)) + rmax = float(jnp.max(r) + 0.1) + zmin = float(jnp.min(z) - 0.1) + zmax = float(jnp.max(z) + 0.1) - @partial(jit, static_argnames=['self']) - def evaluate_xyz(self, xyz): - rphiz = jnp.zeros_like(xyz) - rphiz=rphiz.at[0].set(jnp.linalg.norm(xyz[:2])) - rphiz=rphiz.at[1].set(jnp.mod(jnp.arctan2(xyz[1], xyz[0]), 2*jnp.pi)) - rphiz=rphiz.at[2].set(xyz.at[2].get()) - # initialize to -1 since the regular grid interpolant will just keep - # that value when evaluated outside of bounds - d=self.dist(rphiz)[0][0] - return d + if use_fundamental_phi and nfp > 0: + phimin = 0.0 + phimax = float(2 * jnp.pi / nfp) + else: + phimin = 0.0 + phimax = float(2 * jnp.pi) - @partial(jit, static_argnames=['self']) - def evaluate_rphiz(self, rphiz): - # initialize to -1 since the regular grid interpolant will just keep - # that value when evaluated outside of bounds - d=self.dist(rphiz)[0][0] - return d - + self.rrange = (rmin, rmax) + self.zrange = (zmin, zmax) + self.phirange = (phimin, phimax) -partial(jit, static_argnames=['surface']) -def signed_distance_from_surface_jax(xyz, surface): - """ - Compute the signed distances from points ``xyz`` to a surface. The sign is - positive for points inside the volume surrounded by the surface. - """ - gammas = surface.gamma.reshape((-1, 3)) - #from scipy.spatial import KDTree ##better for cpu? - tree = jaxkd.build_tree(gammas) - mins, _ = jaxkd.query_neighbors(tree, xyz, k=1) - n = surface.unitnormal.reshape((-1, 3)) - nmins = n[mins] - gammamins = gammas[mins] - # Now that we have found the closest node, we approximate the surface with - # a plane through that node with the appropriate normal and then compute - # the distance from the point to that plane - # https://stackoverflow.com/questions/55189333/how-to-get-distance-from-point-to-plane-in-3d - mindist = jnp.sum((xyz-gammamins) * nmins, axis=1) - a_point_in_the_surface = jnp.mean(surface.gamma[0, :, :], axis=0) - sign_of_interiorpoint = jnp.sign(jnp.sum((a_point_in_the_surface-gammas[0, :])*n[0, :])) - signed_dists = mindist * sign_of_interiorpoint - return signed_dists - -#@partial(jit, static_argnames=['surface']) -def signed_distance_from_surface_extras(xyz, surface): - """ - Compute the signed distances from points ``xyz`` to a surface. The sign is - positive for points inside the volume surrounded by the surface. - """ - gammas = surface.gamma.reshape((-1, 3)) - mins, _ = jaxkd.extras.query_neighbors_pairwise(gammas, xyz, k=1) - n = surface.unitnormal.reshape((-1, 3)) - nmins = n[mins] - gammamins = gammas[mins] - # Now that we have found the closest node, we approximate the surface with - # a plane through that node with the appropriate normal and then compute - # the distance from the point to that plane - # https://stackoverflow.com/questions/55189333/how-to-get-distance-from-point-to-plane-in-3d - mindist = jnp.sum((xyz-gammamins) * nmins, axis=1) - a_point_in_the_surface = jnp.mean(surface.gamma[0, :, :], axis=0) - sign_of_interiorpoint = jnp.sign(jnp.sum((a_point_in_the_surface-gammas[0, :])*n[0, :])) - signed_dists = mindist * sign_of_interiorpoint - return signed_dists + # Make sure we have at least 2 points per axis: + nr = max(int((rmax - rmin) / h), 2) + nphi = max(int((phimax - phimin) / h), 3) # keep ≥3 to resolve periodicity a bit + nz = max(int((zmax - zmin) / h), 2) + + print(f"[SC] ranges: r=({rmin:.3f},{rmax:.3f}) phi=({phimin:.3f},{phimax:.3f}) z=({zmin:.3f},{zmax:.3f})") + print(f"[SC] grid sizes: nr={nr}, nphi={nphi}, nz={nz} -> total={nr*nphi*nz:,d} nodes") + + # ------------------------- + # Precompute KD-tree once + # ------------------------- + t_tree = time.perf_counter() + gammas_flat = gammas.reshape((-1, 3)) + normals_flat = surface.unitnormal.reshape((-1, 3)) + self._tree = jaxkd.build_tree(gammas_flat) + # Sign convention (interior point): + a_point = jnp.mean(surface.gamma[0, :, :], axis=0) + sign_of_interiorpoint = jnp.sign(jnp.sum((a_point - gammas_flat[0, :]) * normals_flat[0, :])) + self._sign = float(sign_of_interiorpoint) + print(f"[SC] KD-tree build: {time.perf_counter() - t_tree:.2f}s " + f"(nodes={gammas_flat.shape[0]:,d})") + # ------------------------- + # Build (r,phi,z) grid + # ------------------------- + t_grid = time.perf_counter() + r_list = jnp.linspace(rmin, rmax, nr) + phi_list = jnp.linspace(phimin, phimax, nphi) + z_list = jnp.linspace(zmin, zmax, nz) + + # Mesh in 'ij' so r varies slowest, z fastest when flattened + RR, PP, ZZ = jnp.meshgrid(r_list, phi_list, z_list, indexing="ij") # each (nr, nphi, nz) + Ntot = nr * nphi * nz + + # Convert to Cartesian for nearest-neighbor query: + XX = RR * jnp.cos(PP) + YY = RR * jnp.sin(PP) + xyz_grid = jnp.stack([XX, YY, ZZ], axis=-1).reshape((Ntot, 3)) + + print(f"[SC] grid gen: {time.perf_counter() - t_grid:.2f}s; xyz_grid shape={tuple(xyz_grid.shape)}") + + # Build SciPy KD-tree on CPU (fast) + t_query = time.perf_counter() + tree = cKDTree(np.asarray(gammas_flat)) # (Ng, 3) + dist, idxs = tree.query(np.asarray(xyz_grid), k=1, workers=-1) # (Ntot,), (Ntot,) + nearest_pts = gammas_flat[np.asarray(idxs)] # jnp will accept np indexing + nearest_normals = normals_flat[np.asarray(idxs)] + # signed distance to tangent plane + d_plane = jnp.sum((xyz_grid - nearest_pts) * nearest_normals, axis=-1) # (Ntot,) + signed = self._sign * d_plane + field_vals = signed.reshape((nr, nphi, nz)) + _ = block_until_ready(field_vals) + print(f"[SC] KD query+dist: {time.perf_counter() - t_query:.2f}s (SciPy cKDTree)") + + # ------------------------- + # Build RGI + # ------------------------- + t_rgi = time.perf_counter() + self._r_list = r_list + self._phi_list = phi_list + self._z_list = z_list + + # fill_value < 0.0 => "outside" by default beyond bounds + self.dist = RegularGridInterpolator( + (r_list, phi_list, z_list), field_vals, fill_value=-1.0 + ) + + print(f"[SC] RGI build: {time.perf_counter() - t_rgi:.2f}s") + print(f"[SC] init done in {time.perf_counter() - t0:.2f}s total") + + # ------------------------- + # Vectorized signed-distance API (XYZ) + # ------------------------- + @staticmethod + def _xyz_to_rphiz(xyz): + """xyz: (...,3) -> rphiz: (...,3)""" + x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2] + r = jnp.sqrt(x * x + y * y) + phi = jnp.mod(jnp.arctan2(y, x), 2 * jnp.pi) + return jnp.stack([r, phi, z], axis=-1) + + def _wrap_phi(self, phi): + period = 2 * jnp.pi / max(1, int(getattr(self.surface, "nfp", 1))) + return jnp.mod(phi, period) + + @partial(jit, static_argnames=['self']) + def evaluate_xyz(self, xyz): + rphiz = self._xyz_to_rphiz(xyz) + rphiz = rphiz.at[..., 1].set(self._wrap_phi(rphiz[..., 1])) + return self.dist(rphiz) + + @partial(jit, static_argnames=['self']) + def evaluate_rphiz(self, rphiz): + rphiz = rphiz.at[..., 1].set(self._wrap_phi(rphiz[..., 1])) + return self.dist(rphiz) \ No newline at end of file diff --git a/examples/trace_fieldlines_interpolated.py b/examples/trace_fieldlines_interpolated.py index 26c88fa..d966484 100644 --- a/examples/trace_fieldlines_interpolated.py +++ b/examples/trace_fieldlines_interpolated.py @@ -1,158 +1,275 @@ #!/usr/bin/env python3 -import os -number_of_processors_to_use = 1 # Parallelization, should divide nfieldlines +import os, time +number_of_processors_to_use = 1 os.environ["JAX_ENABLE_X64"] = "true" os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from time import time import jax import jax.numpy as jnp from jax import block_until_ready, vmap +from jax.scipy.interpolate import RegularGridInterpolator as JaxRGI +from functools import partial import matplotlib.pyplot as plt -# --- ESSOS imports --- +# ---- Your ESSOS imports ---- from essos.fields import Vmec -from essos.dynamics import Tracing from essos.surfaces import SurfaceClassifier -# --- Our interpolator (from the canvas code you have) --- -from essos.interpolated_field import InterpolatedField +# ---- Import the RegularGridInterpolant3D / rules you already have ---- +from essos.interpolated_field import RegularGridInterpolant3D, UniformInterpolationRule, ChebyshevInterpolationRule -# ----------------------------- -# Inputs (same as your example) -# ----------------------------- -tmax = 1500 -nfieldlines_per_core = 6 -nfieldlines = number_of_processors_to_use * nfieldlines_per_core -R0 = jnp.linspace(0.05, 0.6, nfieldlines) -trace_tolerance = 1e-10 -num_steps = 10000 +# --------------------------------------------------------------------------------- +# Helper: If you ever build an (r,φ,z)-space interpolant, you can use this skip_fn +# For native (s,θ,φ) it is typically unnecessary, so this is OPTIONAL here. +# --------------------------------------------------------------------------------- +def make_skip_fn_from_classifier(sc, buffer=0.04): + """Return (rvec, phivec, zvec) -> bool[N] mask; True means 'skip' (outside).""" + def _skip(rvec: jnp.ndarray, phivec: jnp.ndarray, zvec: jnp.ndarray) -> jnp.ndarray: + rphiz = jnp.stack([rvec, phivec, zvec], axis=1) # (N,3) + d = jax.vmap(sc.evaluate_rphiz)(rphiz) # (N,) + return (d < -buffer) + return _skip + +# --------------------------------------------------------------------------------- +# Native-coordinate interpolated field for VMEC: (s,θ,φ)-> {B_cov, B_con, sqrtg} +# --------------------------------------------------------------------------------- +class InterpolatedVmecNative: + def __init__(self, vmec, + srange=(0.0, 1.0, 24), + thetarange=(0.0, 2*jnp.pi, 48), + phirange=(0.0, None, 48)): + self.vmec = vmec + nfp = int(vmec.nfp) + if phirange[1] is None: + phirange = (0.0, 2*jnp.pi/nfp, phirange[2]) + self.srange, self.thetarange, self.phirange = srange, thetarange, phirange + self._rgis = None # will hold dict of per-component RGIs + + def build_all(self): + print("[build] Precomputing grids & field values for RGI ...") + t0 = time.perf_counter() + s0, s1, ns = self.srange + th0, th1, nth = self.thetarange + ph0, ph1, nph = self.phirange + + s_list = jnp.linspace(s0, s1, ns) + th_list = jnp.linspace(th0, th1, nth) + ph_list = jnp.linspace(ph0, ph1, nph) + + # Tensor grid (ij indexing so reshape matches (ns,nth,nph)) + SS, TT, PP = jnp.meshgrid(s_list, th_list, ph_list, indexing="ij") + pts_flat = jnp.stack([SS.ravel(), TT.ravel(), PP.ravel()], axis=1) # (N,3) + N = pts_flat.shape[0] + print(f"[build] grid sizes: ns={ns}, nth={nth}, nph={nph} -> N={N}") + + # Batch evaluate VMEC (cartesian) quantities + t_eval = time.perf_counter() + Bcov = jax.vmap(self.vmec.B_covariant)(pts_flat) # (N,3) + Bcon = jax.vmap(self.vmec.B_contravariant)(pts_flat) # (N,3) + sqrtg = jax.vmap(self.vmec.sqrtg)(pts_flat) # (N,) + # Force computation/timing: + Bcov, Bcon, sqrtg = jax.block_until_ready(Bcov), jax.block_until_ready(Bcon), jax.block_until_ready(sqrtg) + print(f"[build] VMEC eval on grid: {time.perf_counter()-t_eval:.2f}s") -# ---------------------------------- -# Load VMEC & set up interpolation -# ---------------------------------- + # Reshape to tensor grid + Bcov = Bcov.reshape((ns, nth, nph, 3)) + Bcon = Bcon.reshape((ns, nth, nph, 3)) + sqrtg = sqrtg.reshape((ns, nth, nph)) + + # Build 7 RGIs (3+3+1), fill_value extrapolates as you prefer: + t_rgi = time.perf_counter() + def rgi3(A3): + # A3 shape = (ns,nth,nph) + return JaxRGI((s_list, th_list, ph_list), A3, fill_value=None) + + rgis = { + "Bcov0": rgi3(Bcov[..., 0]), + "Bcov1": rgi3(Bcov[..., 1]), + "Bcov2": rgi3(Bcov[..., 2]), + "Bcon0": rgi3(Bcon[..., 0]), + "Bcon1": rgi3(Bcon[..., 1]), + "Bcon2": rgi3(Bcon[..., 2]), + "sqrtg": rgi3(sqrtg), + } + self._rgis = rgis + print(f"[build] RGI build: {time.perf_counter()-t_rgi:.2f}s") + print(f"[build] total: {time.perf_counter()-t0:.2f}s") + return self + + # ---------------- evaluate (batched & JIT) ---------------- + @partial(jax.jit, static_argnames=("self",)) + def B_covariant(self, pts_stp): # pts_stp (...,3) + x = pts_stp + b0 = self._rgis["Bcov0"](x) + b1 = self._rgis["Bcov1"](x) + b2 = self._rgis["Bcov2"](x) + return jnp.stack([b0, b1, b2], axis=-1) + + @partial(jax.jit, static_argnames=("self",)) + def B_contravariant(self, pts_stp): + x = pts_stp + b0 = self._rgis["Bcon0"](x) + b1 = self._rgis["Bcon1"](x) + b2 = self._rgis["Bcon2"](x) + return jnp.stack([b0, b1, b2], axis=-1) + + @partial(jax.jit, static_argnames=("self",)) + def sqrtg(self, pts_stp): + return self._rgis["sqrtg"](pts_stp) + + @partial(jax.jit, static_argnames=("self",)) + def AbsB(self, pts_stp): + # Optional, reuse vmec.AbsB to avoid building another set of RGIs: + return jax.vmap(self.vmec.AbsB)(pts_stp) + + +# ------------------------------------------------------------------------------------------------- +# Script body: build interpolants in native coords; trace using your native EOM (s,θ,φ dynamics) +# ------------------------------------------------------------------------------------------------- +t0 = time.perf_counter() +print("[stage] Loading VMEC ...") wout_file = os.path.join(os.path.dirname(__file__), "input_files", "wout_QH_simple_scaled.nc") vmec = Vmec(wout_file) -nfp = int(vmec.nfp) +print(f"[time] VMEC load: {time.perf_counter()-t0:.2f}s (nfp={vmec.nfp})") + +print("[stage] Building SurfaceClassifier (for diagnostics / optional skip) ...") +t_sc = time.perf_counter() +sc = SurfaceClassifier(vmec.surface, h=0.06) +print(f"[time] SurfaceClassifier: {time.perf_counter()-t_sc:.2f}s") + +# Native grid in (s,θ,φ): +srange = (0.0, 1.0, 24) +thetarange = (0.0, 2*jnp.pi, 48) +phirange = (0.0, 2*jnp.pi/int(vmec.nfp), 48) -# Grid extents chosen to tightly cover the surface (like SIMSOPT’s example) -# You can widen these a bit for safety if your tracer steps outside frequently. -ntheta, nphi = 40, 180 -x2d, y2d, z2d, R2d = vmec.surface.get_boundary(r=0.0, ntheta=ntheta, nphi=nphi) # r=0 is the plasma boundary in Vmec coords -rs = jnp.sqrt(x2d**2 + y2d**2) -zsurf = z2d - -rrange = (float(rs.min()), float(rs.max()), 24) # (rmin, rmax, nr_cells) -phirange = (0.0, float(2 * jnp.pi / nfp), 48) # fundamental domain -# We’ll use stellarator symmetry, so keep z >= 0 domain only: -zrange = (0.0, float(jnp.abs(zsurf).max()), 16) - -# A small “buffer” expanding the domain (meters) to avoid skipping tangential cells: -BUFFER = 0.04 -sc_trace = SurfaceClassifier(vmec.surface, h=0.03, p=2) - -def skip_fn(rvec: jnp.ndarray, phivec: jnp.ndarray, zvec: jnp.ndarray) -> jnp.ndarray: - """ - Return True where the point is confidently outside the domain. - Evaluated on all dof nodes; the interpolant will skip cells whose 8 corners are all True. - """ - # Convert (r,phi,z) -> XYZ to reuse SurfaceClassifier (which works in Cartesian): - x = rvec * jnp.cos(phivec) - y = rvec * jnp.sin(phivec) - pts = jnp.stack([x, y, zvec], axis=1) - # Signed distance < -(BUFFER) => outside - d = sc_trace.evaluate(pts) # negative = inside, positive = outside - return (d < -(BUFFER)) - -# Wrap vmec.B(xyz) to feed the interpolant -def base_field_cart(pt_xyz: jnp.ndarray) -> jnp.ndarray: - return vmec.B(pt_xyz) - -# Build interpolated field (cubic per axis; change degree as you like) -interp = InterpolatedField( - base_field_cart=base_field_cart, - degree=3, - rrange=rrange, - phirange=phirange, - zrange=zrange, - extrapolate=True, - nfp=nfp, - stellsym=True, # exploit z→-z reflection - skip_fn=skip_fn, - use_chebyshev=False, - build_gradabsb=False, # flip to True if you also need ∇|B| -) -interp = interp.build_B() - -# Tiny adapter so Tracing can treat it like a field with .B and .to_xyz -class FieldAdapter: - def __init__(self, interpolant: InterpolatedField): - self.interpolant = interpolant - def B(self, points_xyz: jnp.ndarray) -> jnp.ndarray: - return self.interpolant.B_xyz(points_xyz) - def AbsB(self, points_xyz: jnp.ndarray) -> jnp.ndarray: - B = self.B(points_xyz) - return jnp.linalg.norm(B, axis=-1) - def to_xyz(self, pts_xyz: jnp.ndarray) -> jnp.ndarray: - # already in Cartesian for tracing - return pts_xyz - -bsh = FieldAdapter(interp) +print("[stage] Building native (s,θ,φ) interpolants: B_cov, B_con, sqrtg ...") +t_build = time.perf_counter() +interp = InterpolatedVmecNative( + vmec, + srange=(0.0, 1.0, 33), # ns + thetarange=(0.0, 2*jnp.pi, 48), # nth + phirange=(0.0, None, 64), # nph; None -> 2π/nfp +).build_all() +interp = interp.build_all() +print(f"[time] Interpolant build total: {time.perf_counter()-t_build:.2f}s") + +# Adapter exposing the same API your Tracing uses (native EOM): +class VmecNativeAdapter: + def __init__(self, base_vmec, interp_native): + self.vmec = base_vmec + self.I = interp_native + self.nfp = base_vmec.nfp # keep attribute parity + + # EOM-relevant pieces in native coordinates: + def B_covariant(self, pts_stp): + return self.I.B_covariant(pts_stp) + + def B_contravariant(self, pts_stp): + return self.I.B_contravariant(pts_stp) + + def sqrtg(self, pts_stp): + return self.I.sqrtg(pts_stp) + + # Optional: if Tracing uses these too: + def AbsB(self, pts_stp): + return self.I.AbsB(pts_stp) + + # Geometry map (already provided by your Vmec) + def to_xyz(self, pts_stp): + return vmap(self.vmec.to_xyz)(pts_stp) + + # If Tracing occasionally calls B in Cartesian, you can provide: + # def B(self, pts_stp): + # # Cartesian B from covariant basis vectors: + # # You already have vmec.B that returns Cartesian from (s,θ,φ); + # # If you want to be fully interpolated, build a separate Cartesian interpolant. + # return vmap(self.vmec.B)(pts_stp) + +bfield = VmecNativeAdapter(vmec, interp) # --------------------- # Initial conditions # --------------------- -Z0 = jnp.zeros(nfieldlines) -phi0 = jnp.zeros(nfieldlines) -initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T +t_init = time.perf_counter() +nfieldlines_per_core = 6 +nfieldlines = number_of_processors_to_use * nfieldlines_per_core +# Choose initial (s,θ,φ). Your prior script used Cartesian R0/Z0/φ0; here we stay native: +s0 = jnp.linspace(0.02, 0.98, nfieldlines) # avoid exact boundary +th0 = jnp.zeros(nfieldlines) +ph0 = jnp.zeros(nfieldlines) +initial_stp = jnp.stack([s0, th0, ph0], axis=1) +print(f"[time] Init conditions set: {time.perf_counter()-t_init:.2f}s (n={nfieldlines})") # --------------------- -# Trace (interpolated) +# Trace in ESSOS (native EOM) # --------------------- -time0 = time() -tracing = block_until_ready( - Tracing(field=bsh, model="FieldLineAdaptative", initial_conditions=initial_xyz, - maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance, rtol=trace_tolerance) -) -print(f"ESSOS tracing (InterpolatedField) took {time()-time0:.2f} s") -trajectories = tracing.trajectories # still in Cartesian (we kept to_xyz identity) +from essos.dynamics import Tracing +t_trace = time.perf_counter() +tmax = 1500 +trace_tolerance = 1e-10 +num_steps = 10000 + +print("[stage] JIT warmup for Tracing ...") +# Hint: do a tiny warmup call if Tracing JITs internally; otherwise first call will take longer. + +print("[stage] Running Tracing ...") +print("[stage] Tracing fieldlines using interpolated field (native s,θ,φ) …") +t0 = time.perf_counter() +tracing = block_until_ready(Tracing( + field=interp, # <= use interpolant here + model='FieldLineAdaptative', + initial_conditions=initial_stp, # still XYZ for seeding; your Tracing will convert + maxtime=tmax, + times_to_trace=num_steps, + atol=trace_tolerance, + rtol=trace_tolerance +)) +print(f"[time] ESSOS tracing: {time.perf_counter()-t0:.2f}s") + +# Trajectories are in native (s,θ,φ). Convert to xyz for plotting: +t_xyz = time.perf_counter() +trajectories_stp = tracing.trajectories +trajectories_xyz = vmap(vmap(bfield.to_xyz))(trajectories_stp) +print(f"[time] stp->xyz conversion: {time.perf_counter()-t_xyz:.2f}s") # ------------- # Plot results # ------------- +print("[stage] Plotting ...") +t_plot = time.perf_counter() fig = plt.figure(figsize=(9, 5)) ax1 = fig.add_subplot(121, projection="3d") ax2 = fig.add_subplot(122) -# Plot VMEC boundary vmec.surface.plot(ax=ax1, show=False) +# quick plot of trajectories in xyz +for tr in trajectories_xyz: + ax1.plot(tr[:,0], tr[:,1], tr[:,2], lw=0.8, alpha=0.8) -# Plot trajectories (already xyz) -tracing.plot(ax=ax1, show=False) - -# If your Tracing.poincare_plot expects (s,theta,phi), convert from xyz via vmec inverse map if available. -# Here we reuse vmec.to_xyz for consistency with your original script by projecting to (s,theta,phi) first if you have a helper. -# If not, you can directly do a φ=atan2(y,x) Poincaré at fixed φ planes: -def phi_of(xyz): - x, y, _ = xyz - return jnp.arctan2(y, x) - -# Quick-and-dirty Poincaré at φ = 0 plane: -phis = vmap(vmap(phi_of))(trajectories) +# Simple Poincaré at φ ≈ 0 using native coordinates directly: +nfp = int(vmec.nfp) +phis = trajectories_stp[..., 2] mask = jnp.isclose((phis % (2*jnp.pi/nfp)), 0.0, atol=2e-3) -xy_hits = jnp.where(mask[..., None], trajectories[..., :2], jnp.nan) +xy_hits = jnp.where(mask[..., None], trajectories_xyz[..., :2], jnp.nan) for line in xy_hits: pts = jnp.reshape(line, (-1, 2)) ax2.plot(pts[:, 0], pts[:, 1], ".", ms=1, alpha=0.6) -ax2.set_xlabel("X") -ax2.set_ylabel("Y") -ax2.set_title("Poincaré (φ≈0)") +ax2.set_xlabel("X"); ax2.set_ylabel("Y"); ax2.set_title("Poincaré (φ≈0)") +plt.tight_layout(); plt.show() +print(f"[time] Plotting: {time.perf_counter()-t_plot:.2f}s") -plt.tight_layout() -plt.show() - -# Optional sanity check: interpolation error +# Optional: quick accuracy probe at random points key = jax.random.key(0) -rms, mx = interp.estimate_error_B(key, nsamples=5000) -print(f"Interpolant |B| error — RMS: {rms:.3e}, Max: {mx:.3e}") +Nsamp = 4096 +s_s = jax.random.uniform(key, (Nsamp,), minval=srange[0], maxval=srange[1]) +th_s = jax.random.uniform(key, (Nsamp,), minval=thetarange[0], maxval=thetarange[1]) +ph_s = jax.random.uniform(key, (Nsamp,), minval=phirange[0], maxval=phirange[1]) +pts = jnp.stack([s_s, th_s, ph_s], axis=1) + +t_err = time.perf_counter() +Bcov_true = jax.vmap(vmec.B_covariant)(pts) +Bcov_interp = bfield.B_covariant(pts) +err = jnp.linalg.norm(Bcov_true - Bcov_interp, axis=1) +print(f"[check] RMS|B_cov-Interp|={jnp.sqrt(jnp.mean(err**2)):.3e}, " + f"Max={jnp.max(err):.3e} (computed in {time.perf_counter()-t_err:.2f}s)") diff --git a/tests/test_interpolated_field.py b/tests/test_interpolated_field.py new file mode 100644 index 0000000..0c1f873 --- /dev/null +++ b/tests/test_interpolated_field.py @@ -0,0 +1,362 @@ +# tests/test_interpolated_field.py +import math +import pytest +import jax +import jax.numpy as jnp + +# ---- import your code under test ---- +# adjust the import path to wherever you placed InterpolationRule / RegularGridInterpolant3D / InterpolatedField +from essos.interpolated_field import ( + UniformInterpolationRule, + ChebyshevInterpolationRule, + RegularGridInterpolant3D, + GridSpec, + InterpolatedField, + _cart_to_cyl_vectors, + _cyl_to_cart_vectors, +) + +# -------------------------------------------------------------------------------------- +# Helpers +# -------------------------------------------------------------------------------------- + +@pytest.fixture(scope="module", autouse=True) +def _enable_x64(): + # Keep consistent precision for assertions + jax.config.update("jax_enable_x64", True) + + +def linear_cartesian_field(xyz: jnp.ndarray) -> jnp.ndarray: + """A linear field in Cartesian; degree-1 interpolation should be exact.""" + x, y, z = xyz + # Arbitrary but linear map to 3 components: + return jnp.array([ + 2.0 * x - y + 0.5 * z, + -x + 3.0 * y + 0.25 * z, + -0.5 * x + 0.75 * y + 2.0 * z, + ]) + + +def quadratic_cartesian_field(xyz: jnp.ndarray) -> jnp.ndarray: + """Smooth non-linear field to exercise higher-degree rules.""" + x, y, z = xyz + return jnp.array([ + x * x + y + 0.5 * z, + x + y * y - 0.2 * z, + 0.1 * x + 0.3 * y + z * z, + ]) + + +def make_grid(rr=(0.4, 1.2, 4), ph=(0.0, math.pi/2, 3), zz=(-0.5, 0.5, 4)): + return GridSpec(rrange=rr, phi_range=ph, z_range=zz, value_size=3) + + +# Skip-function that masks a thin inner cylinder (r < rmin+0.05) +def make_skip_fn(grid: GridSpec): + rmin, rmax, _ = grid.r_range + cutoff = rmin + 0.05 + def _skip(r: jnp.ndarray, phi: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray: + # return boolean: True=>skip + del phi, z + return r < cutoff + return _skip + +# -------------------------------------------------------------------------------------- +# InterpolationRule (basis) tests +# -------------------------------------------------------------------------------------- + +@pytest.mark.parametrize("deg_cls", [UniformInterpolationRule, ChebyshevInterpolationRule]) +@pytest.mark.parametrize("degree", [1, 2, 3]) +def test_basis_kronecker_and_partition(deg_cls, degree): + rule = deg_cls(degree) + nodes = rule.nodes + + # 1) Kronecker delta at nodes: p_i(x_j) = δ_ij + for i in range(degree + 1): + pis = rule.basis(jnp.array(nodes[i])) + eye = jax.nn.one_hot(i, degree + 1, dtype=pis.dtype) + assert jnp.allclose(pis, eye, atol=1e-12) + + # 2) Partition of unity: sum_i p_i(x) = 1 for x in [0,1] + xs = jnp.linspace(0.0, 1.0, 31) + P = rule.basis(xs) # (d+1, 31) + s = jnp.sum(P, axis=0) + assert jnp.allclose(s, jnp.ones_like(xs), atol=1e-12) + +# -------------------------------------------------------------------------------------- +# RegularGridInterpolant3D structure & build/eval tests +# -------------------------------------------------------------------------------------- + +@pytest.mark.parametrize("deg_cls", [UniformInterpolationRule, ChebyshevInterpolationRule]) +def test_regular_grid_build_and_eval_linear_exact(deg_cls): + # degree 1 should reproduce linear fields exactly + rule = deg_cls(1) + grid = make_grid() + interp = RegularGridInterpolant3D(rule, grid, extrapolate=False, skip_fn=None) + + # fbatch maps (rvec, phivec, zvec) -> (Nd, 3); here use linear field in Cartesian projected to cyl + def fbatch(r, phi, z): + # build N x 3 xyz and evaluate linear field, then rotate to cylindrical + x = r * jnp.cos(phi) + y = r * jnp.sin(phi) + pts = jnp.stack([x, y, z], axis=1) + Bxyz = jax.vmap(linear_cartesian_field)(pts) + Bcyl = _cart_to_cyl_vectors(phi, Bxyz) + return Bcyl + + interp = interp.build(fbatch) + + # evaluate at random batch in-domain; rotate back to compare with original field + key = jax.random.PRNGKey(0) + rmin, rmax, _ = grid.r_range + pmin, pmax, _ = grid.phi_range + zmin, zmax, _ = grid.z_range + u = jax.random.uniform(key, (256, 3)) + r = rmin + (rmax - rmin) * u[:, 0] + phi = pmin + (pmax - pmin) * u[:, 1] + z = zmin + (zmax - zmin) * u[:, 2] + rphiz = jnp.stack([r, phi, z], axis=1) + + Bcyl = interp.evaluate_batch(rphiz) # (N,3) + Bxyz_pred = _cyl_to_cart_vectors(phi, Bcyl) # (N,3) + + xyz = jnp.stack([r * jnp.cos(phi), r * jnp.sin(phi), z], axis=1) + Bxyz_true = jax.vmap(linear_cartesian_field)(xyz) + + assert jnp.allclose(Bxyz_pred, Bxyz_true, atol=1e-11, rtol=1e-11) + + +def test_regular_grid_skip_fn_masks_inner_core(): + rule = UniformInterpolationRule(1) + grid = make_grid() + skip_fn = make_skip_fn(grid) + interp = RegularGridInterpolant3D(rule, grid, extrapolate=True, skip_fn=skip_fn) + + # basic structural properties + assert interp.r_dofs.ndim == 1 and interp.phi_dofs.ndim == 1 and interp.z_dofs.ndim == 1 + assert interp.vals.shape[1] == 3 # vector-valued + + # Build with any function; we only check that masked DOFs weren’t included + def fbatch(r, phi, z): + return jnp.stack([r, phi, z], axis=1) + + interp2 = interp.build(fbatch) + + # Any reduced dof should have r >= rmin+0.05 (mask removes smaller radii) + rmin, _, _ = grid.r_range + assert jnp.all(interp2.r_dofs >= rmin + 0.049999) # allow tiny numerical slack + +# -------------------------------------------------------------------------------------- +# InterpolatedField end-to-end tests (build_B, symmetry, xyz path, jit) +# -------------------------------------------------------------------------------------- + +@pytest.mark.parametrize("deg_cls", [UniformInterpolationRule, ChebyshevInterpolationRule]) +def test_interpolated_field_linear_exact_and_jittable(deg_cls): + # Build with linear base field: degree=1 interpolant should be exact + degree = 1 + grid = make_grid() + field = InterpolatedField( + base_field_cart=linear_cartesian_field, + degree=degree, + rrange=grid.r_range, + phirange=grid.phi_range, + zrange=grid.z_range, + extrapolate=False, + nfp=3, # test periodic reduction path + stellsym=True, # enable stellarator symmetry path + skip_fn=None, + use_chebyshev=(deg_cls == ChebyshevInterpolationRule), + build_gradabsb=False, + ) + field = field.build_B() + + # Batch xyz points + key = jax.random.PRNGKey(42) + rmin, rmax, _ = grid.r_range + pmin, pmax, _ = grid.phi_range + zmin, zmax, _ = grid.z_range + u = jax.random.uniform(key, (128, 3)) + r = rmin + (rmax - rmin) * u[:, 0] + phi = pmin + (pmax - pmin) * u[:, 1] + z = zmin + (zmax - zmin) * u[:, 2] + xyz = jnp.stack([r * jnp.cos(phi), r * jnp.sin(phi), z], axis=1) + + # Exactness for linear field + B_pred = field.B_xyz(xyz) + B_true = jax.vmap(linear_cartesian_field)(xyz) + assert jnp.allclose(B_pred, B_true, atol=1e-11, rtol=1e-11) + + # JIT smoke test: the jitted function should run & match + jit_fun = jax.jit(field.B_xyz) + B_jit = jit_fun(xyz) + assert jnp.allclose(B_jit, B_true, atol=1e-11, rtol=1e-11) + + +def test_interpolated_field_quadratic_uniform_vs_chebyshev_agree_on_grid_nodes(): + # With quadratic field, degree=2 should be exact on the interpolation nodes. + rr = (0.4, 1.2, 3) + ph = (0.0, math.pi/2, 3) + zz = (-0.6, 0.6, 3) + grid = GridSpec(rr, ph, zz, value_size=3) + + for use_cheb in [False, True]: + field = InterpolatedField( + base_field_cart=quadratic_cartesian_field, + degree=2, + rrange=grid.r_range, + phirange=grid.phi_range, + zrange=grid.z_range, + extrapolate=False, + nfp=1, + stellsym=False, + skip_fn=None, + use_chebyshev=use_cheb, + build_gradabsb=False, + ) + field = field.build_B() + + # sample the *dof nodes* of the underlying grid to guarantee exactness + interp_grid = field.interp_B + r_nodes = interp_grid.r_dofs + p_nodes = interp_grid.phi_dofs + z_nodes = interp_grid.z_dofs + R, P, Z = jnp.meshgrid(r_nodes, p_nodes, z_nodes, indexing="ij") + xyz = jnp.stack([R * jnp.cos(P), R * jnp.sin(P), Z], axis=-1).reshape(-1, 3) + + B_true = jax.vmap(quadratic_cartesian_field)(xyz) + B_pred = field.B_xyz(xyz) + assert jnp.allclose(B_pred, B_true, atol=1e-10, rtol=1e-10) + + +def test_symmetry_reflection_rules_consistency(): + # Build small field where we can reason about symmetry flips + # Use a base field symmetric in z apart from a linear term in r to trigger Br flip. + base = linear_cartesian_field + grid = make_grid(rr=(0.6, 0.8, 2), ph=(0.0, 2*math.pi/3, 2), zz=(-0.4, 0.4, 2)) + + field = InterpolatedField( + base_field_cart=base, + degree=1, + rrange=grid.r_range, + phirange=grid.phi_range, + zrange=grid.z_range, + extrapolate=True, + nfp=3, # 2pi/3 periodicity + stellsym=True, # enforce reflection logic in B_cyl + skip_fn=None, + use_chebyshev=False, + build_gradabsb=False, + ).build_B() + + # pick mirrored points: (r, phi, +z) and (r, 2pi - phi, -z) + r = jnp.array([0.7, 0.75, 0.78]) + phi = jnp.array([0.1, 0.6, 1.2]) + z = jnp.array([0.2, 0.3, 0.1]) + batch_pos = jnp.stack([r, phi, z], axis=1) + batch_mir = jnp.stack([r, 2*jnp.pi - phi, -z], axis=1) + + # Evaluate in cylindrical (internal path), then compare applying reflection rule + Bp = field.B_cyl(batch_pos) + Bm = field.B_cyl(batch_mir) + + # For z<0 reflect: Br flips sign; Bphi and Bz remain (per helper in code) + # Compare after manual flip on mirrored outputs + Br, Bphi, Bz = Bm.T + Bm_ref_applied = jnp.stack([-Br, Bphi, Bz], axis=1) + assert jnp.allclose(Bp, Bm_ref_applied, atol=1e-11, rtol=1e-11) + +# -------------------------------------------------------------------------------------- +# Grad|B| option & estimator (light regression) +# -------------------------------------------------------------------------------------- + +def test_build_gradabsb_and_shapes(): + grid = make_grid() + field = InterpolatedField( + base_field_cart=quadratic_cartesian_field, + degree=2, + rrange=grid.r_range, + phirange=grid.phi_range, + zrange=grid.z_range, + extrapolate=True, + nfp=1, + stellsym=False, + skip_fn=None, + use_chebyshev=True, + build_gradabsb=True, + ) + field = field.build_B() + field = field.build_GradAbsB() + + key = jax.random.PRNGKey(0) + rmin, rmax, _ = grid.r_range + pmin, pmax, _ = grid.phi_range + zmin, zmax, _ = grid.z_range + u = jax.random.uniform(key, (32, 3)) + r = rmin + (rmax - rmin) * u[:, 0] + phi = pmin + (pmax - pmin) * u[:, 1] + z = zmin + (zmax - zmin) * u[:, 2] + rphiz = jnp.stack([r, phi, z], axis=1) + + G = field.GradAbsB_cyl(rphiz) + assert G.shape == (32, 3) + assert jnp.all(jnp.isfinite(G)) + +def test_error_estimator_small_for_linear(): + grid = make_grid() + field = InterpolatedField( + base_field_cart=linear_cartesian_field, + degree=1, + rrange=grid.r_range, + phirange=grid.phi_range, + zrange=grid.z_range, + extrapolate=False, + nfp=1, + stellsym=False, + skip_fn=None, + use_chebyshev=False, + build_gradabsb=False, + ).build_B() + rms, mx = field.estimate_error_B(jax.random.PRNGKey(123), nsamples=2000) + assert rms < 1e-11 and mx < 1e-10 + +# -------------------------------------------------------------------------------------- +# Boundary & extrapolation behavior +# -------------------------------------------------------------------------------------- + +def test_boundary_nodes_and_extrapolation_off(): + grid = make_grid() + rule = UniformInterpolationRule(1) + interp = RegularGridInterpolant3D(rule, grid, extrapolate=False) + + # Simple identity function for values + def fbatch(r, phi, z): + return jnp.stack([r, phi, z], axis=1) + + interp = interp.build(fbatch) + + # evaluate exactly at max boundary (should clamp inside) + rmax = grid.r_range[0] + (grid.r_range[1] - grid.r_range[0]) + phmax = grid.phi_range[0] + (grid.phi_range[1] - grid.phi_range[0]) + zmax = grid.z_range[0] + (grid.z_range[1] - grid.z_range[0]) + + pts = jnp.array([ + [grid.r_range[0], grid.phi_range[0], grid.z_range[0]], + [rmax, phmax, zmax] + ]) + vals = interp.evaluate_batch(pts) + assert vals.shape == (2, 3) + # Within domain we get identity; at upper edge we still get finite values + assert jnp.all(jnp.isfinite(vals)) + +# -------------------------------------------------------------------------------------- +# Cyl<->Cart vector transforms (sanity) +# -------------------------------------------------------------------------------------- + +def test_cyl_cart_roundtrip_vectors(): + key = jax.random.PRNGKey(0) + N = 64 + phi = jax.random.uniform(key, (N,), minval=-math.pi, maxval=math.pi) + v = jax.random.normal(key, (N, 3)) + xyz = _cyl_to_cart_vectors(phi, v) + cyl = _cart_to_cyl_vectors(phi, xyz) + assert jnp.allclose(cyl, v, atol=1e-12, rtol=1e-12) diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py new file mode 100644 index 0000000..8a9ff1e --- /dev/null +++ b/tests/test_surfaces.py @@ -0,0 +1,289 @@ +# tests/test_surfaces.py +import math +import os +import numpy as np +import pytest +import jax +import jax.numpy as jnp + +# --- import subject under test --- +from essos.surfaces import ( + SurfaceRZFourier, + B_on_surface, + BdotN, + BdotN_over_B, + SurfaceClassifier, +) + +# ------------------------------------------------------------------------- +# Global JAX settings for numerical stability +# ------------------------------------------------------------------------- + +@pytest.fixture(scope="session", autouse=True) +def _enable_x64(): + jax.config.update("jax_enable_x64", True) + +# ------------------------------------------------------------------------- +# Helpers: Build an analytic circular torus surface via Fourier coefficients +# R(θ,φ) = R0 + a cos θ +# Z(θ,φ) = a sin θ +# (No φ-dependence; nfp can be arbitrary but we’ll use 1 and 4 in tests.) +# ------------------------------------------------------------------------- + +def make_circular_torus_surface( + R0=10.0, + a=2.0, + nfp=1, + ntheta=64, + nphi=48, + close=True, + range_torus="full torus", +): + """ + Construct SurfaceRZFourier via the (rc, zs, nfp) path with only (m=0,n=0) and (m=1,n=0) active: + rmnc(0,0)=R0, rmnc(1,0)=a, zmns(1,0)=a + """ + # mpol must be >= 2 to hold m=0 and m=1 rows + mpol = 2 + ntor = 0 # only n=0 + rc = jnp.zeros((mpol, 2 * ntor + 1)) + zs = jnp.zeros((mpol, 2 * ntor + 1)) + rc = rc.at[0, 0].set(R0) # m=0,n=0 + rc = rc.at[1, 0].set(a) # m=1,n=0 + zs = zs.at[1, 0].set(a) # m=1,n=0 + + surf = SurfaceRZFourier( + vmec=None, + s=1.0, + ntheta=ntheta, + nphi=nphi, + close=close, + range_torus=range_torus, + rc=rc, + zs=zs, + nfp=nfp, + ) + return surf + +# ------------------------------------------------------------------------- +# Mock field for B_on_surface / BdotN tests +# ------------------------------------------------------------------------- + +class ConstBzField: + """Simple mock field with B = (0,0,B0) everywhere (in Cartesian).""" + def __init__(self, B0=1.0): + self.B0 = B0 + + @staticmethod + def B(point_xyz): + # 'point_xyz' is (3,) but we ignore it + return jnp.array([0.0, 0.0, 1.0], dtype=jnp.float64) + + @staticmethod + def AbsB(point_xyz): + return jnp.array(1.0, dtype=jnp.float64) + +# ------------------------------------------------------------------------- +# Unit tests: geometry of SurfaceRZFourier on the analytic torus +# ------------------------------------------------------------------------- + +def test_gamma_matches_analytic_circular_torus(): + R0, a = 10.0, 2.0 + surf = make_circular_torus_surface(R0=R0, a=a, nfp=1, ntheta=64, nphi=48) + + theta_2d, phi_2d = surf.theta_2d, surf.phi_2d + R = R0 + a * jnp.cos(theta_2d) + Z = a * jnp.sin(theta_2d) + X = R * jnp.cos(phi_2d) + Y = R * jnp.sin(phi_2d) + + gamma = surf.gamma # (nphi, ntheta, 3) + assert gamma.shape == (surf.nphi, surf.ntheta, 3) + assert jnp.allclose(gamma[:, :, 0], X, atol=1e-12) + assert jnp.allclose(gamma[:, :, 1], Y, atol=1e-12) + assert jnp.allclose(gamma[:, :, 2], Z, atol=1e-12) + +def test_normals_are_unit_and_perpendicular_to_tangent(): + surf = make_circular_torus_surface(ntheta=48, nphi=32) + n = surf.unitnormal + gt = surf.gammadash_theta + gp = surf.gammadash_phi + + # unit length: + nlen = jnp.linalg.norm(n, axis=2) + assert jnp.allclose(nlen, 1.0, atol=1e-10) + + # orthogonal to both tangent directions: + dot_t = jnp.sum(n * gt, axis=2) + dot_p = jnp.sum(n * gp, axis=2) + assert jnp.allclose(dot_t, 0.0, atol=1e-10) + assert jnp.allclose(dot_p, 0.0, atol=1e-10) + +def test_mean_cross_section_area_matches_pi_a2(): + R0, a = 8.0, 1.5 + surf = make_circular_torus_surface(R0=R0, a=a, nfp=1, ntheta=96, nphi=64) + # For a circular torus, average poloidal cross-sectional area is π a^2 + area = surf.mean_cross_sectional_area() + assert jnp.allclose(area, math.pi * a * a, rtol=2e-3, atol=2e-3) # allow slight discretization error + +def test_dofs_setter_updates_geometry(): + R0, a = 9.0, 1.2 + surf = make_circular_torus_surface(R0=R0, a=a, nfp=1, ntheta=32, nphi=24) + # Keep original gamma: + g0 = jnp.array(surf.gamma) + + # Increase a by 10% by tweaking the corresponding coefficient in dofs: + # Layout in SurfaceRZFourier: dofs concatenates rc (flattened)[ntor:] then zs[ntor:] + # We placed rmnc(1,0)=a and zmns(1,0)=a originally; locate them in the rc/zs arrays. + idx_rm_m1n0 = 1 * (2 * 0 + 1) + 0 # m=1, n=0 within shape (mpol, 1) -> index 1 + idx_zs_m1n0 = 1 * (2 * 0 + 1) + 0 + + dofs = jnp.array(surf.dofs) + rc_len = surf.rc.size + zs_len = surf.zs.size + + # Current values: + assert np.isclose(surf.rc.ravel()[idx_rm_m1n0], a) + assert np.isclose(surf.zs.ravel()[idx_zs_m1n0], a) + + # Bump 'a' by 10% in both R and Z harmonics: + dofs = dofs.at[idx_rm_m1n0].set(1.1 * a) + dofs = dofs.at[rc_len + idx_zs_m1n0].set(1.1 * a) + surf.dofs = dofs + + g1 = surf.gamma + # Expect outward/inward displacement ~0.1*a in R/Z amplitudes; just check that geometry changed: + assert not jnp.allclose(g0, g1) + +# ------------------------------------------------------------------------- +# Field on surface: B_on_surface / BdotN / BdotN_over_B +# ------------------------------------------------------------------------- + +def test_B_on_surface_shapes_and_simple_values(): + surf = make_circular_torus_surface(ntheta=16, nphi=10) + field = ConstBzField(B0=1.0) + + Bout = B_on_surface(surf, field) + assert Bout.shape == (surf.nphi, surf.ntheta, 3) + # all Bz ~ 1; Bx=By=0: + assert jnp.allclose(Bout[..., 0], 0.0, atol=1e-12) + assert jnp.allclose(Bout[..., 1], 0.0, atol=1e-12) + assert jnp.allclose(Bout[..., 2], 1.0, atol=1e-12) + +def test_BdotN_and_BdotN_over_B_ranges(): + surf = make_circular_torus_surface(ntheta=24, nphi=18) + field = ConstBzField(B0=1.0) + + bn = BdotN(surf, field) + assert bn.shape == (surf.nphi, surf.ntheta) + # |B·n| <= |B| = 1 + assert jnp.all(bn <= 1.0 + 1e-12) + assert jnp.all(bn >= -1.0 - 1e-12) + + bn_over_B = BdotN_over_B(surf, field) + assert bn_over_B.shape == (surf.nphi, surf.ntheta) + assert jnp.all(bn_over_B <= 1.0 + 1e-12) + assert jnp.all(bn_over_B >= -1.0 - 1e-12) + # consistency: + assert jnp.allclose(bn_over_B, bn / 1.0, atol=1e-12) + +# ------------------------------------------------------------------------- +# SurfaceClassifier: build & evaluate +# ------------------------------------------------------------------------- + +@pytest.mark.parametrize("nfp", [1, 4]) +def test_surface_classifier_build_and_signs(nfp): + # Build surface and classifier (uses SciPy cKDTree path in your code) + R0, a = 10.0, 2.0 + surf = make_circular_torus_surface(R0=R0, a=a, nfp=nfp, ntheta=64, nphi=64) + + # Keep the grid coarser (h) for speed but still accurate + sc = SurfaceClassifier(surf, h=0.12, use_fundamental_phi=True) + + # (1) Points *on* the surface should give ~0 signed distance + # take a small set of surface samples: + th = jnp.linspace(0, 2 * jnp.pi, 9, endpoint=False) + ph = jnp.linspace(0, 2 * jnp.pi / nfp, 7, endpoint=False) + TH, PH = jnp.meshgrid(th, ph, indexing="ij") + R = R0 + a * jnp.cos(TH) + Z = a * jnp.sin(TH) + X = R * jnp.cos(PH) + Y = R * jnp.sin(PH) + XYZ = jnp.stack([X, Y, Z], axis=-1).reshape(-1, 3) + + d_on = sc.evaluate_xyz(XYZ) + assert d_on.shape == (XYZ.shape[0],) + assert jnp.all(jnp.abs(d_on) < 0.15) # coarse grid => small but nonzero acceptable + + # (2) A point strictly inside (e.g., near magnetic axis) should be positive + inside = jnp.array([R0, 0.0, 0.0]) + assert sc.evaluate_xyz(inside) > -1e-6 + + # (3) A point outside (R larger than R0 + a + margin) should be negative + outside = jnp.array([R0 + a + 0.5, 0.0, 0.0]) + assert sc.evaluate_xyz(outside) < 0.0 + +def test_surface_classifier_phi_wrapping_equivalence(): + surf = make_circular_torus_surface(R0=8.0, a=1.7, nfp=3, ntheta=48, nphi=48) + sc = SurfaceClassifier(surf, h=0.10, use_fundamental_phi=True) + + # Evaluate at same (r,z) but wildly different φ; wrapping should make them equal. + r = 8.5 + z = 0.2 + phi1 = 10.0 * math.pi # large + phi2 = 0.1 # small + d1 = sc.evaluate_rphiz(jnp.array([r, phi1, z])) + d2 = sc.evaluate_rphiz(jnp.array([r, phi2, z])) + assert jnp.allclose(d1, d2, atol=1e-10) + +def test_surface_classifier_vectorized_batch_xyz(): + surf = make_circular_torus_surface(R0=9.0, a=1.0, nfp=2, ntheta=40, nphi=40) + sc = SurfaceClassifier(surf, h=0.12, use_fundamental_phi=True) + + # batch of random xyz within a bounding box: + key = jax.random.PRNGKey(0) + xs = jax.random.uniform(key, (200,), minval=7.0, maxval=11.0) + ys = jax.random.uniform(key, (200,), minval=-3.0, maxval=3.0) + zs = jax.random.uniform(key, (200,), minval=-2.0, maxval=2.0) + xyz = jnp.stack([xs, ys, zs], axis=1) + + vals = sc.evaluate_xyz(xyz) + assert vals.shape == (200,) + assert jnp.all(jnp.isfinite(vals)) + +# ------------------------------------------------------------------------- +# Smoke tests for JIT compilation on classifier APIs +# ------------------------------------------------------------------------- + +def test_classifier_jit_smoke(): + surf = make_circular_torus_surface(R0=8.0, a=1.4, nfp=1, ntheta=32, nphi=32) + sc = SurfaceClassifier(surf, h=0.15, use_fundamental_phi=True) + + xyz = jnp.array([[8.0, 0.0, 0.0], + [9.2, 0.0, 0.1]]) + rphiz = jnp.array([[8.0, 7.0, 0.0], + [9.2, 20.0, 0.1]]) + + # JIT both methods: + f1 = jax.jit(sc.evaluate_xyz, static_argnames=("self",)) + f2 = jax.jit(sc.evaluate_rphiz, static_argnames=("self",)) + + out1 = f1(xyz) + out2 = f2(rphiz) + assert out1.shape == (2,) and out2.shape == (2,) + assert jnp.all(jnp.isfinite(out1)) and jnp.all(jnp.isfinite(out2)) + +# ------------------------------------------------------------------------- +# Optional: to_vmec round-trip smoke (file content check) +# ------------------------------------------------------------------------- + +def test_to_vmec_writes_expected_coeffs(tmp_path): + R0, a = 7.5, 1.1 + surf = make_circular_torus_surface(R0=R0, a=a, nfp=4, ntheta=16, nphi=16) + out = tmp_path / "surf_in.vmec" + surf.to_vmec(str(out)) + txt = out.read_text() + # Sanity: NFP present and at least some RBC/ZBS lines with our values + assert "NFP = 4" in txt + assert "RBC(" in txt and "ZBS(" in txt + assert any("RBC(" in line and "ZBS(" in line for line in txt.splitlines()) From 308fc691b23e43ce5e39434c8ac93378666e635c Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sun, 21 Sep 2025 21:05:21 -0500 Subject: [PATCH 3/6] Refactor and optimize VMEC native interpolant integration in tracing example --- essos/interpolated_field.py | 165 +++++++++++++++ examples/trace_fieldlines_interpolated.py | 237 ++++------------------ 2 files changed, 207 insertions(+), 195 deletions(-) diff --git a/essos/interpolated_field.py b/essos/interpolated_field.py index 0bf72b2..89f5bdf 100644 --- a/essos/interpolated_field.py +++ b/essos/interpolated_field.py @@ -28,6 +28,11 @@ import equinox as eqx from jax import lax +from functools import partial +from jax import block_until_ready +from jax.scipy.interpolate import RegularGridInterpolator as JaxRGI +import time + # -------------------------------------------------------------------------------------- # Utility: Lagrange basis with precomputed nodes and scalings (barycentric‑like product) # -------------------------------------------------------------------------------------- @@ -521,3 +526,163 @@ def estimate_error_B(self, key: jax.Array, nsamples: int = 10_000) -> Tuple[floa rms = jnp.sqrt(jnp.mean(diff**2)) mx = jnp.max(diff) return float(rms), float(mx) + +def _ensure_batch(x: jnp.ndarray): + """Accept (3,) or (N,3). Return (N,3) and a flag indicating 'was_single'.""" + x = jnp.asarray(x) + if x.ndim == 1: + return x[None, :], True + return x, False + + +def _unbatch(out: jnp.ndarray, was_single: bool): + """If 'out' is (N,…) and was_single, return out[0]; else return out.""" + return out[0] if was_single else out + + +class InterpolatedVmecNative(eqx.Module): + """ + Per-component regular-grid interpolators in native VMEC coordinates (s,θ,φ). + Exposes jittable evaluators for B_covariant, B_contravariant, and sqrtg, + and convenience passthroughs for AbsB and to_xyz. + """ + vmec: any + srange: Tuple[float, float, int] + thetarange: Tuple[float, float, int] + phirange: Tuple[float, float, int] + _rgis: dict = eqx.static_field() + + def __init__( + self, + vmec, + srange: Tuple[float, float, int] = (0.0, 1.0, 24), + thetarange: Tuple[float, float, int] = (0.0, 2 * jnp.pi, 48), + phirange: Tuple[float, float, int] = (0.0, None, 64), + ): + self.vmec = vmec + nfp = int(vmec.nfp) + if phirange[1] is None: + phirange = (0.0, 2 * jnp.pi / nfp, phirange[2]) + self.srange, self.thetarange, self.phirange = srange, thetarange, phirange + object.__setattr__(self, "_rgis", {}) + + def build_all(self) -> "InterpolatedVmecNative": + print("[build] Precomputing grids & field values for RGI ...") + t0 = time.perf_counter() + s0, s1, ns = self.srange + th0, th1, nth = self.thetarange + ph0, ph1, nph = self.phirange + + s_list = jnp.linspace(s0, s1, ns) + th_list = jnp.linspace(th0, th1, nth) + ph_list = jnp.linspace(ph0, ph1, nph) + + # Tensor grid (ij indexing so reshape matches (ns,nth,nph)) + SS, TT, PP = jnp.meshgrid(s_list, th_list, ph_list, indexing="ij") + pts_flat = jnp.stack([SS.ravel(), TT.ravel(), PP.ravel()], axis=1) # (N,3) + N = pts_flat.shape[0] + print(f"[build] grid sizes: ns={ns}, nth={nth}, nph={nph} -> N={N}") + + # Batch evaluate VMEC quantities + t_eval = time.perf_counter() + Bcov = jax.vmap(self.vmec.B_covariant)(pts_flat) # (N,3) + Bcon = jax.vmap(self.vmec.B_contravariant)(pts_flat) # (N,3) + sqrtg = jax.vmap(self.vmec.sqrtg)(pts_flat) # (N,) + Bcov, Bcon, sqrtg = (block_until_ready(Bcov), + block_until_ready(Bcon), + block_until_ready(sqrtg)) + print(f"[build] VMEC eval on grid: {time.perf_counter() - t_eval:.2f}s") + + # Reshape to tensor grid + Bcov = Bcov.reshape((ns, nth, nph, 3)) + Bcon = Bcon.reshape((ns, nth, nph, 3)) + sqrtg = sqrtg.reshape((ns, nth, nph)) + + # Build RGIs (3+3+1) + t_rgi = time.perf_counter() + + def rgi3(A3): + # A3 shape = (ns,nth,nph) + return JaxRGI((s_list, th_list, ph_list), A3, fill_value=None) + + rgis = { + "Bcov0": rgi3(Bcov[..., 0]), + "Bcov1": rgi3(Bcov[..., 1]), + "Bcov2": rgi3(Bcov[..., 2]), + "Bcon0": rgi3(Bcon[..., 0]), + "Bcon1": rgi3(Bcon[..., 1]), + "Bcon2": rgi3(Bcon[..., 2]), + "sqrtg": rgi3(sqrtg), + } + object.__setattr__(self, "_rgis", rgis) + + print(f"[build] RGI build: {time.perf_counter() - t_rgi:.2f}s") + print(f"[build] total: {time.perf_counter() - t0:.2f}s") + return self + + # ---------- public, jittable evaluators with shape handling ---------- + @eqx.filter_jit + def B_covariant(self, s_th_phi: jnp.ndarray) -> jnp.ndarray: + pts, single = _ensure_batch(s_th_phi) + v0 = self._rgis["Bcov0"](pts) + v1 = self._rgis["Bcov1"](pts) + v2 = self._rgis["Bcov2"](pts) + val = jnp.stack([v0, v1, v2], axis=-1) # (N,3) + return _unbatch(val, single) # (3,) or (N,3) + + @eqx.filter_jit + def B_contravariant(self, s_th_phi: jnp.ndarray) -> jnp.ndarray: + pts, single = _ensure_batch(s_th_phi) + v0 = self._rgis["Bcon0"](pts) + v1 = self._rgis["Bcon1"](pts) + v2 = self._rgis["Bcon2"](pts) + val = jnp.stack([v0, v1, v2], axis=-1) # (N,3) + return _unbatch(val, single) # (3,) or (N,3) + + @eqx.filter_jit + def sqrtg(self, s_th_phi: jnp.ndarray) -> jnp.ndarray: + pts, single = _ensure_batch(s_th_phi) + g = self._rgis["sqrtg"](pts) # (N,) + return _unbatch(g, single) # () or (N,) + + # ---------- convenience passthroughs ---------- + @partial(jax.jit, static_argnames=("self",)) + def AbsB(self, pts_stp): + """Accept (..., 3) or (3,) and return (...,). Works for any batch rank.""" + pts = jnp.asarray(pts_stp) + if pts.ndim == 1: # (3,) + return self.vmec.AbsB(pts) + if pts.shape[-1] != 3: + raise ValueError(f"AbsB expects last dim = 3; got {pts.shape}") + leading = pts.shape[:-1] + flat = pts.reshape((-1, 3)) + out = jax.vmap(self.vmec.AbsB)(flat) # (N,) + return out.reshape(leading) + + @partial(jax.jit, static_argnames=("self",)) + def to_xyz(self, pts_stp): + """Accept (..., 3) or (3,) and return (..., 3). Works for any batch rank.""" + pts = jnp.asarray(pts_stp) + if pts.ndim == 1: # (3,) + return self.vmec.to_xyz(pts) + if pts.shape[-1] != 3: + raise ValueError(f"to_xyz expects last dim = 3; got {pts.shape}") + # Flatten leading dims, vmap, then reshape back + leading = pts.shape[:-1] + flat = pts.reshape((-1, 3)) + out = jax.vmap(self.vmec.to_xyz)(flat) # (N, 3) + return out.reshape(leading + (3,)) + + +def build_vmec_native_interpolant( + vmec, + srange=(0.0, 1.0, 33), + thetarange=(0.0, 2 * jnp.pi, 48), + phirange=(0.0, None, 64), +) -> InterpolatedVmecNative: + """Helper factory with one-liner build, plus prints for timing.""" + print("[stage] Building native (s,θ,φ) interpolants: B_cov, B_con, sqrtg ...") + t = time.perf_counter() + interp = InterpolatedVmecNative(vmec, srange, thetarange, phirange).build_all() + print(f"[time] Interpolant build total: {time.perf_counter() - t:.2f}s") + return interp diff --git a/examples/trace_fieldlines_interpolated.py b/examples/trace_fieldlines_interpolated.py index d966484..57ca118 100644 --- a/examples/trace_fieldlines_interpolated.py +++ b/examples/trace_fieldlines_interpolated.py @@ -6,235 +6,80 @@ import jax import jax.numpy as jnp -from jax import block_until_ready, vmap -from jax.scipy.interpolate import RegularGridInterpolator as JaxRGI -from functools import partial +from jax import block_until_ready import matplotlib.pyplot as plt -# ---- Your ESSOS imports ---- from essos.fields import Vmec from essos.surfaces import SurfaceClassifier +from essos.interpolated_field import ( + build_vmec_native_interpolant, + InterpolatedVmecNative, +) -# ---- Import the RegularGridInterpolant3D / rules you already have ---- -from essos.interpolated_field import RegularGridInterpolant3D, UniformInterpolationRule, ChebyshevInterpolationRule - -# --------------------------------------------------------------------------------- -# Helper: If you ever build an (r,φ,z)-space interpolant, you can use this skip_fn -# For native (s,θ,φ) it is typically unnecessary, so this is OPTIONAL here. -# --------------------------------------------------------------------------------- -def make_skip_fn_from_classifier(sc, buffer=0.04): - """Return (rvec, phivec, zvec) -> bool[N] mask; True means 'skip' (outside).""" - def _skip(rvec: jnp.ndarray, phivec: jnp.ndarray, zvec: jnp.ndarray) -> jnp.ndarray: - rphiz = jnp.stack([rvec, phivec, zvec], axis=1) # (N,3) - d = jax.vmap(sc.evaluate_rphiz)(rphiz) # (N,) - return (d < -buffer) - return _skip - -# --------------------------------------------------------------------------------- -# Native-coordinate interpolated field for VMEC: (s,θ,φ)-> {B_cov, B_con, sqrtg} -# --------------------------------------------------------------------------------- -class InterpolatedVmecNative: - def __init__(self, vmec, - srange=(0.0, 1.0, 24), - thetarange=(0.0, 2*jnp.pi, 48), - phirange=(0.0, None, 48)): - self.vmec = vmec - nfp = int(vmec.nfp) - if phirange[1] is None: - phirange = (0.0, 2*jnp.pi/nfp, phirange[2]) - self.srange, self.thetarange, self.phirange = srange, thetarange, phirange - self._rgis = None # will hold dict of per-component RGIs - - def build_all(self): - print("[build] Precomputing grids & field values for RGI ...") - t0 = time.perf_counter() - s0, s1, ns = self.srange - th0, th1, nth = self.thetarange - ph0, ph1, nph = self.phirange - - s_list = jnp.linspace(s0, s1, ns) - th_list = jnp.linspace(th0, th1, nth) - ph_list = jnp.linspace(ph0, ph1, nph) - - # Tensor grid (ij indexing so reshape matches (ns,nth,nph)) - SS, TT, PP = jnp.meshgrid(s_list, th_list, ph_list, indexing="ij") - pts_flat = jnp.stack([SS.ravel(), TT.ravel(), PP.ravel()], axis=1) # (N,3) - N = pts_flat.shape[0] - print(f"[build] grid sizes: ns={ns}, nth={nth}, nph={nph} -> N={N}") - - # Batch evaluate VMEC (cartesian) quantities - t_eval = time.perf_counter() - Bcov = jax.vmap(self.vmec.B_covariant)(pts_flat) # (N,3) - Bcon = jax.vmap(self.vmec.B_contravariant)(pts_flat) # (N,3) - sqrtg = jax.vmap(self.vmec.sqrtg)(pts_flat) # (N,) - # Force computation/timing: - Bcov, Bcon, sqrtg = jax.block_until_ready(Bcov), jax.block_until_ready(Bcon), jax.block_until_ready(sqrtg) - print(f"[build] VMEC eval on grid: {time.perf_counter()-t_eval:.2f}s") - - # Reshape to tensor grid - Bcov = Bcov.reshape((ns, nth, nph, 3)) - Bcon = Bcon.reshape((ns, nth, nph, 3)) - sqrtg = sqrtg.reshape((ns, nth, nph)) - - # Build 7 RGIs (3+3+1), fill_value extrapolates as you prefer: - t_rgi = time.perf_counter() - def rgi3(A3): - # A3 shape = (ns,nth,nph) - return JaxRGI((s_list, th_list, ph_list), A3, fill_value=None) - - rgis = { - "Bcov0": rgi3(Bcov[..., 0]), - "Bcov1": rgi3(Bcov[..., 1]), - "Bcov2": rgi3(Bcov[..., 2]), - "Bcon0": rgi3(Bcon[..., 0]), - "Bcon1": rgi3(Bcon[..., 1]), - "Bcon2": rgi3(Bcon[..., 2]), - "sqrtg": rgi3(sqrtg), - } - self._rgis = rgis - print(f"[build] RGI build: {time.perf_counter()-t_rgi:.2f}s") - print(f"[build] total: {time.perf_counter()-t0:.2f}s") - return self - - # ---------------- evaluate (batched & JIT) ---------------- - @partial(jax.jit, static_argnames=("self",)) - def B_covariant(self, pts_stp): # pts_stp (...,3) - x = pts_stp - b0 = self._rgis["Bcov0"](x) - b1 = self._rgis["Bcov1"](x) - b2 = self._rgis["Bcov2"](x) - return jnp.stack([b0, b1, b2], axis=-1) - - @partial(jax.jit, static_argnames=("self",)) - def B_contravariant(self, pts_stp): - x = pts_stp - b0 = self._rgis["Bcon0"](x) - b1 = self._rgis["Bcon1"](x) - b2 = self._rgis["Bcon2"](x) - return jnp.stack([b0, b1, b2], axis=-1) - - @partial(jax.jit, static_argnames=("self",)) - def sqrtg(self, pts_stp): - return self._rgis["sqrtg"](pts_stp) - - @partial(jax.jit, static_argnames=("self",)) - def AbsB(self, pts_stp): - # Optional, reuse vmec.AbsB to avoid building another set of RGIs: - return jax.vmap(self.vmec.AbsB)(pts_stp) - - -# ------------------------------------------------------------------------------------------------- -# Script body: build interpolants in native coords; trace using your native EOM (s,θ,φ dynamics) -# ------------------------------------------------------------------------------------------------- +# ---------------------------------- load VMEC ---------------------------------- t0 = time.perf_counter() print("[stage] Loading VMEC ...") wout_file = os.path.join(os.path.dirname(__file__), "input_files", "wout_QH_simple_scaled.nc") vmec = Vmec(wout_file) print(f"[time] VMEC load: {time.perf_counter()-t0:.2f}s (nfp={vmec.nfp})") +# ------------------------ (optional) SurfaceClassifier ------------------------- print("[stage] Building SurfaceClassifier (for diagnostics / optional skip) ...") t_sc = time.perf_counter() -sc = SurfaceClassifier(vmec.surface, h=0.06) +_ = SurfaceClassifier(vmec.surface, h=0.06) # keeps the nice prints, optional print(f"[time] SurfaceClassifier: {time.perf_counter()-t_sc:.2f}s") -# Native grid in (s,θ,φ): -srange = (0.0, 1.0, 24) -thetarange = (0.0, 2*jnp.pi, 48) -phirange = (0.0, 2*jnp.pi/int(vmec.nfp), 48) - -print("[stage] Building native (s,θ,φ) interpolants: B_cov, B_con, sqrtg ...") -t_build = time.perf_counter() -interp = InterpolatedVmecNative( +# --------------------- build native (s,θ,φ) interpolants ----------------------- +interp: InterpolatedVmecNative = build_vmec_native_interpolant( vmec, - srange=(0.0, 1.0, 33), # ns - thetarange=(0.0, 2*jnp.pi, 48), # nth - phirange=(0.0, None, 64), # nph; None -> 2π/nfp -).build_all() -interp = interp.build_all() -print(f"[time] Interpolant build total: {time.perf_counter()-t_build:.2f}s") - -# Adapter exposing the same API your Tracing uses (native EOM): -class VmecNativeAdapter: - def __init__(self, base_vmec, interp_native): - self.vmec = base_vmec - self.I = interp_native - self.nfp = base_vmec.nfp # keep attribute parity - - # EOM-relevant pieces in native coordinates: - def B_covariant(self, pts_stp): - return self.I.B_covariant(pts_stp) - - def B_contravariant(self, pts_stp): - return self.I.B_contravariant(pts_stp) - - def sqrtg(self, pts_stp): - return self.I.sqrtg(pts_stp) - - # Optional: if Tracing uses these too: - def AbsB(self, pts_stp): - return self.I.AbsB(pts_stp) + srange=(0.0, 1.0, 33), + thetarange=(0.0, 2*jnp.pi, 48), + phirange=(0.0, None, 64), # None -> 2π/nfp +) - # Geometry map (already provided by your Vmec) - def to_xyz(self, pts_stp): - return vmap(self.vmec.to_xyz)(pts_stp) - - # If Tracing occasionally calls B in Cartesian, you can provide: - # def B(self, pts_stp): - # # Cartesian B from covariant basis vectors: - # # You already have vmec.B that returns Cartesian from (s,θ,φ); - # # If you want to be fully interpolated, build a separate Cartesian interpolant. - # return vmap(self.vmec.B)(pts_stp) - -bfield = VmecNativeAdapter(vmec, interp) - -# --------------------- -# Initial conditions -# --------------------- +# ------------------------------ initial conditions ----------------------------- t_init = time.perf_counter() nfieldlines_per_core = 6 nfieldlines = number_of_processors_to_use * nfieldlines_per_core -# Choose initial (s,θ,φ). Your prior script used Cartesian R0/Z0/φ0; here we stay native: -s0 = jnp.linspace(0.02, 0.98, nfieldlines) # avoid exact boundary +s0 = jnp.linspace(0.02, 0.98, nfieldlines) th0 = jnp.zeros(nfieldlines) ph0 = jnp.zeros(nfieldlines) initial_stp = jnp.stack([s0, th0, ph0], axis=1) print(f"[time] Init conditions set: {time.perf_counter()-t_init:.2f}s (n={nfieldlines})") -# --------------------- -# Trace in ESSOS (native EOM) -# --------------------- +# ----------------------------- quick shape sanity ------------------------------ +print("[stage] JIT warmup for Tracing ...") +_test = jnp.array([0.5, 0.0, 0.0]) +print("[dbg] B_con shape:", interp.B_contravariant(_test).shape) +print("[dbg] B_cov shape:", interp.B_covariant(_test).shape) +print("[dbg] sqrtg shape:", jnp.shape(interp.sqrtg(_test))) +print("[dbg] to_xyz shape:", interp.to_xyz(_test).shape) + +# ----------------------------------- trace ------------------------------------ from essos.dynamics import Tracing -t_trace = time.perf_counter() tmax = 1500 trace_tolerance = 1e-10 num_steps = 10000 -print("[stage] JIT warmup for Tracing ...") -# Hint: do a tiny warmup call if Tracing JITs internally; otherwise first call will take longer. - print("[stage] Running Tracing ...") print("[stage] Tracing fieldlines using interpolated field (native s,θ,φ) …") t0 = time.perf_counter() tracing = block_until_ready(Tracing( - field=interp, # <= use interpolant here + field=interp, # interpolated native field model='FieldLineAdaptative', - initial_conditions=initial_stp, # still XYZ for seeding; your Tracing will convert + initial_conditions=initial_stp, # native initial conditions maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance, - rtol=trace_tolerance + rtol=trace_tolerance, )) print(f"[time] ESSOS tracing: {time.perf_counter()-t0:.2f}s") -# Trajectories are in native (s,θ,φ). Convert to xyz for plotting: -t_xyz = time.perf_counter() +# ------------------------- grab trajectories and plot -------------------------- trajectories_stp = tracing.trajectories -trajectories_xyz = vmap(vmap(bfield.to_xyz))(trajectories_stp) -print(f"[time] stp->xyz conversion: {time.perf_counter()-t_xyz:.2f}s") +# If your Tracing already produces xyz, prefer it: +trajectories_xyz = getattr(tracing, "trajectories_xyz", interp.to_xyz(trajectories_stp)) -# ------------- -# Plot results -# ------------- print("[stage] Plotting ...") t_plot = time.perf_counter() fig = plt.figure(figsize=(9, 5)) @@ -242,11 +87,10 @@ def to_xyz(self, pts_stp): ax2 = fig.add_subplot(122) vmec.surface.plot(ax=ax1, show=False) -# quick plot of trajectories in xyz for tr in trajectories_xyz: - ax1.plot(tr[:,0], tr[:,1], tr[:,2], lw=0.8, alpha=0.8) + ax1.plot(tr[:, 0], tr[:, 1], tr[:, 2], lw=0.8, alpha=0.8) -# Simple Poincaré at φ ≈ 0 using native coordinates directly: +# Simple Poincaré at φ ≈ 0 using native coordinates directly nfp = int(vmec.nfp) phis = trajectories_stp[..., 2] mask = jnp.isclose((phis % (2*jnp.pi/nfp)), 0.0, atol=2e-3) @@ -259,17 +103,20 @@ def to_xyz(self, pts_stp): plt.tight_layout(); plt.show() print(f"[time] Plotting: {time.perf_counter()-t_plot:.2f}s") -# Optional: quick accuracy probe at random points +# ------------------------------ interpolation check ---------------------------- key = jax.random.key(0) +s0, s1, ns = (0.0, 1.0, 33) +th0, th1, nth = (0.0, 2*jnp.pi, 48) +ph0, ph1, nph = (0.0, 2*jnp.pi/nfp, 64) Nsamp = 4096 -s_s = jax.random.uniform(key, (Nsamp,), minval=srange[0], maxval=srange[1]) -th_s = jax.random.uniform(key, (Nsamp,), minval=thetarange[0], maxval=thetarange[1]) -ph_s = jax.random.uniform(key, (Nsamp,), minval=phirange[0], maxval=phirange[1]) +s_s = jax.random.uniform(key, (Nsamp,), minval=s0, maxval=s1) +th_s = jax.random.uniform(key, (Nsamp,), minval=th0, maxval=th1) +ph_s = jax.random.uniform(key, (Nsamp,), minval=ph0, maxval=ph1) pts = jnp.stack([s_s, th_s, ph_s], axis=1) t_err = time.perf_counter() Bcov_true = jax.vmap(vmec.B_covariant)(pts) -Bcov_interp = bfield.B_covariant(pts) +Bcov_interp = interp.B_covariant(pts) err = jnp.linalg.norm(Bcov_true - Bcov_interp, axis=1) print(f"[check] RMS|B_cov-Interp|={jnp.sqrt(jnp.mean(err**2)):.3e}, " - f"Max={jnp.max(err):.3e} (computed in {time.perf_counter()-t_err:.2f}s)") + f"Max={jnp.max(err):.3e} (computed in {time.perf_counter()-t_err:.2f}s)") From db64eccafddd26f3531f3b8783958db3dafe5a9d Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sun, 21 Sep 2025 21:07:51 -0500 Subject: [PATCH 4/6] Add tests for InterpolatedVmecNative functionality and behavior --- tests/test_interpolated_field.py | 214 +++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/tests/test_interpolated_field.py b/tests/test_interpolated_field.py index 0c1f873..cd1d96b 100644 --- a/tests/test_interpolated_field.py +++ b/tests/test_interpolated_field.py @@ -14,6 +14,7 @@ InterpolatedField, _cart_to_cyl_vectors, _cyl_to_cart_vectors, + InterpolatedVmecNative, ) # -------------------------------------------------------------------------------------- @@ -360,3 +361,216 @@ def test_cyl_cart_roundtrip_vectors(): xyz = _cyl_to_cart_vectors(phi, v) cyl = _cart_to_cyl_vectors(phi, xyz) assert jnp.allclose(cyl, v, atol=1e-12, rtol=1e-12) + +# -------------------------------------------------------------------------------------- +# InterpolatedVmecNative tests +# -------------------------------------------------------------------------------------- + +class MockVmec: + """ + Minimal stand-in that behaves like Vmec for native (s,θ,φ) calls, + using functions that are linear in (s,θ,φ) so trilinear interpolation + should be exact on any grid. + """ + def __init__(self, nfp=4, R0=10.0): + self.nfp = nfp + self.R0 = R0 + + # Coeffs for linear forms: a0 + a1*s + a2*th + a3*ph + self.bc_a = jnp.array([ + [ 0.1, 0.4, -0.3, 0.2], # B_cov[0] + [-0.2, 0.1, 0.5, 0.7], # B_cov[1] + [ 0.3, -0.6, 0.2, -0.4], # B_cov[2] + ]) + self.bn_a = jnp.array([ + [ 0.05, -0.2, 0.1, 0.3], # B_con[0] + [-0.1, 0.3, 0.4, -0.2], # B_con[1] + [ 0.6, 0.5, -0.3, 0.1], # B_con[2] + ]) + self.g_a = jnp.array([1.1, 0.2, -0.15, 0.05]) # sqrtg + + # ---------- helpers ---------- + @staticmethod + def _lin(a, s, th, ph): + # a: (4,) -> a0 + a1*s + a2*th + a3*ph + return a[0] + a[1]*s + a[2]*th + a[3]*ph + + # ---------- native API ---------- + def B_covariant(self, points): + s, th, ph = points + return jnp.array([self._lin(self.bc_a[0], s, th, ph), + self._lin(self.bc_a[1], s, th, ph), + self._lin(self.bc_a[2], s, th, ph)]) + + def B_contravariant(self, points): + s, th, ph = points + return jnp.array([self._lin(self.bn_a[0], s, th, ph), + self._lin(self.bn_a[1], s, th, ph), + self._lin(self.bn_a[2], s, th, ph)]) + + def sqrtg(self, points): + s, th, ph = points + return self._lin(self.g_a, s, th, ph) + + def to_xyz(self, points): + # simple tokamak-like embedding: R = R0 + s*cos(th), Z = s*sin(th) + # X=R*cos(ph), Y=R*sin(ph) + s, th, ph = points + R = self.R0 + s*jnp.cos(th) + Z = s*jnp.sin(th) + X = R*jnp.cos(ph) + Y = R*jnp.sin(ph) + return jnp.array([X, Y, Z]) + + def AbsB(self, points): + # Just a linear combination; not physically meaningful, but deterministic. + # (Used only for shape/broadcast tests in InterpolatedVmecNative.) + s, th, ph = points + return 2.0 + 0.1*s - 0.05*th + 0.02*ph + + +def _sample_native_box(key, srange, trange, prange, N): + s0, s1, _ = srange + t0, t1, _ = trange + p0, p1, _ = prange + u = jax.random.uniform(key, (N, 3)) + s = s0 + (s1 - s0) * u[:, 0] + th = t0 + (t1 - t0) * u[:, 1] + ph = p0 + (p1 - p0) * u[:, 2] + return jnp.stack([s, th, ph], axis=1) + + +def test_vmec_native_build_and_exact_linear(): + """ + Since MockVmec is linear in (s,θ,φ), the trilinear RGI should be exact across the grid. + """ + vm = MockVmec(nfp=4) + interp = InterpolatedVmecNative( + vm, + srange=(0.0, 1.0, 16), + thetarange=(0.0, 2*math.pi, 17), + phirange=(0.0, 2*math.pi/vm.nfp, 19) + ).build_all() + + key = jax.random.PRNGKey(0) + pts = _sample_native_box(key, interp.srange, interp.thetarange, interp.phirange, N=512) + + Bc_true = jax.vmap(vm.B_covariant)(pts) + Bn_true = jax.vmap(vm.B_contravariant)(pts) + g_true = jax.vmap(vm.sqrtg)(pts) + + Bc_pred = interp.B_covariant(pts) + Bn_pred = interp.B_contravariant(pts) + g_pred = interp.sqrtg(pts) + + assert Bc_pred.shape == (512, 3) + assert Bn_pred.shape == (512, 3) + assert g_pred.shape == (512,) + assert jnp.allclose(Bc_pred, Bc_true, atol=1e-12, rtol=1e-12) + assert jnp.allclose(Bn_pred, Bn_true, atol=1e-12, rtol=1e-12) + assert jnp.allclose(g_pred, g_true, atol=1e-12, rtol=1e-12) + + +def test_vmec_native_shapes_and_broadcast(): + vm = MockVmec() + interp = InterpolatedVmecNative( + vm, + srange=(0.0, 1.0, 8), + thetarange=(0.0, 2*math.pi, 9), + phirange=(0.0, 2*math.pi/vm.nfp, 10) + ).build_all() + + # single point + p = jnp.array([0.3, 1.1, 0.2]) + assert interp.B_covariant(p).shape == (3,) + assert interp.B_contravariant(p).shape == (3,) + assert jnp.shape(interp.sqrtg(p)) == () + + # batched (N,3) + P = jnp.stack([p, p + 0.01], axis=0) + assert interp.B_covariant(P).shape == (2, 3) + assert interp.B_contravariant(P).shape == (2, 3) + assert interp.sqrtg(P).shape == (2,) + + # higher-rank (...,3) + P3 = jnp.stack([P, P], axis=0) # (2,2,3) + # to_xyz/AbsB must broadcast over any leading dims + XYZ = interp.to_xyz(P3) + AB = interp.AbsB(P3) + assert XYZ.shape == (2, 2, 3) + assert AB.shape == (2, 2) + + # compare to underlying vmec for correctness on (N,3) + Bc_true = jax.vmap(vm.B_covariant)(P) + Bc_pred = interp.B_covariant(P) + assert jnp.allclose(Bc_pred, Bc_true, atol=1e-12, rtol=1e-12) + + +def test_vmec_native_jit_smoke(): + vm = MockVmec() + interp = InterpolatedVmecNative(vm).build_all() + + f1 = jax.jit(interp.B_covariant) + f2 = jax.jit(interp.B_contravariant) + f3 = jax.jit(interp.sqrtg) + f4 = jax.jit(interp.to_xyz) + f5 = jax.jit(interp.AbsB) + + p = jnp.array([0.25, 0.5, 0.1]) + # They should run and match non-jitted outputs + assert jnp.allclose(f1(p), interp.B_covariant(p)) + assert jnp.allclose(f2(p), interp.B_contravariant(p)) + assert jnp.allclose(f3(p), interp.sqrtg(p)) + assert jnp.allclose(f4(p), interp.to_xyz(p)) + assert jnp.allclose(f5(p), interp.AbsB(p)) + + +def test_vmec_native_edge_points_in_range(): + """ + Evaluate exactly on edges of the interpolation domain; should be finite and consistent. + """ + vm = MockVmec() + s0, s1, ns = (0.0, 1.0, 7) + t0, t1, nt = (0.0, 2*math.pi, 8) + p0, p1, np_ = (0.0, 2*math.pi/vm.nfp, 9) + interp = InterpolatedVmecNative(vm, + srange=(s0, s1, ns), + thetarange=(t0, t1, nt), + phirange=(p0, p1, np_)).build_all() + + # corners + pts = jnp.array([ + [s0, t0, p0], + [s1, t1, p1], + [s0, t1, p1], + [s1, t0, p0], + ]) + for fn_true, fn_pred in [ + (lambda q: jax.vmap(vm.B_covariant)(q), interp.B_covariant), + (lambda q: jax.vmap(vm.B_contravariant)(q), interp.B_contravariant), + (lambda q: jax.vmap(vm.sqrtg)(q), interp.sqrtg), + ]: + A = fn_pred(pts) + B = fn_true(pts) + assert jnp.all(jnp.isfinite(A)) + assert jnp.allclose(A, B, atol=1e-12, rtol=1e-12) + + +def test_vmec_native_to_xyz_roundtrip_shapes(): + """ + Not an exact inverse test (we don’t have stp<-xyz), but we can at least check + that to_xyz preserves leading batch dims and gives plausible ranges. + """ + vm = MockVmec(R0=10.0) + interp = InterpolatedVmecNative(vm, + srange=(0.0, 1.0, 6), + thetarange=(0.0, 2*math.pi, 7), + phirange=(0.0, 2*math.pi/vm.nfp, 8)).build_all() + + key = jax.random.PRNGKey(123) + pts = _sample_native_box(key, interp.srange, interp.thetarange, interp.phirange, N=128) # (128,3) + xyz = interp.to_xyz(pts) + assert xyz.shape == (128, 3) + # plausible radii near R0..R0+1 + R = jnp.linalg.norm(xyz[:, :2], axis=1) + assert jnp.all((R >= vm.R0 - 1.01) & (R <= vm.R0 + 1.01)) From 8a6dc28ad124d66bafe92305a70ce95295eed528 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Mon, 22 Sep 2025 08:43:27 -0500 Subject: [PATCH 5/6] Fix GridSpec parameter name in make_grid function and simplify JIT compilation in surface tests --- tests/test_interpolated_field.py | 2 +- tests/test_surfaces.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_interpolated_field.py b/tests/test_interpolated_field.py index cd1d96b..147dc44 100644 --- a/tests/test_interpolated_field.py +++ b/tests/test_interpolated_field.py @@ -49,7 +49,7 @@ def quadratic_cartesian_field(xyz: jnp.ndarray) -> jnp.ndarray: def make_grid(rr=(0.4, 1.2, 4), ph=(0.0, math.pi/2, 3), zz=(-0.5, 0.5, 4)): - return GridSpec(rrange=rr, phi_range=ph, z_range=zz, value_size=3) + return GridSpec(r_range=rr, phi_range=ph, z_range=zz, value_size=3) # Skip-function that masks a thin inner cylinder (r < rmin+0.05) diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py index 8a9ff1e..33b9423 100644 --- a/tests/test_surfaces.py +++ b/tests/test_surfaces.py @@ -265,8 +265,8 @@ def test_classifier_jit_smoke(): [9.2, 20.0, 0.1]]) # JIT both methods: - f1 = jax.jit(sc.evaluate_xyz, static_argnames=("self",)) - f2 = jax.jit(sc.evaluate_rphiz, static_argnames=("self",)) + f1 = sc.evaluate_xyz + f2 = sc.evaluate_rphiz out1 = f1(xyz) out2 = f2(rphiz) From b243bfea702969fe0aba655dbbbb9c2f0378fa3d Mon Sep 17 00:00:00 2001 From: eduardolneto Date: Wed, 15 Oct 2025 04:32:54 +0000 Subject: [PATCH 6/6] Added changes to surfaces.py for fixing some changes done for the classifier. It now uses different methods for cpu and for gpu. Also changed use_fundamental_phi default value to False --- essos/surfaces.py | 292 +++++++++++++++++++++++++++++++++------------- 1 file changed, 210 insertions(+), 82 deletions(-) diff --git a/essos/surfaces.py b/essos/surfaces.py index 4b8673c..f12b15c 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -1,12 +1,13 @@ from functools import partial import jax.numpy as jnp from jax.scipy.interpolate import RegularGridInterpolator -from jax import jit, vmap, devices, device_put, block_until_ready +from jax import jit, vmap, devices, device_put, block_until_ready, default_backend from jax.sharding import Mesh, NamedSharding, PartitionSpec from essos.plot import fix_matplotlib_3d import jaxkd import time import numpy as np + try: from scipy.spatial import cKDTree except ImportError: @@ -343,7 +344,7 @@ class SurfaceClassifier: - outside """ - def __init__(self, surface, h=0.05, use_fundamental_phi=True): + def __init__(self, surface, h=0.05, use_fundamental_phi=False): """ Args: surface: SurfaceRZFourier @@ -386,92 +387,219 @@ def __init__(self, surface, h=0.05, use_fundamental_phi=True): print(f"[SC] ranges: r=({rmin:.3f},{rmax:.3f}) phi=({phimin:.3f},{phimax:.3f}) z=({zmin:.3f},{zmax:.3f})") print(f"[SC] grid sizes: nr={nr}, nphi={nphi}, nz={nz} -> total={nr*nphi*nz:,d} nodes") - # ------------------------- - # Precompute KD-tree once - # ------------------------- - t_tree = time.perf_counter() - gammas_flat = gammas.reshape((-1, 3)) - normals_flat = surface.unitnormal.reshape((-1, 3)) - - self._tree = jaxkd.build_tree(gammas_flat) - # Sign convention (interior point): - a_point = jnp.mean(surface.gamma[0, :, :], axis=0) - sign_of_interiorpoint = jnp.sign(jnp.sum((a_point - gammas_flat[0, :]) * normals_flat[0, :])) - self._sign = float(sign_of_interiorpoint) - - print(f"[SC] KD-tree build: {time.perf_counter() - t_tree:.2f}s " - f"(nodes={gammas_flat.shape[0]:,d})") - - # ------------------------- - # Build (r,phi,z) grid - # ------------------------- - t_grid = time.perf_counter() - r_list = jnp.linspace(rmin, rmax, nr) - phi_list = jnp.linspace(phimin, phimax, nphi) - z_list = jnp.linspace(zmin, zmax, nz) - - # Mesh in 'ij' so r varies slowest, z fastest when flattened - RR, PP, ZZ = jnp.meshgrid(r_list, phi_list, z_list, indexing="ij") # each (nr, nphi, nz) - Ntot = nr * nphi * nz - - # Convert to Cartesian for nearest-neighbor query: - XX = RR * jnp.cos(PP) - YY = RR * jnp.sin(PP) - xyz_grid = jnp.stack([XX, YY, ZZ], axis=-1).reshape((Ntot, 3)) - - print(f"[SC] grid gen: {time.perf_counter() - t_grid:.2f}s; xyz_grid shape={tuple(xyz_grid.shape)}") - - # Build SciPy KD-tree on CPU (fast) - t_query = time.perf_counter() - tree = cKDTree(np.asarray(gammas_flat)) # (Ng, 3) - dist, idxs = tree.query(np.asarray(xyz_grid), k=1, workers=-1) # (Ntot,), (Ntot,) - nearest_pts = gammas_flat[np.asarray(idxs)] # jnp will accept np indexing - nearest_normals = normals_flat[np.asarray(idxs)] - # signed distance to tangent plane - d_plane = jnp.sum((xyz_grid - nearest_pts) * nearest_normals, axis=-1) # (Ntot,) - signed = self._sign * d_plane - field_vals = signed.reshape((nr, nphi, nz)) - _ = block_until_ready(field_vals) - print(f"[SC] KD query+dist: {time.perf_counter() - t_query:.2f}s (SciPy cKDTree)") + if default_backend() == 'cpu': + #This method is likely faster on CPU, since SciPy cKDTree is implemented in C + #but is extremely memory intensive, since it requires to store the full grid and + #the full distance field in memory at once. + # ------------------------- + # Precompute KD-tree once + # ------------------------- + t_tree = time.perf_counter() + gammas_flat = gammas.reshape((-1, 3)) + normals_flat = surface.unitnormal.reshape((-1, 3)) + + self._tree = jaxkd.build_tree(gammas_flat) + # Sign convention (interior point): + a_point = jnp.mean(surface.gamma[0, :, :], axis=0) + sign_of_interiorpoint = jnp.sign(jnp.sum((a_point - gammas_flat[0, :]) * normals_flat[0, :])) + self._sign = float(sign_of_interiorpoint) + + print(f"[SC] KD-tree build: {time.perf_counter() - t_tree:.2f}s " + f"(nodes={gammas_flat.shape[0]:,d})") + + # ------------------------- + # Build (r,phi,z) grid + # ------------------------- + t_grid = time.perf_counter() + r_list = jnp.linspace(rmin, rmax, nr) + phi_list = jnp.linspace(phimin, phimax, nphi) + z_list = jnp.linspace(zmin, zmax, nz) + + # Mesh in 'ij' so r varies slowest, z fastest when flattened + RR, PP, ZZ = jnp.meshgrid(r_list, phi_list, z_list, indexing="ij") # each (nr, nphi, nz) + Ntot = nr * nphi * nz + + # Convert to Cartesian for nearest-neighbor query: + XX = RR * jnp.cos(PP) + YY = RR * jnp.sin(PP) + xyz_grid = jnp.stack([XX, YY, ZZ], axis=-1).reshape((Ntot, 3)) + + print(f"[SC] grid gen: {time.perf_counter() - t_grid:.2f}s; xyz_grid shape={tuple(xyz_grid.shape)}") + + + + # Build SciPy KD-tree on CPU (fast) + t_query = time.perf_counter() + tree = cKDTree(np.asarray(gammas_flat)) # (Ng, 3) + dist, idxs = tree.query(np.asarray(xyz_grid), k=1, workers=-1) # (Ntot,), (Ntot,) + nearest_pts = gammas_flat[np.asarray(idxs)] # jnp will accept np indexing + nearest_normals = normals_flat[np.asarray(idxs)] + # signed distance to tangent plane + d_plane = jnp.sum((xyz_grid - nearest_pts) * nearest_normals, axis=-1) # (Ntot,) + signed = self._sign * d_plane + field_vals = signed.reshape((nr, nphi, nz)) + _ = block_until_ready(field_vals) + print(f"[SC] KD query+dist: {time.perf_counter() - t_query:.2f}s (SciPy cKDTree)") + + # ------------------------- + # Build RGI + # ------------------------- + t_rgi = time.perf_counter() + self._r_list = r_list + self._phi_list = phi_list + self._z_list = z_list + + # fill_value < 0.0 => "outside" by default beyond bounds + self.dist = RegularGridInterpolator( + (r_list, phi_list, z_list), field_vals, fill_value=-1.0 + ) + + print(f"[SC] RGI build: {time.perf_counter() - t_rgi:.2f}s") + print(f"[SC] init done in {time.perf_counter() - t0:.2f}s total") + + elif default_backend() == 'gpu': + # #from scipy.spatial import KDTree ##better for cpu? + # tree = jaxkd.build_tree(gammas_flat) # on device + # mins, _ = jaxkd.query_neighbors(tree, xyz_grid, k=1) + # nearest_normals = normals_flat[mins] + # nearest_pts = gammas_flat[mins] + # d_plane = jnp.sum((xyz_grid - nearest_pts) * nearest_normals, axis=-1) # (Ntot,) + # signed = self._sign * d_plane + # field_vals = signed.reshape((nr, nphi, nz)) + # _ = block_until_ready(field_vals) + + + t_grid = time.perf_counter() + r_list = jnp.linspace(rmin, rmax, nr) + phi_list = jnp.linspace(phimin, phimax, nphi) + z_list = jnp.linspace(zmin, zmax, nz) + + + # ------------------------- + # Build SciPy KD-tree on GPU (fast) + t_query = time.perf_counter() + + def fbatch(rs, phis, zs): + xyz = jnp.zeros(( 3)) + xyz=xyz.at[0].set( rs * jnp.cos(phis)) + xyz=xyz.at[1].set(rs * jnp.sin(phis)) + xyz=xyz.at[2].set(zs) + return signed_distance_from_surface_jax(xyz, surface) + #return signed_distance_from_surface_extras(xyz, surface) ####memory bounded + + + # ------------------------- + # Build RGI + # ------------------------- + t_rgi = time.perf_counter() + self._r_list = r_list + self._phi_list = phi_list + self._z_list = z_list + + #self.dist = RegularGridInterpolator((self._r_list, self._phi_list, self._z_list), + # vmap(vmap(vmap(fbatch,in_axes=(None,None,0)),in_axes=(None,0,None)),in_axes=(0,None,None)) + # (self._r_list, self._phi_list, self._z_list),fill_value=-1.) + self.dist = RegularGridInterpolator((jnp.linspace(rmin,rmax,nr), + jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz)), + vmap(vmap(vmap(fbatch,in_axes=(None,None,0)),in_axes=(None,0,None)),in_axes=(0,None,None))(jnp.linspace(rmin,rmax,nr), + jnp.linspace(0., 2*jnp.pi, nphi), jnp.linspace(zmin, zmax, nz)),fill_value=-1.) + + + + print(f"[SC] RGI build: {time.perf_counter() - t_rgi:.2f}s") + print(f"[SC] init done in {time.perf_counter() - t0:.2f}s total") + + + + if default_backend() == 'gpu': + @partial(jit, static_argnames=['self']) + def evaluate_xyz(self, xyz): + rphiz = jnp.zeros_like(xyz) + rphiz=rphiz.at[0].set(jnp.linalg.norm(xyz[:2])) + rphiz=rphiz.at[1].set(self._wrap_phi(jnp.mod(jnp.arctan2(xyz[1], xyz[0]), 2*jnp.pi))) + #rphiz=rphiz.at[1].set(jnp.mod(jnp.arctan2(xyz[1], xyz[0]), 2*jnp.pi)) + rphiz=rphiz.at[2].set(xyz.at[2].get()) + # initialize to -1 since the regular grid interpolant will just keep + # that value when evaluated outside of bounds + d=self.dist(rphiz)[0][0] + return d + + @partial(jit, static_argnames=['self']) + def evaluate_rphiz(self, rphiz): + # initialize to -1 since the regular grid interpolant will just keep + # that value when evaluated outside of bounds + d=self.dist(rphiz)[0][0] + return d + elif default_backend() == 'cpu': # ------------------------- - # Build RGI + # Vectorized signed-distance API (XYZ) # ------------------------- - t_rgi = time.perf_counter() - self._r_list = r_list - self._phi_list = phi_list - self._z_list = z_list - - # fill_value < 0.0 => "outside" by default beyond bounds - self.dist = RegularGridInterpolator( - (r_list, phi_list, z_list), field_vals, fill_value=-1.0 - ) - - print(f"[SC] RGI build: {time.perf_counter() - t_rgi:.2f}s") - print(f"[SC] init done in {time.perf_counter() - t0:.2f}s total") - - # ------------------------- - # Vectorized signed-distance API (XYZ) - # ------------------------- - @staticmethod - def _xyz_to_rphiz(xyz): - """xyz: (...,3) -> rphiz: (...,3)""" - x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2] - r = jnp.sqrt(x * x + y * y) - phi = jnp.mod(jnp.arctan2(y, x), 2 * jnp.pi) - return jnp.stack([r, phi, z], axis=-1) + @staticmethod + def _xyz_to_rphiz(xyz): + """xyz: (...,3) -> rphiz: (...,3)""" + x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2] + r = jnp.sqrt(x * x + y * y) + phi = jnp.mod(jnp.arctan2(y, x), 2 * jnp.pi) + return jnp.stack([r, phi, z], axis=-1) + + + @partial(jit, static_argnames=['self']) + def evaluate_xyz(self, xyz): + rphiz = self._xyz_to_rphiz(xyz) + rphiz = rphiz.at[..., 1].set(self._wrap_phi(rphiz[..., 1])) + return self.dist(rphiz) + + @partial(jit, static_argnames=['self']) + def evaluate_rphiz(self, rphiz): + rphiz = rphiz.at[..., 1].set(self._wrap_phi(rphiz[..., 1])) + return self.dist(rphiz) def _wrap_phi(self, phi): period = 2 * jnp.pi / max(1, int(getattr(self.surface, "nfp", 1))) return jnp.mod(phi, period) - @partial(jit, static_argnames=['self']) - def evaluate_xyz(self, xyz): - rphiz = self._xyz_to_rphiz(xyz) - rphiz = rphiz.at[..., 1].set(self._wrap_phi(rphiz[..., 1])) - return self.dist(rphiz) - @partial(jit, static_argnames=['self']) - def evaluate_rphiz(self, rphiz): - rphiz = rphiz.at[..., 1].set(self._wrap_phi(rphiz[..., 1])) - return self.dist(rphiz) \ No newline at end of file +partial(jit, static_argnames=['surface']) +def signed_distance_from_surface_jax(xyz, surface): + """ + Compute the signed distances from points ``xyz`` to a surface. The sign is + positive for points inside the volume surrounded by the surface. + """ + gammas = surface.gamma.reshape((-1, 3)) + #from scipy.spatial import KDTree ##better for cpu? + tree = jaxkd.build_tree(gammas) + mins, _ = jaxkd.query_neighbors(tree, xyz, k=1) + n = surface.unitnormal.reshape((-1, 3)) + nmins = n[mins] + gammamins = gammas[mins] + # Now that we have found the closest node, we approximate the surface with + # a plane through that node with the appropriate normal and then compute + # the distance from the point to that plane + # https://stackoverflow.com/questions/55189333/how-to-get-distance-from-point-to-plane-in-3d + mindist = jnp.sum((xyz-gammamins) * nmins, axis=1) + a_point_in_the_surface = jnp.mean(surface.gamma[0, :, :], axis=0) + sign_of_interiorpoint = jnp.sign(jnp.sum((a_point_in_the_surface-gammas[0, :])*n[0, :])) + signed_dists = mindist * sign_of_interiorpoint + return signed_dists + +#@partial(jit, static_argnames=['surface']) +def signed_distance_from_surface_extras(xyz, surface): + """ + Compute the signed distances from points ``xyz`` to a surface. The sign is + positive for points inside the volume surrounded by the surface. + """ + gammas = surface.gamma.reshape((-1, 3)) + mins, _ = jaxkd.extras.query_neighbors_pairwise(gammas, xyz, k=1) + n = surface.unitnormal.reshape((-1, 3)) + nmins = n[mins] + gammamins = gammas[mins] + # Now that we have found the closest node, we approximate the surface with + # a plane through that node with the appropriate normal and then compute + # the distance from the point to that plane + # https://stackoverflow.com/questions/55189333/how-to-get-distance-from-point-to-plane-in-3d + mindist = jnp.sum((xyz-gammamins) * nmins, axis=1) + a_point_in_the_surface = jnp.mean(surface.gamma[0, :, :], axis=0) + sign_of_interiorpoint = jnp.sign(jnp.sum((a_point_in_the_surface-gammas[0, :])*n[0, :])) + signed_dists = mindist * sign_of_interiorpoint + return signed_dists