Skip to content
Open

QFM #27

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions essos/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def B(self, points):
dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy")
return jnp.mean(dB_sum, axis=0)

@partial(jit, static_argnames=['self'])
def A(self, points):
dif_R = (jnp.array(points)-self.gamma)
dA = self.gamma_dash / jnp.linalg.norm(dif_R, axis=-1, keepdims=True)
A_vec = jnp.sum(
jnp.mean(self.currents[:, None, None] * dA * 1e-7, axis=1),
axis=0
)

return A_vec

@partial(jit, static_argnames=['self'])
def B_covariant(self, points):
return self.B(points)
Expand Down
3 changes: 0 additions & 3 deletions essos/multiobjectiveoptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from functools import partial
from essos.coils import Coils, Curves, CreateEquallySpacedCurves
from essos.fields import BiotSavart
import numpy as np
from pandas.plotting import parallel_coordinates
import pandas as pd


class MultiObjectiveOptimizer:
Expand Down
221 changes: 221 additions & 0 deletions essos/qfm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import jax
from jax import vmap, grad, device_get
import jax.numpy as jnp
from jaxopt import LBFGS
from essos.surfaces import SurfaceRZFourier
from scipy.optimize import minimize


class QfmSurface:
def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float = None,
toroidal_flux_idx: int = 0):
assert label in ["area", "volume", "toroidal_flux"], f"Unsupported label: {label}"

self.field = field
self.surface = surface
self.surface_optimize = self._build_surface_with_x(surface, surface.x)
self.label = label
self.toroidal_flux_idx = int(toroidal_flux_idx)
self.name = str(id(self))

if targetlabel is None:
self.targetlabel = {
"volume": surface.volume,
"area": surface.area,
"toroidal_flux": self._toroidal_flux(surface)
}[label]
else:
self.targetlabel = targetlabel

def _toroidal_flux(self, surf: SurfaceRZFourier):
curve = surf.gamma[self.toroidal_flux_idx]
dl = jnp.roll(curve, -1, axis=0) - curve
A_vals = vmap(self.field.A)(curve)
return jnp.sum(jnp.sum(A_vals * dl, axis=1))

def _build_surface_with_x(self, surface, x):
rc_safe = device_get(surface.rc)
zs_safe = device_get(surface.zs)
x_safe = device_get(x)

s = SurfaceRZFourier(
rc=rc_safe,
zs=zs_safe,
nfp=int(surface.nfp),
ntheta=int(surface.ntheta),
nphi=int(surface.nphi),
range_torus=surface.range_torus,
close=True
)
s.x = x_safe
return s

def objective(self, x):
surf = self.surface_optimize
x_old = surf.x
surf.x = x
N = surf.unitnormal
norm_N = jnp.linalg.norm(surf.normal, axis=2)
points = surf.gamma.reshape(-1, 3)
B = vmap(self.field.B)(points).reshape(N.shape)
B_n = jnp.sum(B * N, axis=2)
norm_B = jnp.linalg.norm(B, axis=2)
value = jnp.sum(B_n**2 * norm_N) / jnp.sum(norm_B**2 * norm_N)
surf.x = x_old
return value

def constraint(self, x):
surf = self.surface_optimize
x_old = surf.x
surf.x = x

raw_c = {
"volume": surf.volume - self.targetlabel,
"area": surf.area - self.targetlabel,
"toroidal_flux": self._toroidal_flux(surf) - self.targetlabel
}[self.label]

c = raw_c / jnp.abs(self.targetlabel)

surf.x = x_old
return c

def penalty_objective(self, x, constraint_weight=1.0):
r = self.objective(x)
c = self.constraint(x)
return r + 0.5 * constraint_weight * c**2

def _callback(self, info, printlog=True):
if isinstance(info, dict):
# LBFGS
it = info.get("iter", -1)
r = info["objective"]
c = info["constraint"]
penalty = info["penalty"]
grad_norm = info["grad_norm"]

# Print logs if printlog is True
if printlog:
print(f"[LBFGS iter {it}] objective={r:.6e} constraint={c:.3e} "
f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}")
else:
# SLSQP
it = getattr(self, "_slsqp_iter", 0) + 1
setattr(self, "_slsqp_iter", it)

obj = float(self.objective(info))
cst = float(self.constraint(info))
penalty = float(self.penalty_objective(info))
grad_norm = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z))(info)))

# Print logs if printlog is True
if printlog:
print(f"[SLSQP iter {it}] objective={obj:.6e} constraint={cst:.3e} "
f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}")


def minimize_lbfgs(self, x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e4,
printlog=True, **kwargs):
x0 = self.surface_optimize.x if x0 is None else x0

