diff --git a/essos/fields.py b/essos/fields.py index d9e28ee..ac0eb18 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -33,6 +33,17 @@ def B(self, points): dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") return jnp.mean(dB_sum, axis=0) + @partial(jit, static_argnames=['self']) + def A(self, points): + dif_R = (jnp.array(points)-self.gamma) + dA = self.gamma_dash / jnp.linalg.norm(dif_R, axis=-1, keepdims=True) + A_vec = jnp.sum( + jnp.mean(self.currents[:, None, None] * dA * 1e-7, axis=1), + axis=0 + ) + + return A_vec + @partial(jit, static_argnames=['self']) def B_covariant(self, points): return self.B(points) diff --git a/essos/multiobjectiveoptimizer.py b/essos/multiobjectiveoptimizer.py index f3a669d..25c920b 100644 --- a/essos/multiobjectiveoptimizer.py +++ b/essos/multiobjectiveoptimizer.py @@ -8,9 +8,6 @@ from functools import partial from essos.coils import Coils, Curves, CreateEquallySpacedCurves from essos.fields import BiotSavart -import numpy as np -from pandas.plotting import parallel_coordinates -import pandas as pd class MultiObjectiveOptimizer: diff --git a/essos/qfm.py b/essos/qfm.py new file mode 100644 index 0000000..43ce5a0 --- /dev/null +++ b/essos/qfm.py @@ -0,0 +1,221 @@ +import jax +from jax import vmap, grad, device_get +import jax.numpy as jnp +from jaxopt import LBFGS +from essos.surfaces import SurfaceRZFourier +from scipy.optimize import minimize + + +class QfmSurface: + def __init__(self, field, surface: SurfaceRZFourier, label: str, targetlabel: float = None, + toroidal_flux_idx: int = 0): + assert label in ["area", "volume", "toroidal_flux"], f"Unsupported label: {label}" + + self.field = field + self.surface = surface + self.surface_optimize = self._build_surface_with_x(surface, surface.x) + self.label = label + self.toroidal_flux_idx = int(toroidal_flux_idx) + self.name = str(id(self)) + + if targetlabel is None: + self.targetlabel = { + "volume": surface.volume, + "area": surface.area, + "toroidal_flux": self._toroidal_flux(surface) + }[label] + else: + self.targetlabel = targetlabel + + def _toroidal_flux(self, surf: SurfaceRZFourier): + curve = surf.gamma[self.toroidal_flux_idx] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(self.field.A)(curve) + return jnp.sum(jnp.sum(A_vals * dl, axis=1)) + + def _build_surface_with_x(self, surface, x): + rc_safe = device_get(surface.rc) + zs_safe = device_get(surface.zs) + x_safe = device_get(x) + + s = SurfaceRZFourier( + rc=rc_safe, + zs=zs_safe, + nfp=int(surface.nfp), + ntheta=int(surface.ntheta), + nphi=int(surface.nphi), + range_torus=surface.range_torus, + close=True + ) + s.x = x_safe + return s + + def objective(self, x): + surf = self.surface_optimize + x_old = surf.x + surf.x = x + N = surf.unitnormal + norm_N = jnp.linalg.norm(surf.normal, axis=2) + points = surf.gamma.reshape(-1, 3) + B = vmap(self.field.B)(points).reshape(N.shape) + B_n = jnp.sum(B * N, axis=2) + norm_B = jnp.linalg.norm(B, axis=2) + value = jnp.sum(B_n**2 * norm_N) / jnp.sum(norm_B**2 * norm_N) + surf.x = x_old + return value + + def constraint(self, x): + surf = self.surface_optimize + x_old = surf.x + surf.x = x + + raw_c = { + "volume": surf.volume - self.targetlabel, + "area": surf.area - self.targetlabel, + "toroidal_flux": self._toroidal_flux(surf) - self.targetlabel + }[self.label] + + c = raw_c / jnp.abs(self.targetlabel) + + surf.x = x_old + return c + + def penalty_objective(self, x, constraint_weight=1.0): + r = self.objective(x) + c = self.constraint(x) + return r + 0.5 * constraint_weight * c**2 + + def _callback(self, info, printlog=True): + if isinstance(info, dict): + # LBFGS + it = info.get("iter", -1) + r = info["objective"] + c = info["constraint"] + penalty = info["penalty"] + grad_norm = info["grad_norm"] + + # Print logs if printlog is True + if printlog: + print(f"[LBFGS iter {it}] objective={r:.6e} constraint={c:.3e} " + f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}") + else: + # SLSQP + it = getattr(self, "_slsqp_iter", 0) + 1 + setattr(self, "_slsqp_iter", it) + + obj = float(self.objective(info)) + cst = float(self.constraint(info)) + penalty = float(self.penalty_objective(info)) + grad_norm = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z))(info))) + + # Print logs if printlog is True + if printlog: + print(f"[SLSQP iter {it}] objective={obj:.6e} constraint={cst:.3e} " + f"penalty={penalty:.6e} grad_norm={grad_norm:.3e}") + + + def minimize_lbfgs(self, x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e4, + printlog=True, **kwargs): + x0 = self.surface_optimize.x if x0 is None else x0 + + # ---------- Define objective function, return scalar + aux dict (all use jnp.array) ---------- + def fn(x): + value = self.penalty_objective(x, constraint_weight) + aux = { + "objective": self.objective(x), + "constraint": self.constraint(x), + "penalty": value + } + return value, aux + + solver = LBFGS(fun=fn, maxiter=maxiter, tol=tol, has_aux=True) + state = solver.init_state(x0) + + trace = [] + x = x0 + for k in range(maxiter): + x, state = solver.update(x, state) + + info = {key: device_get(v) if isinstance(v, jnp.ndarray) else v for key, v in state.aux.items()} + info["iter"] = k + 1 + info["grad_norm"] = float(jnp.linalg.norm(grad(lambda z: self.penalty_objective(z, constraint_weight))(x))) + info["error"] = float(state.error) + + # Ensure we call _callback for logging every step if printlog is True + self._callback(info, printlog) + + if state.error <= tol: + break + + x_safe = device_get(x) # Move back to host + self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe) + + return { + "fun": float(self.penalty_objective(x, constraint_weight)), + "gradient": jnp.array(grad(lambda z: self.penalty_objective(z, constraint_weight))(x)), + "iter": k + 1, + "info": state, + "success": state.error <= tol, + "s": self.surface_optimize, + } + + def minimize_slsqp(self, x0=None, tol=1e-6, maxiter=1000, printlog=True, **kwargs): + x0 = jnp.array(self.surface_optimize.x if x0 is None else x0) + + # Run the SLSQP optimizer + res = minimize( + fun=lambda x: float(self.objective(x)), + x0=x0, + method="SLSQP", + constraints={"type": "eq", "fun": lambda x: float(self.constraint(x))}, + tol=tol, + options={"maxiter": maxiter, "disp": False}, + callback=lambda x: self._callback(x, printlog) # Use internal callback directly + ) + + # Store the optimized x in the surface + x_safe = device_get(res.x) + self.surface_optimize = self._build_surface_with_x(self.surface_optimize, x_safe) + + # Return the result with optimization trace + return { + "fun": res.fun, + "gradient": jnp.array(jax.grad(self.objective)(res.x)), + "iter": res.nit, + "info": res, + "success": res.success, + "s": self.surface_optimize + } + + def run( + self, + method: str = "SLSQP", + tol: float = 1e-6, + maxiter: int = 1000, + x0=None, + constraint_weight: float = 1e-3, + printlog: bool = True, + **kwargs + ): + + method_up = method.upper() + + if method_up == "SLSQP": + return self.minimize_slsqp( + x0=x0, + tol=tol, + maxiter=maxiter, + printlog=printlog, + **kwargs + ) + elif method_up == "LBFGS": + return self.minimize_lbfgs( + x0=x0, + tol=tol, + maxiter=maxiter, + constraint_weight=constraint_weight, + printlog=printlog, + **kwargs + ) + else: + raise ValueError(f"Unknown method '{method}'") diff --git a/essos/surfaces.py b/essos/surfaces.py index 0048e3c..e0bb9be 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -9,6 +9,16 @@ mesh = Mesh(devices(), ("dev",)) sharding = NamedSharding(mesh, PartitionSpec("dev", None)) +@partial(jit, static_argnames=['surface','field']) +def toroidal_flux(surface, field, idx=0) -> jnp.ndarray: + gamma = surface.gamma + curve = gamma[idx, :, :] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(field.A)(curve) + Adl = jnp.sum(A_vals * dl, axis=1) + tf = jnp.sum(Adl) + return tf + @partial(jit, static_argnames=['surface','field']) def B_on_surface(surface, field): ntheta = surface.ntheta @@ -235,7 +245,88 @@ def x(self): @x.setter def x(self, new_dofs): self.dofs = new_dofs - + + @property + def volume(self): + + xyz = self.gamma # shape: (nphi, ntheta, 3) + n = self.normal # shape: (nphi, ntheta, 3) + + integrand = jnp.sum(xyz * n, axis=2) # dot(x, n), shape: (nphi, ntheta) + volume = jnp.mean(integrand) / 3.0 + return volume + + @property + def area(self): + n = self.normal # shape: (nphi, ntheta, 3) + norm_n = jnp.linalg.norm(n, axis=2) + + dphi = 2 * jnp.pi / self.nphi + dtheta = 2 * jnp.pi / self.ntheta + + area = jnp.sum(norm_n) * dphi * dtheta + return area + + + def change_resolution(self, mpol: int, ntor: int): + """ + Change the values of `mpol` and `ntor`. + New Fourier coefficients are zero by default. + Old coefficients outside the new range are discarded. + """ + rc_old, zs_old = self.rc, self.zs + mpol_old, ntor_old = self.mpol, self.ntor + + rc_new = jnp.zeros((mpol, 2 * ntor + 1)) + zs_new = jnp.zeros((mpol, 2 * ntor + 1)) + + m_keep = min(mpol_old, mpol) + n_keep = min(ntor_old, ntor) + + old_slice = slice(ntor_old - n_keep, ntor_old + n_keep + 1) + new_slice = slice(ntor - n_keep, ntor + n_keep + 1) + + # Copy overlapping region + rc_new = rc_new.at[:m_keep, new_slice].set(rc_old[:m_keep, old_slice]) + zs_new = zs_new.at[:m_keep, new_slice].set(zs_old[:m_keep, old_slice]) + + # Update attributes + self.mpol, self.ntor = mpol, ntor + self.rc, self.zs = rc_new, zs_new + + # Recompute xm/xn and interpolation arrays + m1d = jnp.arange(self.mpol) + n1d = jnp.arange(-self.ntor, self.ntor + 1) + n2d, m2d = jnp.meshgrid(n1d, m1d) + self.xm = m2d.flatten()[self.ntor:] + self.xn = self.nfp * n2d.flatten()[self.ntor:] + + indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T + self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] + self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] + + # Update degrees of freedom + self.num_dofs_rc = len(jnp.ravel(self.rc)[self.ntor:]) + self.num_dofs_zs = len(jnp.ravel(self.zs)[self.ntor:]) + self._dofs = jnp.concatenate( + (jnp.ravel(self.rc)[self.ntor:], jnp.ravel(self.zs)[self.ntor:]) + ) + + # Recompute angles and geometry + self.angles = ( + jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) + - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) + ) + (self._gamma, self._gammadash_theta, self._gammadash_phi, + self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + + # Recompute AbsB if available + if hasattr(self, 'bmnc'): + self._AbsB = self._set_AbsB() + + return self + + def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.") diff --git a/examples/input_files/stellarator_coils.json b/examples/input_files/stellarator_coils.json new file mode 100644 index 0000000..2db7d7d --- /dev/null +++ b/examples/input_files/stellarator_coils.json @@ -0,0 +1 @@ +{"nfp": 2, "stellsym": true, "order": 3, "n_segments": 45, "dofs_curves": [[[11.622700935298976, 0.021162524110782358, 4.969310302793117, 0.21619524408779187, 1.0480524616370408, 0.28649150353772795, -0.28257357882380735], [2.650045608829704, 1.3195115302904592, 2.0655255644324324, -0.9222231514058897, -0.6185897653334609, -0.029257891517016514, -0.4689941563806397], [0.5736288312625168, -6.124609285728119, -0.25270656488797033, -0.7494953058315302, -0.5577484125670837, 0.41280301860103114, 0.49815152814320796]], [[8.270668252860887, 0.3665817338447269, 3.1105370020049437, -1.1258699831006203, 1.3263614203440663, 0.2102848349841804, -0.49142766630482954], [7.030418225942375, 0.6256285318976378, 4.821568953388758, -0.3255223631956292, -0.5563748261863908, 0.5699454535195534, 0.4766999014677532], [1.1730264767378422, -5.641665411926324, -0.3776722058863022, -0.22889203846752135, -0.37533249101718297, -0.13615989719970903, 0.2925684224469988]], [[2.870219295136962, 1.011171856312114, 1.0052842269429074, -1.616205601094772, 0.4568093873850295, 0.8322530493441007, -0.47523615678931685], [9.19851171039639, 0.15301426329616008, 6.079831607393782, -0.1245216794618166, -0.5264330665117136, 0.028667523169553216, 0.4360349874253642], [0.5149651564138196, -5.068527850344334, -0.2765171026569377, -0.4083657568321486, -0.061669206943173696, -0.01678957665952365, 0.08107054639495188]]], "dofs_currents": [0.9714976808550443, 1.0086510674635079, 1.019851251681448]} \ No newline at end of file diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py index 57324b2..e54161f 100644 --- a/examples/optimize_coils_vmec_surface.py +++ b/examples/optimize_coils_vmec_surface.py @@ -15,7 +15,7 @@ max_coil_curvature = 1.0 order_Fourier_series_coils = 3 number_coil_points = order_Fourier_series_coils*15 -maximum_function_evaluations = 50 +maximum_function_evaluations = 500 number_coils_per_half_field_period = 3 tolerance_optimization = 1e-5 ntheta=35 @@ -67,11 +67,11 @@ plt.tight_layout() plt.show() -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") +# Save the coils to a json file +coils_optimized.to_json("input_files/stellarator_coils.json") +# Load the coils from a json file +from essos.coils import Coils_from_json +coils = Coils_from_json("input_files/stellarator_coils.json") # # Save results in vtk format to analyze in Paraview # from essos.fields import BiotSavart diff --git a/examples/optimize_multiple_objectives.py b/examples/optimize_multiple_objectives.py index f38ed8e..78b09d6 100644 --- a/examples/optimize_multiple_objectives.py +++ b/examples/optimize_multiple_objectives.py @@ -17,12 +17,12 @@ "max_coil_curvature": 0.0, }, opt_config={ - "n_trials": 2, - "maximum_function_evaluations": 300, + "n_trials": 20, + "maximum_function_evaluations": 50, "tolerance_optimization": 1e-5, "optimizer_choices": ["adam", "amsgrad", "sgd"], "num_coils": 4, - "order_Fourier": 6, + "order_Fourier": 3, } ) @@ -46,10 +46,7 @@ print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") - - - manager.plot_pareto_fronts(z_thresh=3, save= True) manager.plot_optimization_history(z_thresh=3, save= True) manager.plot_param_importances(save= True) -manager.plot_parallel_coordinates(z_thresh=3, save= True) +# manager.plot_parallel_coordinates(z_thresh=3, save= True) diff --git a/examples/optimize_qfm_surface.py b/examples/optimize_qfm_surface.py new file mode 100644 index 0000000..fea39f4 --- /dev/null +++ b/examples/optimize_qfm_surface.py @@ -0,0 +1,198 @@ +import os +number_of_processors_to_use = 5 +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' + +import jax.numpy as jnp +from jax import device_get +import matplotlib.pyplot as plt +from time import time + +from essos.surfaces import BdotN_over_B, toroidal_flux +from essos.surfaces import SurfaceRZFourier +from essos.qfm import QfmSurface +from essos.fields import Vmec, BiotSavart + +# Load initial guess surface +ntheta=25 +nphi=25 +vmec = os.path.join('input_files','input.rotating_ellipse') +surf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) +surf.change_resolution(5,5) + +initialsurf = SurfaceRZFourier(vmec, ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) + +# Load target VMEC surface +truevmec = Vmec(os.path.join(os.path.dirname(__file__), 'input_files', 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), + ntheta=ntheta, nphi=nphi, range_torus='half period', close=True) + +# Load coils and construct field +from essos.coils import Coils_from_json +coils = Coils_from_json("input_files/stellarator_coils.json") +field = BiotSavart(coils) + +# QFM optimization setup +method = 'lbfgs' # lbfgs, slsqp +label = 'area' # 'area', 'volume', 'toroidal_flux' + +if method == 'lbfgs': + tol = 1e-4 +elif method == 'slsqp': + tol = 1e-6 + +maxiter = 10000 +initial_label = None +targetlabel = None +if label == 'toroidal_flux': + constraint_weight = 1e-3 + initial_label = toroidal_flux(surf, field) + targetlabel = toroidal_flux(truevmec.surface, field) +elif label == 'volume': + constraint_weight = 1e-3 + initial_label = surf.volume + targetlabel = truevmec.surface.volume +elif label == 'area': + constraint_weight = 1e-3 + initial_label = surf.area + targetlabel = truevmec.surface.area + +BdotN_over_B_initial = BdotN_over_B(surf, BiotSavart(coils)) + +# Initialize QFM optimizer +qfm = QfmSurface(field=field, surface=surf, label=label, targetlabel=targetlabel) + +print("Degrees of Freedom:", qfm.surface.x.shape[0]) +start_time = time() +print('start') + + +result = qfm.run( + tol=tol, + maxiter=maxiter, + method=method, + constraint_weight=constraint_weight, + printlog=1 +) + +print('done') +end_time = time() + +# Evaluate final objective and constraint + +x_opt = device_get(result["s"].x) +qfm_loss = float(jnp.asarray(qfm.objective(x_opt))) +c_loss = float(jnp.asarray(qfm.constraint(x_opt))) + +BdotN_over_B_optimized = BdotN_over_B(result['s'], BiotSavart(coils)) +print("Optimization method:", method) +print("Optimization label:", label) +print("Optimization success:", result['success']) +print(f"final qfm objective = {qfm_loss:.3e}, final constraint objective = {c_loss:.3e}") +print("Iterations:", result['iter']) +print(f"Optimization time: {end_time - start_time}") + +print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") +print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") + +initial_area = surf.area +initial_volume = surf.volume +initial_tf = toroidal_flux(surf, field) + +final_area = result['s'].area +final_volume = result['s'].volume +final_tf = toroidal_flux(result['s'], field) + +print(f"Initial labels -> area: {initial_area:.6e}, volume: {initial_volume:.6e}, toroidal_flux: {initial_tf:.6e}") +print(f"target label: {label} target label value: {targetlabel}") +print(f"Final labels -> area: {final_area:.6e}, volume: {final_volume:.6e}, toroidal_flux: {final_tf:.6e}") + +# Plot surfaces +fig = plt.figure(figsize=(8, 4)) +ax1 = fig.add_subplot(131, projection='3d') +ax2 = fig.add_subplot(132, projection='3d') +ax3 = fig.add_subplot(133, projection='3d') + +initialsurf.plot(ax=ax1, show=False) +truevmec.surface.plot(ax=ax2, show=False) +result['s'].plot(ax=ax3, show=False) + +ax1.set_title("Initial Surface") +ax2.set_title("True VMEC Surface") +ax3.set_title("Final Surface") + +plt.tight_layout() +plt.show() + + + + + + + + + +# # Field line tracing +# from jax import block_until_ready +# from essos.dynamics import Tracing + +# tmax = 100000000000 +# nfieldlines_per_core = 5 +# nfieldlines = nfieldlines_per_core * number_of_processors_to_use +# R0 = jnp.linspace(12.2, 13.5, nfieldlines) +# trace_tolerance = 1e-7 +# num_steps = 60000 + +# Z0 = jnp.zeros(nfieldlines) +# phi0 = jnp.zeros(nfieldlines) +# initial_xyz = jnp.array([R0 * jnp.cos(phi0), R0 * jnp.sin(phi0), Z0]).T + +# time0 = time() +# tracing = block_until_ready(Tracing( +# field=field, +# model='FieldLineAdaptative', +# initial_conditions=initial_xyz, +# maxtime=tmax, +# times_to_trace=num_steps, +# atol=trace_tolerance, +# rtol=trace_tolerance +# )) +# print(f"ESSOS tracing took {time() - time0:.2f} seconds") + +# trajectories = tracing.trajectories +# traj = trajectories[0] +# R, phi, Z = traj[:, 0], traj[:, 1], traj[:, 2] + +# phi_u = jnp.unwrap(phi) +# phi0_cross = jnp.where((phi_u[:-1] < 0) & (phi_u[1:] >= 0))[0] +# phi90_cross = jnp.where((phi_u[:-1] < jnp.pi / 2) & (phi_u[1:] >= jnp.pi / 2))[0] + +# theta = jnp.linspace(0, 2 * jnp.pi, 200) + +# def compute_rz_on_phi(surface, theta, phi=0.0): +# angles = jnp.outer(theta, surface.xm) - phi * surface.xn +# R = jnp.sum(surface.rmnc_interp * jnp.cos(angles), axis=1) +# Z = jnp.sum(surface.zmns_interp * jnp.sin(angles), axis=1) +# return R, Z + +# # Contours from optimized surface +# R0_opt, Z0_opt = compute_rz_on_phi(result['s'], theta, phi=0.0) +# R90_opt, Z90_opt = compute_rz_on_phi(result['s'], theta, phi=jnp.pi/2) + +# # Contours from true VMEC surface +# R0_true, Z0_true = compute_rz_on_phi(truevmec.surface, theta, phi=0.0) +# R90_true, Z90_true = compute_rz_on_phi(truevmec.surface, theta, phi=jnp.pi/2) + +# fig, ax = plt.subplots(figsize=(6, 6)) + +# tracing.poincare_plot(ax=ax, show=False, shifts=[0, jnp.pi / 2]) +# ax.plot(R0_opt, Z0_opt, color='black', linewidth=1.5, label=r"Optimized @ $\phi = 0$") +# ax.plot(R90_opt, Z90_opt, color='black', linestyle='--', linewidth=1.5, label=r"Optimized @ $\phi = \pi/2$") +# ax.plot(R0_true, Z0_true, color='blue', linewidth=1.2, label=r"True VMEC @ $\phi = 0$") +# ax.plot(R90_true, Z90_true, color='blue', linestyle='--', linewidth=1.2, label=r"True VMEC @ $\phi = \pi/2$") + +# ax.set_xlabel("R") +# ax.set_ylabel("Z") +# ax.set_title("Poincaré + Surfaces Comparison @ φ = 0 and π/2") +# ax.legend() +# ax.axis("equal") +# plt.tight_layout() +# plt.show() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index aea8487..81ded98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ netcdf4 f90nml pyevtk optuna -pandas \ No newline at end of file +pandas +jaxopt \ No newline at end of file diff --git a/tests/test_qfm.py b/tests/test_qfm.py new file mode 100644 index 0000000..2f31e8a --- /dev/null +++ b/tests/test_qfm.py @@ -0,0 +1,94 @@ +import pytest +from unittest.mock import MagicMock +import jax.numpy as jnp +from jax import random +from essos.surfaces import SurfaceRZFourier +from essos.qfm import QfmSurface +from essos.fields import BiotSavart + + +class MockSurface: + def __init__(self): + self.rc = jnp.array([[1., 2., 3.], + [1., 2., 3.], + [1., 2., 3.]]) + self.zs = jnp.array([[0.5, 1.5, 2.5], + [0.5, 1.5, 2.5], + [0.5, 1.5, 2.5]]) + self.nfp = 2 + self.ntheta = 3 + self.nphi = 3 + self.range_torus = "half period" + + self.area = 1.23 + self.volume = 4.56 + + self.x = jnp.ones(16) + self.gamma = jnp.ones((1, 3, 3)) # 添加这个就不会再报错了 + + def change_resolution(self, ntheta, nphi): + self.ntheta = ntheta + self.nphi = nphi + self.gamma = jnp.ones((ntheta, nphi, 3)) + self.unitnormal = jnp.ones((ntheta, nphi, 3)) + + +class MockField: + def A(self, point): + return jnp.array([1.0, 0.0, 0.0]) + + def B(self, point): + return jnp.array([0.0, 1.0, 0.0]) + + +@pytest.fixture +def mock_data(): + surface = MockSurface() + field = MockField() + return surface, field + + +def test_qfm_surface_initialization(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="area") + + assert qfm.label == "area" + assert qfm.targetlabel == surface.area + assert qfm.surface == surface + assert isinstance(qfm.surface_optimize, SurfaceRZFourier) + assert qfm.toroidal_flux_idx == 0 + + +def test_minimize_lbfgs(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="area") + + qfm.minimize_lbfgs = MagicMock() + qfm.minimize_lbfgs(x0=None, tol=1e-6, maxiter=1000, constraint_weight=1e-3) + qfm.minimize_lbfgs.assert_called_once() + + +def test_minimize_slsqp(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="volume") + + qfm.minimize_slsqp = MagicMock() + qfm.minimize_slsqp(x0=None, tol=1e-6, maxiter=1000) + qfm.minimize_slsqp.assert_called_once() + + + +def test_run_method(mock_data): + surface, field = mock_data + qfm = QfmSurface(field, surface, label="area") + + result_lbfgs = qfm.run(method="LBFGS", tol=1e-6, maxiter=1000) + assert "s" in result_lbfgs + assert result_lbfgs["success"] == True + + result_slsqp = qfm.run(method="SLSQP", tol=1e-6, maxiter=1000) + assert "s" in result_slsqp + assert result_slsqp["success"] == True + +if __name__ == "__main__": + pytest.main()