From 0dbb8bb8f28fb90ad8aec8b76483cec67f50af7a Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Fri, 15 Aug 2025 21:14:44 +0100 Subject: [PATCH 01/18] Refine example of multi objective optimization to run with more trials but each trial is faster --- examples/optimize_multiple_objectives.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/optimize_multiple_objectives.py b/examples/optimize_multiple_objectives.py index f38ed8e..78b09d6 100644 --- a/examples/optimize_multiple_objectives.py +++ b/examples/optimize_multiple_objectives.py @@ -17,12 +17,12 @@ "max_coil_curvature": 0.0, }, opt_config={ - "n_trials": 2, - "maximum_function_evaluations": 300, + "n_trials": 20, + "maximum_function_evaluations": 50, "tolerance_optimization": 1e-5, "optimizer_choices": ["adam", "amsgrad", "sgd"], "num_coils": 4, - "order_Fourier": 6, + "order_Fourier": 3, } ) @@ -46,10 +46,7 @@ print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") - - - manager.plot_pareto_fronts(z_thresh=3, save= True) manager.plot_optimization_history(z_thresh=3, save= True) manager.plot_param_importances(save= True) -manager.plot_parallel_coordinates(z_thresh=3, save= True) +# manager.plot_parallel_coordinates(z_thresh=3, save= True) From ae5904c3c9158efdee38d0bea01cda2c5aa8b2aa Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Mon, 18 Aug 2025 11:38:05 -0500 Subject: [PATCH 02/18] add qfm --- essos/fields.py | 21 +- essos/qfm.py | 272 ++++++++++++++++++++++++ essos/surfaces.py | 20 +- examples/optimize_coils_vmec_surface.py | 10 +- examples/optimize_qfm_surface.py | 69 ++++++ 5 files changed, 385 insertions(+), 7 deletions(-) create mode 100644 essos/qfm.py create mode 100644 examples/optimize_qfm_surface.py diff --git a/essos/fields.py b/essos/fields.py index 0789d2a..214a354 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -27,7 +27,26 @@ def B(self, points): dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") return jnp.mean(dB_sum, axis=0) - + + @partial(jit, static_argnames=['self']) + def B_field(self, points): + points = jnp.array(points) # (Nphi, Ntheta, 3) + points_flat = points.reshape(-1, 3) # (Npoints, 3) + + gamma_flat = self.gamma.reshape(-1, 3) # (Ngamma, 3) + gamma_dash_flat = self.gamma_dash.reshape(-1, 3) # (Ngamma, 3) + currents_flat = jnp.repeat(self.currents, self.gamma.shape[1]) # shape = (Ngamma,) + + def B_at_point(p): + dif_R = p - gamma_flat # (Ngamma, 3) + norm = jnp.linalg.norm(dif_R, axis=1) + dB = jnp.cross(gamma_dash_flat, dif_R) / norm[:, None]**3 # (Ngamma, 3) + return jnp.einsum("i,ij->j", 1e-7 * currents_flat, dB) # (3,) + + B_flat = vmap(B_at_point)(points_flat) # (Npoints, 3) + return B_flat # (Nphi, Ntheta, 3) B_flat.reshape(points.shape) + + @partial(jit, static_argnames=['self']) def B_covariant(self, points): return self.B(points) diff --git a/essos/qfm.py b/essos/qfm.py new file mode 100644 index 0000000..3b8e811 --- /dev/null +++ b/essos/qfm.py @@ -0,0 +1,272 @@ +import jax +import jax.numpy as jnp +from jaxopt import LBFGS, ScipyMinimize +from scipy.optimize import minimize +import optax +from essos.surfaces import SurfaceRZFourier + +class QfmSurface: + def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float): + assert label in ["area", "volume"], f"Unsupported label: {label}" + self.field = field + self.surface = surface + self.surface_optimize = self._with_x(surface, surface.x) + self.label = label + self.targetlabel = targetlabel + self.name = str(id(self)) + + def _with_x(self, surface: SurfaceRZFourier, x): + s = SurfaceRZFourier( + rc=surface.rc, + zs=surface.zs, + nfp=surface.nfp, + ntheta=surface.ntheta, + nphi=surface.nphi, + range_torus=surface.range_torus, + close=True + ) + s.x = x + return s + + def objective(self, x): + surf = self._with_x(self.surface_optimize, x) + N = surf.unitnormal + norm_N = jnp.linalg.norm(surf.normal, axis=2) + B = self.field.B_field(surf.gamma).reshape(N.shape) + B_n = jnp.sum(B * N, axis=2) + norm_B = jnp.linalg.norm(B, axis=2) + result = jnp.sum(B_n**2 * norm_N) / jnp.sum(norm_B**2 * norm_N) + return result + + def constraint(self, x): + surf = self._with_x(self.surface_optimize, x) + if self.label == "volume": + return surf.volume - self.targetlabel + elif self.label == "area": + return surf.area - self.targetlabel + else: + raise ValueError(f"Unsupported label: {self.label}") + + def penalty_objective(self, x, constraint_weight=1.0): + r = self.objective(x) + c = self.constraint(x) + result = r + 0.5 * constraint_weight * c**2 + return jnp.asarray(result), None + + def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): + value_and_grad_fn = jax.value_and_grad( + lambda x: self.penalty_objective(x, constraint_weight), + has_aux=True + ) + solver = LBFGS( + fun=value_and_grad_fn, + value_and_grad=True, + has_aux=True, + implicit_diff=False, + tol=tol, + maxiter=maxiter + ) + x0 = self.surface_optimize.x + res = solver.run(x0) + self.surface_optimize = self._with_x(self.surface_optimize, res.params) + return { + "fun": res.state.value, + "gradient": jax.grad(lambda x: self.penalty_objective(x, constraint_weight)[0])(res.params), + "iter": res.state.iter_num, + "info": res.state, + "success": res.state.error <= tol, + "s": self.surface_optimize, + } + + + def minimize_penalty_scipy_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): + fun = lambda x: jnp.asarray(self.penalty_objective(x, constraint_weight)[0]).item() + grad = lambda x: jax.grad(lambda x_: self.penalty_objective(x_, constraint_weight)[0])(x) + x0 = self.surface_optimize.x + res = minimize( + fun=fun, x0=jnp.array(x0), jac=grad, + method='L-BFGS-B', tol=tol, options={"maxiter": maxiter} + ) + self.surface_optimize = self._with_x(self.surface_optimize, res.x) + return { + "fun": res.fun, + "gradient": grad(res.x), + "iter": res.nit, + "info": res, + "success": res.success, + "s": self.surface_optimize, + } + + def minimize_penalty_slsqp(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): + fun = lambda x: self.penalty_objective(x, constraint_weight)[0] + grad = jax.grad(fun) + + solver = ScipyMinimize( + fun=fun, + method="SLSQP", + tol=tol, + maxiter=maxiter + ) + + x0 = self.surface_optimize.x + res = solver.run(x0) + self.surface_optimize = self._with_x(self.surface_optimize, res.params) + + # 安全获取迭代次数 + iter_count = getattr(res.state, "num_iters", None) + if iter_count is None: + iter_count = getattr(res.state, "maxiter", -1) + + return { + "fun": res.state.fun_val, + "gradient": grad(res.params), + "iter": iter_count, + "info": res.state, + "success": getattr(res.state, "status", 0) == 0, + "s": self.surface_optimize, + } + + + def minimize_exact_SLSQP(self, tol=1e-3, maxiter=1000): + loss_fn = lambda x: self.objective(x) + constraint_fn = lambda x: self.constraint(x) + grad_loss = jax.grad(loss_fn) + dcon = jax.grad(constraint_fn) + solver = ScipyMinimize( + fun=loss_fn, + method="SLSQP", + constraints=[{"type": "eq", "fun": constraint_fn, "jac": dcon}], + tol=tol, + options={"maxiter": maxiter} + ) + x0 = self.surface_optimize.x + res = solver.run(x0) + self.surface_optimize = self._with_x(self.surface_optimize, res.params) + return { + "fun": res.state.fun_val, + "gradient": grad_loss(res.params), + "iter": res.state.nit, + "info": res.state, + "success": res.state.status == 0, + "s": self.surface_optimize, + } + + def minimize_exact_scipy_slsqp(self, tol=1e-3, maxiter=1000): + fun = lambda x: jnp.asarray(self.objective(x)).item() + jac = lambda x: jnp.asarray(jax.grad(self.objective)(x)) + con_fun = lambda x: jnp.asarray(self.constraint(x)).item() + con_jac = lambda x: jnp.asarray(jax.grad(self.constraint)(x)) + constraints = [{"type": "eq", "fun": con_fun, "jac": con_jac}] + x0 = self.surface_optimize.x + res = minimize( + fun=fun, x0=jnp.array(x0), jac=jac, + constraints=constraints, method='SLSQP', + tol=tol, options={"maxiter": maxiter} + ) + self.surface_optimize = self._with_x(self.surface_optimize, res.x) + return { + "fun": res.fun, + "gradient": jac(res.x), + "iter": res.nit, + "info": res, + "success": res.success, + "s": self.surface_optimize, + } + + # ⬅️ 新增 + +# ========== 新增方法:构造 optax 优化器 ========== + def _build_optax_optimizer(self, method: str, lr: float): + m = method.strip().lower() + if m == 'adam': return optax.adam(lr) + if m == 'adamw': return optax.adamw(lr) + if m == 'sgd': return optax.sgd(lr) + if m == 'momentum': return optax.sgd(lr, momentum=0.9) + if m == 'nesterov': return optax.sgd(lr, momentum=0.9, nesterov=True) + if m == 'rmsprop': return optax.rmsprop(lr) + if m == 'adagrad': return optax.adagrad(lr) + if m == 'adafactor': return optax.adafactor(learning_rate=lr) + if m == 'lamb': return optax.lamb(learning_rate=lr) + if m == 'lars': return optax.lars(learning_rate=lr) + raise ValueError(f"Unknown optax optimizer '{method}'") + +# ========== 新增方法:Optax penalty 优化 ========== + def minimize_penalty_optax(self, optimizer='adam', lr=1e-2, tol=1e-3, maxiter=1000, constraint_weight=1.0): + loss = lambda x: self.penalty_objective(x, constraint_weight)[0] + grad_loss = jax.grad(loss) + + opt = self._build_optax_optimizer(optimizer, lr) + x = self.surface_optimize.x + opt_state = opt.init(x) + + for it in range(int(maxiter)): + g = grad_loss(x) + updates, opt_state = opt.update(g, opt_state, x) + x = optax.apply_updates(x, updates) + if float(jnp.linalg.norm(g)) <= tol: + break + + self.surface_optimize = self._with_x(self.surface_optimize, x) + return { + "fun": float(loss(x)), + "gradient": grad_loss(x), + "iter": it + 1, + "info": {"grad_norm": float(jnp.linalg.norm(g)), "optimizer": optimizer, "lr": lr}, + "success": float(jnp.linalg.norm(g)) <= tol, + "s": self.surface_optimize, + } + +# ========== 修改 run 方法:新增 optax 分支 ========== + def run(self, tol=1e-4, maxiter=1000, method='SLSQP', constraint_weight=10.0, lr=1e-2): + method_up = method.upper() + if method_up == 'SLSQP': + return self.minimize_penalty_slsqp(tol=tol, maxiter=maxiter) + elif method_up == 'LBFGS': + return self.minimize_penalty_lbfgs( + tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) + elif method_up == 'SCIPYLBFGS': + return self.minimize_penalty_scipy_lbfgs( + tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) + elif method_up == 'SCIPYSLSQP': + return self.minimize_exact_scipy_slsqp( + tol=tol, maxiter=maxiter) + + # Optax 分支 + OPTAX_METHODS = { + 'OPTAX': 'adam', + 'ADAM': 'adam', + 'ADAMW': 'adamw', + 'SGD': 'sgd', + 'MOMENTUM': 'momentum', + 'NESTEROV': 'nesterov', + 'RMSPROP': 'rmsprop', + 'ADAGRAD': 'adagrad', + 'ADAFACTOR': 'adafactor', + 'LAMB': 'lamb', + 'LARS': 'lars', + } + if method_up in OPTAX_METHODS: + return self.minimize_penalty_optax( + optimizer=OPTAX_METHODS[method_up], + lr=lr, tol=tol, maxiter=maxiter, + constraint_weight=constraint_weight + ) + + raise ValueError(f"Unknown method '{method}'") + + + # def run(self, tol=1e-4, maxiter=1000, method='SLSQP', constraint_weight=10.0): + # method_up = method.upper() + # if method_up == 'SLSQP': + # return self.minimize_penalty_slsqp(tol=tol, maxiter=maxiter) + # elif method_up == 'LBFGS': + # return self.minimize_penalty_lbfgs( + # tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) + # elif method_up == 'SCIPYLBFGS': + # return self.minimize_penalty_scipy_lbfgs( + # tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) + # elif method_up == 'SCIPYSLSQP': + # return self.minimize_exact_scipy_slsqp( + # tol=tol, maxiter=maxiter) + # else: + # raise ValueError(f"Unknown method '{method}'") diff --git a/essos/surfaces.py b/essos/surfaces.py index 0048e3c..891ae7c 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -235,7 +235,25 @@ def x(self): @x.setter def x(self, new_dofs): self.dofs = new_dofs - + + @property + def volume(self): + + xyz = self.gamma # shape: (nphi, ntheta, 3) + n = self.normal # shape: (nphi, ntheta, 3) + + integrand = jnp.sum(xyz * n, axis=2) # dot(x, n), shape: (nphi, ntheta) + volume = jnp.mean(integrand) / 3.0 + return volume + + @property + def area(self): + n = self.normal() # (nphi, ntheta, 3) + norm_n = jnp.linalg.norm(n, axis=2) # shape: (nphi, ntheta) + avg_area = jnp.mean(norm_n) + return avg_area + + def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.") diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py index 2ded4be..c8222b6 100644 --- a/examples/optimize_coils_vmec_surface.py +++ b/examples/optimize_coils_vmec_surface.py @@ -62,11 +62,11 @@ plt.tight_layout() plt.show() -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") +# Save the coils to a json file +coils_optimized.to_json("stellarator_coils.json") +# Load the coils from a json file +from essos.coils import Coils_from_json +coils = Coils_from_json("stellarator_coils.json") # # Save results in vtk format to analyze in Paraview # from essos.fields import BiotSavart diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py new file mode 100644 index 0000000..cb0af03 --- /dev/null +++ b/examples/optimize_qfm_surface.py @@ -0,0 +1,69 @@ +import os +import jax.numpy as jnp +import matplotlib.pyplot as plt + +from essos.surfaces import SurfaceRZFourier +from essos.qfm import QfmSurface +from essos.fields import Vmec, BiotSavart + +method = 'slsqp' #slsqp lbfgs + +# 1. 加载等离子体 VMEC 文件,并生成 surface +ntheta=30 +nphi=30 +vmec = os.path.join('input_files','input.rotating_ellipse') +surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') +initial_vol = surf.volume +# 2. 创建初始线圈并生成 BiotSavart 磁场 + + +ntheta=35 +nphi=35 + +# Initialize VMEC field +truesurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') + + +truevmec = Vmec(os.path.join(os.path.dirname(__file__), 'input_files', 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), + ntheta=ntheta, nphi=nphi, range_torus='half period') + + +from essos.coils import Coils_from_json +coils = Coils_from_json("stellarator_coils.json") + +# 创建磁场对象 +field = BiotSavart(coils) + +# 3. 计算当前体积作为 target label(或设置固定值) +target_volume = truevmec.surface.volume # 你可以手动设置一个目标值,如 target_volume = 1.0 + +# 4. 构建 QfmSurface 优化器 +qfm = QfmSurface( + field=field, + surface=surf, + label='volume', # or "area" + targetlabel=target_volume # or target_area +) + +# 5. 运行优化(选择方法) +result = qfm.run(tol=1e-3, maxiter=10000,method=method) + +# 6. 打印结果 +print("Optimization method:", method) +print("Optimization success:", result['success']) +print("Final Bnormal objective:", result['fun']) +print("Iterations:", result['iter']) +print(f"target volume: {target_volume}, initial volume: {initial_vol}, final volume: {result['s'].volume}") + + + + +fig = plt.figure(figsize=(8, 4)) +ax1 = fig.add_subplot(121, projection='3d') +ax2 = fig.add_subplot(122, projection='3d') +coils.plot(ax=ax1, show=False) +truesurf.plot(ax=ax1, show=False) +coils.plot(ax=ax2, show=False) +result['s'].plot(ax=ax2, show=False) +plt.tight_layout() +plt.show() From 9690c224a496b153748313b79ad3259c6aa982a0 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Tue, 19 Aug 2025 14:32:01 -0500 Subject: [PATCH 03/18] add toroidal flux as qfm label and vector potential in BiotSavart --- essos/fields.py | 24 +--- essos/qfm.py | 211 ++++++------------------------- essos/surfaces.py | 2 +- examples/optimize_qfm_surface.py | 50 +++++--- 4 files changed, 75 insertions(+), 212 deletions(-) diff --git a/essos/fields.py b/essos/fields.py index 214a354..c8ad04a 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -29,23 +29,13 @@ def B(self, points): return jnp.mean(dB_sum, axis=0) @partial(jit, static_argnames=['self']) - def B_field(self, points): - points = jnp.array(points) # (Nphi, Ntheta, 3) - points_flat = points.reshape(-1, 3) # (Npoints, 3) - - gamma_flat = self.gamma.reshape(-1, 3) # (Ngamma, 3) - gamma_dash_flat = self.gamma_dash.reshape(-1, 3) # (Ngamma, 3) - currents_flat = jnp.repeat(self.currents, self.gamma.shape[1]) # shape = (Ngamma,) - - def B_at_point(p): - dif_R = p - gamma_flat # (Ngamma, 3) - norm = jnp.linalg.norm(dif_R, axis=1) - dB = jnp.cross(gamma_dash_flat, dif_R) / norm[:, None]**3 # (Ngamma, 3) - return jnp.einsum("i,ij->j", 1e-7 * currents_flat, dB) # (3,) - - B_flat = vmap(B_at_point)(points_flat) # (Npoints, 3) - return B_flat # (Nphi, Ntheta, 3) B_flat.reshape(points.shape) - + def A(self, points): + dif_R = (jnp.array(points)-self.gamma) + R_norm = jnp.linalg.norm(dif_R, axis=-1) + dA = self.gamma_dash / R_norm[..., None] + weighted = self.currents[:, None, None] * dA * 1e-7 + A_vec = jnp.mean(weighted, axis=(0, 1)) + return A_vec @partial(jit, static_argnames=['self']) def B_covariant(self, points): diff --git a/essos/qfm.py b/essos/qfm.py index 3b8e811..cdaff10 100644 --- a/essos/qfm.py +++ b/essos/qfm.py @@ -1,20 +1,39 @@ import jax +from jax import vmap import jax.numpy as jnp from jaxopt import LBFGS, ScipyMinimize from scipy.optimize import minimize -import optax from essos.surfaces import SurfaceRZFourier class QfmSurface: - def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float): - assert label in ["area", "volume"], f"Unsupported label: {label}" + def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float, + toroidal_flux_idx: int = 0): + assert label in ["area", "volume", "toroidal_flux"], f"Unsupported label: {label}" self.field = field - self.surface = surface - self.surface_optimize = self._with_x(surface, surface.x) + self.surface = surface + self.surface_optimize = self._with_x(surface, surface.x) self.label = label self.targetlabel = targetlabel + self.toroidal_flux_idx = int(toroidal_flux_idx) self.name = str(id(self)) + def _toroidal_flux(self, surf: SurfaceRZFourier) -> jnp.ndarray: + + idx = self.toroidal_flux_idx + + gamma = surf.gamma + + curve = gamma[idx, :, :] + dl = jnp.roll(curve, -1, axis=0) - curve + + A_vals = vmap(self.field.A)(curve) + + Adl = jnp.sum(A_vals * dl, axis=1) + + tf = jnp.sum(Adl) + return tf + + def _with_x(self, surface: SurfaceRZFourier, x): s = SurfaceRZFourier( rc=surface.rc, @@ -32,7 +51,9 @@ def objective(self, x): surf = self._with_x(self.surface_optimize, x) N = surf.unitnormal norm_N = jnp.linalg.norm(surf.normal, axis=2) - B = self.field.B_field(surf.gamma).reshape(N.shape) + points_flat = surf.gamma.reshape(-1, 3) + B = B_flat = vmap(self.field.B)(points_flat) + B = B.reshape(N.shape) B_n = jnp.sum(B * N, axis=2) norm_B = jnp.linalg.norm(B, axis=2) result = jnp.sum(B_n**2 * norm_N) / jnp.sum(norm_B**2 * norm_N) @@ -41,19 +62,22 @@ def objective(self, x): def constraint(self, x): surf = self._with_x(self.surface_optimize, x) if self.label == "volume": - return surf.volume - self.targetlabel + val = surf.volume - self.targetlabel elif self.label == "area": - return surf.area - self.targetlabel + val = surf.area - self.targetlabel + elif self.label == "toroidal_flux": + val = self._toroidal_flux(surf) - self.targetlabel else: raise ValueError(f"Unsupported label: {self.label}") + return val - def penalty_objective(self, x, constraint_weight=1.0): + def penalty_objective(self, x, constraint_weight=10): r = self.objective(x) c = self.constraint(x) result = r + 0.5 * constraint_weight * c**2 return jnp.asarray(result), None - def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): + def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=10): value_and_grad_fn = jax.value_and_grad( lambda x: self.penalty_objective(x, constraint_weight), has_aux=True @@ -79,78 +103,6 @@ def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): } - def minimize_penalty_scipy_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): - fun = lambda x: jnp.asarray(self.penalty_objective(x, constraint_weight)[0]).item() - grad = lambda x: jax.grad(lambda x_: self.penalty_objective(x_, constraint_weight)[0])(x) - x0 = self.surface_optimize.x - res = minimize( - fun=fun, x0=jnp.array(x0), jac=grad, - method='L-BFGS-B', tol=tol, options={"maxiter": maxiter} - ) - self.surface_optimize = self._with_x(self.surface_optimize, res.x) - return { - "fun": res.fun, - "gradient": grad(res.x), - "iter": res.nit, - "info": res, - "success": res.success, - "s": self.surface_optimize, - } - - def minimize_penalty_slsqp(self, tol=1e-3, maxiter=1000, constraint_weight=1.0): - fun = lambda x: self.penalty_objective(x, constraint_weight)[0] - grad = jax.grad(fun) - - solver = ScipyMinimize( - fun=fun, - method="SLSQP", - tol=tol, - maxiter=maxiter - ) - - x0 = self.surface_optimize.x - res = solver.run(x0) - self.surface_optimize = self._with_x(self.surface_optimize, res.params) - - # 安全获取迭代次数 - iter_count = getattr(res.state, "num_iters", None) - if iter_count is None: - iter_count = getattr(res.state, "maxiter", -1) - - return { - "fun": res.state.fun_val, - "gradient": grad(res.params), - "iter": iter_count, - "info": res.state, - "success": getattr(res.state, "status", 0) == 0, - "s": self.surface_optimize, - } - - - def minimize_exact_SLSQP(self, tol=1e-3, maxiter=1000): - loss_fn = lambda x: self.objective(x) - constraint_fn = lambda x: self.constraint(x) - grad_loss = jax.grad(loss_fn) - dcon = jax.grad(constraint_fn) - solver = ScipyMinimize( - fun=loss_fn, - method="SLSQP", - constraints=[{"type": "eq", "fun": constraint_fn, "jac": dcon}], - tol=tol, - options={"maxiter": maxiter} - ) - x0 = self.surface_optimize.x - res = solver.run(x0) - self.surface_optimize = self._with_x(self.surface_optimize, res.params) - return { - "fun": res.state.fun_val, - "gradient": grad_loss(res.params), - "iter": res.state.nit, - "info": res.state, - "success": res.state.status == 0, - "s": self.surface_optimize, - } - def minimize_exact_scipy_slsqp(self, tol=1e-3, maxiter=1000): fun = lambda x: jnp.asarray(self.objective(x)).item() jac = lambda x: jnp.asarray(jax.grad(self.objective)(x)) @@ -173,100 +125,13 @@ def minimize_exact_scipy_slsqp(self, tol=1e-3, maxiter=1000): "s": self.surface_optimize, } - # ⬅️ 新增 -# ========== 新增方法:构造 optax 优化器 ========== - def _build_optax_optimizer(self, method: str, lr: float): - m = method.strip().lower() - if m == 'adam': return optax.adam(lr) - if m == 'adamw': return optax.adamw(lr) - if m == 'sgd': return optax.sgd(lr) - if m == 'momentum': return optax.sgd(lr, momentum=0.9) - if m == 'nesterov': return optax.sgd(lr, momentum=0.9, nesterov=True) - if m == 'rmsprop': return optax.rmsprop(lr) - if m == 'adagrad': return optax.adagrad(lr) - if m == 'adafactor': return optax.adafactor(learning_rate=lr) - if m == 'lamb': return optax.lamb(learning_rate=lr) - if m == 'lars': return optax.lars(learning_rate=lr) - raise ValueError(f"Unknown optax optimizer '{method}'") - -# ========== 新增方法:Optax penalty 优化 ========== - def minimize_penalty_optax(self, optimizer='adam', lr=1e-2, tol=1e-3, maxiter=1000, constraint_weight=1.0): - loss = lambda x: self.penalty_objective(x, constraint_weight)[0] - grad_loss = jax.grad(loss) - - opt = self._build_optax_optimizer(optimizer, lr) - x = self.surface_optimize.x - opt_state = opt.init(x) - - for it in range(int(maxiter)): - g = grad_loss(x) - updates, opt_state = opt.update(g, opt_state, x) - x = optax.apply_updates(x, updates) - if float(jnp.linalg.norm(g)) <= tol: - break - - self.surface_optimize = self._with_x(self.surface_optimize, x) - return { - "fun": float(loss(x)), - "gradient": grad_loss(x), - "iter": it + 1, - "info": {"grad_norm": float(jnp.linalg.norm(g)), "optimizer": optimizer, "lr": lr}, - "success": float(jnp.linalg.norm(g)) <= tol, - "s": self.surface_optimize, - } - -# ========== 修改 run 方法:新增 optax 分支 ========== - def run(self, tol=1e-4, maxiter=1000, method='SLSQP', constraint_weight=10.0, lr=1e-2): + def run(self, tol=1e-4, maxiter=1000, method='SLSQP', constraint_weight=10.0): method_up = method.upper() if method_up == 'SLSQP': - return self.minimize_penalty_slsqp(tol=tol, maxiter=maxiter) + return self.minimize_exact_scipy_slsqp(tol=tol, maxiter=maxiter) elif method_up == 'LBFGS': return self.minimize_penalty_lbfgs( tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) - elif method_up == 'SCIPYLBFGS': - return self.minimize_penalty_scipy_lbfgs( - tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) - elif method_up == 'SCIPYSLSQP': - return self.minimize_exact_scipy_slsqp( - tol=tol, maxiter=maxiter) - - # Optax 分支 - OPTAX_METHODS = { - 'OPTAX': 'adam', - 'ADAM': 'adam', - 'ADAMW': 'adamw', - 'SGD': 'sgd', - 'MOMENTUM': 'momentum', - 'NESTEROV': 'nesterov', - 'RMSPROP': 'rmsprop', - 'ADAGRAD': 'adagrad', - 'ADAFACTOR': 'adafactor', - 'LAMB': 'lamb', - 'LARS': 'lars', - } - if method_up in OPTAX_METHODS: - return self.minimize_penalty_optax( - optimizer=OPTAX_METHODS[method_up], - lr=lr, tol=tol, maxiter=maxiter, - constraint_weight=constraint_weight - ) - - raise ValueError(f"Unknown method '{method}'") - - - # def run(self, tol=1e-4, maxiter=1000, method='SLSQP', constraint_weight=10.0): - # method_up = method.upper() - # if method_up == 'SLSQP': - # return self.minimize_penalty_slsqp(tol=tol, maxiter=maxiter) - # elif method_up == 'LBFGS': - # return self.minimize_penalty_lbfgs( - # tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) - # elif method_up == 'SCIPYLBFGS': - # return self.minimize_penalty_scipy_lbfgs( - # tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) - # elif method_up == 'SCIPYSLSQP': - # return self.minimize_exact_scipy_slsqp( - # tol=tol, maxiter=maxiter) - # else: - # raise ValueError(f"Unknown method '{method}'") + else: + raise ValueError(f"Unknown method '{method}'") diff --git a/essos/surfaces.py b/essos/surfaces.py index 891ae7c..b7dc83d 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -248,7 +248,7 @@ def volume(self): @property def area(self): - n = self.normal() # (nphi, ntheta, 3) + n = self.normal # (nphi, ntheta, 3) norm_n = jnp.linalg.norm(n, axis=2) # shape: (nphi, ntheta) avg_area = jnp.mean(norm_n) return avg_area diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index cb0af03..d2aed4f 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -2,68 +2,76 @@ import jax.numpy as jnp import matplotlib.pyplot as plt +from essos.surfaces import BdotN_over_B from essos.surfaces import SurfaceRZFourier from essos.qfm import QfmSurface from essos.fields import Vmec, BiotSavart -method = 'slsqp' #slsqp lbfgs -# 1. 加载等离子体 VMEC 文件,并生成 surface ntheta=30 nphi=30 vmec = os.path.join('input_files','input.rotating_ellipse') surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') initial_vol = surf.volume -# 2. 创建初始线圈并生成 BiotSavart 磁场 - ntheta=35 nphi=35 # Initialize VMEC field -truesurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') +initialsurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') truevmec = Vmec(os.path.join(os.path.dirname(__file__), 'input_files', 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), ntheta=ntheta, nphi=nphi, range_torus='half period') +method = 'lbfgs' #slsqp lbfgs +label = 'volume' +target_label = truevmec.surface.volume from essos.coils import Coils_from_json coils = Coils_from_json("stellarator_coils.json") -# 创建磁场对象 field = BiotSavart(coils) + -# 3. 计算当前体积作为 target label(或设置固定值) -target_volume = truevmec.surface.volume # 你可以手动设置一个目标值,如 target_volume = 1.0 - -# 4. 构建 QfmSurface 优化器 +BdotN_over_B_initial = BdotN_over_B(surf, BiotSavart(coils)) qfm = QfmSurface( field=field, surface=surf, - label='volume', # or "area" - targetlabel=target_volume # or target_area + label=label, + targetlabel=target_label ) -# 5. 运行优化(选择方法) result = qfm.run(tol=1e-3, maxiter=10000,method=method) -# 6. 打印结果 +BdotN_over_B_optimized = BdotN_over_B(result['s'], BiotSavart(coils)) print("Optimization method:", method) print("Optimization success:", result['success']) -print("Final Bnormal objective:", result['fun']) +print("Final qfm objective:", result['fun']) print("Iterations:", result['iter']) -print(f"target volume: {target_volume}, initial volume: {initial_vol}, final volume: {result['s'].volume}") +print(f"initial volume: {initial_vol}, target volume: {target_label}, final volume: {result['s'].volume}") +print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") +print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") +fig = plt.figure(figsize=(8, 4)) +ax1 = fig.add_subplot(131, projection='3d') +ax2 = fig.add_subplot(132, projection='3d') +ax3 = fig.add_subplot(133, projection='3d') -fig = plt.figure(figsize=(8, 4)) -ax1 = fig.add_subplot(121, projection='3d') -ax2 = fig.add_subplot(122, projection='3d') coils.plot(ax=ax1, show=False) -truesurf.plot(ax=ax1, show=False) +initialsurf.plot(ax=ax1, show=False) +ax1.set_title("Initial Surface") + coils.plot(ax=ax2, show=False) -result['s'].plot(ax=ax2, show=False) +truevmec.surface.plot(ax=ax2, show=False) +ax2.set_title("True VMEC Surface") + +coils.plot(ax=ax3, show=False) +result['s'].plot(ax=ax3, show=False) +ax3.set_title("Final Surface") + +# 布局 & 显示 plt.tight_layout() plt.show() From a41766c01629e1e9e8b04b09b05bd87aea077b96 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 21 Aug 2025 15:25:37 -0500 Subject: [PATCH 04/18] update qfm and example to solve some bugs; add toroidal field in surfaces.py and A in fields.py --- essos/fields.py | 12 +- essos/qfm.py | 59 +++++---- essos/surfaces.py | 68 +++++++++++ examples/input_files/input.initial_guess | 17 +++ examples/optimize_coils_vmec_surface.py | 6 +- examples/optimize_qfm_surface.py | 149 +++++++++++++++++------ 6 files changed, 244 insertions(+), 67 deletions(-) create mode 100644 examples/input_files/input.initial_guess diff --git a/essos/fields.py b/essos/fields.py index c8ad04a..2910cd1 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -27,14 +27,16 @@ def B(self, points): dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") return jnp.mean(dB_sum, axis=0) - + @partial(jit, static_argnames=['self']) def A(self, points): dif_R = (jnp.array(points)-self.gamma) - R_norm = jnp.linalg.norm(dif_R, axis=-1) - dA = self.gamma_dash / R_norm[..., None] - weighted = self.currents[:, None, None] * dA * 1e-7 - A_vec = jnp.mean(weighted, axis=(0, 1)) + dA = self.gamma_dash / jnp.linalg.norm(dif_R, axis=-1, keepdims=True) + A_vec = jnp.sum( + jnp.mean(self.currents[:, None, None] * dA * 1e-7, axis=1), + axis=0 + ) + return A_vec @partial(jit, static_argnames=['self']) diff --git a/essos/qfm.py b/essos/qfm.py index cdaff10..9339fc0 100644 --- a/essos/qfm.py +++ b/essos/qfm.py @@ -6,35 +6,40 @@ from essos.surfaces import SurfaceRZFourier class QfmSurface: - def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float, - toroidal_flux_idx: int = 0): + def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float = None, + toroidal_flux_idx: int = 0): assert label in ["area", "volume", "toroidal_flux"], f"Unsupported label: {label}" + self.field = field self.surface = surface - self.surface_optimize = self._with_x(surface, surface.x) + self.surface_optimize = self._build_surface_with_x(surface, surface.x) self.label = label - self.targetlabel = targetlabel self.toroidal_flux_idx = int(toroidal_flux_idx) self.name = str(id(self)) + if targetlabel is None: + if label == "volume": + self.targetlabel = surface.volume + elif label == "area": + self.targetlabel = surface.area + elif label == "toroidal_flux": + self.targetlabel = self._toroidal_flux(surface) + else: + raise ValueError(f"Unsupported label: {label}") + else: + self.targetlabel = targetlabel + def _toroidal_flux(self, surf: SurfaceRZFourier) -> jnp.ndarray: - idx = self.toroidal_flux_idx - gamma = surf.gamma - curve = gamma[idx, :, :] dl = jnp.roll(curve, -1, axis=0) - curve - A_vals = vmap(self.field.A)(curve) - Adl = jnp.sum(A_vals * dl, axis=1) - tf = jnp.sum(Adl) return tf - - def _with_x(self, surface: SurfaceRZFourier, x): + def _build_surface_with_x(self, surface: SurfaceRZFourier, x): s = SurfaceRZFourier( rc=surface.rc, zs=surface.zs, @@ -42,13 +47,13 @@ def _with_x(self, surface: SurfaceRZFourier, x): ntheta=surface.ntheta, nphi=surface.nphi, range_torus=surface.range_torus, - close=True + close=False ) s.x = x return s def objective(self, x): - surf = self._with_x(self.surface_optimize, x) + surf = self._build_surface_with_x(self.surface_optimize, x) N = surf.unitnormal norm_N = jnp.linalg.norm(surf.normal, axis=2) points_flat = surf.gamma.reshape(-1, 3) @@ -60,7 +65,13 @@ def objective(self, x): return result def constraint(self, x): - surf = self._with_x(self.surface_optimize, x) + """ + result estimate + volume: 1e-6 + area: 1e-6 + toroidal flux: 1e-12 + """ + surf = self._build_surface_with_x(self.surface_optimize, x) if self.label == "volume": val = surf.volume - self.targetlabel elif self.label == "area": @@ -71,13 +82,19 @@ def constraint(self, x): raise ValueError(f"Unsupported label: {self.label}") return val - def penalty_objective(self, x, constraint_weight=10): + def penalty_objective(self, x, constraint_weight=1.0): + """ + weight estimate + volume: 1e1 + area: 1e1 + toroidal flux: 1e10 + """ r = self.objective(x) c = self.constraint(x) result = r + 0.5 * constraint_weight * c**2 return jnp.asarray(result), None - def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=10): + def minimize_penalty_lbfgs(self, tol=1e-6, maxiter=1000, constraint_weight=1.0): value_and_grad_fn = jax.value_and_grad( lambda x: self.penalty_objective(x, constraint_weight), has_aux=True @@ -92,7 +109,7 @@ def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=10): ) x0 = self.surface_optimize.x res = solver.run(x0) - self.surface_optimize = self._with_x(self.surface_optimize, res.params) + self.surface_optimize = self._build_surface_with_x(self.surface_optimize, res.params) return { "fun": res.state.value, "gradient": jax.grad(lambda x: self.penalty_objective(x, constraint_weight)[0])(res.params), @@ -103,7 +120,7 @@ def minimize_penalty_lbfgs(self, tol=1e-3, maxiter=1000, constraint_weight=10): } - def minimize_exact_scipy_slsqp(self, tol=1e-3, maxiter=1000): + def minimize_exact_scipy_slsqp(self, tol=1e-6, maxiter=1000): fun = lambda x: jnp.asarray(self.objective(x)).item() jac = lambda x: jnp.asarray(jax.grad(self.objective)(x)) con_fun = lambda x: jnp.asarray(self.constraint(x)).item() @@ -115,7 +132,7 @@ def minimize_exact_scipy_slsqp(self, tol=1e-3, maxiter=1000): constraints=constraints, method='SLSQP', tol=tol, options={"maxiter": maxiter} ) - self.surface_optimize = self._with_x(self.surface_optimize, res.x) + self.surface_optimize = self._build_surface_with_x(self.surface_optimize, res.x) return { "fun": res.fun, "gradient": jac(res.x), @@ -126,7 +143,7 @@ def minimize_exact_scipy_slsqp(self, tol=1e-3, maxiter=1000): } - def run(self, tol=1e-4, maxiter=1000, method='SLSQP', constraint_weight=10.0): + def run(self, tol=1e-6, maxiter=1000, method='SLSQP', constraint_weight=1.0): method_up = method.upper() if method_up == 'SLSQP': return self.minimize_exact_scipy_slsqp(tol=tol, maxiter=maxiter) diff --git a/essos/surfaces.py b/essos/surfaces.py index b7dc83d..2388d78 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -9,6 +9,16 @@ mesh = Mesh(devices(), ("dev",)) sharding = NamedSharding(mesh, PartitionSpec("dev", None)) +@partial(jit, static_argnames=['surface','field']) +def toroidal_flux(surface, field, idx=0) -> jnp.ndarray: + gamma = surface.gamma + curve = gamma[idx, :, :] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(field.A)(curve) + Adl = jnp.sum(A_vals * dl, axis=1) + tf = jnp.sum(Adl) + return tf + @partial(jit, static_argnames=['surface','field']) def B_on_surface(surface, field): ntheta = surface.ntheta @@ -253,6 +263,64 @@ def area(self): avg_area = jnp.mean(norm_n) return avg_area + def change_resolution(self, mpol: int, ntor: int): + """ + Change the values of `mpol` and `ntor`. + New Fourier coefficients are zero by default. + Old coefficients outside the new range are discarded. + """ + rc_old, zs_old = self.rc, self.zs + mpol_old, ntor_old = self.mpol, self.ntor + + rc_new = jnp.zeros((mpol, 2 * ntor + 1)) + zs_new = jnp.zeros((mpol, 2 * ntor + 1)) + + m_keep = min(mpol_old, mpol) + n_keep = min(ntor_old, ntor) + + old_slice = slice(ntor_old - n_keep, ntor_old + n_keep + 1) + new_slice = slice(ntor - n_keep, ntor + n_keep + 1) + + # Copy overlapping region + rc_new = rc_new.at[:m_keep, new_slice].set(rc_old[:m_keep, old_slice]) + zs_new = zs_new.at[:m_keep, new_slice].set(zs_old[:m_keep, old_slice]) + + # Update attributes + self.mpol, self.ntor = mpol, ntor + self.rc, self.zs = rc_new, zs_new + + # Recompute xm/xn and interpolation arrays + m1d = jnp.arange(self.mpol) + n1d = jnp.arange(-self.ntor, self.ntor + 1) + n2d, m2d = jnp.meshgrid(n1d, m1d) + self.xm = m2d.flatten()[self.ntor:] + self.xn = self.nfp * n2d.flatten()[self.ntor:] + + indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T + self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] + self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] + + # Update degrees of freedom + self.num_dofs_rc = len(jnp.ravel(self.rc)[self.ntor:]) + self.num_dofs_zs = len(jnp.ravel(self.zs)[self.ntor:]) + self._dofs = jnp.concatenate( + (jnp.ravel(self.rc)[self.ntor:], jnp.ravel(self.zs)[self.ntor:]) + ) + + # Recompute angles and geometry + self.angles = ( + jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) + - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) + ) + (self._gamma, self._gammadash_theta, self._gammadash_phi, + self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + + # Recompute AbsB if available + if hasattr(self, 'bmnc'): + self._AbsB = self._set_AbsB() + + return self + def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.") diff --git a/examples/input_files/input.initial_guess b/examples/input_files/input.initial_guess new file mode 100644 index 0000000..da3e30f --- /dev/null +++ b/examples/input_files/input.initial_guess @@ -0,0 +1,17 @@ +!----- Runtime Parameters ----- +&INDATA + LASYM = F + NFP = 0002 + MPOL = 003 + NTOR = 003 +!----- Boundary Parameters (n,m) ----- + RBC( 000,000) = 10 ZBS( 000,000) = 0 + RBC( 001,000) = 1 ZBS( 001,000) = -1 + RBC(-001,001) = 0.1 ZBS(-001,001) = 0.1 + RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 + RBC( 001,001) = -1 ZBS( 001,001) = 1 + RBC(-002,002) = 1E-4 ZBS(-002,002) = 1E-4 + RBC(-002,001) = 1E-4 ZBS(-002,001) = 0. + RBC(-003,003) = 1E-4 ZBS(-003,003) = 0. + +/ diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py index c8222b6..12e5f9e 100644 --- a/examples/optimize_coils_vmec_surface.py +++ b/examples/optimize_coils_vmec_surface.py @@ -15,7 +15,7 @@ max_coil_curvature = 1.0 order_Fourier_series_coils = 3 number_coil_points = order_Fourier_series_coils*15 -maximum_function_evaluations = 50 +maximum_function_evaluations = 500 number_coils_per_half_field_period = 3 tolerance_optimization = 1e-5 ntheta=35 @@ -63,10 +63,10 @@ plt.show() # Save the coils to a json file -coils_optimized.to_json("stellarator_coils.json") +coils_optimized.to_json("input_files/stellarator_coils.json") # Load the coils from a json file from essos.coils import Coils_from_json -coils = Coils_from_json("stellarator_coils.json") +coils = Coils_from_json("input_files/stellarator_coils.json") # # Save results in vtk format to analyze in Paraview # from essos.fields import BiotSavart diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index d2aed4f..e6dd0e7 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -1,77 +1,150 @@ import os +number_of_processors_to_use = 3 # Parallelization +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' + import jax.numpy as jnp import matplotlib.pyplot as plt +from time import time +from jax import block_until_ready -from essos.surfaces import BdotN_over_B +from essos.dynamics import Tracing +from essos.surfaces import BdotN_over_B, toroidal_flux from essos.surfaces import SurfaceRZFourier from essos.qfm import QfmSurface from essos.fields import Vmec, BiotSavart - -ntheta=30 -nphi=30 -vmec = os.path.join('input_files','input.rotating_ellipse') -surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') -initial_vol = surf.volume - +# Load initial guess surface ntheta=35 -nphi=35 - -# Initialize VMEC field -initialsurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period') +nphi=36 +vmec = os.path.join('input_files','input.initial_guess') +surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) +surf.change_resolution(6,6) +initialsurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) +# Load target VMEC surface truevmec = Vmec(os.path.join(os.path.dirname(__file__), 'input_files', 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), - ntheta=ntheta, nphi=nphi, range_torus='half period') - -method = 'lbfgs' #slsqp lbfgs -label = 'volume' -target_label = truevmec.surface.volume + ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) +# Load coils and construct field from essos.coils import Coils_from_json -coils = Coils_from_json("stellarator_coils.json") - +coils = Coils_from_json("input_files/stellarator_coils.json") # from optimize_coils_vmec_surface.py field = BiotSavart(coils) - + +# QFM optimization setup +method = 'lbfgs' +label = 'toroidal_flux' +initial_label = toroidal_flux(surf, field) +targetlabel = toroidal_flux(truevmec.surface, field) +tol = 1e-6 +constraint_weight = 1e10 +maxiter = 1000 BdotN_over_B_initial = BdotN_over_B(surf, BiotSavart(coils)) -qfm = QfmSurface( - field=field, - surface=surf, - label=label, - targetlabel=target_label -) -result = qfm.run(tol=1e-3, maxiter=10000,method=method) +# Initialize QFM optimizer +qfm = QfmSurface(field=field, surface=surf, label=label, targetlabel=targetlabel) + +print("Degrees of Freedom:", qfm.surface.x.shape[0]) +result = qfm.run(tol=tol, maxiter=maxiter, method=method, constraint_weight=constraint_weight) + +# Evaluate final objective and constraint +x_opt = result["s"].x +qfm_loss = float(jnp.asarray(qfm.objective(x_opt))) +c_loss = float(jnp.asarray(qfm.constraint(x_opt))) BdotN_over_B_optimized = BdotN_over_B(result['s'], BiotSavart(coils)) print("Optimization method:", method) +print("Optimization label:", label) print("Optimization success:", result['success']) -print("Final qfm objective:", result['fun']) +print(f"final qfm objective = {qfm_loss:.3e}, final constraint objective = {c_loss:.3e}") print("Iterations:", result['iter']) -print(f"initial volume: {initial_vol}, target volume: {target_label}, final volume: {result['s'].volume}") +print(f"initial label: {initial_label}, target label: {targetlabel}, final label: {toroidal_flux(result['s'], field)}") print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") - +# Plot surfaces fig = plt.figure(figsize=(8, 4)) ax1 = fig.add_subplot(131, projection='3d') ax2 = fig.add_subplot(132, projection='3d') ax3 = fig.add_subplot(133, projection='3d') +# coils.plot(ax=ax1, show=False) +# coils.plot(ax=ax2, show=False) +# coils.plot(ax=ax3, show=False) -coils.plot(ax=ax1, show=False) -initialsurf.plot(ax=ax1, show=False) -ax1.set_title("Initial Surface") -coils.plot(ax=ax2, show=False) +initialsurf.plot(ax=ax1, show=False) truevmec.surface.plot(ax=ax2, show=False) -ax2.set_title("True VMEC Surface") - -coils.plot(ax=ax3, show=False) result['s'].plot(ax=ax3, show=False) + +ax1.set_title("Initial Surface") +ax2.set_title("True VMEC Surface") ax3.set_title("Final Surface") -# 布局 & 显示 +plt.tight_layout() +plt.show() + +# Field line tracing +tmax = 10000000000 +nfieldlines_per_core = 3 +nfieldlines = nfieldlines_per_core * number_of_processors_to_use +R0 = jnp.linspace(12.2, 13.5, nfieldlines) +trace_tolerance = 1e-7 +num_steps = 60000 + +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T + +time0 = time() +tracing = block_until_ready(Tracing( + field=field, + model='FieldLineAdaptative', + initial_conditions=initial_xyz, + maxtime=tmax, + times_to_trace=num_steps, + atol=trace_tolerance, + rtol=trace_tolerance +)) +print(f"ESSOS tracing took {time() - time0:.2f} seconds") + +trajectories = tracing.trajectories +traj = trajectories[0] +R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] + +phi_u = jnp.unwrap(phi) +phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] +phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] + +theta = jnp.linspace(0, 2 * jnp.pi, 200) + +def compute_rz_on_phi(surface, theta, phi=0.0): + angles = jnp.outer(theta, surface.xm) - phi * surface.xn + R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) + Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) + return R, Z + +# Contours from optimized surface +R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) +R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) + +# Contours from true VMEC surface +R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) +R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) + +fig, ax = plt.subplots(figsize=(6, 6)) + +tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) +ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") +ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") +ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") +ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") + +ax.set_xlabel("R") +ax.set_ylabel("Z") +ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") +ax.legend() +ax.axis("equal") plt.tight_layout() plt.show() From 60df5361e5f5c3e8bd59e48340d3e0810ab6e116 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 21 Aug 2025 15:29:46 -0500 Subject: [PATCH 05/18] delete input that we dont need actually --- examples/input_files/input.initial_guess | 17 ----------------- examples/optimize_qfm_surface.py | 2 +- 2 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 examples/input_files/input.initial_guess diff --git a/examples/input_files/input.initial_guess b/examples/input_files/input.initial_guess deleted file mode 100644 index da3e30f..0000000 --- a/examples/input_files/input.initial_guess +++ /dev/null @@ -1,17 +0,0 @@ -!----- Runtime Parameters ----- -&INDATA - LASYM = F - NFP = 0002 - MPOL = 003 - NTOR = 003 -!----- Boundary Parameters (n,m) ----- - RBC( 000,000) = 10 ZBS( 000,000) = 0 - RBC( 001,000) = 1 ZBS( 001,000) = -1 - RBC(-001,001) = 0.1 ZBS(-001,001) = 0.1 - RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 - RBC( 001,001) = -1 ZBS( 001,001) = 1 - RBC(-002,002) = 1E-4 ZBS(-002,002) = 1E-4 - RBC(-002,001) = 1E-4 ZBS(-002,001) = 0. - RBC(-003,003) = 1E-4 ZBS(-003,003) = 0. - -/ diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index e6dd0e7..fcf42c0 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -16,7 +16,7 @@ # Load initial guess surface ntheta=35 nphi=36 -vmec = os.path.join('input_files','input.initial_guess') +vmec = os.path.join('input_files','input.rotating_ellipse') surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) surf.change_resolution(6,6) From 4198c3c8b00bf7478cd248530d57bd491f26e1db Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 21 Aug 2025 15:33:16 -0500 Subject: [PATCH 06/18] add jaxopt in requirements --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f7c6b4a..fe2a944 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ netcdf4 f90nml pyevtk optuna -pandas \ No newline at end of file +pandas +jaxopt \ No newline at end of file From d2f21ca4bc960578daf601f85923b9f8838231dd Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 21 Aug 2025 15:39:22 -0500 Subject: [PATCH 07/18] add stellarator_coils.json in input_files --- examples/input_files/stellarator_coils.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 examples/input_files/stellarator_coils.json diff --git a/examples/input_files/stellarator_coils.json b/examples/input_files/stellarator_coils.json new file mode 100644 index 0000000..2db7d7d --- /dev/null +++ b/examples/input_files/stellarator_coils.json @@ -0,0 +1 @@ +{"nfp": 2, "stellsym": true, "order": 3, "n_segments": 45, "dofs_curves": [[[11.622700935298976, 0.021162524110782358, 4.969310302793117, 0.21619524408779187, 1.0480524616370408, 0.28649150353772795, -0.28257357882380735], [2.650045608829704, 1.3195115302904592, 2.0655255644324324, -0.9222231514058897, -0.6185897653334609, -0.029257891517016514, -0.4689941563806397], [0.5736288312625168, -6.124609285728119, -0.25270656488797033, -0.7494953058315302, -0.5577484125670837, 0.41280301860103114, 0.49815152814320796]], [[8.270668252860887, 0.3665817338447269, 3.1105370020049437, -1.1258699831006203, 1.3263614203440663, 0.2102848349841804, -0.49142766630482954], [7.030418225942375, 0.6256285318976378, 4.821568953388758, -0.3255223631956292, -0.5563748261863908, 0.5699454535195534, 0.4766999014677532], [1.1730264767378422, -5.641665411926324, -0.3776722058863022, -0.22889203846752135, -0.37533249101718297, -0.13615989719970903, 0.2925684224469988]], [[2.870219295136962, 1.011171856312114, 1.0052842269429074, -1.616205601094772, 0.4568093873850295, 0.8322530493441007, -0.47523615678931685], [9.19851171039639, 0.15301426329616008, 6.079831607393782, -0.1245216794618166, -0.5264330665117136, 0.028667523169553216, 0.4360349874253642], [0.5149651564138196, -5.068527850344334, -0.2765171026569377, -0.4083657568321486, -0.061669206943173696, -0.01678957665952365, 0.08107054639495188]]], "dofs_currents": [0.9714976808550443, 1.0086510674635079, 1.019851251681448]} \ No newline at end of file From 0f3aa15fda748e1fd25baba1f40afd5f567b1116 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Fri, 12 Sep 2025 16:05:51 +0800 Subject: [PATCH 08/18] write better example --- examples/optimize_qfm_surface.py | 171 ++++++++++++++++++------------- 1 file changed, 99 insertions(+), 72 deletions(-) diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index fcf42c0..917b0ea 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -1,13 +1,11 @@ import os -number_of_processors_to_use = 3 # Parallelization +number_of_processors_to_use = 3 os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' import jax.numpy as jnp import matplotlib.pyplot as plt from time import time -from jax import block_until_ready -from essos.dynamics import Tracing from essos.surfaces import BdotN_over_B, toroidal_flux from essos.surfaces import SurfaceRZFourier from essos.qfm import QfmSurface @@ -18,7 +16,7 @@ nphi=36 vmec = os.path.join('input_files','input.rotating_ellipse') surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) -surf.change_resolution(6,6) +surf.change_resolution(5,5) initialsurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) @@ -32,10 +30,19 @@ field = BiotSavart(coils) # QFM optimization setup -method = 'lbfgs' -label = 'toroidal_flux' -initial_label = toroidal_flux(surf, field) -targetlabel = toroidal_flux(truevmec.surface, field) +method = 'slsqp' # lbfgs, slsqp +label = 'area' # 'area', 'volume', 'toroidal_flux' +initial_label = None +targetlabel = None +if label == 'toroidal_flux': + initial_label = toroidal_flux(surf, field) + targetlabel = toroidal_flux(truevmec.surface, field) +elif label == 'volume': + initial_label = surf.volume + targetlabel = truevmec.surface.volume +elif label == 'area': + initial_label = surf.area + targetlabel = truevmec.surface.area tol = 1e-6 constraint_weight = 1e10 maxiter = 1000 @@ -46,7 +53,9 @@ qfm = QfmSurface(field=field, surface=surf, label=label, targetlabel=targetlabel) print("Degrees of Freedom:", qfm.surface.x.shape[0]) +start_time = time() result = qfm.run(tol=tol, maxiter=maxiter, method=method, constraint_weight=constraint_weight) +end_time = time() # Evaluate final objective and constraint x_opt = result["s"].x @@ -59,10 +68,24 @@ print("Optimization success:", result['success']) print(f"final qfm objective = {qfm_loss:.3e}, final constraint objective = {c_loss:.3e}") print("Iterations:", result['iter']) -print(f"initial label: {initial_label}, target label: {targetlabel}, final label: {toroidal_flux(result['s'], field)}") +print(f"Optimization time: {end_time - start_time}") + print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") +initial_area = surf.area +initial_volume = surf.volume +initial_tf = toroidal_flux(surf, field) + +final_area = result['s'].area +final_volume = result['s'].volume +final_tf = toroidal_flux(result['s'], field) + +print(f"Initial labels -> area: {initial_area:.6e}, volume: {initial_volume:.6e}, toroidal_flux: {initial_tf:.6e}") +print(f"target label: {label} target label value: {targetlabel}") +print(f"Final labels -> area: {final_area:.6e}, volume: {final_volume:.6e}, toroidal_flux: {final_tf:.6e}") + + # Plot surfaces fig = plt.figure(figsize=(8, 4)) ax1 = fig.add_subplot(131, projection='3d') @@ -85,66 +108,70 @@ plt.tight_layout() plt.show() -# Field line tracing -tmax = 10000000000 -nfieldlines_per_core = 3 -nfieldlines = nfieldlines_per_core * number_of_processors_to_use -R0 = jnp.linspace(12.2, 13.5, nfieldlines) -trace_tolerance = 1e-7 -num_steps = 60000 - -Z0 = jnp.zeros(nfieldlines) -phi0 = jnp.zeros(nfieldlines) -initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T - -time0 = time() -tracing = block_until_ready(Tracing( - field=field, - model='FieldLineAdaptative', - initial_conditions=initial_xyz, - maxtime=tmax, - times_to_trace=num_steps, - atol=trace_tolerance, - rtol=trace_tolerance -)) -print(f"ESSOS tracing took {time() - time0:.2f} seconds") - -trajectories = tracing.trajectories -traj = trajectories[0] -R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] - -phi_u = jnp.unwrap(phi) -phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] -phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] - -theta = jnp.linspace(0, 2 * jnp.pi, 200) - -def compute_rz_on_phi(surface, theta, phi=0.0): - angles = jnp.outer(theta, surface.xm) - phi * surface.xn - R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) - Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) - return R, Z - -# Contours from optimized surface -R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) -R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) - -# Contours from true VMEC surface -R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) -R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) - -fig, ax = plt.subplots(figsize=(6, 6)) - -tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) -ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") -ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") -ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") -ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") - -ax.set_xlabel("R") -ax.set_ylabel("Z") -ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") -ax.legend() -ax.axis("equal") -plt.tight_layout() -plt.show() + +# # Field line tracing +# from jax import block_until_ready +# from essos.dynamics import Tracing + +# tmax = 100000000000 +# nfieldlines_per_core = 5 +# nfieldlines = nfieldlines_per_core * number_of_processors_to_use +# R0 = jnp.linspace(12.2, 13.5, nfieldlines) +# trace_tolerance = 1e-7 +# num_steps = 60000 + +# Z0 = jnp.zeros(nfieldlines) +# phi0 = jnp.zeros(nfieldlines) +# initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T + +# time0 = time() +# tracing = block_until_ready(Tracing( +# field=field, +# model='FieldLineAdaptative', +# initial_conditions=initial_xyz, +# maxtime=tmax, +# times_to_trace=num_steps, +# atol=trace_tolerance, +# rtol=trace_tolerance +# )) +# print(f"ESSOS tracing took {time() - time0:.2f} seconds") + +# trajectories = tracing.trajectories +# traj = trajectories[0] +# R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] + +# phi_u = jnp.unwrap(phi) +# phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] +# phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] + +# theta = jnp.linspace(0, 2 * jnp.pi, 200) + +# def compute_rz_on_phi(surface, theta, phi=0.0): +# angles = jnp.outer(theta, surface.xm) - phi * surface.xn +# R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) +# Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) +# return R, Z + +# # Contours from optimized surface +# R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) +# R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) + +# # Contours from true VMEC surface +# R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) +# R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) + +# fig, ax = plt.subplots(figsize=(6, 6)) + +# tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) +# ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") +# ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") +# ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") +# ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") + +# ax.set_xlabel("R") +# ax.set_ylabel("Z") +# ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") +# ax.legend() +# ax.axis("equal") +# plt.tight_layout() +# plt.show() From 96495633d1a583a7490d88525296bc2cede75016 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Tue, 16 Sep 2025 10:24:12 +0800 Subject: [PATCH 09/18] change suggested weight --- essos/qfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/essos/qfm.py b/essos/qfm.py index 9339fc0..0ba115f 100644 --- a/essos/qfm.py +++ b/essos/qfm.py @@ -94,7 +94,7 @@ def penalty_objective(self, x, constraint_weight=1.0): result = r + 0.5 * constraint_weight * c**2 return jnp.asarray(result), None - def minimize_penalty_lbfgs(self, tol=1e-6, maxiter=1000, constraint_weight=1.0): + def minimize_penalty_lbfgs(self, tol=1e-6, maxiter=1000, constraint_weight=1e4): value_and_grad_fn = jax.value_and_grad( lambda x: self.penalty_objective(x, constraint_weight), has_aux=True @@ -143,7 +143,7 @@ def minimize_exact_scipy_slsqp(self, tol=1e-6, maxiter=1000): } - def run(self, tol=1e-6, maxiter=1000, method='SLSQP', constraint_weight=1.0): + def run(self, tol=1e-6, maxiter=1000, method='SLSQP', constraint_weight=1e4): method_up = method.upper() if method_up == 'SLSQP': return self.minimize_exact_scipy_slsqp(tol=tol, maxiter=maxiter) From 5bba340d7660f0381edeb53a674a50b917980d9c Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 20:47:51 +0800 Subject: [PATCH 10/18] change qfm class to a better structure for jax; normalize constrain; --- essos/qfm.py | 281 ++++++++++++++++++++----------- essos/surfaces.py | 13 +- examples/optimize_qfm_surface.py | 185 +++++++++++--------- 3 files changed, 292 insertions(+), 187 deletions(-) diff --git a/essos/qfm.py b/essos/qfm.py index 0ba115f..5df1fe5 100644 --- a/essos/qfm.py +++ b/essos/qfm.py @@ -1,154 +1,233 @@ +# qfm_jax.py import jax -from jax import vmap +from jax import vmap, grad, value_and_grad, device_get import jax.numpy as jnp -from jaxopt import LBFGS, ScipyMinimize +from jaxopt import LBFGS +from essos.surfaces import SurfaceRZFourier from scipy.optimize import minimize -from essos.surfaces import SurfaceRZFourier + class QfmSurface: def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float = None, - toroidal_flux_idx: int = 0): + toroidal_flux_idx: int = 0): assert label in ["area", "volume", "toroidal_flux"], f"Unsupported label: {label}" - + self.field = field - self.surface = surface - self.surface_optimize = self._build_surface_with_x(surface, surface.x) + self.surface = surface + self.surface_optimize = self._build_surface_with_x(surface, surface.x) self.label = label - self.toroidal_flux_idx = int(toroidal_flux_idx) + self.toroidal_flux_idx = int(toroidal_flux_idx) self.name = str(id(self)) if targetlabel is None: - if label == "volume": - self.targetlabel = surface.volume - elif label == "area": - self.targetlabel = surface.area - elif label == "toroidal_flux": - self.targetlabel = self._toroidal_flux(surface) - else: - raise ValueError(f"Unsupported label: {label}") + self.targetlabel = { + "volume": surface.volume, + "area": surface.area, + "toroidal_flux": self._toroidal_flux(surface) + }[label] else: self.targetlabel = targetlabel - def _toroidal_flux(self, surf: SurfaceRZFourier) -> jnp.ndarray: - idx = self.toroidal_flux_idx - gamma = surf.gamma - curve = gamma[idx, :, :] - dl = jnp.roll(curve, -1, axis=0) - curve + def _toroidal_flux(self, surf: SurfaceRZFourier): + curve = surf.gamma[self.toroidal_flux_idx] + dl = jnp.roll(curve, -1, axis=0) - curve A_vals = vmap(self.field.A)(curve) - Adl = jnp.sum(A_vals * dl, axis=1) - tf = jnp.sum(Adl) - return tf + return jnp.sum(jnp.sum(A_vals * dl, axis=1)) + + def _build_surface_with_x(self, surface, x): + rc_safe = device_get(surface.rc) # <- 确保不是 tracer + zs_safe = device_get(surface.zs) + x_safe = device_get(x) # <- 确保不是 tracer - def _build_surface_with_x(self, surface: SurfaceRZFourier, x): s = SurfaceRZFourier( - rc=surface.rc, - zs=surface.zs, - nfp=surface.nfp, - ntheta=surface.ntheta, - nphi=surface.nphi, + rc=rc_safe, + zs=zs_safe, + nfp=int(surface.nfp), + ntheta=int(surface.ntheta), + nphi=int(surface.nphi), range_torus=surface.range_torus, - close=False + close=True ) - s.x = x + s.x = x_safe return s def objective(self, x): - surf = self._build_surface_with_x(self.surface_optimize, x) + surf = self.surface_optimize + x_old = surf.x + surf.x = x N = surf.unitnormal norm_N = jnp.linalg.norm(surf.normal, axis=2) - points_flat = surf.gamma.reshape(-1, 3) - B = B_flat = vmap(self.field.B)(points_flat) - B = B.reshape(N.shape) + points = surf.gamma.reshape(-1, 3) + B = vmap(self.field.B)(points).reshape(N.shape) B_n = jnp.sum(B * N, axis=2) norm_B = jnp.linalg.norm(B, axis=2) - result = jnp.sum(B_n**2 * norm_N) / jnp.sum(norm_B**2 * norm_N) - return result + value = jnp.sum(B_n**2 * norm_N) / jnp.sum(norm_B**2 * norm_N) + surf.x = x_old + return value + def constraint(self, x): - """ - result estimate - volume: 1e-6 - area: 1e-6 - toroidal flux: 1e-12 - """ - surf = self._build_surface_with_x(self.surface_optimize, x) - if self.label == "volume": - val = surf.volume - self.targetlabel - elif self.label == "area": - val = surf.area - self.targetlabel - elif self.label == "toroidal_flux": - val = self._toroidal_flux(surf) - self.targetlabel - else: - raise ValueError(f"Unsupported label: {self.label}") - return val + surf = self.surface_optimize + x_old = surf.x + surf.x = x + + raw_c = { + "volume": surf.volume - self.targetlabel, + "area": surf.area - self.targetlabel, + "toroidal_flux": self._toroidal_flux(surf) - self.targetlabel + }[self.label] + + c = raw_c / jnp.abs(self.targetlabel) + + surf.x = x_old + return c + def penalty_objective(self, x, constraint_weight=1.0): - """ - weight estimate - volume: 1e1 - area: 1e1 - toroidal flux: 1e10 - """ r = self.objective(x) c = self.constraint(x) - result = r + 0.5 * constraint_weight * c**2 - return jnp.asarray(result), None + return r + 0.5 * constraint_weight * c**2 + + def default_callback(self, info): + if isinstance(info, dict): + # LBFGS + it = info.get("iter", -1) + r = info["objective"] + c = info["constraint"] + print(f"[LBFGS iter {it}] objective={r:.6e} constraint={c:.3e} " + f"penalty={info['penalty']:.6e} grad_norm={info['grad_norm']:.3e}") + else: + # SLSQP + # 最小修改:用 self 属性跟踪迭代次数 + it = getattr(self, "_slsqp_iter", 0) + 1 + setattr(self, "_slsqp_iter", it) + + x = jnp.array(info) + obj = float(self.objective(x)) + cst = float(self.constraint(x)) + penalty = float(self.penalty_objective(x)) + grad_norm = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z))(x))) + print(f"[SLSQP iter {it}] objective={obj:.6e} constraint={cst:.3e} " + f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}") + + def minimize_lbfgs(self, x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e4, + return_trace=False, log_every=1, callback=None, **kwargs): + x0 = self.surface_optimize.x if x0 is None else x0 + + # ---------- 定义目标函数,返回 scalar + aux dict(全用 jnp.array) ---------- + def fn(x): + value = self.penalty_objective(x, constraint_weight) + aux = { + "objective": self.objective(x), + "constraint": self.constraint(x), + "penalty": value + } + return value, aux + + solver = LBFGS(fun=fn, maxiter=maxiter, tol=tol, has_aux=True) + state = solver.init_state(x0) + + trace = [] + x = x0 + for k in range(maxiter): + x, state = solver.update(x, state) + + info = {key: device_get(v) if isinstance(v, jnp.ndarray) else v for key, v in state.aux.items()} + info["iter"] = k + 1 + info["grad_norm"] = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z, constraint_weight))(x))) + info["error"] = float(state.error) + + if return_trace: + trace.append(info) + + if callback is None: + self.default_callback(info) + else: + callback(info) + + if state.error <= tol: + break + + x_safe = device_get(x) # 拉回 host + self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe) - def minimize_penalty_lbfgs(self, tol=1e-6, maxiter=1000, constraint_weight=1e4): - value_and_grad_fn = jax.value_and_grad( - lambda x: self.penalty_objective(x, constraint_weight), - has_aux=True - ) - solver = LBFGS( - fun=value_and_grad_fn, - value_and_grad=True, - has_aux=True, - implicit_diff=False, - tol=tol, - maxiter=maxiter - ) - x0 = self.surface_optimize.x - res = solver.run(x0) - self.surface_optimize = self._build_surface_with_x(self.surface_optimize, res.params) return { - "fun": res.state.value, - "gradient": jax.grad(lambda x: self.penalty_objective(x, constraint_weight)[0])(res.params), - "iter": res.state.iter_num, - "info": res.state, - "success": res.state.error <= tol, + "fun": float(self.penalty_objective(x, constraint_weight)), + "gradient": jnp.array(grad(lambda z: self.penalty_objective(z, constraint_weight))(x)), + "iter": k + 1, + "info": state, + "success": state.error <= tol, "s": self.surface_optimize, } - def minimize_exact_scipy_slsqp(self, tol=1e-6, maxiter=1000): - fun = lambda x: jnp.asarray(self.objective(x)).item() - jac = lambda x: jnp.asarray(jax.grad(self.objective)(x)) - con_fun = lambda x: jnp.asarray(self.constraint(x)).item() - con_jac = lambda x: jnp.asarray(jax.grad(self.constraint)(x)) - constraints = [{"type": "eq", "fun": con_fun, "jac": con_jac}] - x0 = self.surface_optimize.x + def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, **kwargs): + x0 = jnp.array(self.surface_optimize.x if x0 is None else x0) + res = minimize( - fun=fun, x0=jnp.array(x0), jac=jac, - constraints=constraints, method='SLSQP', - tol=tol, options={"maxiter": maxiter} + fun=lambda x: float(self.objective(x)), + x0=x0, + method="SLSQP", + constraints={"type": "eq", "fun": lambda x: float(self.constraint(x))}, + tol=tol, + options={"maxiter": maxiter, "disp": False}, + callback=self.default_callback ) - self.surface_optimize = self._build_surface_with_x(self.surface_optimize, res.x) + x_safe = device_get(res.x) + self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe) + return { "fun": res.fun, - "gradient": jac(res.x), + "gradient": jnp.array(jax.grad(self.objective)(res.x)), "iter": res.nit, "info": res, "success": res.success, "s": self.surface_optimize, } + def run( + self, + method: str = "SLSQP", + # 通用优化参数 + tol: float = 1e-6, + maxiter: int = 1000, + + # LBFGS 专用参数 + x0=None, + constraint_weight: float = 1e4, + return_trace: bool = False, + log_every: int = 1, - def run(self, tol=1e-6, maxiter=1000, method='SLSQP', constraint_weight=1e4): + # 可选非必须参数(注释保留) + # early_stop: bool = False, # LBFGS 可选早停策略 + # c_tol: float = 5e-7, # LBFGS 可选约束容忍度 + # rel_tol: float = 1e-5, # LBFGS 可选相对误差容忍度 + # g_tol: float = 5e-1, # LBFGS 可选梯度容忍度 + # patience: int = 50, # LBFGS 可选早停耐心值 + **kwargs # 额外参数自动传递给优化函数 + ): + """ + 统一优化入口: + - method="SLSQP": 使用 scipy.optimize.minimize 的 SLSQP 等式约束优化 + - method="LBFGS": 使用 jaxopt LBFGS 带惩罚项优化,可返回逐步 trace,支持 log + """ method_up = method.upper() - if method_up == 'SLSQP': - return self.minimize_exact_scipy_slsqp(tol=tol, maxiter=maxiter) - elif method_up == 'LBFGS': - return self.minimize_penalty_lbfgs( - tol=tol, maxiter=maxiter, constraint_weight=constraint_weight) + if method_up == "SLSQP": + return self.minimize_slsqp( + x0=x0, + tol=tol, + maxiter=maxiter, + **kwargs + ) + elif method_up == "LBFGS": + return self.minimize_lbfgs( + x0=x0, + tol=tol, + maxiter=maxiter, + constraint_weight=constraint_weight, + return_trace=return_trace, + log_every=log_every, + **kwargs + ) else: raise ValueError(f"Unknown method '{method}'") diff --git a/essos/surfaces.py b/essos/surfaces.py index 2388d78..e0bb9be 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -258,10 +258,15 @@ def volume(self): @property def area(self): - n = self.normal # (nphi, ntheta, 3) - norm_n = jnp.linalg.norm(n, axis=2) # shape: (nphi, ntheta) - avg_area = jnp.mean(norm_n) - return avg_area + n = self.normal # shape: (nphi, ntheta, 3) + norm_n = jnp.linalg.norm(n, axis=2) + + dphi = 2 * jnp.pi / self.nphi + dtheta = 2 * jnp.pi / self.ntheta + + area = jnp.sum(norm_n) * dphi * dtheta + return area + def change_resolution(self, mpol: int, ntor: int): """ diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index 917b0ea..97046f3 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -1,8 +1,9 @@ import os -number_of_processors_to_use = 3 +number_of_processors_to_use = 5 os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' import jax.numpy as jnp +from jax import device_get # ← 最小改动:用于避免 tracer 泄漏 import matplotlib.pyplot as plt from time import time @@ -12,8 +13,8 @@ from essos.fields import Vmec, BiotSavart # Load initial guess surface -ntheta=35 -nphi=36 +ntheta=25 +nphi=25 vmec = os.path.join('input_files','input.rotating_ellipse') surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) surf.change_resolution(5,5) @@ -26,26 +27,33 @@ # Load coils and construct field from essos.coils import Coils_from_json -coils = Coils_from_json("input_files/stellarator_coils.json") # from optimize_coils_vmec_surface.py +coils = Coils_from_json("input_files/stellarator_coils.json") field = BiotSavart(coils) # QFM optimization setup method = 'slsqp' # lbfgs, slsqp -label = 'area' # 'area', 'volume', 'toroidal_flux' +label = 'toroidal_flux' # 'area', 'volume', 'toroidal_flux' + +if method == 'lbfgs': + tol = 1e-4 +elif method == 'slsqp': + tol = 1e-6 + +maxiter = 10000 initial_label = None targetlabel = None if label == 'toroidal_flux': + constraint_weight = 1e-3 initial_label = toroidal_flux(surf, field) targetlabel = toroidal_flux(truevmec.surface, field) elif label == 'volume': + constraint_weight = 1e-3 initial_label = surf.volume targetlabel = truevmec.surface.volume elif label == 'area': + constraint_weight = 1e-3 initial_label = surf.area targetlabel = truevmec.surface.area -tol = 1e-6 -constraint_weight = 1e10 -maxiter = 1000 BdotN_over_B_initial = BdotN_over_B(surf, BiotSavart(coils)) @@ -54,11 +62,23 @@ print("Degrees of Freedom:", qfm.surface.x.shape[0]) start_time = time() -result = qfm.run(tol=tol, maxiter=maxiter, method=method, constraint_weight=constraint_weight) +print('start') + + +result = qfm.run( + tol=tol, + maxiter=maxiter, + method=method, + constraint_weight=constraint_weight, + log_every=10 # 仅对 LBFGS 有意义,SLSQP 保留无害 +) + +print('done') end_time = time() # Evaluate final objective and constraint -x_opt = result["s"].x +# ← 最小改动:把 x_opt 拉回 host,避免写入 surf.x 时出现 tracer 泄漏 +x_opt = device_get(result["s"].x) qfm_loss = float(jnp.asarray(qfm.objective(x_opt))) c_loss = float(jnp.asarray(qfm.constraint(x_opt))) @@ -85,18 +105,12 @@ print(f"target label: {label} target label value: {targetlabel}") print(f"Final labels -> area: {final_area:.6e}, volume: {final_volume:.6e}, toroidal_flux: {final_tf:.6e}") - # Plot surfaces fig = plt.figure(figsize=(8, 4)) ax1 = fig.add_subplot(131, projection='3d') ax2 = fig.add_subplot(132, projection='3d') ax3 = fig.add_subplot(133, projection='3d') -# coils.plot(ax=ax1, show=False) -# coils.plot(ax=ax2, show=False) -# coils.plot(ax=ax3, show=False) - - initialsurf.plot(ax=ax1, show=False) truevmec.surface.plot(ax=ax2, show=False) result['s'].plot(ax=ax3, show=False) @@ -109,69 +123,76 @@ plt.show() -# # Field line tracing -# from jax import block_until_ready -# from essos.dynamics import Tracing - -# tmax = 100000000000 -# nfieldlines_per_core = 5 -# nfieldlines = nfieldlines_per_core * number_of_processors_to_use -# R0 = jnp.linspace(12.2, 13.5, nfieldlines) -# trace_tolerance = 1e-7 -# num_steps = 60000 - -# Z0 = jnp.zeros(nfieldlines) -# phi0 = jnp.zeros(nfieldlines) -# initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T - -# time0 = time() -# tracing = block_until_ready(Tracing( -# field=field, -# model='FieldLineAdaptative', -# initial_conditions=initial_xyz, -# maxtime=tmax, -# times_to_trace=num_steps, -# atol=trace_tolerance, -# rtol=trace_tolerance -# )) -# print(f"ESSOS tracing took {time() - time0:.2f} seconds") - -# trajectories = tracing.trajectories -# traj = trajectories[0] -# R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] - -# phi_u = jnp.unwrap(phi) -# phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] -# phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] - -# theta = jnp.linspace(0, 2 * jnp.pi, 200) - -# def compute_rz_on_phi(surface, theta, phi=0.0): -# angles = jnp.outer(theta, surface.xm) - phi * surface.xn -# R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) -# Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) -# return R, Z - -# # Contours from optimized surface -# R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) -# R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) - -# # Contours from true VMEC surface -# R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) -# R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) - -# fig, ax = plt.subplots(figsize=(6, 6)) - -# tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) -# ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") -# ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") -# ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") -# ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") - -# ax.set_xlabel("R") -# ax.set_ylabel("Z") -# ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") -# ax.legend() -# ax.axis("equal") -# plt.tight_layout() -# plt.show() + + + + + + + +# Field line tracing +from jax import block_until_ready +from essos.dynamics import Tracing + +tmax = 100000000000 +nfieldlines_per_core = 5 +nfieldlines = nfieldlines_per_core * number_of_processors_to_use +R0 = jnp.linspace(12.2, 13.5, nfieldlines) +trace_tolerance = 1e-7 +num_steps = 60000 + +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T + +time0 = time() +tracing = block_until_ready(Tracing( + field=field, + model='FieldLineAdaptative', + initial_conditions=initial_xyz, + maxtime=tmax, + times_to_trace=num_steps, + atol=trace_tolerance, + rtol=trace_tolerance +)) +print(f"ESSOS tracing took {time() - time0:.2f} seconds") + +trajectories = tracing.trajectories +traj = trajectories[0] +R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] + +phi_u = jnp.unwrap(phi) +phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] +phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] + +theta = jnp.linspace(0, 2 * jnp.pi, 200) + +def compute_rz_on_phi(surface, theta, phi=0.0): + angles = jnp.outer(theta, surface.xm) - phi * surface.xn + R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) + Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) + return R, Z + +# Contours from optimized surface +R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) +R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) + +# Contours from true VMEC surface +R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) +R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) + +fig, ax = plt.subplots(figsize=(6, 6)) + +tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) +ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") +ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") +ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") +ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") + +ax.set_xlabel("R") +ax.set_ylabel("Z") +ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") +ax.legend() +ax.axis("equal") +plt.tight_layout() +plt.show() \ No newline at end of file From 393ca8f5f316db1ce8b44fbb4cbf243718d73f9e Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 20:50:48 +0800 Subject: [PATCH 11/18] delete unused code --- essos/multiobjectiveoptimizer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/essos/multiobjectiveoptimizer.py b/essos/multiobjectiveoptimizer.py index f3a669d..25c920b 100644 --- a/essos/multiobjectiveoptimizer.py +++ b/essos/multiobjectiveoptimizer.py @@ -8,9 +8,6 @@ from functools import partial from essos.coils import Coils, Curves, CreateEquallySpacedCurves from essos.fields import BiotSavart -import numpy as np -from pandas.plotting import parallel_coordinates -import pandas as pd class MultiObjectiveOptimizer: From 3de61c30e94ba02782ba329c5ea9b2053f25c069 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 20:54:31 +0800 Subject: [PATCH 12/18] comment fieldlinetracing --- examples/optimize_qfm_surface.py | 132 +++++++++++++++---------------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index 97046f3..a06439f 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -130,69 +130,69 @@ -# Field line tracing -from jax import block_until_ready -from essos.dynamics import Tracing - -tmax = 100000000000 -nfieldlines_per_core = 5 -nfieldlines = nfieldlines_per_core * number_of_processors_to_use -R0 = jnp.linspace(12.2, 13.5, nfieldlines) -trace_tolerance = 1e-7 -num_steps = 60000 - -Z0 = jnp.zeros(nfieldlines) -phi0 = jnp.zeros(nfieldlines) -initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T - -time0 = time() -tracing = block_until_ready(Tracing( - field=field, - model='FieldLineAdaptative', - initial_conditions=initial_xyz, - maxtime=tmax, - times_to_trace=num_steps, - atol=trace_tolerance, - rtol=trace_tolerance -)) -print(f"ESSOS tracing took {time() - time0:.2f} seconds") - -trajectories = tracing.trajectories -traj = trajectories[0] -R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] - -phi_u = jnp.unwrap(phi) -phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] -phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] - -theta = jnp.linspace(0, 2 * jnp.pi, 200) - -def compute_rz_on_phi(surface, theta, phi=0.0): - angles = jnp.outer(theta, surface.xm) - phi * surface.xn - R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) - Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) - return R, Z - -# Contours from optimized surface -R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) -R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) - -# Contours from true VMEC surface -R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) -R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) - -fig, ax = plt.subplots(figsize=(6, 6)) - -tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) -ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") -ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") -ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") -ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") - -ax.set_xlabel("R") -ax.set_ylabel("Z") -ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") -ax.legend() -ax.axis("equal") -plt.tight_layout() -plt.show() \ No newline at end of file +# # Field line tracing +# from jax import block_until_ready +# from essos.dynamics import Tracing + +# tmax = 100000000000 +# nfieldlines_per_core = 5 +# nfieldlines = nfieldlines_per_core * number_of_processors_to_use +# R0 = jnp.linspace(12.2, 13.5, nfieldlines) +# trace_tolerance = 1e-7 +# num_steps = 60000 + +# Z0 = jnp.zeros(nfieldlines) +# phi0 = jnp.zeros(nfieldlines) +# initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T + +# time0 = time() +# tracing = block_until_ready(Tracing( +# field=field, +# model='FieldLineAdaptative', +# initial_conditions=initial_xyz, +# maxtime=tmax, +# times_to_trace=num_steps, +# atol=trace_tolerance, +# rtol=trace_tolerance +# )) +# print(f"ESSOS tracing took {time() - time0:.2f} seconds") + +# trajectories = tracing.trajectories +# traj = trajectories[0] +# R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] + +# phi_u = jnp.unwrap(phi) +# phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] +# phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] + +# theta = jnp.linspace(0, 2 * jnp.pi, 200) + +# def compute_rz_on_phi(surface, theta, phi=0.0): +# angles = jnp.outer(theta, surface.xm) - phi * surface.xn +# R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) +# Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) +# return R, Z + +# # Contours from optimized surface +# R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) +# R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) + +# # Contours from true VMEC surface +# R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) +# R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) + +# fig, ax = plt.subplots(figsize=(6, 6)) + +# tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) +# ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") +# ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") +# ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") +# ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") + +# ax.set_xlabel("R") +# ax.set_ylabel("Z") +# ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") +# ax.legend() +# ax.axis("equal") +# plt.tight_layout() +# plt.show() \ No newline at end of file From 07dc3e08ab55015c3a08e99b7e571f845d055535 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 22:36:35 +0800 Subject: [PATCH 13/18] comment in english; change callback method --- essos/qfm.py | 114 +++++++++++++++++-------------- examples/optimize_qfm_surface.py | 6 +- 2 files changed, 67 insertions(+), 53 deletions(-) diff --git a/essos/qfm.py b/essos/qfm.py index 5df1fe5..4c78923 100644 --- a/essos/qfm.py +++ b/essos/qfm.py @@ -1,6 +1,5 @@ -# qfm_jax.py import jax -from jax import vmap, grad, value_and_grad, device_get +from jax import vmap, grad, device_get import jax.numpy as jnp from jaxopt import LBFGS from essos.surfaces import SurfaceRZFourier @@ -35,9 +34,9 @@ def _toroidal_flux(self, surf: SurfaceRZFourier): return jnp.sum(jnp.sum(A_vals * dl, axis=1)) def _build_surface_with_x(self, surface, x): - rc_safe = device_get(surface.rc) # <- 确保不是 tracer + rc_safe = device_get(surface.rc) # <- Ensure it's not a tracer zs_safe = device_get(surface.zs) - x_safe = device_get(x) # <- 确保不是 tracer + x_safe = device_get(x) # <- Ensure it's not a tracer s = SurfaceRZFourier( rc=rc_safe, @@ -65,7 +64,6 @@ def objective(self, x): surf.x = x_old return value - def constraint(self, x): surf = self.surface_optimize x_old = surf.x @@ -77,44 +75,50 @@ def constraint(self, x): "toroidal_flux": self._toroidal_flux(surf) - self.targetlabel }[self.label] - c = raw_c / jnp.abs(self.targetlabel) + c = raw_c / jnp.abs(self.targetlabel) surf.x = x_old return c - def penalty_objective(self, x, constraint_weight=1.0): r = self.objective(x) c = self.constraint(x) return r + 0.5 * constraint_weight * c**2 - def default_callback(self, info): + def _callback(self, info, printlog=True): if isinstance(info, dict): # LBFGS it = info.get("iter", -1) r = info["objective"] c = info["constraint"] - print(f"[LBFGS iter {it}] objective={r:.6e} constraint={c:.3e} " - f"penalty={info['penalty']:.6e} grad_norm={info['grad_norm']:.3e}") + penalty = info["penalty"] + grad_norm = info["grad_norm"] + + # Print logs if printlog is True + if printlog: + print(f"[LBFGS iter {it}] objective={r:.6e} constraint={c:.3e} " + f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}") else: # SLSQP - # 最小修改:用 self 属性跟踪迭代次数 it = getattr(self, "_slsqp_iter", 0) + 1 setattr(self, "_slsqp_iter", it) - x = jnp.array(info) - obj = float(self.objective(x)) - cst = float(self.constraint(x)) - penalty = float(self.penalty_objective(x)) - grad_norm = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z))(x))) - print(f"[SLSQP iter {it}] objective={obj:.6e} constraint={cst:.3e} " - f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}") + obj = float(self.objective(info)) + cst = float(self.constraint(info)) + penalty = float(self.penalty_objective(info)) + grad_norm = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z))(info))) + + # Print logs if printlog is True + if printlog: + print(f"[SLSQP iter {it}] objective={obj:.6e} constraint={cst:.3e} " + f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}") + def minimize_lbfgs(self, x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e4, - return_trace=False, log_every=1, callback=None, **kwargs): + printlog=True, **kwargs): x0 = self.surface_optimize.x if x0 is None else x0 - # ---------- 定义目标函数,返回 scalar + aux dict(全用 jnp.array) ---------- + # ---------- Define objective function, return scalar + aux dict (all use jnp.array) ---------- def fn(x): value = self.penalty_objective(x, constraint_weight) aux = { @@ -137,18 +141,13 @@ def fn(x): info["grad_norm"] = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z, constraint_weight))(x))) info["error"] = float(state.error) - if return_trace: - trace.append(info) - - if callback is None: - self.default_callback(info) - else: - callback(info) + # Ensure we call _callback for logging every step if printlog is True + self._callback(info, printlog) if state.error <= tol: break - x_safe = device_get(x) # 拉回 host + x_safe = device_get(x) # Move back to host self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe) return { @@ -160,10 +159,10 @@ def fn(x): "s": self.surface_optimize, } - - def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, **kwargs): + def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, printlog=True, **kwargs): x0 = jnp.array(self.surface_optimize.x if x0 is None else x0) + # Run the SLSQP optimizer res = minimize( fun=lambda x: float(self.objective(x)), x0=x0, @@ -171,52 +170,68 @@ def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, **kwargs): constraints={"type": "eq", "fun": lambda x: float(self.constraint(x))}, tol=tol, options={"maxiter": maxiter, "disp": False}, - callback=self.default_callback + callback=lambda x: self._callback(x, printlog) # Use internal callback directly ) + + # Store the optimized x in the surface x_safe = device_get(res.x) self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe) + # Return the result with optimization trace return { "fun": res.fun, "gradient": jnp.array(jax.grad(self.objective)(res.x)), "iter": res.nit, "info": res, "success": res.success, - "s": self.surface_optimize, + "s": self.surface_optimize } def run( self, method: str = "SLSQP", - # 通用优化参数 + # General optimization parameters tol: float = 1e-6, maxiter: int = 1000, - # LBFGS 专用参数 + # Method-specific parameters x0=None, - constraint_weight: float = 1e4, - return_trace: bool = False, - log_every: int = 1, - - # 可选非必须参数(注释保留) - # early_stop: bool = False, # LBFGS 可选早停策略 - # c_tol: float = 5e-7, # LBFGS 可选约束容忍度 - # rel_tol: float = 1e-5, # LBFGS 可选相对误差容忍度 - # g_tol: float = 5e-1, # LBFGS 可选梯度容忍度 - # patience: int = 50, # LBFGS 可选早停耐心值 - **kwargs # 额外参数自动传递给优化函数 + constraint_weight: float = 1e-3, + printlog: bool = True, + **kwargs ): """ - 统一优化入口: - - method="SLSQP": 使用 scipy.optimize.minimize 的 SLSQP 等式约束优化 - - method="LBFGS": 使用 jaxopt LBFGS 带惩罚项优化,可返回逐步 trace,支持 log + Main optimization function to run either SLSQP or LBFGS methods. + + Args: + method (str): Optimization method to use, either 'SLSQP' or 'LBFGS'. + tol (float): Tolerance for stopping criteria. + maxiter (int): Maximum number of iterations. + x0 (array): Initial guess for optimization. + constraint_weight (float): Weight for the constraint term in penalty function. + printlog (bool): Whether to print log information. + **kwargs: Additional arguments to pass to the specific optimization method. + + Returns: + dict: A dictionary containing the optimization result, including: + - 'fun': Final objective function value. + - 'gradient': Final gradient. + - 'iter': Number of iterations. + - 'info': Optimization details. + - 'success': Whether the optimization was successful. + - 's': Optimized surface. """ + + # Convert method to uppercase to standardize comparison method_up = method.upper() + + # Validate method input if method_up == "SLSQP": return self.minimize_slsqp( x0=x0, tol=tol, maxiter=maxiter, + printlog=printlog, **kwargs ) elif method_up == "LBFGS": @@ -225,8 +240,7 @@ def run( tol=tol, maxiter=maxiter, constraint_weight=constraint_weight, - return_trace=return_trace, - log_every=log_every, + printlog=printlog, **kwargs ) else: diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index a06439f..455d647 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -32,7 +32,7 @@ # QFM optimization setup method = 'slsqp' # lbfgs, slsqp -label = 'toroidal_flux' # 'area', 'volume', 'toroidal_flux' +label = 'area' # 'area', 'volume', 'toroidal_flux' if method == 'lbfgs': tol = 1e-4 @@ -70,14 +70,14 @@ maxiter=maxiter, method=method, constraint_weight=constraint_weight, - log_every=10 # 仅对 LBFGS 有意义,SLSQP 保留无害 + printlog=1 ) print('done') end_time = time() # Evaluate final objective and constraint -# ← 最小改动:把 x_opt 拉回 host,避免写入 surf.x 时出现 tracer 泄漏 + x_opt = device_get(result["s"].x) qfm_loss = float(jnp.asarray(qfm.objective(x_opt))) c_loss = float(jnp.asarray(qfm.constraint(x_opt))) From 422055974502b89ae99d30564906b637f38a7085 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 22:37:20 +0800 Subject: [PATCH 14/18] comment in english; change callback method --- essos/qfm.py | 30 ++---------------------------- examples/optimize_qfm_surface.py | 4 ++-- 2 files changed, 4 insertions(+), 30 deletions(-) diff --git a/essos/qfm.py b/essos/qfm.py index 4c78923..43ce5a0 100644 --- a/essos/qfm.py +++ b/essos/qfm.py @@ -34,9 +34,9 @@ def _toroidal_flux(self, surf: SurfaceRZFourier): return jnp.sum(jnp.sum(A_vals * dl, axis=1)) def _build_surface_with_x(self, surface, x): - rc_safe = device_get(surface.rc) # <- Ensure it's not a tracer + rc_safe = device_get(surface.rc) zs_safe = device_get(surface.zs) - x_safe = device_get(x) # <- Ensure it's not a tracer + x_safe = device_get(x) s = SurfaceRZFourier( rc=rc_safe, @@ -190,42 +190,16 @@ def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, printlog=True, **kwarg def run( self, method: str = "SLSQP", - # General optimization parameters tol: float = 1e-6, maxiter: int = 1000, - - # Method-specific parameters x0=None, constraint_weight: float = 1e-3, printlog: bool = True, **kwargs ): - """ - Main optimization function to run either SLSQP or LBFGS methods. - - Args: - method (str): Optimization method to use, either 'SLSQP' or 'LBFGS'. - tol (float): Tolerance for stopping criteria. - maxiter (int): Maximum number of iterations. - x0 (array): Initial guess for optimization. - constraint_weight (float): Weight for the constraint term in penalty function. - printlog (bool): Whether to print log information. - **kwargs: Additional arguments to pass to the specific optimization method. - - Returns: - dict: A dictionary containing the optimization result, including: - - 'fun': Final objective function value. - - 'gradient': Final gradient. - - 'iter': Number of iterations. - - 'info': Optimization details. - - 'success': Whether the optimization was successful. - - 's': Optimized surface. - """ - # Convert method to uppercase to standardize comparison method_up = method.upper() - # Validate method input if method_up == "SLSQP": return self.minimize_slsqp( x0=x0, diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py index 455d647..fea39f4 100644 --- a/examples/optimize_qfm_surface.py +++ b/examples/optimize_qfm_surface.py @@ -3,7 +3,7 @@ os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' import jax.numpy as jnp -from jax import device_get # ← 最小改动:用于避免 tracer 泄漏 +from jax import device_get import matplotlib.pyplot as plt from time import time @@ -31,7 +31,7 @@ field = BiotSavart(coils) # QFM optimization setup -method = 'slsqp' # lbfgs, slsqp +method = 'lbfgs' # lbfgs, slsqp label = 'area' # 'area', 'volume', 'toroidal_flux' if method == 'lbfgs': From c63debe873a595c467a3e9f468865e458db4e960 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 22:51:36 +0800 Subject: [PATCH 15/18] add qfm test --- tests/test_qfm.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/test_qfm.py diff --git a/tests/test_qfm.py b/tests/test_qfm.py new file mode 100644 index 0000000..1c298e3 --- /dev/null +++ b/tests/test_qfm.py @@ -0,0 +1,73 @@ +import pytest +import jax.numpy as jnp +from jax import device_get +from essos.surfaces import SurfaceRZFourier, BdotN_over_B, toroidal_flux +from essos.fields import BiotSavart +from essos.qfm import QfmSurface +from essos.coils import Coils_from_json +from unittest.mock import MagicMock + +# Mock function to simulate VMEC +def mock_vmec(): + vmec = MagicMock() + vmec.nfp = 2 + vmec.r_axis = 10.0 + vmec.surface = surface() # Assume surface function is defined elsewhere + return vmec + +# Mock surface for testing +def surface(): + surface = MagicMock() + surface.nphi = 3 + surface.ntheta = 3 + surface.gamma = jnp.ones((3, 3, 3)) + surface.unitnormal = jnp.ones((3, 3, 3)) + surface.volume = 1000 + surface.area = 500 + return surface + +# Mock field for testing +def mock_field(): + coils = Coils_from_json("input_files/stellarator_coils.json") + return BiotSavart(coils) + +# Test QfmSurface class +def test_qfm_surface(): + # Setup + vmec = mock_vmec() + field = mock_field() + surface_instance = vmec.surface + label = "toroidal_flux" + targetlabel = toroidal_flux(surface_instance, field) + qfm = QfmSurface(field=field, surface=surface_instance, label=label, targetlabel=targetlabel) + + # Check initialization + assert qfm.label == label + assert qfm.targetlabel == targetlabel + assert qfm.field == field + assert qfm.surface == surface_instance + + # Test the optimization run + method = "slsqp" # or 'lbfgs' + result = qfm.run( + tol=1e-6, + maxiter=1000, + method=method, + constraint_weight=1e-3, + log_every=10 + ) + + # Check if the optimization was successful + assert result["success"] + assert "s" in result # Check if optimized surface is returned + + # Check if final objective and constraint values are within expected range + x_opt = device_get(result["s"].x) + qfm_loss = float(jnp.asarray(qfm.objective(x_opt))) + c_loss = float(jnp.asarray(qfm.constraint(x_opt))) + assert qfm_loss < 1e-3 # Expected value for the objective + assert abs(c_loss) < 1e-3 # Expected value for the constraint + +# Run test +if __name__ == "__main__": + pytest.main() From 4c30c61570aa2c1574e4d4e0d464e69951a91f82 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 23:04:04 +0800 Subject: [PATCH 16/18] add qfm test --- tests/test_qfm.py | 175 ++++++++++++++++++++++++++++------------------ 1 file changed, 108 insertions(+), 67 deletions(-) diff --git a/tests/test_qfm.py b/tests/test_qfm.py index 1c298e3..6986422 100644 --- a/tests/test_qfm.py +++ b/tests/test_qfm.py @@ -1,73 +1,114 @@ import pytest +from unittest.mock import MagicMock import jax.numpy as jnp -from jax import device_get -from essos.surfaces import SurfaceRZFourier, BdotN_over_B, toroidal_flux -from essos.fields import BiotSavart +from jax import random +from essos.surfaces import SurfaceRZFourier from essos.qfm import QfmSurface -from essos.coils import Coils_from_json -from unittest.mock import MagicMock +from essos.fields import BiotSavart + + +class MockSurface: + def __init__(self): + self.nfp = 2 + self.ntheta = 3 + self.nphi = 3 + self.gamma = jnp.ones((self.ntheta, self.nphi, 3)) # Simulated 3D points + self.unitnormal = jnp.ones((self.ntheta, self.nphi, 3)) # Simulated unit normal + self.volume = 10.0 + self.area = 5.0 + self.x = jnp.ones(5) # Example parameter array + + def change_resolution(self, ntheta, nphi): + self.ntheta = ntheta + self.nphi = nphi + self.gamma = jnp.ones((ntheta, nphi, 3)) # Simulated 3D points + self.unitnormal = jnp.ones((ntheta, nphi, 3)) # Simulated unit normal + + +class MockField: + def A(self, point): + return jnp.array([1.0, 0.0, 0.0]) # Mock field A function + + def B(self, point): + return jnp.array([0.0, 1.0, 0.0]) # Mock field B function + + +@pytest.fixture +def mock_data(): + surface = MockSurface() + field = MockField() + return surface, field + + +def test_qfm_surface_initialization(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="area") + + assert qfm.label == "area" + assert qfm.targetlabel == surface.area + assert qfm.surface == surface + assert isinstance(qfm.surface_optimize, SurfaceRZFourier) + assert qfm.toroidal_flux_idx == 0 # Default value + + +def test_minimize_lbfgs(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="area") + + # Mock the optimizer + qfm.minimize_lbfgs = MagicMock() + qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e3) + + # Call minimize_lbfgs and check if the function was called + qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e3) + qfm.minimize_lbfgs.assert_called_once() + + +def test_minimize_slsqp(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="volume") + + # Mock the optimizer + qfm.minimize_slsqp = MagicMock() + qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) + + # Call minimize_slsqp and check if the function was called + qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) + qfm.minimize_slsqp.assert_called_once() + + +def test_callback(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="toroidal_flux") + + # Mock the callback + info = { + "objective": 1.0, + "constraint": 0.5, + "penalty": 0.2, + "grad_norm": 0.1, + "iter": 1, + } + + # Test LBFGS callback + qfm._callback(info) + qfm._callback(info, printlog=False) + + +def test_run_method(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="area") + + # Mock the run method for LBFGS + result_lbfgs = qfm.run(method="LBFGS", tol=1e-6, maxiter=1000) + assert "s" in result_lbfgs + assert result_lbfgs["success"] is True + + # Mock the run method for SLSQP + result_slsqp = qfm.run(method="SLSQP", tol=1e-6, maxiter=1000) + assert "s" in result_slsqp + assert result_slsqp["success"] is True + -# Mock function to simulate VMEC -def mock_vmec(): - vmec = MagicMock() - vmec.nfp = 2 - vmec.r_axis = 10.0 - vmec.surface = surface() # Assume surface function is defined elsewhere - return vmec - -# Mock surface for testing -def surface(): - surface = MagicMock() - surface.nphi = 3 - surface.ntheta = 3 - surface.gamma = jnp.ones((3, 3, 3)) - surface.unitnormal = jnp.ones((3, 3, 3)) - surface.volume = 1000 - surface.area = 500 - return surface - -# Mock field for testing -def mock_field(): - coils = Coils_from_json("input_files/stellarator_coils.json") - return BiotSavart(coils) - -# Test QfmSurface class -def test_qfm_surface(): - # Setup - vmec = mock_vmec() - field = mock_field() - surface_instance = vmec.surface - label = "toroidal_flux" - targetlabel = toroidal_flux(surface_instance, field) - qfm = QfmSurface(field=field, surface=surface_instance, label=label, targetlabel=targetlabel) - - # Check initialization - assert qfm.label == label - assert qfm.targetlabel == targetlabel - assert qfm.field == field - assert qfm.surface == surface_instance - - # Test the optimization run - method = "slsqp" # or 'lbfgs' - result = qfm.run( - tol=1e-6, - maxiter=1000, - method=method, - constraint_weight=1e-3, - log_every=10 - ) - - # Check if the optimization was successful - assert result["success"] - assert "s" in result # Check if optimized surface is returned - - # Check if final objective and constraint values are within expected range - x_opt = device_get(result["s"].x) - qfm_loss = float(jnp.asarray(qfm.objective(x_opt))) - c_loss = float(jnp.asarray(qfm.constraint(x_opt))) - assert qfm_loss < 1e-3 # Expected value for the objective - assert abs(c_loss) < 1e-3 # Expected value for the constraint - -# Run test if __name__ == "__main__": pytest.main() From 8a1760b5789977dd3c48e93bd962382bed36d98b Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 23:08:55 +0800 Subject: [PATCH 17/18] add qfm test --- tests/test_qfm.py | 39 ++++++++------------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/tests/test_qfm.py b/tests/test_qfm.py index 6986422..3552af9 100644 --- a/tests/test_qfm.py +++ b/tests/test_qfm.py @@ -12,25 +12,25 @@ def __init__(self): self.nfp = 2 self.ntheta = 3 self.nphi = 3 - self.gamma = jnp.ones((self.ntheta, self.nphi, 3)) # Simulated 3D points - self.unitnormal = jnp.ones((self.ntheta, self.nphi, 3)) # Simulated unit normal + self.gamma = jnp.ones((self.ntheta, self.nphi, 3)) + self.unitnormal = jnp.ones((self.ntheta, self.nphi, 3)) self.volume = 10.0 self.area = 5.0 - self.x = jnp.ones(5) # Example parameter array + self.x = jnp.ones(5) def change_resolution(self, ntheta, nphi): self.ntheta = ntheta self.nphi = nphi - self.gamma = jnp.ones((ntheta, nphi, 3)) # Simulated 3D points - self.unitnormal = jnp.ones((ntheta, nphi, 3)) # Simulated unit normal + self.gamma = jnp.ones((ntheta, nphi, 3)) + self.unitnormal = jnp.ones((ntheta, nphi, 3)) class MockField: def A(self, point): - return jnp.array([1.0, 0.0, 0.0]) # Mock field A function + return jnp.array([1.0, 0.0, 0.0]) def B(self, point): - return jnp.array([0.0, 1.0, 0.0]) # Mock field B function + return jnp.array([0.0, 1.0, 0.0]) @pytest.fixture @@ -48,18 +48,16 @@ def test_qfm_surface_initialization(mock_data): assert qfm.targetlabel == surface.area assert qfm.surface == surface assert isinstance(qfm.surface_optimize, SurfaceRZFourier) - assert qfm.toroidal_flux_idx == 0 # Default value + assert qfm.toroidal_flux_idx == 0 def test_minimize_lbfgs(mock_data): surface, field = mock_data qfm = QfmSurface(field, surface, label="area") - # Mock the optimizer qfm.minimize_lbfgs = MagicMock() qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e3) - # Call minimize_lbfgs and check if the function was called qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e3) qfm.minimize_lbfgs.assert_called_once() @@ -68,43 +66,22 @@ def test_minimize_slsqp(mock_data): surface, field = mock_data qfm = QfmSurface(field, surface, label="volume") - # Mock the optimizer qfm.minimize_slsqp = MagicMock() qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) - # Call minimize_slsqp and check if the function was called qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) qfm.minimize_slsqp.assert_called_once() -def test_callback(mock_data): - surface, field = mock_data - qfm = QfmSurface(field, surface, label="toroidal_flux") - - # Mock the callback - info = { - "objective": 1.0, - "constraint": 0.5, - "penalty": 0.2, - "grad_norm": 0.1, - "iter": 1, - } - - # Test LBFGS callback - qfm._callback(info) - qfm._callback(info, printlog=False) - def test_run_method(mock_data): surface, field = mock_data qfm = QfmSurface(field, surface, label="area") - # Mock the run method for LBFGS result_lbfgs = qfm.run(method="LBFGS", tol=1e-6, maxiter=1000) assert "s" in result_lbfgs assert result_lbfgs["success"] is True - # Mock the run method for SLSQP result_slsqp = qfm.run(method="SLSQP", tol=1e-6, maxiter=1000) assert "s" in result_slsqp assert result_slsqp["success"] is True From 08b27d2fedac3ece5eccd45eb6c7e19dc79f7ee6 Mon Sep 17 00:00:00 2001 From: zhouyebi Date: Thu, 25 Sep 2025 23:41:51 +0800 Subject: [PATCH 18/18] add qfm test --- tests/test_qfm.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/test_qfm.py b/tests/test_qfm.py index 3552af9..2f31e8a 100644 --- a/tests/test_qfm.py +++ b/tests/test_qfm.py @@ -9,14 +9,22 @@ class MockSurface: def __init__(self): + self.rc = jnp.array([[1., 2., 3.], + [1., 2., 3.], + [1., 2., 3.]]) + self.zs = jnp.array([[0.5, 1.5, 2.5], + [0.5, 1.5, 2.5], + [0.5, 1.5, 2.5]]) self.nfp = 2 self.ntheta = 3 self.nphi = 3 - self.gamma = jnp.ones((self.ntheta, self.nphi, 3)) - self.unitnormal = jnp.ones((self.ntheta, self.nphi, 3)) - self.volume = 10.0 - self.area = 5.0 - self.x = jnp.ones(5) + self.range_torus = "half period" + + self.area = 1.23 + self.volume = 4.56 + + self.x = jnp.ones(16) + self.gamma = jnp.ones((1, 3, 3)) # 添加这个就不会再报错了 def change_resolution(self, ntheta, nphi): self.ntheta = ntheta @@ -56,9 +64,7 @@ def test_minimize_lbfgs(mock_data): qfm = QfmSurface(field, surface, label="area") qfm.minimize_lbfgs = MagicMock() - qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e3) - - qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e3) + qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e-3) qfm.minimize_lbfgs.assert_called_once() @@ -67,8 +73,6 @@ def test_minimize_slsqp(mock_data): qfm = QfmSurface(field, surface, label="volume") qfm.minimize_slsqp = MagicMock() - qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) - qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) qfm.minimize_slsqp.assert_called_once() @@ -80,12 +84,11 @@ def test_run_method(mock_data): result_lbfgs = qfm.run(method="LBFGS", tol=1e-6, maxiter=1000) assert "s" in result_lbfgs - assert result_lbfgs["success"] is True + assert result_lbfgs["success"] == True result_slsqp = qfm.run(method="SLSQP", tol=1e-6, maxiter=1000) assert "s" in result_slsqp - assert result_slsqp["success"] is True - + assert result_slsqp["success"] == True if __name__ == "__main__": pytest.main()