# ---------- Define objective function, return scalar + aux dict (all use jnp.array) ----------
def fn(x):
value = self.penalty_objective(x, constraint_weight)
aux = {
"objective": self.objective(x),
"constraint": self.constraint(x),
"penalty": value
}
return value, aux

solver = LBFGS(fun=fn, maxiter=maxiter, tol=tol, has_aux=True)
state = solver.init_state(x0)

trace = []
x = x0
for k in range(maxiter):
x, state = solver.update(x, state)

info = {key: device_get(v) if isinstance(v, jnp.ndarray) else v for key, v in state.aux.items()}
info["iter"] = k + 1
info["grad_norm"] = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z, constraint_weight))(x)))
info["error"] = float(state.error)

# Ensure we call _callback for logging every step if printlog is True
self._callback(info, printlog)

if state.error <= tol:
break

x_safe = device_get(x) # Move back to host
self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe)

return {
"fun": float(self.penalty_objective(x, constraint_weight)),
"gradient": jnp.array(grad(lambda z: self.penalty_objective(z, constraint_weight))(x)),
"iter": k + 1,
"info": state,
"success": state.error <= tol,
"s": self.surface_optimize,
}

def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, printlog=True, **kwargs):
x0 = jnp.array(self.surface_optimize.x if x0 is None else x0)

# Run the SLSQP optimizer
res = minimize(
fun=lambda x: float(self.objective(x)),
x0=x0,
method="SLSQP",
constraints={"type": "eq", "fun": lambda x: float(self.constraint(x))},
tol=tol,
options={"maxiter": maxiter, "disp": False},
callback=lambda x: self._callback(x, printlog) # Use internal callback directly
)

# Store the optimized x in the surface
x_safe = device_get(res.x)
self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe)

# Return the result with optimization trace
return {
"fun": res.fun,
"gradient": jnp.array(jax.grad(self.objective)(res.x)),
"iter": res.nit,
"info": res,
"success": res.success,
"s": self.surface_optimize
}

def run(
self,
method: str = "SLSQP",
tol: float = 1e-6,
maxiter: int = 1000,
x0=None,
constraint_weight: float = 1e-3,
printlog: bool = True,
**kwargs
):

method_up = method.upper()

if method_up == "SLSQP":
return self.minimize_slsqp(
x0=x0,
tol=tol,
maxiter=maxiter,
printlog=printlog,
**kwargs
)
elif method_up == "LBFGS":
return self.minimize_lbfgs(
x0=x0,
tol=tol,
maxiter=maxiter,
constraint_weight=constraint_weight,
printlog=printlog,
**kwargs
)
else:
raise ValueError(f"Unknown method '{method}'")
93 changes: 92 additions & 1 deletion essos/surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
mesh = Mesh(devices(), ("dev",))
sharding = NamedSharding(mesh, PartitionSpec("dev", None))

@partial(jit, static_argnames=['surface','field'])
def toroidal_flux(surface, field, idx=0) -> jnp.ndarray:
gamma = surface.gamma
curve = gamma[idx, :, :]
dl = jnp.roll(curve, -1, axis=0) - curve
A_vals = vmap(field.A)(curve)
Adl = jnp.sum(A_vals * dl, axis=1)
tf = jnp.sum(Adl)
return tf

@partial(jit, static_argnames=['surface','field'])
def B_on_surface(surface, field):
ntheta = surface.ntheta
Expand Down Expand Up @@ -235,7 +245,88 @@ def x(self):
@x.setter
def x(self, new_dofs):
self.dofs = new_dofs


@property
def volume(self):

xyz = self.gamma # shape: (nphi, ntheta, 3)
n = self.normal # shape: (nphi, ntheta, 3)

integrand = jnp.sum(xyz * n, axis=2) # dot(x, n), shape: (nphi, ntheta)
volume = jnp.mean(integrand) / 3.0
return volume

@property
def area(self):
n = self.normal # shape: (nphi, ntheta, 3)
norm_n = jnp.linalg.norm(n, axis=2)

dphi = 2 * jnp.pi / self.nphi
dtheta = 2 * jnp.pi / self.ntheta

area = jnp.sum(norm_n) * dphi * dtheta
return area


def change_resolution(self, mpol: int, ntor: int):
"""
Change the values of `mpol` and `ntor`.
New Fourier coefficients are zero by default.
Old coefficients outside the new range are discarded.
"""
rc_old, zs_old = self.rc, self.zs
mpol_old, ntor_old = self.mpol, self.ntor

rc_new = jnp.zeros((mpol, 2 * ntor + 1))
zs_new = jnp.zeros((mpol, 2 * ntor + 1))

m_keep = min(mpol_old, mpol)
n_keep = min(ntor_old, ntor)

old_slice = slice(ntor_old - n_keep, ntor_old + n_keep + 1)
new_slice = slice(ntor - n_keep, ntor + n_keep + 1)

# Copy overlapping region
rc_new = rc_new.at[:m_keep, new_slice].set(rc_old[:m_keep, old_slice])
zs_new = zs_new.at[:m_keep, new_slice].set(zs_old[:m_keep, old_slice])

# Update attributes
self.mpol, self.ntor = mpol, ntor
self.rc, self.zs = rc_new, zs_new

# Recompute xm/xn and interpolation arrays
m1d = jnp.arange(self.mpol)
n1d = jnp.arange(-self.ntor, self.ntor + 1)
n2d, m2d = jnp.meshgrid(n1d, m1d)
self.xm = m2d.flatten()[self.ntor:]
self.xn = self.nfp * n2d.flatten()[self.ntor:]

indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T
self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]]
self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]]

# Update degrees of freedom
self.num_dofs_rc = len(jnp.ravel(self.rc)[self.ntor:])
self.num_dofs_zs = len(jnp.ravel(self.zs)[self.ntor:])
self._dofs = jnp.concatenate(
(jnp.ravel(self.rc)[self.ntor:], jnp.ravel(self.zs)[self.ntor:])
)

# Recompute angles and geometry
self.angles = (
jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)
- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)
)
(self._gamma, self._gammadash_theta, self._gammadash_phi,
self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp)

# Recompute AbsB if available
if hasattr(self, 'bmnc'):
self._AbsB = self._set_AbsB()

return self


def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs):
if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.")

Expand Down
1 change: 1 addition & 0 deletions examples/input_files/stellarator_coils.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nfp": 2, "stellsym": true, "order": 3, "n_segments": 45, "dofs_curves": [[[11.622700935298976, 0.021162524110782358, 4.969310302793117, 0.21619524408779187, 1.0480524616370408, 0.28649150353772795, -0.28257357882380735], [2.650045608829704, 1.3195115302904592, 2.0655255644324324, -0.9222231514058897, -0.6185897653334609, -0.029257891517016514, -0.4689941563806397], [0.5736288312625168, -6.124609285728119, -0.25270656488797033, -0.7494953058315302, -0.5577484125670837, 0.41280301860103114, 0.49815152814320796]], [[8.270668252860887, 0.3665817338447269, 3.1105370020049437, -1.1258699831006203, 1.3263614203440663, 0.2102848349841804, -0.49142766630482954], [7.030418225942375, 0.6256285318976378, 4.821568953388758, -0.3255223631956292, -0.5563748261863908, 0.5699454535195534, 0.4766999014677532], [1.1730264767378422, -5.641665411926324, -0.3776722058863022, -0.22889203846752135, -0.37533249101718297, -0.13615989719970903, 0.2925684224469988]], [[2.870219295136962, 1.011171856312114, 1.0052842269429074, -1.616205601094772, 0.4568093873850295, 0.8322530493441007, -0.47523615678931685], [9.19851171039639, 0.15301426329616008, 6.079831607393782, -0.1245216794618166, -0.5264330665117136, 0.028667523169553216, 0.4360349874253642], [0.5149651564138196, -5.068527850344334, -0.2765171026569377, -0.4083657568321486, -0.061669206943173696, -0.01678957665952365, 0.08107054639495188]]], "dofs_currents": [0.9714976808550443, 1.0086510674635079, 1.019851251681448]}
12 changes: 6 additions & 6 deletions examples/optimize_coils_vmec_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
max_coil_curvature = 1.0
order_Fourier_series_coils = 3
number_coil_points = order_Fourier_series_coils*15
maximum_function_evaluations = 50
maximum_function_evaluations = 500
number_coils_per_half_field_period = 3
tolerance_optimization = 1e-5
ntheta=35
Expand Down Expand Up @@ -67,11 +67,11 @@
plt.tight_layout()
plt.show()

# # Save the coils to a json file
# coils_optimized.to_json("stellarator_coils.json")
# # Load the coils from a json file
# from essos.coils import Coils_from_json
# coils = Coils_from_json("stellarator_coils.json")
# Save the coils to a json file
coils_optimized.to_json("input_files/stellarator_coils.json")
# Load the coils from a json file
from essos.coils import Coils_from_json
coils = Coils_from_json("input_files/stellarator_coils.json")

# # Save results in vtk format to analyze in Paraview
# from essos.fields import BiotSavart
Expand Down
Loading