From 871d16938657a1b1d0eaca5bf558cac1d4435698 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 20 Apr 2025 20:12:43 +0200 Subject: [PATCH 01/63] Add new integrator scripts for field tracing and guiding center dynamics - Implemented `fo_integrators.py` for full orbit tracing with various methods and parameters. - Implemented `gc_integrators.py` for guiding center dynamics with adaptative and constant step sizes. - Enhanced `Tracing` class in `dynamics.py` to support multiple methods and step sizes. --- analysis/fo_integrators.py | 90 +++++++++++++++++++++++++++++++++ analysis/gc_integrators.py | 100 +++++++++++++++++++++++++++++++++++++ essos/dynamics.py | 36 ++++++++++--- 3 files changed, 219 insertions(+), 7 deletions(-) create mode 100644 analysis/fo_integrators.py create mode 100644 analysis/gc_integrators.py diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py new file mode 100644 index 0000000..aaa0e36 --- /dev/null +++ b/analysis/fo_integrators.py @@ -0,0 +1,90 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 16}) +from essos.fields import BiotSavart +from essos.coils import Coils_from_json +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +# import integrators +import diffrax + +# Input parameters +tmax = 1e-4 +nparticles = number_of_processors_to_use +R0 = jnp.linspace(1.23, 1.27, nparticles) +trace_tolerance = 1e-12 +num_steps = 5000 +mass=PROTON_MASS +energy=4000*ONE_EV + +print(f"dt = {tmax/num_steps:.2e}") + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils_from_json(json_file) +field = BiotSavart(coils) + +# Initialize particles +Z0 = jnp.zeros(nparticles) +phi0 = jnp.zeros(nparticles) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, field=field) + +fig, ax = plt.subplots(figsize=(7, 5)) + +method_names = ['Tsit5', 'Dopri5', 'Dopri8', 'Boris'] +methods = [getattr(diffrax, method) for method in method_names[:-1]] + ['Boris'] +for method_name, method in zip(method_names, methods): + if method_name != 'Boris': + energies = [] + tracing_times = [] + for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14]: + time0 = time() + tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, + maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + tracing_times += [time() - time0] + + print(f"Tracing with adaptative {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'adaptative {method_name}', marker='o', markersize=3, linestyle='-') + + energies = [] + tracing_times = [] + for num_steps in [5000, 10000, 20000, 50000, 100000]: + time0 = time() + tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, + stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + tracing_times += [time() - time0] + + print(f"Tracing with {method_name} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'{method_name}', marker='o', markersize=4, linestyle='-') + +from matplotlib.ticker import LogFormatterMathtext + +ax.legend() +ax.set_xlabel('Computation time (s)') +ax.set_ylabel('Relative Energy Error') +# ax.set_xscale('log') +ax.set_yscale('log') +# ax.xaxis.set_major_formatter(LogFormatterMathtext()) +ax.yaxis.set_major_formatter(LogFormatterMathtext()) +ax.tick_params(axis='x', which='both', length=0) +yticks = [1e-1, 1e-4, 1e-7, 1e-10, 1e-13, 1e-16] +ax.set_yticks(yticks) +# xticks = [1e-1, 1e-0, 1e1, 1e2] +# ax.set_xticks(xticks) + +plt.tight_layout() +plt.savefig(os.path.dirname(__file__) + '/fo_integration.pdf') +plt.show() + +## Save results in vtk format to analyze in Paraview +# tracing.to_vtk('trajectories') +# coils.to_vtk('coils') \ No newline at end of file diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py new file mode 100644 index 0000000..98ca38b --- /dev/null +++ b/analysis/gc_integrators.py @@ -0,0 +1,100 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 16}) +from essos.fields import BiotSavart +from essos.coils import Coils_from_json +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +# import integrators +import diffrax + +# Input parameters +tmax = 1e-4 +nparticles = number_of_processors_to_use +R0 = jnp.linspace(1.23, 1.27, nparticles) +trace_tolerance = 1e-12 +num_steps = 1500 +mass=PROTON_MASS +energy=4000*ONE_EV + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils_from_json(json_file) +field = BiotSavart(coils) + +# Initialize particles +Z0 = jnp.zeros(nparticles) +phi0 = jnp.zeros(nparticles) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy) + +fig, ax = plt.subplots(figsize=(7, 5)) + +for method in ['Tsit5', 'Dopri5', 'Dopri8']: + energies = [] + tracing_times = [] + for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14, 1e-16]: + time0 = time() + tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, + maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + tracing_times += [time() - time0] + + print(f"Tracing with adaptative {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'adaptative {method}', marker='o', markersize=3, linestyle='-') + + energies = [] + tracing_times = [] + for num_steps in [500, 1000, 2000, 5000, 10000]: + time0 = time() + tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, + stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + tracing_times += [time() - time0] + + print(f"Tracing with {method} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') + +# num_steps = 100 +# for method in ['Kvaerno5', 'Kvaerno4']: +# energies = [] +# tracing_times = [] +# for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14, 1e-16]: +# time0 = time() +# tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, +# stepsize="adaptative", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) +# tracing_times += [time() - time0] + +# print(f"Tracing with adaptative {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + +# energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] +# ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') + +from matplotlib.ticker import LogFormatterMathtext + +ax.legend() +ax.set_xlabel('Computation time (s)') +ax.set_ylabel('Relative Energy Error') +# ax.set_xscale('log') +ax.set_yscale('log') +# ax.xaxis.set_major_formatter(LogFormatterMathtext()) +ax.yaxis.set_major_formatter(LogFormatterMathtext()) +ax.tick_params(axis='x', which='both', length=0) +yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] +ax.set_yticks(yticks) +# xticks = [1e-1, 1e-0, 1e1, 1e2] +# ax.set_xticks(xticks) + +plt.tight_layout() +plt.savefig(os.path.dirname(__file__) + '/gc_integration.pdf') +plt.show() + +## Save results in vtk format to analyze in Paraview +# tracing.to_vtk('trajectories') +# coils.to_vtk('coils') \ No newline at end of file diff --git a/essos/dynamics.py b/essos/dynamics.py index 63929db..28fe49c 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -5,7 +5,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax import jit, vmap, tree_util, random, lax, device_put from functools import partial -from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, PIDController, Event +from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri8, PIDController, Event, AbstractSolver, ConstantStepSize from essos.coils import Coils from essos.fields import BiotSavart, Vmec from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, FUSION_ALPHA_PARTICLE_ENERGY @@ -133,19 +133,29 @@ def FieldLine(t, # return lax.cond(condition, zero_derivatives, compute_derivatives, operand=None) class Tracing(): - def __init__(self, trajectories_input=None, initial_conditions=None, times=None, - field=None, model=None, maxtime: float = 1e-7, timesteps: int = 500, - tol_step_size = 1e-7, particles=None, condition=None): + def __init__(self, trajectories_input=None, initial_conditions=None, times=None, field=None, + model=None, method=None, maxtime: float = 1e-7, timesteps: int = 500, stepsize: str = "adaptative", + trajectories=None, tol_step_size = 1e-10, particles=None, condition=None): + + assert method == None or \ + method == 'Boris' or \ + issubclass(method, AbstractSolver), "Method must be None, 'Boris', or a DIFFRAX solver" + if method == 'Boris': + assert model == 'FullOrbit' or model == 'FullOrbit_Boris', "Method 'Boris' is only available for FullOrbit models" if isinstance(field, Coils): self.field = BiotSavart(field) else: self.field = field + assert stepsize in ["adaptative", "constant"], "stepsize must be 'adaptative' or 'constant'" + self.model = model + self.method = method self.initial_conditions = initial_conditions self.times = times self.maxtime = maxtime self.timesteps = timesteps + self.stepsize = stepsize self.tol_step_size = tol_step_size self._trajectories = trajectories_input self.particles = particles @@ -160,6 +170,8 @@ def condition_Vmec(t, y, args, **kwargs): self.ODE_term = ODETerm(GuidingCenter) self.args = (self.field, self.particles) self.initial_conditions = jnp.concatenate([self.particles.initial_xyz, self.particles.initial_vparallel[:, None]], axis=1) + if self.method is None: + self.method = Dopri8 elif model == 'FullOrbit' or model == 'FullOrbit_Boris': self.ODE_term = ODETerm(Lorentz) self.args = (self.field, self.particles) @@ -168,9 +180,15 @@ def condition_Vmec(t, y, args, **kwargs): self.initial_conditions = jnp.concatenate([self.particles.initial_xyz_fullorbit, self.particles.initial_vxvyvz], axis=1) if field is None: raise ValueError("Field parameter is required for FullOrbit model") + if self.method is None: + self.method = 'Boris' elif model == 'FieldLine': self.ODE_term = ODETerm(FieldLine) self.args = self.field + if self.method is None: + self.method = Dopri8 + else: + raise ValueError("Model must be one of: 'GuidingCenter', 'FullOrbit', 'FullOrbit_Boris', or 'FieldLine'") if self.times is None: self.times = jnp.linspace(0, self.maxtime, self.timesteps) @@ -215,7 +233,7 @@ def trace(self): @jit def compute_trajectory(initial_condition) -> jnp.ndarray: # initial_condition = initial_condition[0] - if self.model == 'FullOrbit_Boris': + if self.model == 'FullOrbit_Boris' or self.method == 'Boris': dt=self.maxtime / self.timesteps def update_state(state, _): # def update_fn(state): @@ -239,18 +257,22 @@ def update_state(state, _): else: import warnings warnings.simplefilter("ignore", category=FutureWarning) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation + if self.stepsize == "adaptative": + controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size) + elif self.stepsize == "constant": + controller = ConstantStepSize() trajectory = diffeqsolve( self.ODE_term, t0=0.0, t1=self.maxtime, dt0=self.maxtime / self.timesteps, y0=initial_condition, - solver=Tsit5(), + solver=self.method(), args=self.args, saveat=SaveAt(ts=self.times), throw=False, # adjoint=DirectAdjoint(), - stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size), + stepsize_controller = controller, max_steps=10000000000, event = Event(self.condition) ).ys From f86f5ee493fbb52854d6aeeffd7fee33a2c101e0 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 20 Apr 2025 20:16:09 +0200 Subject: [PATCH 02/63] Improve loss functions speed --- essos/objective_functions.py | 39 ++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/essos/objective_functions.py b/essos/objective_functions.py index 42c94ea..53ab751 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -93,29 +93,37 @@ def loss_particle_drift(field, particles, maxtime=1e-5, num_steps=300, trace_tol tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime, timesteps=num_steps, tol_step_size=trace_tolerance) trajectories = tracing.trajectories + R_axis = jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(field.coils.dofs_curves))) - radial_factor = jnp.sqrt(jnp.square(trajectories[:,:,0])+jnp.square(trajectories[:,:,1]))-R_axis + radial_factor = jnp.sqrt(trajectories[:, :, 0]**2 + trajectories[:,:,1]**2)-R_axis vertical_factor = trajectories[:,:,2] - radial_drift=jnp.square(radial_factor)+jnp.square(vertical_factor) - radial_drift=jnp.sum(jnp.diff(radial_drift,axis=1),axis=1)/num_steps - angular_drift = jnp.arctan2(trajectories[:, :, 2]+1e-10, jnp.sqrt(trajectories[:, :, 0]**2+trajectories[:, :, 1]**2)-R_axis) - angular_drift=(jnp.sum(jnp.diff(angular_drift,axis=1),axis=1))/num_steps + + radial_drift = radial_factor**2 + vertical_factor**2 + # radial_drift = jnp.sqrt(radial_drift) + radial_drift = jnp.mean(jnp.diff(radial_drift, axis=1), axis=1) + + angular_drift = jnp.arctan2(vertical_factor, radial_factor+1e-10) + angular_drift = jnp.mean(jnp.diff(angular_drift, axis=1), axis=1) + return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)), jnp.ravel(jnp.abs(vertical_factor)))) + # return jnp.concatenate((jnp.ravel(jnp.abs(angular_drift)), jnp.ravel(jnp.abs(radial_drift)))) # @partial(jit, static_argnums=(0)) -def loss_coil_length(field): - return jnp.ravel(field.coils.length) +def loss_coil_length(field, max_coil_length): + coil_length = jnp.ravel(field.coils.length) + return jnp.maximum(coil_length-max_coil_length, 0) # @partial(jit, static_argnums=(0)) -def loss_coil_curvature(field): - return jnp.mean(field.coils.curvature, axis=1) +def loss_coil_curvature(field, max_coil_curvature): + coil_curvature = jnp.mean(field.coils.curvature, axis=1) + return jnp.maximum(coil_curvature-max_coil_curvature, 0) # @partial(jit, static_argnums=(0, 1)) -def loss_normB_axis(field, npoints=15): +def loss_normB_axis(field, target_B_on_axis, npoints=15): R_axis = jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(field.coils.dofs_curves))) phi_array = jnp.linspace(0, 2 * jnp.pi, npoints) B_axis = vmap(lambda phi: field.AbsB(jnp.array([R_axis * jnp.cos(phi), R_axis * jnp.sin(phi), 0])))(phi_array) - return B_axis + return jnp.abs(B_axis-target_B_on_axis) @partial(jit, static_argnums=(1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)) def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, currents_scale, nfp, max_coil_curvature=0.5, @@ -130,12 +138,9 @@ def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, curr field = BiotSavart(coils) particles_drift_loss = loss_particle_drift(field, particles, maxtime, num_steps, trace_tolerance, model=model) - normB_axis = loss_normB_axis(field) - normB_axis_loss = jnp.abs(normB_axis-target_B_on_axis) - coil_length = loss_coil_length(field) - coil_length_loss = jnp.array([jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])]))]) - coil_curvature = loss_coil_curvature(field) - coil_curvature_loss = jnp.array([jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])]))]) + normB_axis_loss = loss_normB_axis(field, target_B_on_axis) + coil_length_loss = loss_coil_length(field, max_coil_length) + coil_curvature_loss = loss_coil_curvature(field, max_coil_curvature) loss = jnp.concatenate((normB_axis_loss, coil_length_loss, particles_drift_loss, coil_curvature_loss)) return jnp.sum(loss) From 408c2c5109aff928b6ed9f5849873f97d80db418 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 21 Apr 2025 17:13:48 +0200 Subject: [PATCH 03/63] Refine integrator analysis --- analysis/fo_integrators.py | 16 ++++++++++++---- analysis/gc_integrators.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index aaa0e36..638f81f 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -42,7 +42,12 @@ if method_name != 'Boris': energies = [] tracing_times = [] - for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14]: + for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-13, 1e-14]: + if method_name == 'Dopri8': + if trace_tolerance == 1e-13: + trace_tolerance = 1e-14 + elif trace_tolerance == 1e-14: + trace_tolerance = 1e-15 time0 = time() tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) @@ -55,7 +60,9 @@ energies = [] tracing_times = [] - for num_steps in [5000, 10000, 20000, 50000, 100000]: + for num_steps in [100000, 200000, 300000, 500000, 1000000]: + if method_name == 'Boris' or method_name == 'Dopri8': + num_steps //= 10 time0 = time() tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) @@ -75,9 +82,10 @@ ax.set_yscale('log') # ax.xaxis.set_major_formatter(LogFormatterMathtext()) ax.yaxis.set_major_formatter(LogFormatterMathtext()) -ax.tick_params(axis='x', which='both', length=0) -yticks = [1e-1, 1e-4, 1e-7, 1e-10, 1e-13, 1e-16] +ax.tick_params(axis='x', which='minor', length=0) +yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] ax.set_yticks(yticks) +ax.set_ylim(top=1e-6) # xticks = [1e-1, 1e-0, 1e1, 1e2] # ax.set_xticks(xticks) diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index 98ca38b..5a9c9ad 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -85,7 +85,7 @@ ax.set_yscale('log') # ax.xaxis.set_major_formatter(LogFormatterMathtext()) ax.yaxis.set_major_formatter(LogFormatterMathtext()) -ax.tick_params(axis='x', which='both', length=0) +ax.tick_params(axis='x', which='minor', length=0) yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] ax.set_yticks(yticks) # xticks = [1e-1, 1e-0, 1e1, 1e2] From 6ff79dbf1af4baa8ce66622da01663f1c127b1b4 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 27 Apr 2025 12:59:03 +0200 Subject: [PATCH 04/63] Optimize rotation matrix computation in RotatedCurve function --- essos/coils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index cc1e715..a0cfc08 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -492,10 +492,7 @@ def RotatedCurve(curve, phi, flip): [jnp.sin(phi), jnp.cos(phi), 0], [0, 0, 1]]).T if flip: - rotmat = rotmat @ jnp.array( - [[1, 0, 0], - [0, -1, 0], - [0, 0, -1]]) + rotmat = rotmat @ jnp.diag(jnp.array([1, -1, -1])) return curve @ rotmat @partial(jit, static_argnames=['nfp', 'stellsym']) From fdb1508931be506bd4d6596cd772f1ef29edde53 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 28 Apr 2025 20:07:12 +0200 Subject: [PATCH 05/63] Edit fo and gc integration analysis plots --- analysis/fo_integrators.py | 13 +++++-------- analysis/gc_integrators.py | 13 +++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index 638f81f..c2b0663 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -4,7 +4,7 @@ from time import time import jax.numpy as jnp import matplotlib.pyplot as plt -plt.rcParams.update({'font.size': 16}) +plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart from essos.coils import Coils_from_json from essos.constants import PROTON_MASS, ONE_EV @@ -34,7 +34,7 @@ initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, field=field) -fig, ax = plt.subplots(figsize=(7, 5)) +fig, ax = plt.subplots(figsize=(9, 6)) method_names = ['Tsit5', 'Dopri5', 'Dopri8', 'Boris'] methods = [getattr(diffrax, method) for method in method_names[:-1]] + ['Boris'] @@ -80,17 +80,14 @@ ax.set_ylabel('Relative Energy Error') # ax.set_xscale('log') ax.set_yscale('log') -# ax.xaxis.set_major_formatter(LogFormatterMathtext()) -ax.yaxis.set_major_formatter(LogFormatterMathtext()) ax.tick_params(axis='x', which='minor', length=0) yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] ax.set_yticks(yticks) ax.set_ylim(top=1e-6) -# xticks = [1e-1, 1e-0, 1e1, 1e2] -# ax.set_xticks(xticks) - +plt.grid() plt.tight_layout() -plt.savefig(os.path.dirname(__file__) + '/fo_integration.pdf') +plt.savefig(os.path.join(os.path.dirname(__file__), 'fo_integration.pdf')) +plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'fo_integration.pdf')) plt.show() ## Save results in vtk format to analyze in Paraview diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index 5a9c9ad..e699dc8 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -4,7 +4,7 @@ from time import time import jax.numpy as jnp import matplotlib.pyplot as plt -plt.rcParams.update({'font.size': 16}) +plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart from essos.coils import Coils_from_json from essos.constants import PROTON_MASS, ONE_EV @@ -32,7 +32,7 @@ initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy) -fig, ax = plt.subplots(figsize=(7, 5)) +fig, ax = plt.subplots(figsize=(9, 6)) for method in ['Tsit5', 'Dopri5', 'Dopri8']: energies = [] @@ -83,16 +83,13 @@ ax.set_ylabel('Relative Energy Error') # ax.set_xscale('log') ax.set_yscale('log') -# ax.xaxis.set_major_formatter(LogFormatterMathtext()) -ax.yaxis.set_major_formatter(LogFormatterMathtext()) ax.tick_params(axis='x', which='minor', length=0) yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] ax.set_yticks(yticks) -# xticks = [1e-1, 1e-0, 1e1, 1e2] -# ax.set_xticks(xticks) - +plt.grid() plt.tight_layout() -plt.savefig(os.path.dirname(__file__) + '/gc_integration.pdf') +plt.savefig(os.path.join(os.path.dirname(__file__), 'gc_integration.pdf')) +plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'gc_integration.pdf')) plt.show() ## Save results in vtk format to analyze in Paraview From 415bbaedc93af232af74e7797104c5ca34709e59 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 28 Apr 2025 20:11:13 +0200 Subject: [PATCH 06/63] Improve loss functions computational efficiency --- essos/objective_functions.py | 66 +++++++++++++++++------------------- essos/optimization.py | 3 +- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/essos/objective_functions.py b/essos/objective_functions.py index 53ab751..b5bd2b6 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -9,12 +9,12 @@ from essos.coils import Curves, Coils from essos.optimization import new_nearaxis_from_x_and_old_nearaxis -@partial(jit, static_argnums=(1, 4, 5, 6, 7, 8)) -def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, max_coil_length=42, +@partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8)) +def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, n_segments=60, stellsym=True, max_coil_curvature=0.1): - len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) - dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], (dofs_curves.shape)) - dofs_currents = x[len_dofs_curves_ravelled:] + dofs_curves_size = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] + dofs_curves = jnp.reshape(x[:dofs_curves_size], (dofs_curves_shape)) + dofs_currents = x[dofs_curves_size:] curves = Curves(dofs_curves, n_segments, nfp, stellsym) coils = Coils(curves=curves, currents=dofs_currents*currents_scale) @@ -32,13 +32,11 @@ def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, gradB_nearaxis = field_nearaxis.grad_B_axis.T gradB_coils = vmap(field.dB_by_dX)(points.T) - coil_length = loss_coil_length(field) - coil_curvature = loss_coil_curvature(field) - B_difference_loss = jnp.sum(jnp.abs(jnp.array(B_coils)-jnp.array(B_nearaxis))) gradB_difference_loss = jnp.sum(jnp.abs(jnp.array(gradB_coils)-jnp.array(gradB_nearaxis))) - coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) + coil_length_loss = 1e3*jnp.max(loss_coil_length(field, max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(field, max_coil_curvature)) + return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss @@ -58,22 +56,19 @@ def difference_B_gradB_onaxis(nearaxis_field, coils_field): return jnp.array(B_coils)-jnp.array(B_nearaxis), jnp.array(gradB_coils)-jnp.array(gradB_nearaxis) -@partial(jit, static_argnums=(1, 4, 5, 6, 7, 8)) -def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, max_coil_length=42, +@partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8)) +def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, n_segments=60, stellsym=True, max_coil_curvature=0.1): - len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) - dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], (dofs_curves.shape)) + dofs_curves_size = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] + dofs_curves = jnp.reshape(x[:dofs_curves_size], (dofs_curves_shape)) len_dofs_nearaxis = len(field_nearaxis.x) - dofs_currents = x[len_dofs_curves_ravelled:-len_dofs_nearaxis] + dofs_currents = x[dofs_curves_size:-len_dofs_nearaxis] curves = Curves(dofs_curves, n_segments, nfp, stellsym) coils = Coils(curves=curves, currents=dofs_currents*currents_scale) field = BiotSavart(coils) new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len_dofs_nearaxis:], field_nearaxis) - coil_length = loss_coil_length(field) - coil_curvature = loss_coil_curvature(field) - elongation = new_field_nearaxis.elongation iota = new_field_nearaxis.iota @@ -81,8 +76,8 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, B_difference_loss = 3*jnp.sum(jnp.abs(B_difference)) gradB_difference_loss = jnp.sum(jnp.abs(gradB_difference)) - coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) + coil_length_loss = 1e3*jnp.max(loss_coil_length(field, max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(field, max_coil_curvature)) elongation_loss = jnp.sum(jnp.abs(elongation)) iota_loss = 30/jnp.abs(iota) @@ -100,13 +95,16 @@ def loss_particle_drift(field, particles, maxtime=1e-5, num_steps=300, trace_tol radial_drift = radial_factor**2 + vertical_factor**2 # radial_drift = jnp.sqrt(radial_drift) + # print("radial_drift", radial_drift) radial_drift = jnp.mean(jnp.diff(radial_drift, axis=1), axis=1) angular_drift = jnp.arctan2(vertical_factor, radial_factor+1e-10) angular_drift = jnp.mean(jnp.diff(angular_drift, axis=1), axis=1) - return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)), jnp.ravel(jnp.abs(vertical_factor)))) + # return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)), jnp.ravel(jnp.abs(vertical_factor)))) + # return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)))) # return jnp.concatenate((jnp.ravel(jnp.abs(angular_drift)), jnp.ravel(jnp.abs(radial_drift)))) + return jnp.concatenate((jnp.ravel(jnp.abs(vertical_factor)),)) # @partial(jit, static_argnums=(0)) def loss_coil_length(field, max_coil_length): @@ -125,13 +123,13 @@ def loss_normB_axis(field, target_B_on_axis, npoints=15): B_axis = vmap(lambda phi: field.AbsB(jnp.array([R_axis * jnp.cos(phi), R_axis * jnp.sin(phi), 0])))(phi_array) return jnp.abs(B_axis-target_B_on_axis) -@partial(jit, static_argnums=(1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)) -def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, currents_scale, nfp, max_coil_curvature=0.5, +@partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)) +def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves_shape, currents_scale, nfp, max_coil_curvature=0.5, n_segments=60, stellsym=True, target_B_on_axis=5.7, maxtime=1e-5, max_coil_length=22, num_steps=30, trace_tolerance=1e-5, model='GuidingCenter'): - len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) - dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], dofs_curves.shape) - dofs_currents = x[len_dofs_curves_ravelled:] + dofs_curves_size = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] + dofs_curves = jnp.reshape(x[:dofs_curves_size], (dofs_curves_shape)) + dofs_currents = x[dofs_curves_size:] curves = Curves(dofs_curves, n_segments, nfp, stellsym) coils = Coils(curves=curves, currents=dofs_currents*currents_scale) @@ -145,23 +143,21 @@ def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, curr loss = jnp.concatenate((normB_axis_loss, coil_length_loss, particles_drift_loss, coil_curvature_loss)) return jnp.sum(loss) -@partial(jit, static_argnums=(1, 4, 5, 6, 7)) -def loss_BdotN(x, vmec, dofs_curves, currents_scale, nfp, max_coil_length=42, +@partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8)) +def loss_BdotN(x, vmec, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, n_segments=60, stellsym=True, max_coil_curvature=0.1): - len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) - dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], (dofs_curves.shape)) - dofs_currents = x[len_dofs_curves_ravelled:] + dofs_curves_size = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] + dofs_curves = jnp.reshape(x[:dofs_curves_size], (dofs_curves_shape)) + dofs_currents = x[dofs_curves_size:] curves = Curves(dofs_curves, n_segments, nfp, stellsym) coils = Coils(curves=curves, currents=dofs_currents*currents_scale) field = BiotSavart(coils) bdotn_over_b = BdotN_over_B(vmec.surface, field) - coil_length = loss_coil_length(field) - coil_curvature = loss_coil_curvature(field) + coil_length_loss = jnp.max(loss_coil_length(field, max_coil_length)) + coil_curvature_loss = jnp.max(loss_coil_curvature(field, max_coil_curvature)) bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) - coil_length_loss = jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss \ No newline at end of file diff --git a/essos/optimization.py b/essos/optimization.py index 9e2d5f6..bd134a1 100644 --- a/essos/optimization.py +++ b/essos/optimization.py @@ -31,7 +31,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- dofs_curves_shape = coils.dofs_curves.shape currents_scale = coils.currents_scale - loss_partial = partial(func, dofs_curves=coils.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym, **kwargs) + loss_partial = partial(func, dofs_curves_shape=coils.dofs_curves.shape, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym, **kwargs) ## Without JAX gradients, using finite differences # result = least_squares(loss_partial, x0=initial_dofs, verbose=2, diff_step=1e-4, @@ -43,6 +43,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- # result = least_squares(loss_partial, x0=initial_dofs, verbose=2, jac=jac_loss_partial, # ftol=tolerance_optimization, gtol=tolerance_optimization, # xtol=1e-14, max_nfev=maximum_function_evaluations) + print("Starting optimization") result = minimize(loss_partial, x0=initial_dofs, jac=jac_loss_partial, method=method, tol=tolerance_optimization, options={'maxiter': maximum_function_evaluations, 'disp': True, 'gtol': 1e-14, 'ftol': 1e-14}) From f38fdea2e774d3aad8e6a24f9bfb1c491c3ee6e7 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 28 Apr 2025 20:12:05 +0200 Subject: [PATCH 07/63] Feature: gradient analysis --- analysis/gradients.py | 110 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 analysis/gradients.py diff --git a/analysis/gradients.py b/analysis/gradients.py new file mode 100644 index 0000000..57d0552 --- /dev/null +++ b/analysis/gradients.py @@ -0,0 +1,110 @@ +import os +from functools import partial +number_of_processors_to_use = 8 # Parallelization, this should divide ntheta*nphi +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +from jax import jit, grad +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.coils import Coils, CreateEquallySpacedCurves +from essos.fields import Vmec +from essos.objective_functions import loss_BdotN + +# Optimization parameters +max_coil_length = 40 +max_coil_curvature = 0.5 +order_Fourier_series_coils = 6 +number_coil_points = order_Fourier_series_coils*10 +maximum_function_evaluations = 300 +number_coils_per_half_field_period = 4 +tolerance_optimization = 1e-5 +ntheta=32 +nphi=32 + +# Initialize VMEC field +vmec = Vmec(os.path.join(os.path.dirname(__file__), '../examples/input_files', + 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), + ntheta=ntheta, nphi=nphi, range_torus='half period') + +# Initialize coils +current_on_each_coil = 1 +number_of_field_periods = vmec.nfp +major_radius_coils = vmec.r_axis +minor_radius_coils = vmec.r_axis/1.5 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) + +coils = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + + +loss_partial = partial(loss_BdotN, dofs_curves_shape=coils.dofs_curves.shape, currents_scale=coils.currents_scale, + nfp=coils.nfp, n_segments=coils.n_segments, stellsym=coils.stellsym, + vmec=vmec, max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature) + +grad_loss_partial = jit(grad(loss_partial)) + +time0 = time() +grad_loss = grad_loss_partial(coils.x) +print(f"Gradient took {time()-time0:.4f} seconds") + +time0 = time() +grad_loss_comp = grad_loss_partial(coils.x) +print(f"Compiled gradient took {time()-time0:.4f} seconds") + +# Parameter to perturb +param = 42 + +# Set the possible perturbations +h_list = jnp.arange(-10, -1, 0.5) +h_list = 10.**h_list + +# Number of orders for finite differences +fd_loss = jnp.zeros(4) + +# Array to store the relative difference +fd_diff = jnp.zeros((fd_loss.size, h_list.size)) + +# Compute finite differences +for index, h in enumerate(h_list): + delta = jnp.zeros(coils.x.shape) + delta = delta.at[param].set(h) + + # 1st order finite differences + fd_loss = fd_loss.at[0].set((loss_partial(coils.x+delta)-loss_partial(coils.x))/h) + # 2nd order finite differences + fd_loss = fd_loss.at[1].set((loss_partial(coils.x+delta)-loss_partial(coils.x-delta))/(2*h)) + # 4th order finite differences + fd_loss = fd_loss.at[2].set((loss_partial(coils.x-2*delta)-8*loss_partial(coils.x-delta)+8*loss_partial(coils.x+delta)-loss_partial(coils.x+2*delta))/(12*h)) + # 6th order finite differences + fd_loss = fd_loss.at[3].set((loss_partial(coils.x+3*delta)-9*loss_partial(coils.x+2*delta)+45*loss_partial(coils.x+delta)-45*loss_partial(coils.x-delta)+9*loss_partial(coils.x-2*delta)-loss_partial(coils.x-3*delta))/(60*h)) + + fd_diff_h = jnp.abs((grad_loss[param]-fd_loss)/grad_loss[param]) + fd_diff = fd_diff.at[:, index].set(fd_diff_h) + + +# plot relative difference +plt.figure(figsize=(9, 6)) +plt.plot(h_list, fd_diff[0], "o-", label=f'1st order', clip_on=False) +plt.plot(h_list, fd_diff[1], "^-", label=f'2nd order', clip_on=False) +plt.plot(h_list, fd_diff[2], "*-", label=f'4th order', clip_on=False) +plt.plot(h_list, fd_diff[3], "s-", label=f'6th order', clip_on=False) +plt.legend() +plt.xlabel('Finite differences stepsize h') +plt.ylabel('Relative difference') +plt.xscale('log') +plt.yscale('log') +plt.xlim(jnp.min(h_list), jnp.max(h_list)) +plt.grid(which='both', axis='x') +plt.grid(which='major', axis='y') +for spine in plt.gca().spines.values(): + spine.set_zorder(0) +# plt.yticks([1e-11, 1e-9, 1e-7, 1e-5, 1e-3]) +# plt.gca().yaxis.set_minor_locator(plt.NullLocator()) +plt.tight_layout() +plt.savefig(os.path.join(os.path.dirname(__file__), 'gradients.pdf')) +plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" ,'gradients.pdf')) +plt.show() \ No newline at end of file From 079cdb7d294675d4dbbaa5774b03edcf2df122f7 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 4 May 2025 18:51:24 +0200 Subject: [PATCH 08/63] Add block_until_ready to integrators analysis --- analysis/fo_integrators.py | 3 +++ analysis/gc_integrators.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index c2b0663..48802ee 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -2,6 +2,7 @@ number_of_processors_to_use = 1 # Parallelization, this should divide nparticles os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time +from jax import block_until_ready import jax.numpy as jnp import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) @@ -51,6 +52,7 @@ time0 = time() tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + block_until_ready(tracing) tracing_times += [time() - time0] print(f"Tracing with adaptative {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") @@ -66,6 +68,7 @@ time0 = time() tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + block_until_ready(tracing) tracing_times += [time() - time0] print(f"Tracing with {method_name} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index e699dc8..29d8298 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -2,6 +2,7 @@ number_of_processors_to_use = 1 # Parallelization, this should divide nparticles os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time +from jax import block_until_ready import jax.numpy as jnp import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) @@ -41,6 +42,7 @@ time0 = time() tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + block_until_ready(tracing) tracing_times += [time() - time0] print(f"Tracing with adaptative {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") @@ -54,6 +56,7 @@ time0 = time() tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + block_until_ready(tracing) tracing_times += [time() - time0] print(f"Tracing with {method} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") From fe5d81d610eba26872f1b5a2223f9b5556775e15 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 4 May 2025 18:51:41 +0200 Subject: [PATCH 09/63] Minor tweaks to gradient analysis plot --- analysis/gradients.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/analysis/gradients.py b/analysis/gradients.py index 57d0552..8cf8798 100644 --- a/analysis/gradients.py +++ b/analysis/gradients.py @@ -59,7 +59,7 @@ param = 42 # Set the possible perturbations -h_list = jnp.arange(-10, -1, 0.5) +h_list = jnp.arange(-10, -1.9, 1/3) h_list = 10.**h_list # Number of orders for finite differences @@ -88,10 +88,10 @@ # plot relative difference plt.figure(figsize=(9, 6)) -plt.plot(h_list, fd_diff[0], "o-", label=f'1st order', clip_on=False) -plt.plot(h_list, fd_diff[1], "^-", label=f'2nd order', clip_on=False) -plt.plot(h_list, fd_diff[2], "*-", label=f'4th order', clip_on=False) -plt.plot(h_list, fd_diff[3], "s-", label=f'6th order', clip_on=False) +plt.plot(h_list, fd_diff[0], "o-", label=f'1st order', clip_on=False, linewidth=2.5) +plt.plot(h_list, fd_diff[1], "^-", label=f'2nd order', clip_on=False, linewidth=2.5) +plt.plot(h_list, fd_diff[2], "*-", label=f'4th order', clip_on=False, linewidth=2.5) +plt.plot(h_list, fd_diff[3], "s-", label=f'6th order', clip_on=False, linewidth=2.5) plt.legend() plt.xlabel('Finite differences stepsize h') plt.ylabel('Relative difference') From 9595ab44ef01dfc1ab2bcad328f75252810d62d6 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 4 May 2025 18:52:21 +0200 Subject: [PATCH 10/63] Create Poincare Plots analysis --- analysis/poincare_plots.py | 143 +++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 analysis/poincare_plots.py diff --git a/analysis/poincare_plots.py b/analysis/poincare_plots.py new file mode 100644 index 0000000..79fc5e8 --- /dev/null +++ b/analysis/poincare_plots.py @@ -0,0 +1,143 @@ +import os +from functools import partial +number_of_processors_to_use = 4 # Parallelization, this should divide ntheta*nphi +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +from jax import jit, grad, block_until_ready +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.coils import Coils_from_json +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE +from essos.fields import BiotSavart +from essos.dynamics import Tracing, Particles + + +# Input parameters +tmax_fl = 50000 +tmax_gc = 5e-3 +tmax_fo = 1e-3 + +nparticles = number_of_processors_to_use*8 +nfieldlines = number_of_processors_to_use*8 +s = 0.25 # s-coordinate: flux surface label +trace_tolerance = 1e-14 +dt_fo = 1e-10 +dt_gc = 1e-7 +timesteps_gc = int(tmax_gc/dt_gc) +timesteps_fo = int(tmax_fo/dt_fo) +mass = PROTON_MASS +energy = 4000*ONE_EV +print("cyclotron period:", 1/(ELEMENTARY_CHARGE*5/mass)) + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils_from_json(json_file) +field = BiotSavart(coils) + +R0_fieldlines = jnp.linspace(1.21, 1.41, nfieldlines) +R0_particles= jnp.linspace(1.21, 1.41, nparticles) +Z0_fieldlines = jnp.zeros(nfieldlines) +Z0_particles = jnp.zeros(nparticles) +phi0_fieldlines = jnp.zeros(nfieldlines) +phi0_particles = jnp.zeros(nparticles) + +initial_xyz_fieldlines=jnp.array([R0_fieldlines*jnp.cos(phi0_fieldlines), R0_fieldlines*jnp.sin(phi0_fieldlines), Z0_fieldlines]).T +initial_xyz_particles=jnp.array([R0_particles*jnp.cos(phi0_particles), R0_particles*jnp.sin(phi0_particles), Z0_particles]).T + +particles = Particles(initial_xyz=initial_xyz_particles, mass=mass, energy=energy, field=field, min_vparallel_over_v=0.8) + +# Trace in ESSOS +time0 = time() +tracing_fl = Tracing(field=field, model='FieldLine', initial_conditions=initial_xyz_fieldlines, + maxtime=tmax_fl, timesteps=tmax_fl*10, tol_step_size=trace_tolerance) +block_until_ready(tracing_fl) +print(f"ESSOS tracing of {nfieldlines} field lines took {time()-time0:.2f} seconds") + +time0 = time() +tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo, + timesteps=timesteps_fo, tol_step_size=trace_tolerance) +tracing_fo.trajectories = tracing_fo.trajectories[:, 0::1000, :] +tracing_fo.times = tracing_fo.times[0::1000] +tracing_fo.energy = tracing_fo.energy[:, 0::1000] +block_until_ready(tracing_fo) +print(f"ESSOS tracing of {nparticles} particles with FO for {tmax_fo:.1e}s took {time()-time0:.2f} seconds") + +time0 = time() +tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, maxtime=tmax_gc, + timesteps=timesteps_gc, tol_step_size=trace_tolerance) +block_until_ready(tracing_gc) +print(f"ESSOS tracing of {nparticles} particles with GC for {tmax_gc:.1e}s took {time()-time0:.2f} seconds") + +# plt.figure(figsize=(9, 6)) +# plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy[0]/particles.energy-1), label='Guiding Center', color='red') +# plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy[0]/particles.energy-1), label='Full Orbit', color='blue') +# plt.xlabel('Time (ms)') +# plt.ylabel('Relative Energy Error') +# plt.xlim(0, tmax*1000) +# plt.ylim(bottom=0) +# plt.legend() +# plt.tight_layout() +# plt.savefig(os.path.join(os.path.dirname(__file__), 'energies.png'), dpi=300) + + +# fig = plt.figure(figsize=(9, 6)) +# ax = fig.add_subplot(projection='3d') +# coils.plot(ax=ax, show=False) +# tracing_fl.plot(ax=ax, show=False) +# plt.tight_layout() + +# fig = plt.figure(figsize=(9, 6)) +# ax = fig.add_subplot(projection='3d') +# coils.plot(ax=ax, show=False) +# tracing_fo.plot(ax=ax, show=False) +# plt.tight_layout() + +# fig = plt.figure(figsize=(9, 6)) +# ax = fig.add_subplot(projection='3d') +# coils.plot(ax=ax, show=False) +# tracing_gc.plot(ax=ax, show=False) +# plt.tight_layout() + +# fig, ax = plt.subplots(figsize=(9, 6)) +# time0 = time() +# tracing_fl.poincare_plot(ax=ax, shifts=[jnp.pi/2], show=False, s=0.5) +# print(f"ESSOS Poincare plot of {nfieldlines} field lines took {time()-time0:.2f} seconds") +# plt.xlabel('R (m)') +# plt.ylabel('Z (m)') +# ax.set_xlim(0.3, 1.3) +# ax.set_ylim(-0.3, 0.3) +# plt.grid(visible=False) +# plt.tight_layout() +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fl.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_fl.png'), dpi=300) + + +# fig, ax = plt.subplots(figsize=(9, 6)) +# time0 = time() +# tracing_fo.poincare_plot(ax=ax, shifts=[jnp.pi/2], show=False) +# print(f"ESSOS Poincare plot of {nparticles} particles took {time()-time0:.2f} seconds") +# plt.xlabel('R (m)') +# plt.ylabel('Z (m)') +# plt.xlim(0.3, 1.3) +# plt.ylim(-0.3, 0.3) +# plt.grid(visible=False) +# plt.tight_layout() +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fo.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_fo.png'), dpi=300) + + +# fig, ax = plt.subplots(figsize=(9, 6)) +# time0 = time() +# tracing_gc.poincare_plot(ax=ax, shifts=[jnp.pi/2], show=False) +# print(f"ESSOS Poincare plot of {nparticles} particles took {time()-time0:.2f} seconds") +# plt.xlabel('R (m)') +# plt.ylabel('Z (m)') +# ax.set_xlim(0.3, 1.3) +# ax.set_ylim(-0.3, 0.3) +# plt.grid(visible=False) +# plt.tight_layout() +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_gc.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_gc.png'), dpi=300) + +# plt.show() \ No newline at end of file From 108123e64dc8468195e2331f8f8cc407f56b6340 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 11 May 2025 15:48:41 +0200 Subject: [PATCH 11/63] Refactor integrators analysis: add cyclotron frequency calculation and adjust num_steps based on dt --- analysis/fo_integrators.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index 48802ee..71c03af 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -8,7 +8,7 @@ plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart from essos.coils import Coils_from_json -from essos.constants import PROTON_MASS, ONE_EV +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles # import integrators import diffrax @@ -18,11 +18,10 @@ nparticles = number_of_processors_to_use R0 = jnp.linspace(1.23, 1.27, nparticles) trace_tolerance = 1e-12 -num_steps = 5000 mass=PROTON_MASS energy=4000*ONE_EV - -print(f"dt = {tmax/num_steps:.2e}") +cyclotron_frequency = ELEMENTARY_CHARGE*5/mass +print("cyclotron period:", 1/cyclotron_frequency) # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') @@ -41,6 +40,8 @@ methods = [getattr(diffrax, method) for method in method_names[:-1]] + ['Boris'] for method_name, method in zip(method_names, methods): if method_name != 'Boris': + starting_dt = 1e-10 + num_steps = int(tmax/starting_dt) energies = [] tracing_times = [] for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-13, 1e-14]: @@ -62,9 +63,9 @@ energies = [] tracing_times = [] - for num_steps in [100000, 200000, 300000, 500000, 1000000]: - if method_name == 'Boris' or method_name == 'Dopri8': - num_steps //= 10 + for n_points_in_gyration in [5, 10, 20, 50, 100]: + dt = 1/(n_points_in_gyration*cyclotron_frequency) + num_steps = int(tmax/dt) time0 = time() tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) From c80ff2cb4d880e7324954a14cf451e82119d0d0f Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 11 May 2025 16:02:54 +0200 Subject: [PATCH 12/63] Change loss functions to be quadratic & implement separation loss --- essos/objective_functions.py | 67 ++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/essos/objective_functions.py b/essos/objective_functions.py index b5bd2b6..ade9ac6 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -2,6 +2,7 @@ jax.config.update("jax_enable_x64", True) import jax.numpy as jnp from jax import jit, vmap +from jax.lax import fori_loop from functools import partial from essos.dynamics import Tracing from essos.fields import BiotSavart @@ -95,38 +96,52 @@ def loss_particle_drift(field, particles, maxtime=1e-5, num_steps=300, trace_tol radial_drift = radial_factor**2 + vertical_factor**2 # radial_drift = jnp.sqrt(radial_drift) - # print("radial_drift", radial_drift) radial_drift = jnp.mean(jnp.diff(radial_drift, axis=1), axis=1) angular_drift = jnp.arctan2(vertical_factor, radial_factor+1e-10) angular_drift = jnp.mean(jnp.diff(angular_drift, axis=1), axis=1) - # return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)), jnp.ravel(jnp.abs(vertical_factor)))) + return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)), jnp.ravel(jnp.abs(vertical_factor)))) # return jnp.concatenate((jnp.max(radial_drift)*jnp.ravel(2./jnp.pi*jnp.abs(jnp.arctan(radial_drift/(angular_drift+1e-10)))), jnp.ravel(jnp.abs(radial_drift)))) # return jnp.concatenate((jnp.ravel(jnp.abs(angular_drift)), jnp.ravel(jnp.abs(radial_drift)))) - return jnp.concatenate((jnp.ravel(jnp.abs(vertical_factor)),)) + # return jnp.concatenate((jnp.ravel(jnp.abs(vertical_factor)),)) -# @partial(jit, static_argnums=(0)) -def loss_coil_length(field, max_coil_length): - coil_length = jnp.ravel(field.coils.length) - return jnp.maximum(coil_length-max_coil_length, 0) +@partial(jit, static_argnames=['max_coil_length']) +def loss_coil_length(coils, max_coil_length): + return jnp.square((coils.length-max_coil_length)/max_coil_length) -# @partial(jit, static_argnums=(0)) -def loss_coil_curvature(field, max_coil_curvature): - coil_curvature = jnp.mean(field.coils.curvature, axis=1) - return jnp.maximum(coil_curvature-max_coil_curvature, 0) +@partial(jit, static_argnames=['max_coil_curvature']) +def loss_coil_curvature(coils, max_coil_curvature): + pointwise_curvature_loss = jnp.square(jnp.maximum(coils.curvature-max_coil_curvature, 0)) + return jnp.mean(pointwise_curvature_loss, axis=1) -# @partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnames=['min_separation']) +def loss_coil_separation(coils, min_separation): + i_vals, j_vals = jnp.triu_indices(len(coils), k=1) + + def pair_loss(i, j): + gamma_i = coils.gamma[i] + gamma_j = coils.gamma[j] + dists = jnp.linalg.norm(gamma_i[:, None, :] - gamma_j[None, :, :], axis=2) + penalty = jnp.maximum(0, min_separation - dists) + return jnp.mean(jnp.square(penalty)) + + losses = jax.vmap(pair_loss)(i_vals, j_vals) + return jnp.sum(losses) + +# @partial(jit, static_argnames=['target_B_on_axis', 'npoints']) def loss_normB_axis(field, target_B_on_axis, npoints=15): R_axis = jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(field.coils.dofs_curves))) phi_array = jnp.linspace(0, 2 * jnp.pi, npoints) B_axis = vmap(lambda phi: field.AbsB(jnp.array([R_axis * jnp.cos(phi), R_axis * jnp.sin(phi), 0])))(phi_array) - return jnp.abs(B_axis-target_B_on_axis) + return jnp.square(B_axis-target_B_on_axis) @partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)) def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves_shape, currents_scale, nfp, max_coil_curvature=0.5, n_segments=60, stellsym=True, target_B_on_axis=5.7, maxtime=1e-5, - max_coil_length=22, num_steps=30, trace_tolerance=1e-5, model='GuidingCenter'): + max_coil_length=22, num_steps=30, trace_tolerance=1e-5, model='GuidingCenter', + coil_length_loss_factor=1, coil_curvature_loss_factor=1): + dofs_curves_size = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] dofs_curves = jnp.reshape(x[:dofs_curves_size], (dofs_curves_shape)) dofs_currents = x[dofs_curves_size:] @@ -135,17 +150,19 @@ def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves_shape coils = Coils(curves=curves, currents=dofs_currents*currents_scale) field = BiotSavart(coils) - particles_drift_loss = loss_particle_drift(field, particles, maxtime, num_steps, trace_tolerance, model=model) - normB_axis_loss = loss_normB_axis(field, target_B_on_axis) - coil_length_loss = loss_coil_length(field, max_coil_length) - coil_curvature_loss = loss_coil_curvature(field, max_coil_curvature) + particles_drift_loss = jnp.sum(loss_particle_drift(field, particles, maxtime, num_steps, trace_tolerance, model=model)) + normB_axis_loss = jnp.sum(loss_normB_axis(field, target_B_on_axis)) + coil_length_loss = coil_length_loss_factor * jnp.sum(loss_coil_length(coils, max_coil_length)) + coil_curvature_loss = coil_curvature_loss_factor * jnp.sum(loss_coil_curvature(coils, max_coil_curvature)) + coils_separation_loss = jnp.sum(loss_coil_separation(coils, 0.5)) - loss = jnp.concatenate((normB_axis_loss, coil_length_loss, particles_drift_loss, coil_curvature_loss)) - return jnp.sum(loss) + return normB_axis_loss + coil_length_loss + particles_drift_loss + coil_curvature_loss + coils_separation_loss @partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8)) def loss_BdotN(x, vmec, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, - n_segments=60, stellsym=True, max_coil_curvature=0.1): + n_segments=60, stellsym=True, max_coil_curvature=0.1, + coil_length_loss_factor=1, coil_curvature_loss_factor=1): + dofs_curves_size = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] dofs_curves = jnp.reshape(x[:dofs_curves_size], (dofs_curves_shape)) dofs_currents = x[dofs_curves_size:] @@ -154,10 +171,8 @@ def loss_BdotN(x, vmec, dofs_curves_shape, currents_scale, nfp, max_coil_length= coils = Coils(curves=curves, currents=dofs_currents*currents_scale) field = BiotSavart(coils) - bdotn_over_b = BdotN_over_B(vmec.surface, field) - coil_length_loss = jnp.max(loss_coil_length(field, max_coil_length)) - coil_curvature_loss = jnp.max(loss_coil_curvature(field, max_coil_curvature)) - - bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) + coil_length_loss = coil_length_loss_factor * jnp.sum(loss_coil_length(coils, max_coil_length)) + coil_curvature_loss = coil_curvature_loss_factor * jnp.sum(loss_coil_curvature(coils, max_coil_curvature)) + bdotn_over_b_loss = jnp.sum(jnp.square(BdotN_over_B(vmec.surface, field))) return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss \ No newline at end of file From ea0af03f6f1533922930fb8a68fb91f5601850a8 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 11 May 2025 16:03:26 +0200 Subject: [PATCH 13/63] Minor improvements in coils class --- essos/coils.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index a0cfc08..a34bf7f 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -21,8 +21,8 @@ class Curves: stellsym (bool): Stellarator symmetry order (int): Order of the Fourier series curves jnp.ndarray - shape (n_indcurves*nfp*(1+stellsym), 3, 2*order+1)): Curves obtained by applying rotations and flipping corresponding to nfp fold rotational symmetry and optionally stellarator symmetry - gamma (jnp.array - shape (n_coils, n_segments, 3)): Discretized curves - gamma_dash (jnp.array - shape (n_coils, n_segments, 3)): Discretized curves derivatives + gamma (jnp.array - shape (n_curves, n_segments, 3)): Discretized curves + gamma_dash (jnp.array - shape (n_curves, n_segments, 3)): Discretized curves derivatives """ def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stellsym: bool = True): @@ -63,7 +63,7 @@ def _tree_flatten(self): def _tree_unflatten(cls, aux_data, children): return cls(*children, **aux_data) - partial(jit, static_argnames=['self']) + # @partial(jit, static_argnames=['self']) def _set_gamma(self): def fori_createdata(order_index: int, data: jnp.ndarray) -> jnp.ndarray: return data[0] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], jnp.cos(2 * jnp.pi * order_index * self.quadpoints)), \ @@ -366,7 +366,7 @@ def x(self, new_dofs): old_dofs_curves = jnp.ravel(self.dofs) old_dofs_currents = jnp.ravel(self.dofs_currents) new_dofs_curves = new_dofs[:old_dofs_curves.shape[0]] - new_dofs_currents = new_dofs[old_dofs_curves.shape[0]:] + new_dofs_currents = new_dofs[old_dofs_currents.shape[0]:] self.dofs_curves = jnp.reshape(new_dofs_curves, (self.dofs_curves.shape)) self.dofs_currents = new_dofs_currents @@ -401,13 +401,10 @@ def __eq__(self, other): return jnp.all(self.dofs == other.dofs) and jnp.all(self.dofs_currents == other.dofs_currents) else: raise TypeError(f"Invalid argument type. Got {type(other)}, expected Coils.") - - def __ne__(self, other): - return not self.__eq__(other) def _tree_flatten(self): - children = (Curves(self.dofs, self.n_segments, self.nfp, self.stellsym), self._dofs_currents) # arrays / dynamic values + children = (Curves(self.dofs, self.n_segments, self.nfp, self.stellsym), self._dofs_currents*self._currents_scale) # arrays / dynamic values aux_data = {} # static values return (children, aux_data) @@ -487,13 +484,13 @@ def CreateEquallySpacedCurves(n_curves: int, order: int, R: float, r: float, n_s return Curves(curves, n_segments=n_segments, nfp=nfp, stellsym=stellsym) def RotatedCurve(curve, phi, flip): - rotmat = jnp.array( - [[jnp.cos(phi), -jnp.sin(phi), 0], - [jnp.sin(phi), jnp.cos(phi), 0], - [0, 0, 1]]).T + rotmat_T = jnp.array( + [[ jnp.cos(phi), jnp.sin(phi), 0], + [-jnp.sin(phi), jnp.cos(phi), 0], + [ 0, 0, 1]]) if flip: - rotmat = rotmat @ jnp.diag(jnp.array([1, -1, -1])) - return curve @ rotmat + rotmat_T = rotmat_T @ jnp.diag(jnp.array([1, -1, -1])) + return curve @ rotmat_T @partial(jit, static_argnames=['nfp', 'stellsym']) def apply_symmetries_to_curves(base_curves, nfp, stellsym): From 5617e7f281c7552c27b30759bf4f3f10681d5cdc Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 11 May 2025 16:03:41 +0200 Subject: [PATCH 14/63] Add join method to Particles class and optimize tracing parameters in examples --- essos/dynamics.py | 43 +++++++++++++++++++------------------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/essos/dynamics.py b/essos/dynamics.py index 28fe49c..b53af69 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -73,6 +73,20 @@ def to_full_orbit(self, field): total_speed=self.total_speed, mass=self.mass, charge=self.charge, phase_angle_full_orbit=self.phase_angle_full_orbit) + def join(self, other, field=None): + assert isinstance(other, Particles), "Cannot join with non-Particles object" + assert self.charge == other.charge, "Cannot join particles with different charges" + assert self.mass == other.mass, "Cannot join particles with different masses" + assert self.energy == other.energy, "Cannot join particles with different energies" + + charge = self.charge + mass = self.mass + energy = self.energy + initial_xyz = jnp.concatenate((self.initial_xyz, other.initial_xyz), axis=0) + initial_vparallel_over_v = jnp.concatenate((self.initial_vparallel_over_v, other.initial_vparallel_over_v), axis=0) + + return Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, charge=charge, mass=mass, energy=energy, field=field) + @partial(jit, static_argnums=(2)) def GuidingCenter(t, initial_condition, @@ -135,7 +149,7 @@ def FieldLine(t, class Tracing(): def __init__(self, trajectories_input=None, initial_conditions=None, times=None, field=None, model=None, method=None, maxtime: float = 1e-7, timesteps: int = 500, stepsize: str = "adaptative", - trajectories=None, tol_step_size = 1e-10, particles=None, condition=None): + tol_step_size = 1e-10, particles=None, condition=None): assert method == None or \ method == 'Boris' or \ @@ -228,9 +242,7 @@ def compute_energy_fo(trajectory): self.total_particles_lost = None self.loss_times = None - @partial(jit, static_argnums=(0)) def trace(self): - @jit def compute_trajectory(initial_condition) -> jnp.ndarray: # initial_condition = initial_condition[0] if self.model == 'FullOrbit_Boris' or self.method == 'Boris': @@ -278,27 +290,9 @@ def update_state(state, _): ).ys return trajectory - # if len(jax.devices())!=len(self.initial_conditions): - # return vmap(compute_trajectory)(self.initial_conditions[:,None,:]) - # else: - # # num_devices = len(jax.devices()) - # shape = self.initial_conditions.shape - # # distributed_initial_conditions = self.initial_conditions.reshape(num_devices, -1, *shape[1:]) - # mesh = Mesh(devices=jax.devices(), axis_names=('workers')) - # in_spec = PartitionSpec('workers') # Distribute along the workers axis - # out_spec = PartitionSpec('workers') # Gather results along the same axis - # return shard_map(compute_trajectory, mesh, in_specs=in_spec, out_specs=out_spec, check_rep=False)( - # self.initial_conditions).reshape((shape[0], self.timesteps, shape[1])) - return jit(vmap(compute_trajectory), in_shardings=sharding, out_shardings=sharding)( device_put(self.initial_conditions, sharding)) - # trajectories = [] - # for initial_condition in self.initial_conditions: - # trajectory = compute_trajectory(initial_condition) - # trajectories.append(trajectory) - # return jnp.array(trajectories) - @property def trajectories(self): return self._trajectories @@ -308,8 +302,9 @@ def trajectories(self, value): self._trajectories = value def _tree_flatten(self): - children = (self.trajectories,) # arrays / dynamic values - aux_data = {'field': self.field, 'model': self.model} # static values + children = (self.trajectories, self.initial_conditions, self.times, self.field) # arrays / dynamic values + aux_data = {'model': self.model, 'method': self.method, 'maxtime': self.maxtime, 'timesteps': self.timesteps,'stepsize': + self.stepsize, 'tol_step_size': self.tol_step_size, 'particles': self.particles, 'condition': self.condition} # static values return (children, aux_data) @classmethod @@ -359,7 +354,7 @@ def poincare_plot(self, shifts = [jnp.pi/2], orientation = 'toroidal', length = """ Plot Poincare plots using scipy to find the roots of an interpolation. Can take particle trace or field lines. Args: - shifts (list, optional): Apply a linear shift to dependent data. Default is [0]. + shifts (list, optional): Apply a linear shift to dependent data. Default is [pi/2]. orientation (str, optional): 'toroidal' - find time values when toroidal angle = shift [0, 2pi]. 'z' - find time values where z coordinate = shift. Default is 'toroidal'. From 0bb4462d74c7b9b731be520c1caaf8bd868573e6 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sun, 11 May 2025 20:50:12 +0200 Subject: [PATCH 15/63] Simplify curve appending logic in apply_symmetries_to_curves --- essos/coils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index a34bf7f..817fc6f 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -499,11 +499,8 @@ def apply_symmetries_to_curves(base_curves, nfp, stellsym): for k in range(0, nfp): for flip in flip_list: for i in range(len(base_curves)): - if k == 0 and not flip: - curves.append(base_curves[i]) - else: - rotcurve = RotatedCurve(base_curves[i].T, 2*jnp.pi*k/nfp, flip) - curves.append(rotcurve.T) + rotcurve = RotatedCurve(base_curves[i].T, 2*jnp.pi*k/nfp, flip) + curves.append(rotcurve.T) return jnp.array(curves) @partial(jit, static_argnames=['nfp', 'stellsym']) From 91f8cd01a7fbefe0074cf194846ecc23ca3b1f5b Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 12 May 2025 12:32:05 +0200 Subject: [PATCH 16/63] Update cyclotron frequency calculation and adjust tracing parameters in integrators --- analysis/fo_integrators.py | 16 ++++------ analysis/gc_integrators.py | 60 +++++++++++++++----------------------- analysis/poincare_plots.py | 2 +- 3 files changed, 29 insertions(+), 49 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index 71c03af..6bafe72 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -20,7 +20,7 @@ trace_tolerance = 1e-12 mass=PROTON_MASS energy=4000*ONE_EV -cyclotron_frequency = ELEMENTARY_CHARGE*5/mass +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass print("cyclotron period:", 1/cyclotron_frequency) # Load coils and field @@ -36,20 +36,15 @@ fig, ax = plt.subplots(figsize=(9, 6)) -method_names = ['Tsit5', 'Dopri5', 'Dopri8', 'Boris'] +method_names = ['Dopri8', 'Boris'] methods = [getattr(diffrax, method) for method in method_names[:-1]] + ['Boris'] for method_name, method in zip(method_names, methods): if method_name != 'Boris': - starting_dt = 1e-10 + starting_dt = 1e-9 num_steps = int(tmax/starting_dt) energies = [] tracing_times = [] - for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-13, 1e-14]: - if method_name == 'Dopri8': - if trace_tolerance == 1e-13: - trace_tolerance = 1e-14 - elif trace_tolerance == 1e-14: - trace_tolerance = 1e-15 + for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14]: time0 = time() tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) @@ -63,7 +58,7 @@ energies = [] tracing_times = [] - for n_points_in_gyration in [5, 10, 20, 50, 100]: + for n_points_in_gyration in [5, 10, 20, 30, 40]: dt = 1/(n_points_in_gyration*cyclotron_frequency) num_steps = int(tmax/dt) time0 = time() @@ -77,7 +72,6 @@ energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method_name}', marker='o', markersize=4, linestyle='-') -from matplotlib.ticker import LogFormatterMathtext ax.legend() ax.set_xlabel('Computation time (s)') diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index 29d8298..67eab81 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -8,78 +8,64 @@ plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart from essos.coils import Coils_from_json -from essos.constants import PROTON_MASS, ONE_EV +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles # import integrators import diffrax -# Input parameters -tmax = 1e-4 -nparticles = number_of_processors_to_use -R0 = jnp.linspace(1.23, 1.27, nparticles) -trace_tolerance = 1e-12 -num_steps = 1500 -mass=PROTON_MASS -energy=4000*ONE_EV - # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils_from_json(json_file) field = BiotSavart(coils) -# Initialize particles -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy) +# Particle parameters +nparticles = number_of_processors_to_use +mass=PROTON_MASS +energy=5000*ONE_EV +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass +print("cyclotron period:", 1/cyclotron_frequency) + +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.9], phase_angle_full_orbit=0) + +# Tracing parameters +tmax = 1e-4 +dt = 5e-8 +num_steps = int(tmax/dt) fig, ax = plt.subplots(figsize=(9, 6)) for method in ['Tsit5', 'Dopri5', 'Dopri8']: energies = [] tracing_times = [] - for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14, 1e-16]: + for trace_tolerance in [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]: time0 = time() tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) - block_until_ready(tracing) + block_until_ready(tracing.trajectories) tracing_times += [time() - time0] print(f"Tracing with adaptative {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") - energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'adaptative {method}', marker='o', markersize=3, linestyle='-') energies = [] tracing_times = [] - for num_steps in [500, 1000, 2000, 5000, 10000]: + for dt in [2e-7, 1e-7, 5e-8, 2.5e-8]: + num_steps = int(tmax/dt) time0 = time() tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) - block_until_ready(tracing) + block_until_ready(tracing.trajectories) tracing_times += [time() - time0] print(f"Tracing with {method} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") - energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') -# num_steps = 100 -# for method in ['Kvaerno5', 'Kvaerno4']: -# energies = [] -# tracing_times = [] -# for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14, 1e-16]: -# time0 = time() -# tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, -# stepsize="adaptative", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) -# tracing_times += [time() - time0] - -# print(f"Tracing with adaptative {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") - -# energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] -# ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') - -from matplotlib.ticker import LogFormatterMathtext ax.legend() ax.set_xlabel('Computation time (s)') diff --git a/analysis/poincare_plots.py b/analysis/poincare_plots.py index 79fc5e8..4c8179f 100644 --- a/analysis/poincare_plots.py +++ b/analysis/poincare_plots.py @@ -28,7 +28,7 @@ timesteps_fo = int(tmax_fo/dt_fo) mass = PROTON_MASS energy = 4000*ONE_EV -print("cyclotron period:", 1/(ELEMENTARY_CHARGE*5/mass)) +print("cyclotron period:", 1/(ELEMENTARY_CHARGE*0.3/mass)) # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') From 549a43f42f8c24052ba3d0bf110e1a65b21ef59c Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 12 May 2025 12:32:27 +0200 Subject: [PATCH 17/63] Enhance Particles class with phase angle parameter and update Tracing class for improved step size handling --- essos/dynamics.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/essos/dynamics.py b/essos/dynamics.py index b53af69..aee72a2 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -5,7 +5,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax import jit, vmap, tree_util, random, lax, device_put from functools import partial -from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri8, PIDController, Event, AbstractSolver, ConstantStepSize +from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri8, PIDController, Event, AbstractSolver, ConstantStepSize, StepTo from essos.coils import Coils from essos.fields import BiotSavart, Vmec from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, FUSION_ALPHA_PARTICLE_ENERGY @@ -45,7 +45,7 @@ def compute_orbit_params(xyz, vpar): class Particles(): def __init__(self, initial_xyz=None, initial_vparallel_over_v=None, charge=ALPHA_PARTICLE_CHARGE, mass=ALPHA_PARTICLE_MASS, energy=FUSION_ALPHA_PARTICLE_ENERGY, min_vparallel_over_v=-1, - max_vparallel_over_v=1, field=None, initial_vxvyvz=None, initial_xyz_fullorbit=None): + max_vparallel_over_v=1, field=None, initial_vxvyvz=None, initial_xyz_fullorbit=None, phase_angle_full_orbit = 0): self.charge = charge self.mass = mass self.energy = energy @@ -53,7 +53,7 @@ def __init__(self, initial_xyz=None, initial_vparallel_over_v=None, charge=ALPHA self.nparticles = len(initial_xyz) self.initial_xyz_fullorbit = initial_xyz_fullorbit self.initial_vxvyvz = initial_vxvyvz - self.phase_angle_full_orbit = 0 + self.phase_angle_full_orbit = phase_angle_full_orbit if initial_vparallel_over_v is not None: self.initial_vparallel_over_v = jnp.array(initial_vparallel_over_v) @@ -267,25 +267,27 @@ def update_state(state, _): _, trajectory = lax.scan(update_state, initial_condition, jnp.arange(len(self.times)-1)) trajectory = jnp.vstack([initial_condition, trajectory]) else: - import warnings - warnings.simplefilter("ignore", category=FutureWarning) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation + # import warnings + # warnings.simplefilter("ignore", category=FutureWarning) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation if self.stepsize == "adaptative": controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size) + dt0 = self.maxtime / self.timesteps elif self.stepsize == "constant": - controller = ConstantStepSize() + controller = StepTo(self.times) + dt0 = None trajectory = diffeqsolve( self.ODE_term, t0=0.0, t1=self.maxtime, - dt0=self.maxtime / self.timesteps, + dt0=dt0, y0=initial_condition, solver=self.method(), args=self.args, saveat=SaveAt(ts=self.times), - throw=False, + throw=True, # adjoint=DirectAdjoint(), stepsize_controller = controller, - max_steps=10000000000, + max_steps = int(1e10), event = Event(self.condition) ).ys return trajectory @@ -302,8 +304,8 @@ def trajectories(self, value): self._trajectories = value def _tree_flatten(self): - children = (self.trajectories, self.initial_conditions, self.times, self.field) # arrays / dynamic values - aux_data = {'model': self.model, 'method': self.method, 'maxtime': self.maxtime, 'timesteps': self.timesteps,'stepsize': + children = (self.trajectories, self.initial_conditions, self.times) # arrays / dynamic values + aux_data = {'field': self.field, 'model': self.model, 'method': self.method, 'maxtime': self.maxtime, 'timesteps': self.timesteps,'stepsize': self.stepsize, 'tol_step_size': self.tol_step_size, 'particles': self.particles, 'condition': self.condition} # static values return (children, aux_data) From 23bf2718a94b76bd81e42f0a1282c9bbbf12e9ea Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 12 May 2025 12:32:51 +0200 Subject: [PATCH 18/63] Refactor coil separation loss function and optimize tracing parameters in examples --- essos/objective_functions.py | 6 +++ examples/compare_guidingcenter_fullorbit.py | 52 +++++++++++-------- ...ze_coils_particle_confinement_fullorbit.py | 48 ++++++----------- ...oils_particle_confinement_guidingcenter.py | 6 +-- examples/optimize_coils_vmec_surface.py | 6 +-- 5 files changed, 58 insertions(+), 60 deletions(-) diff --git a/essos/objective_functions.py b/essos/objective_functions.py index ade9ac6..8ffa78e 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -117,6 +117,12 @@ def loss_coil_curvature(coils, max_coil_curvature): @partial(jit, static_argnames=['min_separation']) def loss_coil_separation(coils, min_separation): + # Sort coils by angle + # sorting = jnp.argsort(jnp.arctan2(coils.curves[:,1,0], coils.curves[:,0,0])%(2*jnp.pi)) + # This can be useful to only cosider the separation between adjacent coils + # i_vals, j_vals = jnp.arange(len(coils)), jnp.arange(1, len(coils)+1)%len(coils) + # but in this case gamma_i and gamma_j have to be sorted with the sorting mask + i_vals, j_vals = jnp.triu_indices(len(coils), k=1) def pair_loss(i, j): diff --git a/examples/compare_guidingcenter_fullorbit.py b/examples/compare_guidingcenter_fullorbit.py index 1b0a69f..302120b 100644 --- a/examples/compare_guidingcenter_fullorbit.py +++ b/examples/compare_guidingcenter_fullorbit.py @@ -7,31 +7,36 @@ import matplotlib.pyplot as plt from essos.fields import BiotSavart from essos.coils import Coils_from_json -from essos.constants import PROTON_MASS, ONE_EV +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles from jax import block_until_ready -# Input parameters -tmax = 1e-2 -nparticles = number_of_processors_to_use -R0 = jnp.linspace(1.23, 1.27, nparticles) -trace_tolerance = 1e-5 -num_steps_gc = 5000 -num_steps_fo = 100000 -mass=PROTON_MASS -energy=5000*ONE_EV - # Load coils and field json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils_from_json(json_file) field = BiotSavart(coils) -# Initialize particles -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -initial_vparallel_over_v = [0.1] -particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, field=field, initial_vparallel_over_v=initial_vparallel_over_v) +# Particle parameters +nparticles = number_of_processors_to_use +mass=PROTON_MASS +energy=5000*ONE_EV +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass +print("cyclotron period:", 1/cyclotron_frequency) + +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) + +particles_passing = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.1], phase_angle_full_orbit=0) +particles_traped = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.9], phase_angle_full_orbit=0) +particles = particles_passing.join(particles_traped, field=field) + +# Tracing parameters +tmax = 1e-4 +trace_tolerance = 1e-15 +dt_gc = 1e-7 +dt_fo = 1e-9 +num_steps_gc = int(tmax/dt_gc) +num_steps_fo = int(tmax/dt_fo) # Trace in ESSOS time0 = time() @@ -41,9 +46,9 @@ print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") time0 = time() -tracing_fullorbit = Tracing(field=field, model='FullOrbit_Boris', particles=particles, - maxtime=tmax, timesteps=num_steps_fo, tol_step_size=trace_tolerance) -trajectories_fullorbit = block_until_ready(tracing_fullorbit.trajectories) +tracing_fullorbit = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax, + timesteps=num_steps_fo, tol_step_size=trace_tolerance) +block_until_ready(tracing_fullorbit.trajectories) print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") # Plot trajectories, velocity parallel to the magnetic field, and energy error @@ -57,7 +62,7 @@ tracing_guidingcenter.plot(ax=ax1, show=False) tracing_fullorbit.plot(ax=ax1, show=False) -for i, (trajectory_gc, trajectory_fo) in enumerate(zip(trajectories_guidingcenter, trajectories_fullorbit)): +for i, (trajectory_gc, trajectory_fo) in enumerate(zip(trajectories_guidingcenter, tracing_fullorbit.trajectories)): ax2.plot(tracing_guidingcenter.times, jnp.abs(tracing_guidingcenter.energy[i]-particles.energy)/particles.energy, '-', label=f'Particle {i+1} GC', linewidth=1.0, alpha=0.7) ax2.plot(tracing_fullorbit.times, jnp.abs(tracing_fullorbit.energy[i]-particles.energy)/particles.energy, '--', label=f'Particle {i+1} FO', linewidth=1.0, markersize=0.5, alpha=0.7) def compute_v_parallel(trajectory_t): @@ -84,5 +89,6 @@ def compute_v_parallel(trajectory_t): plt.show() ## Save results in vtk format to analyze in Paraview -# tracing.to_vtk('trajectories') -# coils.to_vtk('coils') \ No newline at end of file +tracing_guidingcenter.to_vtk('trajectories_gc') +tracing_fullorbit.to_vtk('trajectories_fo') +coils.to_vtk('coils') \ No newline at end of file diff --git a/examples/optimize_coils_particle_confinement_fullorbit.py b/examples/optimize_coils_particle_confinement_fullorbit.py index a58c95b..9f37313 100644 --- a/examples/optimize_coils_particle_confinement_fullorbit.py +++ b/examples/optimize_coils_particle_confinement_fullorbit.py @@ -1,12 +1,11 @@ import os -number_of_processors_to_use = 12 # Parallelization, this should divide nparticles +number_of_processors_to_use = 4 # Parallelization, this should divide nparticles os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time import jax.numpy as jnp import matplotlib.pyplot as plt from essos.dynamics import Particles, Tracing -from essos.fields import BiotSavart from essos.coils import Coils, CreateEquallySpacedCurves from essos.optimization import optimize_loss_function from essos.objective_functions import loss_optimize_coils_for_particle_confinement @@ -15,20 +14,14 @@ target_B_on_axis = 5.7 max_coil_length = 31 max_coil_curvature = 0.4 -nparticles = number_of_processors_to_use +nparticles = number_of_processors_to_use*3 order_Fourier_series_coils = 4 number_coil_points = 80 -maximum_function_evaluations = 30 -maxtime_tracing = 1e-5 +maximum_function_evaluations = 29 +maxtime_tracing = 2e-5 number_coils_per_half_field_period = 3 number_of_field_periods = 2 -model = 'FullOrbit_Boris' -timesteps = 3000#int(3*maxtime_tracing/1e-8) - -nparticles_plot = number_of_processors_to_use*2 -model_plot = 'GuidingCenter' -timesteps_plot = 10000 -maxtime_tracing_plot = 3e-5 +model = 'GuidingCenter' # Initialize coils current_on_each_coil = 1.84e7 @@ -45,25 +38,19 @@ phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T particles = Particles(initial_xyz=initial_xyz) -particles.to_full_orbit(BiotSavart(coils_initial)) -tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=maxtime_tracing, model=model, timesteps=timesteps) +tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=maxtime_tracing, model=model, tol_step_size = 1e-14) # Optimize coils print(f'Optimizing coils with {maximum_function_evaluations} function evaluations and maxtime_tracing={maxtime_tracing}') time0 = time() -coils_optimized = optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils_initial.x, - coils=coils_initial, tolerance_optimization=1e-4, particles=particles, - maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature, - target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, model=model, - maxtime=maxtime_tracing, num_steps=timesteps) +coils_optimized = optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils_initial.x, coils=coils_initial, + tolerance_optimization=1e-4, particles=particles, maximum_function_evaluations=maximum_function_evaluations, + max_coil_curvature=max_coil_curvature, target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, + model=model, maxtime=maxtime_tracing, num_steps=500, trace_tolerance=1e-5) +# coils_optimized = optimize_coils_for_particle_confinement(coils_initial, particles, target_B_on_axis=target_B_on_axis, maxtime=maxtime_tracing, model=model, +# max_coil_length=max_coil_length, maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature) print(f" Optimization took {time()-time0:.2f} seconds") -particles.to_full_orbit(BiotSavart(coils_optimized)) - -phi_array_plot = jnp.linspace(0, 2*jnp.pi, nparticles_plot) -initial_xyz_plot=jnp.array([major_radius_coils*jnp.cos(phi_array_plot), major_radius_coils*jnp.sin(phi_array_plot), 0*phi_array_plot]).T -particles_plot = Particles(initial_xyz=initial_xyz_plot) -particles.to_full_orbit(BiotSavart(coils_optimized)) -tracing_optimized = Tracing(field=coils_optimized, particles=particles, maxtime=maxtime_tracing_plot, model=model_plot, timesteps=timesteps_plot) +tracing_optimized = Tracing(field=coils_optimized, particles=particles, maxtime=maxtime_tracing, model=model) # Plot trajectories, before and after optimization fig = plt.figure(figsize=(9, 8)) @@ -75,14 +62,13 @@ coils_initial.plot(ax=ax1, show=False) tracing_initial.plot(ax=ax1, show=False) for i, trajectory in enumerate(tracing_initial.trajectories): - ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}', linewidth=0.2) + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') ax3.set_xlabel('R (m)');ax3.set_ylabel('Z (m)');#ax3.legend() coils_optimized.plot(ax=ax2, show=False) tracing_optimized.plot(ax=ax2, show=False) -# for i, trajectory in enumerate(tracing_optimized.trajectories): -# ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}', linewidth=0.2) -# ax4.set_xlabel('R (m)');ax4.set_ylabel('Z (m)');#ax4.legend() -plotting_data = tracing_optimized.poincare_plot(ax=ax4, shifts = [jnp.pi/4, jnp.pi/2, 3*jnp.pi/4], show=False) +for i, trajectory in enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)');ax4.set_ylabel('Z (m)');#ax4.legend() plt.tight_layout() plt.show() diff --git a/examples/optimize_coils_particle_confinement_guidingcenter.py b/examples/optimize_coils_particle_confinement_guidingcenter.py index eae5f63..ff49100 100644 --- a/examples/optimize_coils_particle_confinement_guidingcenter.py +++ b/examples/optimize_coils_particle_confinement_guidingcenter.py @@ -1,6 +1,6 @@ import os -number_of_processors_to_use = 12 # Parallelization, this should divide nparticles +number_of_processors_to_use = 4 # Parallelization, this should divide nparticles os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time import jax.numpy as jnp @@ -14,7 +14,7 @@ target_B_on_axis = 5.7 max_coil_length = 31 max_coil_curvature = 0.4 -nparticles = number_of_processors_to_use +nparticles = number_of_processors_to_use*3 order_Fourier_series_coils = 4 number_coil_points = 80 maximum_function_evaluations = 29 @@ -38,7 +38,7 @@ phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T particles = Particles(initial_xyz=initial_xyz) -tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=maxtime_tracing, model=model) +tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=maxtime_tracing, model=model, tol_step_size = 1e-14) # Optimize coils print(f'Optimizing coils with {maximum_function_evaluations} function evaluations and maxtime_tracing={maxtime_tracing}') diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py index 1296aea..cf2b337 100644 --- a/examples/optimize_coils_vmec_surface.py +++ b/examples/optimize_coils_vmec_surface.py @@ -13,9 +13,9 @@ # Optimization parameters max_coil_length = 40 max_coil_curvature = 0.5 -order_Fourier_series_coils = 6 +order_Fourier_series_coils = 10 number_coil_points = order_Fourier_series_coils*10 -maximum_function_evaluations = 300 +maximum_function_evaluations = 500 number_coils_per_half_field_period = 4 tolerance_optimization = 1e-5 ntheta=32 @@ -37,7 +37,7 @@ n_segments=number_coil_points, nfp=number_of_field_periods, stellsym=True) coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) - +print(coils_initial.dofs_curves.shape) # Optimize coils print(f'Optimizing coils with {maximum_function_evaluations} function evaluations.') time0 = time() From dcef625d1cd67f1e7acf33d2f39955b5e845da37 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 12 May 2025 12:33:42 +0200 Subject: [PATCH 19/63] Tentative to incorporate pytree methods for BiotSavart class --- essos/fields.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/essos/fields.py b/essos/fields.py index 94b2b18..fabc09a 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -44,7 +44,19 @@ def dAbsB_by_dX(self, points): @partial(jit, static_argnames=['self']) def to_xyz(self, points): return points + +# def _tree_flatten(self): +# children = (self.coils,) +# aux_data = {} +# return (children, aux_data) + +# @classmethod +# def _tree_unflatten(cls, aux_data, children): +# return cls(*children, **aux_data) +# tree_util.register_pytree_node(BiotSavart, +# BiotSavart._tree_flatten, +# BiotSavart._tree_unflatten) class Vmec(): def __init__(self, wout_filename, ntheta=50, nphi=50, close=True, range_torus='full torus'): From 3c96714d9bba75b01422a84c60eedf2c13416932 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 14 May 2025 16:59:03 +0200 Subject: [PATCH 20/63] Refactor Tracing class initialization and improve parameter handling for adaptive step size --- analysis/fo_integrators.py | 44 +++++++++--------- analysis/gc_integrators.py | 20 ++++---- essos/dynamics.py | 94 ++++++++++++++++++++++++-------------- 3 files changed, 91 insertions(+), 67 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index 6bafe72..d654beb 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -13,42 +13,40 @@ # import integrators import diffrax -# Input parameters -tmax = 1e-4 +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils_from_json(json_file) +field = BiotSavart(coils) + +# Particle parameters nparticles = number_of_processors_to_use -R0 = jnp.linspace(1.23, 1.27, nparticles) -trace_tolerance = 1e-12 mass=PROTON_MASS -energy=4000*ONE_EV +energy=5000*ONE_EV cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass print("cyclotron period:", 1/cyclotron_frequency) -# Load coils and field -json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) -field = BiotSavart(coils) +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.8], field=field) -# Initialize particles -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, field=field) +# Tracing parameters +tmax = 1e-5 +dt = 1e-9 +num_steps = int(tmax/dt) fig, ax = plt.subplots(figsize=(9, 6)) -method_names = ['Dopri8', 'Boris'] +method_names = ['Tsit5', 'Dopri5', 'Dopri8', 'Boris'] methods = [getattr(diffrax, method) for method in method_names[:-1]] + ['Boris'] for method_name, method in zip(method_names, methods): if method_name != 'Boris': - starting_dt = 1e-9 - num_steps = int(tmax/starting_dt) energies = [] tracing_times = [] for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14]: time0 = time() - tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, - maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) - block_until_ready(tracing) + tracing = Tracing('FullOrbit', field, tmax, method=method, timesteps=num_steps, + stepsize='adaptive', tol_step_size=trace_tolerance, particles=particles) + block_until_ready(tracing.trajectories) tracing_times += [time() - time0] print(f"Tracing with adaptative {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") @@ -62,9 +60,9 @@ dt = 1/(n_points_in_gyration*cyclotron_frequency) num_steps = int(tmax/dt) time0 = time() - tracing = Tracing(field=field, model='FullOrbit', method=method, particles=particles, - stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) - block_until_ready(tracing) + tracing = Tracing('FullOrbit', field, tmax, method=method, timesteps=num_steps, + stepsize="constant", particles=particles) + block_until_ready(tracing.trajectories) tracing_times += [time() - time0] print(f"Tracing with {method_name} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index 67eab81..50478ba 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -27,11 +27,11 @@ # Particles initialization initial_xyz=jnp.array([[1.23, 0, 0]]) -particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.9], phase_angle_full_orbit=0) +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.8]) # Tracing parameters tmax = 1e-4 -dt = 5e-8 +dt = 1e-7 num_steps = int(tmax/dt) fig, ax = plt.subplots(figsize=(9, 6)) @@ -39,25 +39,25 @@ for method in ['Tsit5', 'Dopri5', 'Dopri8']: energies = [] tracing_times = [] - for trace_tolerance in [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]: + for trace_tolerance in [1e-9, 1e-10, 1e-11, 1e-12, 1e-13]: time0 = time() - tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, - maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), timesteps=num_steps, + stepsize='adaptive', tol_step_size=trace_tolerance, particles=particles,) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] - print(f"Tracing with adaptative {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + print(f"Tracing with adaptive {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] - ax.plot(tracing_times, energies, label=f'adaptative {method}', marker='o', markersize=3, linestyle='-') + ax.plot(tracing_times, energies, label=f'adaptive {method}', marker='o', markersize=3, linestyle='-') energies = [] tracing_times = [] - for dt in [2e-7, 1e-7, 5e-8, 2.5e-8]: + for dt in [2e-7, 1e-7, 5e-8, 2e-8]: num_steps = int(tmax/dt) time0 = time() - tracing = Tracing(field=field, model='GuidingCenter', method=getattr(diffrax, method), particles=particles, - stepsize="constant", maxtime=tmax, timesteps=num_steps, tol_step_size=trace_tolerance) + tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), + timesteps=num_steps, stepsize="constant", particles=particles) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] diff --git a/essos/dynamics.py b/essos/dynamics.py index aee72a2..9e567de 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -147,32 +147,65 @@ def FieldLine(t, # return lax.cond(condition, zero_derivatives, compute_derivatives, operand=None) class Tracing(): - def __init__(self, trajectories_input=None, initial_conditions=None, times=None, field=None, - model=None, method=None, maxtime: float = 1e-7, timesteps: int = 500, stepsize: str = "adaptative", - tol_step_size = 1e-10, particles=None, condition=None): + def __init__(self, model: str, field, maxtime: float, method=None, times=None, + timesteps: int = None, stepsize: str = "adaptive", dt0: float=1e-5, + tol_step_size = 1e-10, particles=None, initial_conditions=None, condition=None): + """ + Tracing class to compute the trajectories of particles in a magnetic field. + + Parameters + ---------- + + """ - assert method == None or \ + assert model in ["GuidingCenter", "FullOrbit", "FieldLine"], "Model must be one of: 'GuidingCenter', 'FullOrbit', or 'FieldLine'" + assert method is None or \ method == 'Boris' or \ issubclass(method, AbstractSolver), "Method must be None, 'Boris', or a DIFFRAX solver" + assert stepsize in ["adaptive", "constant"], "stepsize must be 'adaptive' or 'constant'" if method == 'Boris': - assert model == 'FullOrbit' or model == 'FullOrbit_Boris', "Method 'Boris' is only available for FullOrbit models" - - if isinstance(field, Coils): - self.field = BiotSavart(field) - else: - self.field = field - assert stepsize in ["adaptative", "constant"], "stepsize must be 'adaptative' or 'constant'" - + assert model == 'FullOrbit', "Method 'Boris' is only available for full orbit model" + assert stepsize == "constant", "Method 'Boris' is only available for constant step size" self.model = model self.method = method - self.initial_conditions = initial_conditions - self.times = times - self.maxtime = maxtime - self.timesteps = timesteps self.stepsize = stepsize - self.tol_step_size = tol_step_size - self._trajectories = trajectories_input - self.particles = particles + + assert isinstance(field, (BiotSavart, Coils, Vmec)), "Field must be a BiotSavart, Coils, or Vmec object" + self.field = BiotSavart(field) if isinstance(field, Coils) else field + + assert isinstance(maxtime, (int, float)), "maxtime must be a float" + assert maxtime > 0, "maxtime must be greater than 0" + self.maxtime = maxtime + + assert times is not None or timesteps is not None, "Either times or timesteps must be provided" + + assert timesteps is None or \ + isinstance(timesteps, (int, float)) and \ + timesteps > 0, "timesteps must be None or a positive float" + assert times is None or \ + isinstance(times, jnp.ndarray), "times must be None or a numpy array" + self.times = jnp.linspace(0, maxtime, timesteps) if times is None else times + self.timesteps = len(self.times) + + if stepsize == "adaptive": + # assert dt0 is not None, "dt0 must be provided for adaptive step size" + assert tol_step_size is not None, "tol_step_size must be provided for adaptive step size" + assert isinstance(tol_step_size, float), "tol_step_size must be a float" + assert tol_step_size > 0, "tol_step_size must be greater than 0" + # self.dt0 = dt0 + self.tol_step_size = tol_step_size + elif stepsize == "constant": + assert maxtime == self.times[-1], "maxtime must be equal to the last time in the times array for constant step size" + # self.dt0 = None + + if model == 'FieldLine': + assert initial_conditions is not None, "initial_conditions must be provided for FieldLine model" + self.initial_conditions = initial_conditions + elif model == 'GuidingCenter' or model == 'FullOrbit': + assert isinstance(particles, Particles), "particles object must be provided for GuidingCenter and FullOrbit models" + self.particles = particles + + if condition is None: self.condition = lambda t, y, args, **kwargs: False if isinstance(field, Vmec): @@ -180,13 +213,14 @@ def condition_Vmec(t, y, args, **kwargs): s, _, _, _ = y return s-1 self.condition = condition_Vmec + if model == 'GuidingCenter': self.ODE_term = ODETerm(GuidingCenter) self.args = (self.field, self.particles) self.initial_conditions = jnp.concatenate([self.particles.initial_xyz, self.particles.initial_vparallel[:, None]], axis=1) if self.method is None: self.method = Dopri8 - elif model == 'FullOrbit' or model == 'FullOrbit_Boris': + elif model == 'FullOrbit': self.ODE_term = ODETerm(Lorentz) self.args = (self.field, self.particles) if self.particles.initial_xyz_fullorbit is None: @@ -201,14 +235,8 @@ def condition_Vmec(t, y, args, **kwargs): self.args = self.field if self.method is None: self.method = Dopri8 - else: - raise ValueError("Model must be one of: 'GuidingCenter', 'FullOrbit', 'FullOrbit_Boris', or 'FieldLine'") - if self.times is None: - self.times = jnp.linspace(0, self.maxtime, self.timesteps) - else: - self.maxtime = jnp.max(self.times) - self.timesteps = len(self.times) + self._trajectories = self.trace() @@ -224,7 +252,7 @@ def compute_energy_gc(trajectory): mu = (self.particles.energy - self.particles.mass * vpar[0]**2 / 2) / AbsB[0] return self.particles.mass * vpar**2 / 2 + mu * AbsB self.energy = vmap(compute_energy_gc)(self._trajectories) - elif model == 'FullOrbit' or model == 'FullOrbit_Boris': + elif model == 'FullOrbit': @jit def compute_energy_fo(trajectory): vxvyvz = trajectory[:, 3:] @@ -246,7 +274,7 @@ def trace(self): def compute_trajectory(initial_condition) -> jnp.ndarray: # initial_condition = initial_condition[0] if self.model == 'FullOrbit_Boris' or self.method == 'Boris': - dt=self.maxtime / self.timesteps + dt = self.times[1] - self.times[0] def update_state(state, _): # def update_fn(state): x = state[:3] @@ -269,17 +297,15 @@ def update_state(state, _): else: # import warnings # warnings.simplefilter("ignore", category=FutureWarning) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation - if self.stepsize == "adaptative": + if self.stepsize == "adaptive": controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size) - dt0 = self.maxtime / self.timesteps elif self.stepsize == "constant": controller = StepTo(self.times) - dt0 = None trajectory = diffeqsolve( self.ODE_term, t0=0.0, t1=self.maxtime, - dt0=dt0, + dt0=None, y0=initial_condition, solver=self.method(), args=self.args, @@ -332,7 +358,7 @@ def plot(self, ax=None, show=True, axis_equal=True, n_trajectories_plot=5, **kwa trajectories_xyz = jnp.array(self.trajectories_xyz) n_trajectories_plot = jnp.min(jnp.array([n_trajectories_plot, trajectories_xyz.shape[0]])) for i in random.choice(random.PRNGKey(0), trajectories_xyz.shape[0], (n_trajectories_plot,), replace=False): - ax.plot(trajectories_xyz[i, :, 0], trajectories_xyz[i, :, 1], trajectories_xyz[i, :, 2], linewidth=0.5, **kwargs) + ax.plot(trajectories_xyz[i, :, 0], trajectories_xyz[i, :, 1], trajectories_xyz[i, :, 2], **kwargs) ax.grid(False) if axis_equal: fix_matplotlib_3d(ax) From b1a227e1aa7e09fa1a3bab989b95fed0e8558fb5 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 14 May 2025 16:59:44 +0200 Subject: [PATCH 21/63] Add gc_vs_fo.py for particle tracing and visualization; adjust parameters for performance --- analysis/gc_vs_fo.py | 79 ++++++++++++++++++++++++++++++++++++++ analysis/gradients.py | 18 +++++++-- analysis/poincare_plots.py | 38 +++++++----------- 3 files changed, 106 insertions(+), 29 deletions(-) create mode 100644 analysis/gc_vs_fo.py diff --git a/analysis/gc_vs_fo.py b/analysis/gc_vs_fo.py new file mode 100644 index 0000000..d4c1678 --- /dev/null +++ b/analysis/gc_vs_fo.py @@ -0,0 +1,79 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from jax import vmap +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.fields import BiotSavart +from essos.coils import Coils_from_json +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE +from essos.dynamics import Tracing, Particles +from jax import block_until_ready + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils_from_json(json_file) +field = BiotSavart(coils) + +# Particle parameters +nparticles = number_of_processors_to_use +mass=PROTON_MASS +energy=5000*ONE_EV +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass +print("cyclotron period:", 1/cyclotron_frequency) + +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) + +particles_passing = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.1], phase_angle_full_orbit=0) +particles_traped = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.9], phase_angle_full_orbit=0) +particles = particles_passing.join(particles_traped, field=field) + +# Tracing parameters +tmax = 1e-4 +trace_tolerance = 1e-15 +dt_gc = 1e-7 +dt_fo = 1e-9 +num_steps_gc = int(tmax/dt_gc) +num_steps_fo = int(tmax/dt_fo) + +# Trace in ESSOS +time0 = time() +tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, + maxtime=tmax, timesteps=num_steps_gc, tol_step_size=trace_tolerance) +trajectories_guidingcenter = block_until_ready(tracing_gc.trajectories) +print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") + +time0 = time() +tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax, + timesteps=num_steps_fo, tol_step_size=trace_tolerance) +block_until_ready(tracing_fo.trajectories) +print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") + +# Plot trajectories, velocity parallel to the magnetic field, and energy error +fig = plt.figure(figsize=(9, 8)) +ax = fig.add_subplot(projection='3d') +coils.plot(ax=ax, show=False) +tracing_gc.plot(ax=ax, show=False, color='black', linewidth=2) +tracing_fo.plot(ax=ax, show=False) +plt.tight_layout() + +plt.figure(figsize=(9, 6)) +plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy[0]/particles.energy-1), label='Guiding Center', color='red') +plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy[0]/particles.energy-1), label='Full Orbit', color='blue') +plt.xlabel('Time (ms)') +plt.ylabel('Relative Energy Error') +plt.xlim(0, tmax*1000) +plt.ylim(bottom=0) +plt.legend() +plt.tight_layout() +plt.savefig(os.path.join(os.path.dirname(__file__), 'energies.png'), dpi=300) + + +plt.show() + +## Save results in vtk format to analyze in Paraview +tracing_gc.to_vtk(os.path.join(os.path.dirname(__file__), 'trajectories_gc')) +tracing_fo.to_vtk(os.path.join(os.path.dirname(__file__), 'trajectories_fo')) +coils.to_vtk(os.path.join(os.path.dirname(__file__), 'coils')) \ No newline at end of file diff --git a/analysis/gradients.py b/analysis/gradients.py index 8cf8798..0e991f6 100644 --- a/analysis/gradients.py +++ b/analysis/gradients.py @@ -3,7 +3,7 @@ number_of_processors_to_use = 8 # Parallelization, this should divide ntheta*nphi os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time -from jax import jit, grad +from jax import jit, grad, block_until_ready import jax.numpy as jnp import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) @@ -47,19 +47,31 @@ grad_loss_partial = jit(grad(loss_partial)) +time0 = time() +loss = loss_partial(coils.x) +block_until_ready(loss) +print(f"Loss took {time()-time0:.4f} seconds. Gradient would take {(time()-time0)*(coils.x.size +1):.4f} seconds") + +time0 = time() +loss_comp = loss_partial(coils.x) +block_until_ready(loss_comp) +print(f"Compiled loss took {time()-time0:.4f} seconds. Gradient would take {(time()-time0)*(coils.x.size +1):.4f} seconds") + time0 = time() grad_loss = grad_loss_partial(coils.x) +block_until_ready(grad_loss) print(f"Gradient took {time()-time0:.4f} seconds") time0 = time() grad_loss_comp = grad_loss_partial(coils.x) +block_until_ready(grad_loss_comp) print(f"Compiled gradient took {time()-time0:.4f} seconds") # Parameter to perturb param = 42 # Set the possible perturbations -h_list = jnp.arange(-10, -1.9, 1/3) +h_list = jnp.arange(-9, -0.9, 1/3) h_list = 10.**h_list # Number of orders for finite differences @@ -102,8 +114,6 @@ plt.grid(which='major', axis='y') for spine in plt.gca().spines.values(): spine.set_zorder(0) -# plt.yticks([1e-11, 1e-9, 1e-7, 1e-5, 1e-3]) -# plt.gca().yaxis.set_minor_locator(plt.NullLocator()) plt.tight_layout() plt.savefig(os.path.join(os.path.dirname(__file__), 'gradients.pdf')) plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" ,'gradients.pdf')) diff --git a/analysis/poincare_plots.py b/analysis/poincare_plots.py index 4c8179f..c43fb00 100644 --- a/analysis/poincare_plots.py +++ b/analysis/poincare_plots.py @@ -1,6 +1,6 @@ import os from functools import partial -number_of_processors_to_use = 4 # Parallelization, this should divide ntheta*nphi +number_of_processors_to_use = 1 # Parallelization, this should divide ntheta*nphi os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time from jax import jit, grad, block_until_ready @@ -15,14 +15,14 @@ # Input parameters tmax_fl = 50000 -tmax_gc = 5e-3 +tmax_gc = 1e-3 tmax_fo = 1e-3 -nparticles = number_of_processors_to_use*8 +nparticles = number_of_processors_to_use*1 nfieldlines = number_of_processors_to_use*8 s = 0.25 # s-coordinate: flux surface label -trace_tolerance = 1e-14 -dt_fo = 1e-10 +trace_tolerance = 1e-15 +dt_fo = 1e-9 dt_gc = 1e-7 timesteps_gc = int(tmax_gc/dt_gc) timesteps_fo = int(tmax_fo/dt_fo) @@ -48,18 +48,18 @@ particles = Particles(initial_xyz=initial_xyz_particles, mass=mass, energy=energy, field=field, min_vparallel_over_v=0.8) # Trace in ESSOS -time0 = time() -tracing_fl = Tracing(field=field, model='FieldLine', initial_conditions=initial_xyz_fieldlines, - maxtime=tmax_fl, timesteps=tmax_fl*10, tol_step_size=trace_tolerance) -block_until_ready(tracing_fl) -print(f"ESSOS tracing of {nfieldlines} field lines took {time()-time0:.2f} seconds") +# time0 = time() +# tracing_fl = Tracing(field=field, model='FieldLine', initial_conditions=initial_xyz_fieldlines, +# maxtime=tmax_fl, timesteps=tmax_fl*10, tol_step_size=trace_tolerance) +# block_until_ready(tracing_fl) +# print(f"ESSOS tracing of {nfieldlines} field lines took {time()-time0:.2f} seconds") time0 = time() tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo, timesteps=timesteps_fo, tol_step_size=trace_tolerance) -tracing_fo.trajectories = tracing_fo.trajectories[:, 0::1000, :] -tracing_fo.times = tracing_fo.times[0::1000] -tracing_fo.energy = tracing_fo.energy[:, 0::1000] +# tracing_fo.trajectories = tracing_fo.trajectories[:, 0::100, :] +# tracing_fo.times = tracing_fo.times[0::100] +# tracing_fo.energy = tracing_fo.energy[:, 0::100] block_until_ready(tracing_fo) print(f"ESSOS tracing of {nparticles} particles with FO for {tmax_fo:.1e}s took {time()-time0:.2f} seconds") @@ -69,18 +69,6 @@ block_until_ready(tracing_gc) print(f"ESSOS tracing of {nparticles} particles with GC for {tmax_gc:.1e}s took {time()-time0:.2f} seconds") -# plt.figure(figsize=(9, 6)) -# plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy[0]/particles.energy-1), label='Guiding Center', color='red') -# plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy[0]/particles.energy-1), label='Full Orbit', color='blue') -# plt.xlabel('Time (ms)') -# plt.ylabel('Relative Energy Error') -# plt.xlim(0, tmax*1000) -# plt.ylim(bottom=0) -# plt.legend() -# plt.tight_layout() -# plt.savefig(os.path.join(os.path.dirname(__file__), 'energies.png'), dpi=300) - - # fig = plt.figure(figsize=(9, 6)) # ax = fig.add_subplot(projection='3d') # coils.plot(ax=ax, show=False) From 8332e1e535c9478cf8bb71584d2a7f439686b16f Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 19 May 2025 16:14:45 +0200 Subject: [PATCH 22/63] Add output directory creation and update file saving paths in analysis scripts --- analysis/fo_integrators.py | 34 ++++++++++++++++++---------------- analysis/gc_integrators.py | 35 +++++++++++++++++++++-------------- analysis/gc_vs_fo.py | 18 +++++++++++------- analysis/gradients.py | 13 +++++++++---- analysis/poincare_plots.py | 9 ++++++--- 5 files changed, 65 insertions(+), 44 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index d654beb..25b971d 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -10,9 +10,12 @@ from essos.coils import Coils_from_json from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles -# import integrators import diffrax +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils_from_json(json_file) @@ -30,7 +33,7 @@ particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.8], field=field) # Tracing parameters -tmax = 1e-5 +tmax = 1e-4 dt = 1e-9 num_steps = int(tmax/dt) @@ -42,47 +45,46 @@ if method_name != 'Boris': energies = [] tracing_times = [] - for trace_tolerance in [1e-8, 1e-10, 1e-12, 1e-14]: + for trace_tolerance in [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]: time0 = time() tracing = Tracing('FullOrbit', field, tmax, method=method, timesteps=num_steps, stepsize='adaptive', tol_step_size=trace_tolerance, particles=particles) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] - print(f"Tracing with adaptative {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + print(f"Tracing with adaptive {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] - ax.plot(tracing_times, energies, label=f'adaptative {method_name}', marker='o', markersize=3, linestyle='-') + ax.plot(tracing_times, energies, label=f'{method_name} adapt', marker='o', markersize=3, linestyle='-') energies = [] tracing_times = [] - for n_points_in_gyration in [5, 10, 20, 30, 40]: + for n_points_in_gyration in [10, 20, 50, 75, 100, 150, 200]: dt = 1/(n_points_in_gyration*cyclotron_frequency) num_steps = int(tmax/dt) time0 = time() tracing = Tracing('FullOrbit', field, tmax, method=method, timesteps=num_steps, - stepsize="constant", particles=particles) + stepsize="constant", particles=particles) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] - print(f"Tracing with {method_name} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") + print(f"Tracing with {method_name} and step {dt:.2e} took {tracing_times[-1]:.2f} seconds") energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method_name}', marker='o', markersize=4, linestyle='-') -ax.legend() +ax.legend(fontsize=15, loc='upper left') ax.set_xlabel('Computation time (s)') ax.set_ylabel('Relative Energy Error') -# ax.set_xscale('log') +ax.set_xscale('log') ax.set_yscale('log') -ax.tick_params(axis='x', which='minor', length=0) -yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] -ax.set_yticks(yticks) -ax.set_ylim(top=1e-6) -plt.grid() +ax.set_xlim(1e-1, 1e2) +ax.set_ylim(1e-16, 1e-4) +plt.grid(axis='x', which='both', linestyle='--', linewidth=0.6) +plt.grid(axis='y', which='major', linestyle='--', linewidth=0.6) plt.tight_layout() -plt.savefig(os.path.join(os.path.dirname(__file__), 'fo_integration.pdf')) +plt.savefig(os.path.join(output_dir, 'fo_integration.pdf')) plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'fo_integration.pdf')) plt.show() diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index 50478ba..9520c0b 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -13,6 +13,10 @@ # import integrators import diffrax +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils_from_json(json_file) @@ -36,24 +40,26 @@ fig, ax = plt.subplots(figsize=(9, 6)) -for method in ['Tsit5', 'Dopri5', 'Dopri8']: +for method in ['Tsit5', 'Dopri5', 'Dopri8', 'Kvaerno5']: energies = [] tracing_times = [] - for trace_tolerance in [1e-9, 1e-10, 1e-11, 1e-12, 1e-13]: + for tolerance in [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]: time0 = time() tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), timesteps=num_steps, - stepsize='adaptive', tol_step_size=trace_tolerance, particles=particles,) + stepsize='adaptive', tol_step_size=tolerance, particles=particles) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] - print(f"Tracing with adaptive {method} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + print(f"Tracing with adaptive {method} and {tolerance=:.0e} took {tracing_times[-1]:.2f} seconds") energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] - ax.plot(tracing_times, energies, label=f'adaptive {method}', marker='o', markersize=3, linestyle='-') + ax.plot(tracing_times, energies, label=f'{method} adapt', marker='o', markersize=3, linestyle='-') + + if method == 'Kvaerno5': continue energies = [] tracing_times = [] - for dt in [2e-7, 1e-7, 5e-8, 2e-8]: + for dt in [4e-7, 2e-7, 1e-7, 8e-8, 6e-8, 4e-8, 2e-8, 1e-8]: num_steps = int(tmax/dt) time0 = time() tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), @@ -61,23 +67,24 @@ block_until_ready(tracing.trajectories) tracing_times += [time() - time0] - print(f"Tracing with {method} and step {tmax/num_steps:.2e} took {tracing_times[-1]:.2f} seconds") + print(f"Tracing with {method} and {dt=:.2e} took {tracing_times[-1]:.2f} seconds") energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') -ax.legend() +ax.legend(fontsize=15) ax.set_xlabel('Computation time (s)') ax.set_ylabel('Relative Energy Error') -# ax.set_xscale('log') ax.set_yscale('log') -ax.tick_params(axis='x', which='minor', length=0) -yticks = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14, 1e-16] -ax.set_yticks(yticks) -plt.grid() +ax.set_xscale('log') +ax.set_yscale('log') +ax.set_xlim(1e-1, 1e2) +ax.set_ylim(1e-16, 1e-4) +plt.grid(axis='x', which='both', linestyle='--', linewidth=0.6) +plt.grid(axis='y', which='major', linestyle='--', linewidth=0.6) plt.tight_layout() -plt.savefig(os.path.join(os.path.dirname(__file__), 'gc_integration.pdf')) +plt.savefig(os.path.join(output_dir, 'gc_integration.pdf')) plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'gc_integration.pdf')) plt.show() diff --git a/analysis/gc_vs_fo.py b/analysis/gc_vs_fo.py index d4c1678..b07d7ec 100644 --- a/analysis/gc_vs_fo.py +++ b/analysis/gc_vs_fo.py @@ -11,6 +11,10 @@ from essos.dynamics import Tracing, Particles from jax import block_until_ready +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils_from_json(json_file) @@ -31,8 +35,8 @@ particles = particles_passing.join(particles_traped, field=field) # Tracing parameters -tmax = 1e-4 -trace_tolerance = 1e-15 +tmax = 1e-3 +trace_tolerance = 1e-14 dt_gc = 1e-7 dt_fo = 1e-9 num_steps_gc = int(tmax/dt_gc) @@ -68,12 +72,12 @@ plt.ylim(bottom=0) plt.legend() plt.tight_layout() -plt.savefig(os.path.join(os.path.dirname(__file__), 'energies.png'), dpi=300) - +plt.savefig(os.path.join(output_dir, 'energies.png'), dpi=300) +plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" ,'energies.png'), dpi=300) plt.show() ## Save results in vtk format to analyze in Paraview -tracing_gc.to_vtk(os.path.join(os.path.dirname(__file__), 'trajectories_gc')) -tracing_fo.to_vtk(os.path.join(os.path.dirname(__file__), 'trajectories_fo')) -coils.to_vtk(os.path.join(os.path.dirname(__file__), 'coils')) \ No newline at end of file +# tracing_gc.to_vtk(os.path.join(output_dir, 'trajectories_gc')) +# tracing_fo.to_vtk(os.path.join(output_dir, 'trajectories_fo')) +# coils.to_vtk(os.path.join(output_dir, 'coils')) \ No newline at end of file diff --git a/analysis/gradients.py b/analysis/gradients.py index 0e991f6..8fa4939 100644 --- a/analysis/gradients.py +++ b/analysis/gradients.py @@ -11,6 +11,10 @@ from essos.fields import Vmec from essos.objective_functions import loss_BdotN +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Optimization parameters max_coil_length = 40 max_coil_curvature = 0.5 @@ -104,17 +108,18 @@ plt.plot(h_list, fd_diff[1], "^-", label=f'2nd order', clip_on=False, linewidth=2.5) plt.plot(h_list, fd_diff[2], "*-", label=f'4th order', clip_on=False, linewidth=2.5) plt.plot(h_list, fd_diff[3], "s-", label=f'6th order', clip_on=False, linewidth=2.5) -plt.legend() +plt.legend(fontsize=15) plt.xlabel('Finite differences stepsize h') plt.ylabel('Relative difference') plt.xscale('log') plt.yscale('log') +plt.ylim(1e-13, 1e-1) plt.xlim(jnp.min(h_list), jnp.max(h_list)) -plt.grid(which='both', axis='x') -plt.grid(which='major', axis='y') +plt.grid(which='both', axis='x', linestyle='--', linewidth=0.6) +plt.grid(which='major', axis='y', linestyle='--', linewidth=0.6) for spine in plt.gca().spines.values(): spine.set_zorder(0) plt.tight_layout() -plt.savefig(os.path.join(os.path.dirname(__file__), 'gradients.pdf')) +plt.savefig(os.path.join(output_dir, 'gradients.pdf')) plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" ,'gradients.pdf')) plt.show() \ No newline at end of file diff --git a/analysis/poincare_plots.py b/analysis/poincare_plots.py index c43fb00..c2c9d87 100644 --- a/analysis/poincare_plots.py +++ b/analysis/poincare_plots.py @@ -12,6 +12,9 @@ from essos.fields import BiotSavart from essos.dynamics import Tracing, Particles +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) # Input parameters tmax_fl = 50000 @@ -97,7 +100,7 @@ # ax.set_ylim(-0.3, 0.3) # plt.grid(visible=False) # plt.tight_layout() -# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fl.png'), dpi=300) +# plt.savefig(os.path.join(output_dir, 'poincare_plot_fl.png'), dpi=300) # plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_fl.png'), dpi=300) @@ -111,7 +114,7 @@ # plt.ylim(-0.3, 0.3) # plt.grid(visible=False) # plt.tight_layout() -# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fo.png'), dpi=300) +# plt.savefig(os.path.join(output_dir 'poincare_plot_fo.png'), dpi=300) # plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_fo.png'), dpi=300) @@ -125,7 +128,7 @@ # ax.set_ylim(-0.3, 0.3) # plt.grid(visible=False) # plt.tight_layout() -# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_gc.png'), dpi=300) +# plt.savefig(os.path.join(output_dir, 'poincare_plot_gc.png'), dpi=300) # plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_gc.png'), dpi=300) # plt.show() \ No newline at end of file From c64a9c1f2453f83402f1ae4927a632e947027bd7 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 19 May 2025 16:14:55 +0200 Subject: [PATCH 23/63] Add comparison_coils.py for BiotSavart field analysis and performance evaluation --- analysis/comparison_coils.py | 226 +++++++++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 analysis/comparison_coils.py diff --git a/analysis/comparison_coils.py b/analysis/comparison_coils.py new file mode 100644 index 0000000..a835246 --- /dev/null +++ b/analysis/comparison_coils.py @@ -0,0 +1,226 @@ +import os +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from jax import block_until_ready +from essos.fields import BiotSavart as BiotSavart_essos +from essos.coils import Coils_from_simsopt, Curves_from_simsopt +from simsopt import load +from simsopt.geo import CurveXYZFourier, curves_to_vtk +from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries +from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +n_segments = 100 + +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +nfp_array = [3, 2, 5, 4, 2] +curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] +currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] +name_array = ["NCSX", "QA(json)", "W7-X", "HSX", "Giuliani"] + +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') +for nfp, curves_stel, currents_stel, name in zip(nfp_array, curves_array, currents_array, name_array): + print(f' Running {name} and saving to output directory...') + if currents_stel is None: + json_file_stel = curves_stel + field_simsopt = load(json_file_stel) + coils_simsopt = field_simsopt.coils + curves_simsopt = [coil.curve for coil in coils_simsopt] + currents_simsopt = [coil.current for coil in coils_simsopt] + coils_essos = Coils_from_simsopt(json_file_stel, nfp) + curves_essos = Curves_from_simsopt(json_file_stel, nfp) + else: + coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + currents_simsopt = [c.current for c in coils_simsopt] + field_simsopt = BiotSavart_simsopt(coils_simsopt) + + coils_essos = Coils_from_simsopt(coils_simsopt, nfp) + curves_essos = Curves_from_simsopt(curves_simsopt, nfp) + + field_essos = BiotSavart_essos(coils_essos) + + coils_essos_to_simsopt = coils_essos.to_simsopt() + curves_essos_to_simsopt = curves_essos.to_simsopt() + field_essos_to_simsopt = BiotSavart_simsopt(coils_essos_to_simsopt) + + # curves_to_vtk(curves_simsopt, os.path.join(output_dir,f"curves_simsopt_{name}")) + # curves_essos.to_vtk(os.path.join(output_dir,f"curves_essos_{name}")) + # curves_to_vtk(curves_essos_to_simsopt, os.path.join(output_dir,f"curves_essos_to_simsopt_{name}")) + + base_coils_simsopt = coils_simsopt[:int(len(coils_simsopt)/2/nfp)] + R = jnp.mean(jnp.array([jnp.sqrt(coil.curve.x[coil.curve.local_dof_names.index('xc(0)')]**2 + +coil.curve.x[coil.curve.local_dof_names.index('yc(0)')]**2) + for coil in base_coils_simsopt])) + x = jnp.array([R+0.01,R,R]) + y = jnp.array([R,R+0.01,R-0.01]) + z = jnp.array([0.05,0.06,0.07]) + + positions = jnp.array((x,y,z)) + + def update_nsegments_simsopt(curve_simsopt, n_segments): + new_curve = CurveXYZFourier(n_segments, curve_simsopt.order) + new_curve.x = curve_simsopt.x + return new_curve + + coils_essos.n_segments = n_segments + + base_curves_simsopt = [update_nsegments_simsopt(coil_simsopt.curve, n_segments) for coil_simsopt in base_coils_simsopt] + coils_simsopt = coils_via_symmetries(base_curves_simsopt, currents_simsopt[0:len(base_coils_simsopt)], nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + + # Running the first time for compilation + [curve.gamma() for curve in curves_simsopt] + coils_essos.gamma + + # Running the second time for coils characteristics comparison + start_time = time() + gamma_curves_simsopt = block_until_ready(jnp.array([curve.gamma() for curve in curves_simsopt])) + t_gamma_avg_simsopt = time() - start_time + + start_time = time() + gamma_curves_essos = block_until_ready(jnp.array(coils_essos.gamma)) + t_gamma_avg_essos = time() - start_time + + start_time = time() + gammadash_curves_simsopt = block_until_ready(jnp.array([curve.gammadash() for curve in curves_simsopt])) + t_gammadash_avg_simsopt = time() - start_time + + start_time = time() + gammadash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dash)) + t_gammadash_avg_essos = time() - start_time + + start_time = time() + gammadashdash_curves_simsopt = block_until_ready(jnp.array([curve.gammadashdash() for curve in curves_simsopt])) + t_gammadashdash_avg_simsopt = time() - start_time + + start_time = time() + gammadashdash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dashdash)) + t_gammadashdash_avg_essos = time() - start_time + + start_time = time() + curvature_curves_simsopt = block_until_ready(jnp.array([curve.kappa() for curve in curves_simsopt])) + t_curvature_avg_simsopt = time() - start_time + + start_time = time() + curvature_curves_essos = block_until_ready(jnp.array(coils_essos.curvature)) + t_curvature_avg_essos = time() - start_time + + gamma_error_avg = jnp.linalg.norm(gamma_curves_essos - gamma_curves_simsopt) + gammadash_error_avg = jnp.linalg.norm(gammadash_curves_essos - gammadash_curves_simsopt) + gammadashdash_error_avg = jnp.linalg.norm(gammadashdash_curves_essos - gammadashdash_curves_simsopt) + curvature_error_avg = jnp.linalg.norm(curvature_curves_essos - curvature_curves_simsopt) + + # Magnetic field comparison + + field_essos = BiotSavart_essos(coils_essos) + field_simsopt = BiotSavart_simsopt(coils_simsopt) + + t_B_avg_essos = 0 + t_B_avg_simsopt = 0 + B_error_avg = 0 + t_dB_by_dX_avg_essos = 0 + t_dB_by_dX_avg_simsopt = 0 + dB_by_dX_error_avg = 0 + + for position in positions: + field_essos.B(position) + time1 = time() + result_B_essos = field_essos.B(position) + t_B_avg_essos = t_B_avg_essos + time() - time1 + normB_essos = jnp.linalg.norm(result_B_essos) + + field_simsopt.set_points(jnp.array([position])) + field_simsopt.B() + time3 = time() + field_simsopt.set_points(jnp.array([position])) + result_simsopt = field_simsopt.B() + t_B_avg_simsopt = t_B_avg_simsopt + time() - time3 + normB_simsopt = jnp.linalg.norm(jnp.array(result_simsopt)) + + B_error_avg = B_error_avg + jnp.abs(normB_essos - normB_simsopt) + + field_essos.dB_by_dX(position) + time1 = time() + field_simsopt.set_points(jnp.array([position])) + result_dB_by_dX_essos = field_essos.dB_by_dX(position) + t_dB_by_dX_avg_essos = t_dB_by_dX_avg_essos + time() - time1 + norm_dB_by_dX_essos = jnp.linalg.norm(result_dB_by_dX_essos) + + field_simsopt.dB_by_dX() + time3 = time() + field_simsopt.set_points(jnp.array([position])) + result_dB_by_dX_simsopt = field_simsopt.dB_by_dX() + t_dB_by_dX_avg_simsopt = t_dB_by_dX_avg_simsopt + time() - time3 + norm_dB_by_dX_simsopt = jnp.linalg.norm(jnp.array(result_dB_by_dX_simsopt)) + + dB_by_dX_error_avg = dB_by_dX_error_avg + jnp.abs(norm_dB_by_dX_essos - norm_dB_by_dX_simsopt) + + # Labels and corresponding absolute errors (ESSOS - SIMSOPT) + quantities_errors = [ + (r"$B$", jnp.abs(B_error_avg)), + (r"$B'$", jnp.abs(dB_by_dX_error_avg)), + (r"$\Gamma$", jnp.abs(gamma_error_avg)), + (r"$\Gamma'$", jnp.abs(gammadash_error_avg)), + (r"$\Gamma''$", jnp.abs(gammadashdash_error_avg)), + (r"$\kappa$", jnp.abs(curvature_error_avg)), + ] + + labels = [q[0] for q in quantities_errors] + error_vals = [q[1] for q in quantities_errors] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.6 + + fig, ax = plt.subplots(figsize=(9, 5)) + ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Absolute error") + ax.set_yscale("log") + ax.set_ylim(1e-17, 1e-12) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparison_error_BiotSavart_{name}.pdf"), transparent=True) + plt.close() + + + # Labels and corresponding timings + quantities = [ + (r"$B$", t_B_avg_essos, t_B_avg_simsopt), + (r"$B'$", t_dB_by_dX_avg_essos, t_dB_by_dX_avg_simsopt), + (r"$\Gamma$", t_gamma_avg_essos, t_gamma_avg_simsopt), + (r"$\Gamma'$", t_gammadash_avg_essos, t_gammadash_avg_simsopt), + (r"$\Gamma''$", t_gammadashdash_avg_essos, t_gammadashdash_avg_simsopt), + (r"$\kappa$", t_curvature_avg_essos, t_curvature_avg_simsopt), + ] + + labels = [q[0] for q in quantities] + essos_vals = [q[1] for q in quantities] + simsopt_vals = [q[2] for q in quantities] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.35 + + fig, ax = plt.subplots(figsize=(9, 5)) + ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") + ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Computation time (s)") + ax.set_yscale("log") + ax.set_ylim(1e-5, 1e-1) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + ax.legend(fontsize=12) + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparison_time_BiotSavart_{name}.pdf"), transparent=True) + plt.close() From f86785711058a5786f9ed7ef2ae4fb95cc40aa0b Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 21 May 2025 11:28:39 +0200 Subject: [PATCH 24/63] Fix errors derived from dynamics refactoring --- essos/dynamics.py | 10 +++-- .../fullorbit_SIMSOPT_vs_ESSOS.py | 41 ++++++++++--------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/essos/dynamics.py b/essos/dynamics.py index 9e567de..0ea32dc 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -11,6 +11,7 @@ from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, FUSION_ALPHA_PARTICLE_ENERGY from essos.plot import fix_matplotlib_3d from essos.util import roots +import warnings mesh = Mesh(jax.devices(), ("dev",)) sharding = NamedSharding(mesh, PartitionSpec("dev", None)) @@ -165,7 +166,8 @@ def __init__(self, model: str, field, maxtime: float, method=None, times=None, assert stepsize in ["adaptive", "constant"], "stepsize must be 'adaptive' or 'constant'" if method == 'Boris': assert model == 'FullOrbit', "Method 'Boris' is only available for full orbit model" - assert stepsize == "constant", "Method 'Boris' is only available for constant step size" + warnings.warn("The 'Boris' method is only supported with a constant step size. 'stepsize' has been set to constant.") + stepsize = "constant" self.model = model self.method = method self.stepsize = stepsize @@ -192,15 +194,15 @@ def __init__(self, model: str, field, maxtime: float, method=None, times=None, assert tol_step_size is not None, "tol_step_size must be provided for adaptive step size" assert isinstance(tol_step_size, float), "tol_step_size must be a float" assert tol_step_size > 0, "tol_step_size must be greater than 0" - # self.dt0 = dt0 self.tol_step_size = tol_step_size elif stepsize == "constant": assert maxtime == self.times[-1], "maxtime must be equal to the last time in the times array for constant step size" - # self.dt0 = None + self.tol_step_size = None if model == 'FieldLine': assert initial_conditions is not None, "initial_conditions must be provided for FieldLine model" self.initial_conditions = initial_conditions + self.particles = None elif model == 'GuidingCenter' or model == 'FullOrbit': assert isinstance(particles, Particles), "particles object must be provided for GuidingCenter and FullOrbit models" self.particles = particles @@ -273,7 +275,7 @@ def compute_energy_fo(trajectory): def trace(self): def compute_trajectory(initial_condition) -> jnp.ndarray: # initial_condition = initial_condition[0] - if self.model == 'FullOrbit_Boris' or self.method == 'Boris': + if self.method == 'Boris': dt = self.times[1] - self.times[0] def update_state(state, _): # def update_fn(state): diff --git a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py index c1b2aba..fc9ca34 100644 --- a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py +++ b/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py @@ -9,6 +9,7 @@ from essos.dynamics import Tracing, Particles from essos.fields import BiotSavart as BiotSavart_essos import matplotlib.pyplot as plt +from diffrax import Dopri8 tmax_full = 1e-5 nparticles = 3 @@ -18,7 +19,7 @@ trace_tolerance_ESSOS = 1e-5 mass=PROTON_MASS energy=5000*ONE_EV -model_ESSOS_array = ['FullOrbit', 'FullOrbit_Boris'] +method_ESSOS_array = ['Boris', Dopri8] output_dir = os.path.join(os.path.dirname(__file__), 'output') if not os.path.exists(output_dir): @@ -72,15 +73,15 @@ tracing_array = [] trajectories_ESSOS_array = [] time_ESSOS_array = [] -for model_ESSOS in model_ESSOS_array: - print(f'Tracing ESSOS full orbit '+('Boris' if model_ESSOS=='FullOrbit_Boris' else f'with tolerance={trace_tolerance_ESSOS}')+f' and plotting the result.') +for method_ESSOS in method_ESSOS_array: + print(f'Tracing ESSOS full orbit '+('Boris' if method_ESSOS=='Boris' else f'with tolerance={trace_tolerance_ESSOS}')+f' and plotting the result.') t1 = time.time() - tracing = block_until_ready(Tracing(field=field_essos, model=model_ESSOS, particles=particles, - maxtime=tmax_full, timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) + tracing = block_until_ready(Tracing('FullOrbit', field_essos, tmax_full, method=method_ESSOS, particles=particles, + timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) trajectories_ESSOS = tracing.trajectories time_ESSOS = time.time()-t1 - print(f" Time for ESSOS tracing={time.time()-t1:.3f}s "+('Boris' if model_ESSOS=='FullOrbit_Boris' else f'')+f". Num steps={len(trajectories_ESSOS[0])}") - tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_ESSOS')) + print(f" Time for ESSOS tracing={time.time()-t1:.3f}s "+('Boris' if method_ESSOS=='Boris' else f'')+f". Num steps={len(trajectories_ESSOS[0])}") + tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='Boris' else '')+'_ESSOS')) tracing_array.append(tracing) trajectories_ESSOS_array.append(trajectories_ESSOS) time_ESSOS_array.append(time_ESSOS) @@ -93,9 +94,9 @@ SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') -for model_ESSOS, tracing, trajectories_ESSOS in zip(model_ESSOS_array, tracing_array, trajectories_ESSOS_array): +for method_ESSOS, tracing, trajectories_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array): relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy - plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS'+(' Boris' if model_ESSOS=='FullOrbit_Boris' else f' Tol={trace_tolerance_ESSOS}')) + plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS'+(' Boris' if method_ESSOS=='Boris' else f' Tol={trace_tolerance_ESSOS}')) plt.legend() plt.yscale('log') plt.xlabel('Time (s)') @@ -107,9 +108,9 @@ labels = [f'SIMSOPT Tol={tol}' for tol in trace_tolerance_SIMSOPT_array] times = time_SIMSOPT_array plt.figure() -for model_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(model_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): +for method_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): # Plot time comparison in a bar chart - labels += ([f'ESSOS Boris Algorithm'] if model_ESSOS=='FullOrbit_Boris' else [f'ESSOS Tol={trace_tolerance_ESSOS}']) + labels += ([f'ESSOS Boris Algorithm'] if method_ESSOS=='FullOrbit_Boris' else [f'ESSOS Tol={trace_tolerance_ESSOS}']) times += [time_ESSOS] bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red', 'orange'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black']*2, hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']*2) plt.xlabel('Tracing Tolerance of SIMSOPT') @@ -120,7 +121,7 @@ red_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') orange_patch = plt.Line2D([0], [0], color='orange', lw=4, label=f'ESSOS\nBoris Algorithm') plt.legend(handles=[blue_patch, red_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, 'times_full_orbit'+('_boris' if method_ESSOS=='Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) plt.close() def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): @@ -136,13 +137,13 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_vx, interp_vy, interp_vz]) return coords_ESSOS_interp -for model_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(model_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): +for method_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): relative_error_array = [] for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] tracing.trajectories = trajectories_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_ESSOS_interp')) + if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_ESSOS_interp')) relative_error_trajectories_SIMSOPT_vs_ESSOS = [] plt.figure() @@ -166,7 +167,7 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): plt.ylabel('Relative Error') plt.yscale('log') plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+f'_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, f'relative_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+f'_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) plt.close() relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) @@ -187,7 +188,7 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): plt.xlabel('R') plt.ylabel('Z') plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+f'_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+f'_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) plt.close() plt.figure() @@ -203,7 +204,7 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): plt.ylabel(r'$v_x/v$') # plt.yscale('log') plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+f'_vx_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+f'_vx_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) plt.close() # Calculate RMS error for each tolerance @@ -221,7 +222,7 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) plt.legend() plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'rms_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'rms_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) plt.close() # Calculate maximum error for each tolerance @@ -238,7 +239,7 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) plt.legend() plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'max_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'max_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) plt.close() # Calculate mean error for each tolerance @@ -255,5 +256,5 @@ def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) plt.legend() plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'mean_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'mean_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) plt.close() From 7d7be44f38a1b4bce2a1408f9910879b80f5f2eb Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 21 May 2025 11:28:50 +0200 Subject: [PATCH 25/63] Enhance plotting in gc_integrators.py: adjust figure sizes, add tolerance plots, and improve layout for better visualization --- analysis/comparison_coils.py | 4 +-- analysis/gc_integrators.py | 49 +++++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/analysis/comparison_coils.py b/analysis/comparison_coils.py index a835246..03288c0 100644 --- a/analysis/comparison_coils.py +++ b/analysis/comparison_coils.py @@ -178,7 +178,7 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): X_axis = jnp.arange(len(labels)) bar_width = 0.6 - fig, ax = plt.subplots(figsize=(9, 5)) + fig, ax = plt.subplots(figsize=(9, 6)) ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") ax.set_xticks(X_axis) @@ -210,7 +210,7 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): X_axis = jnp.arange(len(labels)) bar_width = 0.35 - fig, ax = plt.subplots(figsize=(9, 5)) + fig, ax = plt.subplots(figsize=(9, 6)) ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index 9520c0b..dad15dc 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -1,4 +1,5 @@ import os +import gc number_of_processors_to_use = 1 # Parallelization, this should divide nparticles os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time @@ -35,15 +36,17 @@ # Tracing parameters tmax = 1e-4 -dt = 1e-7 -num_steps = int(tmax/dt) fig, ax = plt.subplots(figsize=(9, 6)) - -for method in ['Tsit5', 'Dopri5', 'Dopri8', 'Kvaerno5']: +fig_tol, ax_tol = plt.subplots(figsize=(9, 6)) +markers = ["o-", "^-", "*-", "s-"] +for method, marker in zip(['Tsit5', 'Dopri5', 'Dopri8', 'Kvaerno5'], markers): + dt = 1e-7 + num_steps = int(tmax/dt) energies = [] tracing_times = [] - for tolerance in [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]: + tolerances = [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16] + for tolerance in tolerances: time0 = time() tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), timesteps=num_steps, stepsize='adaptive', tol_step_size=tolerance, particles=particles) @@ -53,7 +56,8 @@ print(f"Tracing with adaptive {method} and {tolerance=:.0e} took {tracing_times[-1]:.2f} seconds") energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] - ax.plot(tracing_times, energies, label=f'{method} adapt', marker='o', markersize=3, linestyle='-') + ax.plot(tracing_times, energies, label=f'{method} adapt', marker='o', markersize=3) + ax_tol.plot(tolerances, energies, marker, label=f'{method} adapt', clip_on=False, linewidth=2.5) if method == 'Kvaerno5': continue @@ -71,21 +75,30 @@ energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') + gc.collect() - -ax.legend(fontsize=15) ax.set_xlabel('Computation time (s)') -ax.set_ylabel('Relative Energy Error') -ax.set_yscale('log') -ax.set_xscale('log') -ax.set_yscale('log') +ax_tol.set_xlabel('Tracing tolerance') ax.set_xlim(1e-1, 1e2) -ax.set_ylim(1e-16, 1e-4) -plt.grid(axis='x', which='both', linestyle='--', linewidth=0.6) -plt.grid(axis='y', which='major', linestyle='--', linewidth=0.6) -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'gc_integration.pdf')) -plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'gc_integration.pdf')) +ax_tol.set_xlim(tolerances[-1], tolerances[0]) + +for axis in [ax, ax_tol]: + axis.legend(fontsize=15) + axis.set_ylabel('Relative Energy Error') + axis.set_xscale('log') + axis.set_yscale('log') + axis.set_ylim(1e-16, 1e-4) + axis.grid(axis='x', which='both', linestyle='--', linewidth=0.6) + axis.grid(axis='y', which='major', linestyle='--', linewidth=0.6) +for figure in [fig, fig_tol]: + figure.tight_layout() + +for spine in ax_tol.spines.values(): + spine.set_zorder(0) + +fig.savefig(os.path.join(output_dir, 'gc_integration.pdf')) +fig.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'gc_integration.pdf')) +fig_tol.savefig(os.path.join(output_dir, 'energy_vs_tol.pdf')) plt.show() ## Save results in vtk format to analyze in Paraview From 53682b645d20881e1ef811a2db8b76ba0dab0a38 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sat, 24 May 2025 17:07:53 +0200 Subject: [PATCH 26/63] Change dynamics to accept Integrator name --- analysis/gc_integrators.py | 8 ++- essos/dynamics.py | 102 +++++++++++++++++++++---------------- 2 files changed, 61 insertions(+), 49 deletions(-) diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index dad15dc..d6efaa3 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -11,8 +11,6 @@ from essos.coils import Coils_from_json from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles -# import integrators -import diffrax output_dir = os.path.join(os.path.dirname(__file__), 'output') if not os.path.exists(output_dir): @@ -45,10 +43,10 @@ num_steps = int(tmax/dt) energies = [] tracing_times = [] - tolerances = [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16] + tolerances = [1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16] for tolerance in tolerances: time0 = time() - tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), timesteps=num_steps, + tracing = Tracing('GuidingCenter', field, tmax, method=method, timesteps=num_steps, stepsize='adaptive', tol_step_size=tolerance, particles=particles) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] @@ -66,7 +64,7 @@ for dt in [4e-7, 2e-7, 1e-7, 8e-8, 6e-8, 4e-8, 2e-8, 1e-8]: num_steps = int(tmax/dt) time0 = time() - tracing = Tracing('GuidingCenter', field, tmax, method=getattr(diffrax, method), + tracing = Tracing('GuidingCenter', field, tmax, method=method, timesteps=num_steps, stepsize="constant", particles=particles) block_until_ready(tracing.trajectories) tracing_times += [time() - time0] diff --git a/essos/dynamics.py b/essos/dynamics.py index 0ea32dc..6fc91dd 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -5,6 +5,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax import jit, vmap, tree_util, random, lax, device_put from functools import partial +import diffrax from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri8, PIDController, Event, AbstractSolver, ConstantStepSize, StepTo from essos.coils import Coils from essos.fields import BiotSavart, Vmec @@ -160,6 +161,11 @@ def __init__(self, model: str, field, maxtime: float, method=None, times=None, """ assert model in ["GuidingCenter", "FullOrbit", "FieldLine"], "Model must be one of: 'GuidingCenter', 'FullOrbit', or 'FieldLine'" + if isinstance(method, str) and method != 'Boris': + try: + method = getattr(diffrax, method) + except AttributeError: + raise ValueError(f"String method '{method}' is not a valid diffrax solver") assert method is None or \ method == 'Boris' or \ issubclass(method, AbstractSolver), "Method must be None, 'Boris', or a DIFFRAX solver" @@ -168,6 +174,7 @@ def __init__(self, model: str, field, maxtime: float, method=None, times=None, assert model == 'FullOrbit', "Method 'Boris' is only available for full orbit model" warnings.warn("The 'Boris' method is only supported with a constant step size. 'stepsize' has been set to constant.") stepsize = "constant" + self.model = model self.method = method self.stepsize = stepsize @@ -238,43 +245,15 @@ def condition_Vmec(t, y, args, **kwargs): if self.method is None: self.method = Dopri8 - self._trajectories = self.trace() - - if self.particles is not None: - self.energy = jnp.zeros((self.particles.nparticles, self.timesteps)) - - if model == 'GuidingCenter': - @jit - def compute_energy_gc(trajectory): - xyz = trajectory[:, :3] - vpar = trajectory[:, 3] - AbsB = vmap(self.field.AbsB)(xyz) - mu = (self.particles.energy - self.particles.mass * vpar[0]**2 / 2) / AbsB[0] - return self.particles.mass * vpar**2 / 2 + mu * AbsB - self.energy = vmap(compute_energy_gc)(self._trajectories) - elif model == 'FullOrbit': - @jit - def compute_energy_fo(trajectory): - vxvyvz = trajectory[:, 3:] - return self.particles.mass / 2 * (vxvyvz[:, 0]**2 + vxvyvz[:, 1]**2 + vxvyvz[:, 2]**2) - self.energy = vmap(compute_energy_fo)(self._trajectories) - elif model == 'FieldLine': - self.energy = jnp.ones((len(initial_conditions), self.timesteps)) - - self.trajectories_xyz = vmap(lambda xyz: vmap(lambda point: self.field.to_xyz(point[:3]))(xyz))(self.trajectories) - if isinstance(field, Vmec): - self.loss_fractions, self.total_particles_lost, self.lost_times = self.loss_fraction() + self.trajectories_xyz = vmap(lambda xyz: vmap(lambda point: self.field.to_xyz(point[:3]))(xyz))(self.trajectories) else: - self.loss_fractions = None - self.total_particles_lost = None - self.loss_times = None + self.trajectories_xyz = self.trajectories def trace(self): def compute_trajectory(initial_condition) -> jnp.ndarray: - # initial_condition = initial_condition[0] if self.method == 'Boris': dt = self.times[1] - self.times[0] def update_state(state, _): @@ -297,17 +276,20 @@ def update_state(state, _): _, trajectory = lax.scan(update_state, initial_condition, jnp.arange(len(self.times)-1)) trajectory = jnp.vstack([initial_condition, trajectory]) else: - # import warnings - # warnings.simplefilter("ignore", category=FutureWarning) # see https://github.com/patrick-kidger/diffrax/issues/445 for explanation if self.stepsize == "adaptive": - controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size) + r0 = jnp.linalg.norm(initial_condition[:2]) + dtmax = r0*0.5*jnp.pi/self.particles.total_speed # can at most do quarter of a revolution per step + controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, dtmax=dtmax, rtol=self.tol_step_size, atol=self.tol_step_size) + dt0 = 1e-3 * dtmax # initial guess for first timestep, will be adjusted by adaptive timestepper elif self.stepsize == "constant": controller = StepTo(self.times) + dt0 = None + trajectory = diffeqsolve( self.ODE_term, t0=0.0, t1=self.maxtime, - dt0=None, + dt0=dt0, y0=initial_condition, solver=self.method(), args=self.args, @@ -322,7 +304,7 @@ def update_state(state, _): return jit(vmap(compute_trajectory), in_shardings=sharding, out_shardings=sharding)( device_put(self.initial_conditions, sharding)) - + @property def trajectories(self): return self._trajectories @@ -331,15 +313,36 @@ def trajectories(self): def trajectories(self, value): self._trajectories = value - def _tree_flatten(self): - children = (self.trajectories, self.initial_conditions, self.times) # arrays / dynamic values - aux_data = {'field': self.field, 'model': self.model, 'method': self.method, 'maxtime': self.maxtime, 'timesteps': self.timesteps,'stepsize': - self.stepsize, 'tol_step_size': self.tol_step_size, 'particles': self.particles, 'condition': self.condition} # static values - return (children, aux_data) + def _energy(self): + assert self.model in ['GuidingCenter', 'FullOrbit'], "Energy calculation is only available for GuidingCenter and FullOrbit models" + mass = self.particles.mass - @classmethod - def _tree_unflatten(cls, aux_data, children): - return cls(*children, **aux_data) + if self.model == 'GuidingCenter': + initial_xyz = self.initial_conditions[:, :3] + initial_vparallel = self.initial_conditions[:, 3] + initial_B = vmap(self.field.AbsB)(initial_xyz) + mu_array = (self.particles.energy - 0.5 * mass * jnp.square(initial_vparallel)) / initial_B + def compute_energy(trajectory, mu): + xyz = trajectory[:, :3] + vpar = trajectory[:, 3] + AbsB = vmap(self.field.AbsB)(xyz) + return 0.5 * mass * jnp.square(vpar) + mu * AbsB + + energy = vmap(compute_energy)(self.trajectories, mu_array) + + elif self.model == 'FullOrbit': + def compute_energy(trajectory): + vxvyvz = trajectory[:, 3:] + v_squared = jnp.dot(vxvyvz, vxvyvz, axis=1) + return 0.5 * mass * v_squared + + energy = vmap(compute_energy)(self.trajectories) + + return energy + + @property + def energy(self): + return self._energy() def to_vtk(self, filename): try: import numpy as np @@ -471,7 +474,18 @@ def process_trajectory(X_i, Y_i, T_i): plt.show() return plotting_data - + + def _tree_flatten(self): + children = (self.trajectories, self.initial_conditions, self.times) # arrays / dynamic values + aux_data = {'field': self.field, 'model': self.model, 'method': self.method, 'maxtime': self.maxtime, 'timesteps': self.timesteps,'stepsize': + self.stepsize, 'tol_step_size': self.tol_step_size, 'particles': self.particles, 'condition': self.condition} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) + + tree_util.register_pytree_node(Tracing, Tracing._tree_flatten, Tracing._tree_unflatten) \ No newline at end of file From 093f11d1ad77630b458eac50513f1aa68899b391 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Sat, 24 May 2025 17:08:20 +0200 Subject: [PATCH 27/63] Add comparison_gc.py for comparing gc trajectories between SIMSOPT and ESSOS --- analysis/comparison_gc.py | 254 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 analysis/comparison_gc.py diff --git a/analysis/comparison_gc.py b/analysis/comparison_gc.py new file mode 100644 index 0000000..ad9ea64 --- /dev/null +++ b/analysis/comparison_gc.py @@ -0,0 +1,254 @@ +import os +from time import time +import jax.numpy as jnp +from jax import block_until_ready, random +from simsopt import load +from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) +from essos.coils import Coils_from_simsopt +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +from essos.fields import BiotSavart as BiotSavart_essos +import matplotlib.pyplot as plt + +tmax_gc = 1e-4 +nparticles = 5 +axis_shft=0.02 +R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) +trace_tolerance_SIMSOPT_array = [1e-5, 1e-7, 1e-9, 1e-11] +trace_tolerance_ESSOS = 1e-9 +mass=PROTON_MASS +energy=5000*ONE_EV + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +nfp=2 +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +field_simsopt = load(LandremanPaulQA_json_file) +field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) + +Z0 = jnp.zeros(nparticles) +phi0 = jnp.zeros(nparticles) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) + +phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] + +particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) + +# Trace in SIMSOPT +time_SIMSOPT_array = [] +trajectories_SIMSOPT_array = [] +avg_steps_SIMSOPT = 0 +relative_energy_error_SIMSOPT_array = [] +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') +for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: + print(f'Tracing SIMSOPT guiding center with tolerance={trace_tolerance_SIMSOPT}') + t1 = time() + trajectories_SIMSOPT_this_tolerance, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( + field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, + parallel_speeds=particles.initial_vparallel, tmax=tmax_gc, mode='gc_vac', + charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) + time_SIMSOPT_array.append(time()-t1) + avg_steps_SIMSOPT += sum([len(l) for l in trajectories_SIMSOPT_this_tolerance]) // nparticles + print(f" Time for SIMSOPT tracing={time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT_this_tolerance) + + relative_energy_SIMSOPT = [] + for i, trajectory in enumerate(trajectories_SIMSOPT_this_tolerance): + xyz = jnp.asarray(trajectory[:, 1:4]) + vpar = trajectory[:, 4] + field_simsopt.set_points(xyz) + AbsB = field_simsopt.AbsB()[:,0] + mu = (particles.energy - particles.mass*vpar[0]**2/2)/AbsB[0] + relative_energy_SIMSOPT.append(jnp.abs(particles.mass*vpar**2/2+mu*AbsB-particles.energy)/particles.energy) + relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) + +# particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + +# Trace in ESSOS +num_steps_essos = 1000#int(jnp.mean(jnp.array([len(trajectories_SIMSOPT[0]) for trajectories_SIMSOPT in trajectories_SIMSOPT_array]))) +time_essos = jnp.linspace(0, tmax_gc, num_steps_essos) + +tracing = Tracing('GuidingCenter', field_essos, 1e-7, timesteps=100, method='Dopri8', + stepsize='adaptive', tol_step_size=1e-7, particles=particles) +block_until_ready(tracing.trajectories) + +print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') +start_time = time() +tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri8', + stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) +block_until_ready(tracing.trajectories) +time_ESSOS = time() - start_time + +trajectories_ESSOS = tracing.trajectories +print(f" Time for ESSOS tracing={time_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}") +tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) + +relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy + +print('Plotting the results to output directory...') +plt.figure() +SIMSOPT_energy_interp_this_particle = jnp.zeros((len(trace_tolerance_SIMSOPT_array), nparticles, len(trajectories_SIMSOPT_array[-1][-1][:,0]))) +for j in range(nparticles): + for i, relative_energy_error_SIMSOPT in enumerate(relative_energy_error_SIMSOPT_array): + SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) +plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS Tol={trace_tolerance_ESSOS}') +for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): + plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') +plt.legend() +plt.yscale('log') +plt.xlabel('Time (s)') +plt.ylabel('Average Relative Energy Error') +plt.tight_layout() +plt.savefig(os.path.join(output_dir, f'relative_energy_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.close() + +# Plot time comparison in a bar chart +labels = [f'SIMSOPT\nTol={tol}' for tol in trace_tolerance_SIMSOPT_array] + [f'ESSOS\nTol={trace_tolerance_ESSOS}'] +times = time_SIMSOPT_array + [time_ESSOS] +plt.figure() +bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black'], hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']) +plt.xlabel('Tracing Tolerance of SIMSOPT') +plt.ylabel('Time (s)') +plt.xticks(rotation=45) +plt.tight_layout() +blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') +orange_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') +plt.legend(handles=[blue_patch, orange_patch]) +plt.savefig(os.path.join(output_dir, 'times_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.close() + +def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): + time_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 0] # Time values from guiding center SIMSOPT + # coords_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from guiding center SIMSOPT + coords_ESSOS = jnp.array(trajectory_ESSOS) + + interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) + interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) + interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) + interp_v = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 3]) + + coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_v]) + + return coords_ESSOS_interp + +relative_error_array = [] +for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): + trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] + tracing.trajectories = trajectories_ESSOS_interp + if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS_interp')) + + relative_error_trajectories_SIMSOPT_vs_ESSOS = [] + plt.figure() + for j in range(nparticles): + this_trajectory_SIMSOPT = jnp.array(trajectories_SIMSOPT[j])[:,1:] + this_trajectory_ESSOS = trajectories_ESSOS_interp[j] + average_relative_error = [] + for trajectory_SIMSOPT_t, trajectory_ESSOS_t in zip(this_trajectory_SIMSOPT, this_trajectory_ESSOS): + relative_error_x = jnp.abs(trajectory_SIMSOPT_t[0] - trajectory_ESSOS_t[0])/(jnp.abs(trajectory_SIMSOPT_t[0])+1e-12) + relative_error_y = jnp.abs(trajectory_SIMSOPT_t[1] - trajectory_ESSOS_t[1])/(jnp.abs(trajectory_SIMSOPT_t[1])+1e-12) + relative_error_z = jnp.abs(trajectory_SIMSOPT_t[2] - trajectory_ESSOS_t[2])/(jnp.abs(trajectory_SIMSOPT_t[2])+1e-12) + relative_error_v = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[3])+1e-12) + average_relative_error.append((relative_error_x + relative_error_y + relative_error_z + relative_error_v)/4) + average_relative_error = jnp.array(average_relative_error) + relative_error_trajectories_SIMSOPT_vs_ESSOS.append(average_relative_error) + plt.plot(jnp.linspace(0, tmax_gc, len(average_relative_error))[1:], average_relative_error[1:], label=f'Particle {1+j}') + plt.legend() + plt.xlabel('Time') + plt.ylabel('Relative Error') + plt.yscale('log') + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f'relative_error_guiding_center_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) + plt.close() + + relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) + + plt.figure() + for j in range(nparticles): + R_SIMSOPT = jnp.sqrt(trajectories_SIMSOPT[j][:,1]**2+trajectories_SIMSOPT[j][:,2]**2) + phi_SIMSOPT = jnp.arctan2(trajectories_SIMSOPT[j][:,2], trajectories_SIMSOPT[j][:,1]) + Z_SIMSOPT = trajectories_SIMSOPT[j][:,3] + + R_ESSOS = jnp.sqrt(trajectories_ESSOS_interp[j][:,0]**2+trajectories_ESSOS_interp[j][:,1]**2) + phi_ESSOS = jnp.arctan2(trajectories_ESSOS_interp[j][:,1], trajectories_ESSOS_interp[j][:,0]) + Z_ESSOS = trajectories_ESSOS_interp[j][:,2] + + plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') + plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') + plt.legend() + plt.xlabel('R') + plt.ylabel('Z') + plt.tight_layout() + plt.savefig(os.path.join(output_dir,f'guiding_center_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) + plt.close() + + plt.figure() + for j in range(nparticles): + time_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,0]) + vpar_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,4]) + vpar_ESSOS = jnp.array(trajectories_ESSOS_interp[j][:,3]) + # plt.plot(time_SIMSOPT, jnp.abs((vpar_SIMSOPT-vpar_ESSOS)/vpar_SIMSOPT), '-', linewidth=2.5, label=f'Particle {1+j}') + plt.plot(time_SIMSOPT, vpar_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') + plt.plot(time_SIMSOPT, vpar_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') + plt.legend() + plt.xlabel('Time (s)') + plt.ylabel(r'$v_{\parallel}/v$') + # plt.yscale('log') + plt.tight_layout() + plt.savefig(os.path.join(output_dir,f'guiding_center_vpar_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) + plt.close() + +# Calculate RMS error for each tolerance +rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) + +# Plot RMS error in a bar chart +plt.figure() +bar_width = 0.15 +x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) +for i in range(rms_error_array.shape[1]): + plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Particle {1+i}') +plt.xlabel('Tracing Tolerance of SIMSOPT') +plt.ylabel('RMS Error') +plt.yscale('log') +plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) +plt.legend() +plt.tight_layout() +plt.savefig(os.path.join(output_dir, 'rms_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.close() + +# Calculate maximum error for each tolerance +max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) +# Plot maximum error in a bar chart +plt.figure() +bar_width = 0.15 +x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) +for i in range(max_error_array.shape[1]): + plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Particle {1+i}') +plt.xlabel('Tracing Tolerance of SIMSOPT') +plt.ylabel('Maximum Error') +plt.yscale('log') +plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) +plt.legend() +plt.tight_layout() +plt.savefig(os.path.join(output_dir, 'max_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.close() + +# Calculate mean error for each tolerance +mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) +# Plot mean error in a bar chart +plt.figure() +bar_width = 0.15 +x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) +for i in range(mean_error_array.shape[1]): + plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Particle {1+i}') +plt.xlabel('Tracing Tolerance of SIMSOPT') +plt.ylabel('Mean Error') +plt.yscale('log') +plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) +plt.legend() +plt.tight_layout() +plt.savefig(os.path.join(output_dir, 'mean_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.close() \ No newline at end of file From 1c6584c8e9a81c36802d82c85cb58d96c58182d3 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 26 May 2025 13:06:38 +0200 Subject: [PATCH 28/63] Fixed energy calculation for fo trajectories --- essos/dynamics.py | 10 +++------- .../comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/essos/dynamics.py b/essos/dynamics.py index 6fc91dd..08cc717 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -190,7 +190,7 @@ def __init__(self, model: str, field, maxtime: float, method=None, times=None, assert timesteps is None or \ isinstance(timesteps, (int, float)) and \ - timesteps > 0, "timesteps must be None or a positive float" + timesteps > 0, f"timesteps must be None or a positive float. Got: {type(timesteps)}" assert times is None or \ isinstance(times, jnp.ndarray), "times must be None or a numpy array" self.times = jnp.linspace(0, maxtime, timesteps) if times is None else times @@ -313,7 +313,7 @@ def trajectories(self): def trajectories(self, value): self._trajectories = value - def _energy(self): + def energy(self): assert self.model in ['GuidingCenter', 'FullOrbit'], "Energy calculation is only available for GuidingCenter and FullOrbit models" mass = self.particles.mass @@ -333,17 +333,13 @@ def compute_energy(trajectory, mu): elif self.model == 'FullOrbit': def compute_energy(trajectory): vxvyvz = trajectory[:, 3:] - v_squared = jnp.dot(vxvyvz, vxvyvz, axis=1) + v_squared = jnp.sum(jnp.square(vxvyvz), axis=1) return 0.5 * mass * v_squared energy = vmap(compute_energy)(self.trajectories) return energy - @property - def energy(self): - return self._energy() - def to_vtk(self, filename): try: import numpy as np except ImportError: raise ImportError("The 'numpy' library is required. Please install it using 'pip install numpy'.") diff --git a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py index fc9ca34..fa5fe45 100644 --- a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py +++ b/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py @@ -95,7 +95,7 @@ for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') for method_ESSOS, tracing, trajectories_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array): - relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy + relative_energy_error_ESSOS = jnp.abs(tracing.energy()-particles.energy)/particles.energy plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS'+(' Boris' if method_ESSOS=='Boris' else f' Tol={trace_tolerance_ESSOS}')) plt.legend() plt.yscale('log') From 32b84ad9bf47048745627f04e17dff2fcc12ae47 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 26 May 2025 13:07:14 +0200 Subject: [PATCH 29/63] Finalize gc analysis ESSOS vs SIMSOPT --- analysis/comparison_gc.py | 344 ++++++++++++++++++-------------------- 1 file changed, 159 insertions(+), 185 deletions(-) diff --git a/analysis/comparison_gc.py b/analysis/comparison_gc.py index ad9ea64..cfd57c5 100644 --- a/analysis/comparison_gc.py +++ b/analysis/comparison_gc.py @@ -1,4 +1,6 @@ import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time import jax.numpy as jnp from jax import block_until_ready, random @@ -9,12 +11,14 @@ from essos.dynamics import Tracing, Particles from essos.fields import BiotSavart as BiotSavart_essos import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +plt.rcParams.update({'font.size': 18}) -tmax_gc = 1e-4 +tmax_gc = 5e-4 nparticles = 5 axis_shft=0.02 R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) -trace_tolerance_SIMSOPT_array = [1e-5, 1e-7, 1e-9, 1e-11] +trace_tolerance_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] trace_tolerance_ESSOS = 1e-9 mass=PROTON_MASS energy=5000*ONE_EV @@ -38,26 +42,29 @@ particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) # Trace in SIMSOPT -time_SIMSOPT_array = [] +runtime_SIMSOPT_array = [] trajectories_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 +avg_steps_SIMSOPT_array = [] relative_energy_error_SIMSOPT_array = [] print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}\n') +for trace_tolerance_SIMSOPT in trace_tolerance_array: print(f'Tracing SIMSOPT guiding center with tolerance={trace_tolerance_SIMSOPT}') t1 = time() - trajectories_SIMSOPT_this_tolerance, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( + trajectories_SIMSOPT, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, parallel_speeds=particles.initial_vparallel, tmax=tmax_gc, mode='gc_vac', charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) - time_SIMSOPT_array.append(time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in trajectories_SIMSOPT_this_tolerance]) // nparticles - print(f" Time for SIMSOPT tracing={time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - trajectories_SIMSOPT_array.append(trajectories_SIMSOPT_this_tolerance) + runtime_SIMSOPT = time() - t1 + runtime_SIMSOPT_array.append(runtime_SIMSOPT) + avg_steps_SIMSOPT = sum([len(l) for l in trajectories_SIMSOPT]) // nparticles + avg_steps_SIMSOPT_array.append(avg_steps_SIMSOPT) + # print(trajectories_SIMSOPT_this_tolerance[0].shape) + print(f"Time for SIMSOPT tracing={runtime_SIMSOPT:.3f}s. Avg num steps={avg_steps_SIMSOPT}\n") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT) relative_energy_SIMSOPT = [] - for i, trajectory in enumerate(trajectories_SIMSOPT_this_tolerance): + for i, trajectory in enumerate(trajectories_SIMSOPT): xyz = jnp.asarray(trajectory[:, 1:4]) vpar = trajectory[:, 4] field_simsopt.set_points(xyz) @@ -66,189 +73,156 @@ relative_energy_SIMSOPT.append(jnp.abs(particles.mass*vpar**2/2+mu*AbsB-particles.energy)/particles.energy) relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) -# particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + # particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) # Trace in ESSOS -num_steps_essos = 1000#int(jnp.mean(jnp.array([len(trajectories_SIMSOPT[0]) for trajectories_SIMSOPT in trajectories_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_gc, num_steps_essos) +runtime_ESSOS_array = [] +times_essos_array = [] +trajectories_ESSOS_array = [] +relative_energy_error_ESSOS_array = [] + +# Creating a tracing object for compilation +compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri8', + stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) +block_until_ready(compile_tracing.trajectories) + +for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): + num_steps_essos = avg_steps_SIMSOPT_array[index] + print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') + start_time = time() + tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri8', + stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) + block_until_ready(tracing.trajectories) + runtime_ESSOS = time() - start_time + runtime_ESSOS_array.append(runtime_ESSOS) + times_essos_array.append(tracing.times) + trajectories_ESSOS_array.append(tracing.trajectories) + # print(tracing.trajectories.shape) + + trajectories_ESSOS = tracing.trajectories + print(f"Time for ESSOS tracing={runtime_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}\n") + + relative_energy_error_ESSOS = jnp.abs(tracing.energy()-particles.energy)/particles.energy + relative_energy_error_ESSOS_array.append(relative_energy_error_ESSOS) + # tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) -tracing = Tracing('GuidingCenter', field_essos, 1e-7, timesteps=100, method='Dopri8', - stepsize='adaptive', tol_step_size=1e-7, particles=particles) -block_until_ready(tracing.trajectories) +print('Plotting the results to output directory...') +plt.figure(figsize=(9, 6)) +colors = ['blue', 'orange', 'green', 'red', 'purple'] -print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') -start_time = time() -tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri8', - stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) -block_until_ready(tracing.trajectories) -time_ESSOS = time() - start_time +SIMSOPT_energy_interp = [] -trajectories_ESSOS = tracing.trajectories -print(f" Time for ESSOS tracing={time_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}") -tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) +for tolerance_idx in range(len(trace_tolerance_array)): + interpolation = jnp.stack([ + jnp.interp(times_essos_array[tolerance_idx], trajectories_SIMSOPT_array[tolerance_idx][particle_idx][:, 0], relative_energy_error_SIMSOPT_array[tolerance_idx][particle_idx]) + for particle_idx in range(nparticles) + ]) # This will have shape (nparticles, len(times_essos_array[tolerance_idx])) + SIMSOPT_energy_interp.append(interpolation) -relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(interpolation, axis=0), '--', color=colors[tolerance_idx]) + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(relative_energy_error_ESSOS_array[tolerance_idx], axis=0), '-', color=colors[tolerance_idx]) -print('Plotting the results to output directory...') -plt.figure() -SIMSOPT_energy_interp_this_particle = jnp.zeros((len(trace_tolerance_SIMSOPT_array), nparticles, len(trajectories_SIMSOPT_array[-1][-1][:,0]))) -for j in range(nparticles): - for i, relative_energy_error_SIMSOPT in enumerate(relative_energy_error_SIMSOPT_array): - SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) -plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS Tol={trace_tolerance_ESSOS}') -for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): - plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') -plt.legend() +legend_elements = [Line2D([0], [0], color=colors[tolerance_idx], linestyle='-', label=fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$") + for tolerance_idx in range(len(trace_tolerance_array))] + +plt.legend(handles=legend_elements, loc='lower right', title='ESSOS (─), SIMSOPT (--)', fontsize=14, title_fontsize=14) plt.yscale('log') -plt.xlabel('Time (s)') +plt.xlabel('Time (ms)') plt.ylabel('Average Relative Energy Error') plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() +plt.savefig(os.path.join(output_dir, f'relative_energy_error_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) # Plot time comparison in a bar chart -labels = [f'SIMSOPT\nTol={tol}' for tol in trace_tolerance_SIMSOPT_array] + [f'ESSOS\nTol={trace_tolerance_ESSOS}'] -times = time_SIMSOPT_array + [time_ESSOS] -plt.figure() -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black'], hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -orange_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -plt.legend(handles=[blue_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): - time_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 0] # Time values from guiding center SIMSOPT - # coords_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from guiding center SIMSOPT - coords_ESSOS = jnp.array(trajectory_ESSOS) - - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - interp_v = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 3]) - - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_v]) - - return coords_ESSOS_interp - -relative_error_array = [] -for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): - trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] - tracing.trajectories = trajectories_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS_interp')) - - relative_error_trajectories_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nparticles): - this_trajectory_SIMSOPT = jnp.array(trajectories_SIMSOPT[j])[:,1:] - this_trajectory_ESSOS = trajectories_ESSOS_interp[j] - average_relative_error = [] - for trajectory_SIMSOPT_t, trajectory_ESSOS_t in zip(this_trajectory_SIMSOPT, this_trajectory_ESSOS): - relative_error_x = jnp.abs(trajectory_SIMSOPT_t[0] - trajectory_ESSOS_t[0])/(jnp.abs(trajectory_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(trajectory_SIMSOPT_t[1] - trajectory_ESSOS_t[1])/(jnp.abs(trajectory_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(trajectory_SIMSOPT_t[2] - trajectory_ESSOS_t[2])/(jnp.abs(trajectory_SIMSOPT_t[2])+1e-12) - relative_error_v = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[3])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z + relative_error_v)/4) - average_relative_error = jnp.array(average_relative_error) - relative_error_trajectories_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_gc, len(average_relative_error))[1:], average_relative_error[1:], label=f'Particle {1+j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_guiding_center_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nparticles): - R_SIMSOPT = jnp.sqrt(trajectories_SIMSOPT[j][:,1]**2+trajectories_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(trajectories_SIMSOPT[j][:,2], trajectories_SIMSOPT[j][:,1]) - Z_SIMSOPT = trajectories_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(trajectories_ESSOS_interp[j][:,0]**2+trajectories_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(trajectories_ESSOS_interp[j][:,1], trajectories_ESSOS_interp[j][:,0]) - Z_ESSOS = trajectories_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'guiding_center_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", runtime_ESSOS_array[tolerance_idx], runtime_SIMSOPT_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale('log') +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'times_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +################################## + +def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): + time_simsopt = trajectory_SIMSOPT[:, 0] # Time values from SIMSOPT trajectory + + interp_x = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 1]) + interp_y = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 2]) + interp_z = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 3]) + interp_v = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 4]) + + coords_SIMSOPT_interp = jnp.column_stack([interp_x, interp_y, interp_z, interp_v]) - plt.figure() - for j in range(nparticles): - time_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,0]) - vpar_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,4]) - vpar_ESSOS = jnp.array(trajectories_ESSOS_interp[j][:,3]) - # plt.plot(time_SIMSOPT, jnp.abs((vpar_SIMSOPT-vpar_ESSOS)/vpar_SIMSOPT), '-', linewidth=2.5, label=f'Particle {1+j}') - plt.plot(time_SIMSOPT, vpar_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(time_SIMSOPT, vpar_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('Time (s)') - plt.ylabel(r'$v_{\parallel}/v$') - # plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'guiding_center_vpar_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - -# Calculate RMS error for each tolerance -rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - -# Plot RMS error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('RMS Error') -plt.yscale('log') -plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'rms_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate maximum error for each tolerance -max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot maximum error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Maximum Error') -plt.yscale('log') -plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'max_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate mean error for each tolerance -mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot mean error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Mean Error') -plt.yscale('log') -plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'mean_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() \ No newline at end of file + return coords_SIMSOPT_interp + +xyz_error_fig, xyz_error_ax = plt.subplots(figsize=(9, 6)) +vpar_error_fig, vpar_error_ax = plt.subplots(figsize=(9, 6)) + +avg_relative_xyz_error_array = [] +avg_relative_v_error_array = [] +for tolerance_idx in range(len(trace_tolerance_array)): + this_trajectory_SIMSOPT = jnp.stack([interpolate_SIMSOPT_to_ESSOS( + trajectories_SIMSOPT_array[tolerance_idx][particle_idx], times_essos_array[tolerance_idx] + ) for particle_idx in range(nparticles)]) + + this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] + + relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) + relative_v_errros = jnp.abs(this_trajectory_SIMSOPT[:, :, 3] - this_trajectory_ESSOS[:, :, 3]) / (jnp.abs(this_trajectory_SIMSOPT[:, :, 3]) + 1e-12) + + avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) + avg_relative_v_errors = jnp.mean(relative_v_errros, axis=0) + avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) + avg_relative_v_error_array.append(jnp.mean(avg_relative_v_errors)) + + xyz_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_xyz_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + vpar_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_v_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + +for ax, fig in zip([xyz_error_ax, vpar_error_ax], [xyz_error_fig, vpar_error_fig]): + ax.legend() + ax.set_xlabel('Time (ms)') + ax.set_yscale('log') + +xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') +vpar_error_ax.set_ylabel(r'Relative $v_\parallel$ Error') +xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +vpar_error_fig.savefig(os.path.join(output_dir, f'relative_vpar_error_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +xyz_vals = [q[1] for q in quantities] +vpar_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, xyz_vals, bar_width, label=r"x,y,z", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, vpar_vals, bar_width, label=r"$v_\parallel$", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Time Averaged Relative Error") +ax.set_yscale('log') +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'relative_errors_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +plt.show() \ No newline at end of file From f902e22eaa133544427bcfb0cb32a1ac52f56b26 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 28 May 2025 16:22:09 +0200 Subject: [PATCH 30/63] Add comparison script for fo tracing & improve gc script plots --- analysis/comparison_fo.py | 230 ++++++++++++++++++++++++++++++++++++++ analysis/comparison_gc.py | 6 +- 2 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 analysis/comparison_fo.py diff --git a/analysis/comparison_fo.py b/analysis/comparison_fo.py new file mode 100644 index 0000000..c01efef --- /dev/null +++ b/analysis/comparison_fo.py @@ -0,0 +1,230 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +from jax import block_until_ready, random +from simsopt import load +from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) +from essos.coils import Coils_from_simsopt +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +from essos.fields import BiotSavart as BiotSavart_essos +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +plt.rcParams.update({'font.size': 18}) + +tmax = 5e-5 +nparticles = 5 +axis_shft=0.02 +R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) +trace_tolerance_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] +trace_tolerance_ESSOS = 1e-9 +mass=PROTON_MASS +energy=5000*ONE_EV + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +nfp=2 +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +field_simsopt = load(LandremanPaulQA_json_file) +field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) + +Z0 = jnp.zeros(nparticles) +phi0 = jnp.zeros(nparticles) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) + +particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy, field=field_essos) + +# Trace in SIMSOPT +runtime_SIMSOPT_array = [] +trajectories_SIMSOPT_array = [] +avg_steps_SIMSOPT_array = [] +relative_energy_error_SIMSOPT_array = [] +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}\n') +for trace_tolerance_SIMSOPT in trace_tolerance_array: + print(f'Tracing SIMSOPT full orbit with tolerance={trace_tolerance_SIMSOPT}') + t1 = time() + trajectories_SIMSOPT, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( + field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, + parallel_speeds=particles.initial_vparallel, tmax=tmax, mode='full', + charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) + runtime_SIMSOPT = time() - t1 + runtime_SIMSOPT_array.append(runtime_SIMSOPT) + avg_steps_SIMSOPT = sum([len(l) for l in trajectories_SIMSOPT]) // nparticles + avg_steps_SIMSOPT_array.append(avg_steps_SIMSOPT) + + print(f"Time for SIMSOPT tracing={runtime_SIMSOPT:.3f}s. Avg num steps={avg_steps_SIMSOPT}\n") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT) + + relative_energy_SIMSOPT = [jnp.abs(0.5 * mass * jnp.sum(jnp.square(trajectory[:, 4:]), axis=1) - particles.energy) / particles.energy + for trajectory in trajectories_SIMSOPT] + + relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) + + # particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + +# Trace in ESSOS +runtime_ESSOS_array = [] +times_essos_array = [] +trajectories_ESSOS_array = [] +relative_energy_error_ESSOS_array = [] + +# Creating a tracing object for compilation +compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) +# compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Boris', +# stepsize='constant', particles=particles) + +block_until_ready(compile_tracing.trajectories) + +for tolerance_idx, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): + num_steps_essos = 10000 # avg_steps_SIMSOPT_array[tolerance_idx] + print(f'Tracing ESSOS full orbit with tolerance={trace_tolerance_ESSOS}') + start_time = time() + tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) + # tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Boris', + # stepsize='constant', particles=particles) + block_until_ready(tracing.trajectories) + runtime_ESSOS = time() - start_time + runtime_ESSOS_array.append(runtime_ESSOS) + times_essos_array.append(tracing.times) + trajectories_ESSOS_array.append(tracing.trajectories) + # print(tracing.trajectories.shape) + + trajectories_ESSOS = tracing.trajectories + print(f"Time for ESSOS tracing={runtime_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}\n") + + relative_energy_error_ESSOS = jnp.abs(tracing.energy()-particles.energy)/particles.energy + relative_energy_error_ESSOS_array.append(relative_energy_error_ESSOS) + # tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) + +print('Plotting the results to output directory...') +plt.figure(figsize=(9, 6)) +colors = ['blue', 'orange', 'green', 'red', 'purple'] + +SIMSOPT_energy_interp = [] + +for tolerance_idx in range(len(trace_tolerance_array)): + interpolation = jnp.stack([ + jnp.interp(times_essos_array[tolerance_idx], trajectories_SIMSOPT_array[tolerance_idx][particle_idx][:, 0], relative_energy_error_SIMSOPT_array[tolerance_idx][particle_idx]) + for particle_idx in range(nparticles) + ]) # This will have shape (nparticles, len(times_essos_array[tolerance_idx])) + SIMSOPT_energy_interp.append(interpolation) + + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(interpolation, axis=0), '--', color=colors[tolerance_idx]) + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(relative_energy_error_ESSOS_array[tolerance_idx], axis=0), '-', color=colors[tolerance_idx]) + +legend_elements = [Line2D([0], [0], color=colors[tolerance_idx], linestyle='-', label=fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$") + for tolerance_idx in range(len(trace_tolerance_array))] + +plt.legend(handles=legend_elements, loc='lower right', title='ESSOS (─), SIMSOPT (--)', fontsize=14, title_fontsize=14) +plt.yscale('log') +plt.xlabel('Time (ms)') +plt.ylabel('Average Relative Energy Error') +plt.tight_layout() +plt.savefig(os.path.join(output_dir, f'relative_energy_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +# Plot time comparison in a bar chart + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", runtime_ESSOS_array[tolerance_idx], runtime_SIMSOPT_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale('log') +ax.set_ylim(1e0, 1e3) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'times_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +################################## + +def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): + time_simsopt = trajectory_SIMSOPT[:, 0] # Time values from SIMSOPT trajectory + + interp_x = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 1]) + interp_y = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 2]) + interp_z = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 3]) + interp_vx = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 4]) + interp_vy = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 5]) + interp_vz = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 6]) + + coords_SIMSOPT_interp = jnp.column_stack([interp_x, interp_y, interp_z, interp_vx, interp_vy, interp_vz]) + + return coords_SIMSOPT_interp + +xyz_error_fig, xyz_error_ax = plt.subplots(figsize=(9, 6)) +v_error_fig, v_error_ax = plt.subplots(figsize=(9, 6)) + +avg_relative_xyz_error_array = [] +avg_relative_v_error_array = [] +for tolerance_idx in range(len(trace_tolerance_array)): + this_trajectory_SIMSOPT = jnp.stack([interpolate_SIMSOPT_to_ESSOS( + trajectories_SIMSOPT_array[tolerance_idx][particle_idx], times_essos_array[tolerance_idx] + ) for particle_idx in range(nparticles)]) + + this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] + + relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) + relative_v_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, 3:] - this_trajectory_SIMSOPT[:, :, 3:], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, 3:], axis=2) + 1e-12) + + avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) + avg_relative_v_errors = jnp.mean(relative_v_errors, axis=0) + avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) + avg_relative_v_error_array.append(jnp.mean(avg_relative_v_errors)) + + xyz_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_xyz_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + v_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_v_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + +for ax, fig in zip([xyz_error_ax, v_error_ax], [xyz_error_fig, v_error_fig]): + ax.legend() + ax.set_xlabel('Time (ms)') + ax.set_yscale('log') + +xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') +v_error_ax.set_ylabel(r'Relative $v_x,v_y,v_z$ Error') +xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +v_error_fig.savefig(os.path.join(output_dir, f'relative_v_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +xyz_vals = [q[1] for q in quantities] +v_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, xyz_vals, bar_width, label=r"x,y,z", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, v_vals, bar_width, label=r"$v_x,v_y,v_z$", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Time Averaged Relative Error") +ax.set_yscale('log') +ax.set_ylim(1e-8, 1e-1) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'relative_errors_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +plt.show() \ No newline at end of file diff --git a/analysis/comparison_gc.py b/analysis/comparison_gc.py index cfd57c5..6ac6b57 100644 --- a/analysis/comparison_gc.py +++ b/analysis/comparison_gc.py @@ -82,7 +82,7 @@ relative_energy_error_ESSOS_array = [] # Creating a tracing object for compilation -compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri8', +compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) block_until_ready(compile_tracing.trajectories) @@ -90,7 +90,7 @@ num_steps_essos = avg_steps_SIMSOPT_array[index] print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') start_time = time() - tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri8', + tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) block_until_ready(tracing.trajectories) runtime_ESSOS = time() - start_time @@ -152,6 +152,7 @@ ax.set_xticklabels(labels) ax.set_ylabel("Computation time (s)") ax.set_yscale('log') +ax.set_ylim(1e0, 1e2) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) plt.savefig(os.path.join(output_dir, 'times_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) @@ -221,6 +222,7 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): ax.set_xticklabels(labels) ax.set_ylabel("Time Averaged Relative Error") ax.set_yscale('log') +ax.set_ylim(1e-6, 1e-1) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) plt.savefig(os.path.join(output_dir, 'relative_errors_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) From 823c6dbd704196b046e50d4aebd85cffac2bbaee Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 28 May 2025 19:12:54 +0200 Subject: [PATCH 31/63] Fixed error in dynamics when tracing fieldlines --- essos/dynamics.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/essos/dynamics.py b/essos/dynamics.py index 08cc717..31cc236 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -278,9 +278,13 @@ def update_state(state, _): else: if self.stepsize == "adaptive": r0 = jnp.linalg.norm(initial_condition[:2]) - dtmax = r0*0.5*jnp.pi/self.particles.total_speed # can at most do quarter of a revolution per step + if self.model != 'FieldLine': + dtmax = r0*0.5*jnp.pi/self.particles.total_speed # can at most do quarter of a revolution per step + dt0 = 1e-3 * dtmax # initial guess for first timestep, will be adjusted by adaptive timestepper + else: + dtmax = None + dt0 = None controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, dtmax=dtmax, rtol=self.tol_step_size, atol=self.tol_step_size) - dt0 = 1e-3 * dtmax # initial guess for first timestep, will be adjusted by adaptive timestepper elif self.stepsize == "constant": controller = StepTo(self.times) dt0 = None From 32c36b3ab0b3f6890380ffc0bbeb528709c8093d Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Fri, 6 Jun 2025 11:00:19 +0200 Subject: [PATCH 32/63] Based on work from PR #19 https://github.com/uwplasma/ESSOS/pull/19 --- essos/dynamics.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/essos/dynamics.py b/essos/dynamics.py index 31cc236..338a803 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -6,7 +6,7 @@ from jax import jit, vmap, tree_util, random, lax, device_put from functools import partial import diffrax -from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri8, PIDController, Event, AbstractSolver, ConstantStepSize, StepTo +from diffrax import diffeqsolve, ODETerm, SaveAt, Dopri8, PIDController, Event, AbstractSolver, ConstantStepSize, StepTo, NoProgressMeter, TqdmProgressMeter from essos.coils import Coils from essos.fields import BiotSavart, Vmec from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, FUSION_ALPHA_PARTICLE_ENERGY @@ -151,7 +151,8 @@ def FieldLine(t, class Tracing(): def __init__(self, model: str, field, maxtime: float, method=None, times=None, timesteps: int = None, stepsize: str = "adaptive", dt0: float=1e-5, - tol_step_size = 1e-10, particles=None, initial_conditions=None, condition=None): + tol_step_size = 1e-10, particles=None, initial_conditions=None, condition=None, + progress_meter=False): """ Tracing class to compute the trajectories of particles in a magnetic field. @@ -178,6 +179,7 @@ def __init__(self, model: str, field, maxtime: float, method=None, times=None, self.model = model self.method = method self.stepsize = stepsize + self.progress_meter = progress_meter assert isinstance(field, (BiotSavart, Coils, Vmec)), "Field must be a BiotSavart, Coils, or Vmec object" self.field = BiotSavart(field) if isinstance(field, Coils) else field @@ -288,6 +290,11 @@ def update_state(state, _): elif self.stepsize == "constant": controller = StepTo(self.times) dt0 = None + + if self.progress_meter: + progress_meter = TqdmProgressMeter() + else: + progress_meter = NoProgressMeter() trajectory = diffeqsolve( self.ODE_term, @@ -300,6 +307,7 @@ def update_state(state, _): saveat=SaveAt(ts=self.times), throw=True, # adjoint=DirectAdjoint(), + progress_meter = progress_meter, stepsize_controller = controller, max_steps = int(1e10), event = Event(self.condition) From 798731d234105999ac9591c09629d57eee8d96e4 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Fri, 6 Jun 2025 11:00:57 +0200 Subject: [PATCH 33/63] Fix energy calls in integrators analysis --- analysis/fo_integrators.py | 4 ++-- analysis/gc_integrators.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index 25b971d..79c408a 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -54,7 +54,7 @@ print(f"Tracing with adaptive {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") - energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + energies += [jnp.mean(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method_name} adapt', marker='o', markersize=3, linestyle='-') energies = [] @@ -70,7 +70,7 @@ print(f"Tracing with {method_name} and step {dt:.2e} took {tracing_times[-1]:.2f} seconds") - energies += [jnp.mean(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + energies += [jnp.mean(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method_name}', marker='o', markersize=4, linestyle='-') diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index d6efaa3..d274563 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -53,7 +53,7 @@ print(f"Tracing with adaptive {method} and {tolerance=:.0e} took {tracing_times[-1]:.2f} seconds") - energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + energies += [jnp.max(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method} adapt', marker='o', markersize=3) ax_tol.plot(tolerances, energies, marker, label=f'{method} adapt', clip_on=False, linewidth=2.5) @@ -71,7 +71,7 @@ print(f"Tracing with {method} and {dt=:.2e} took {tracing_times[-1]:.2f} seconds") - energies += [jnp.max(jnp.abs(tracing.energy-particles.energy)/particles.energy)] + energies += [jnp.max(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') gc.collect() From c3134b906ea3c589f058e36b5bca89cce18cad57 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 25 Jun 2025 03:34:03 +0200 Subject: [PATCH 34/63] Fixed loss functions and enhanced coil separation logic --- essos/objective_functions.py | 43 ++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/essos/objective_functions.py b/essos/objective_functions.py index 8ffa78e..6acfa86 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -35,8 +35,8 @@ def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves_shape, currents_scale B_difference_loss = jnp.sum(jnp.abs(jnp.array(B_coils)-jnp.array(B_nearaxis))) gradB_difference_loss = jnp.sum(jnp.abs(jnp.array(gradB_coils)-jnp.array(gradB_nearaxis))) - coil_length_loss = 1e3*jnp.max(loss_coil_length(field, max_coil_length)) - coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(field, max_coil_curvature)) + coil_length_loss = 1e3*jnp.max(loss_coil_length(coils, max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(coils, max_coil_curvature)) return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss @@ -77,8 +77,8 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves_shape, currents_scale B_difference_loss = 3*jnp.sum(jnp.abs(B_difference)) gradB_difference_loss = jnp.sum(jnp.abs(gradB_difference)) - coil_length_loss = 1e3*jnp.max(loss_coil_length(field, max_coil_length)) - coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(field, max_coil_curvature)) + coil_length_loss = 1e3*jnp.max(loss_coil_length(coils, max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(coils, max_coil_curvature)) elongation_loss = jnp.sum(jnp.abs(elongation)) iota_loss = 30/jnp.abs(iota) @@ -107,32 +107,41 @@ def loss_particle_drift(field, particles, maxtime=1e-5, num_steps=300, trace_tol # return jnp.concatenate((jnp.ravel(jnp.abs(vertical_factor)),)) @partial(jit, static_argnames=['max_coil_length']) -def loss_coil_length(coils, max_coil_length): - return jnp.square((coils.length-max_coil_length)/max_coil_length) +def loss_coil_length(coils, max_coil_length=0): + return jnp.square(coils.length/max_coil_length - 1) @partial(jit, static_argnames=['max_coil_curvature']) -def loss_coil_curvature(coils, max_coil_curvature): +def loss_coil_curvature(coils, max_coil_curvature=0): pointwise_curvature_loss = jnp.square(jnp.maximum(coils.curvature-max_coil_curvature, 0)) - return jnp.mean(pointwise_curvature_loss, axis=1) + return jnp.mean(pointwise_curvature_loss*jnp.linalg.norm(coils.gamma_dash, axis=-1), axis=1) -@partial(jit, static_argnames=['min_separation']) -def loss_coil_separation(coils, min_separation): - # Sort coils by angle - # sorting = jnp.argsort(jnp.arctan2(coils.curves[:,1,0], coils.curves[:,0,0])%(2*jnp.pi)) - # This can be useful to only cosider the separation between adjacent coils - # i_vals, j_vals = jnp.arange(len(coils)), jnp.arange(1, len(coils)+1)%len(coils) - # but in this case gamma_i and gamma_j have to be sorted with the sorting mask +def compute_candidates(coils, min_separation): + centers = coils.curves[:, :, 0] + a_n = coils.curves[:, :, 2 : 2*coils.order+1 : 2] + b_n = coils.curves[:, :, 1 : 2*coils.order : 2] + radii = jnp.sum(jnp.linalg.norm(a_n, axis=1)+jnp.linalg.norm(b_n, axis=1), axis=1) i_vals, j_vals = jnp.triu_indices(len(coils), k=1) + centers_dists = jnp.linalg.norm(centers[i_vals] - centers[j_vals], axis=1) + mask = centers_dists <= min_separation + radii[i_vals] + radii[j_vals] + + return i_vals[mask], j_vals[mask] + +@partial(jit, static_argnames=['min_separation']) +def loss_coil_separation(coils, min_separation, candidates=None): + if candidates is None: + candidates = jnp.triu_indices(len(coils), k=1) def pair_loss(i, j): gamma_i = coils.gamma[i] + gamma_dash_i = jnp.linalg.norm(coils.gamma_dash[i], axis=-1) gamma_j = coils.gamma[j] + gamma_dash_j = jnp.linalg.norm(coils.gamma_dash[j], axis=-1) dists = jnp.linalg.norm(gamma_i[:, None, :] - gamma_j[None, :, :], axis=2) penalty = jnp.maximum(0, min_separation - dists) - return jnp.mean(jnp.square(penalty)) + return jnp.mean(jnp.square(penalty)*gamma_dash_i*gamma_dash_j) - losses = jax.vmap(pair_loss)(i_vals, j_vals) + losses = jax.vmap(pair_loss)(*candidates) return jnp.sum(losses) # @partial(jit, static_argnames=['target_B_on_axis', 'npoints']) From 5a728c0c981bd3ad8e43e507379305313dcf4d03 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 25 Jun 2025 03:35:05 +0200 Subject: [PATCH 35/63] Fixed coils&surface opt example --- examples/optimize_coils_and_surface.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/optimize_coils_and_surface.py b/examples/optimize_coils_and_surface.py index 99950ed..42089f0 100644 --- a/examples/optimize_coils_and_surface.py +++ b/examples/optimize_coils_and_surface.py @@ -117,12 +117,12 @@ def loss_normal_cross_GradB_dot_grad_B_dot_GradB_surface(surface, field): normal_cross_GradB_dot_grad_B_dot_GradB_surface = jnp.sum(normal_cross_GradB_surface * grad_B_dot_GradB_surface, axis=-1) return normal_cross_GradB_dot_grad_B_dot_GradB_surface -@partial(jit, static_argnums=(1, 5, 6, 7, 8, 9, 10)) -def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents_scale, nfp, max_coil_length=42, +@partial(jit, static_argnums=(1, 3, 5, 6, 7, 8, 9, 10)) +def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, n_segments=60, stellsym=True, max_coil_curvature=0.5, target_B_on_surface=5.7): - len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) + len_dofs_curves_ravelled = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] dofs_currents = x[len_dofs_curves_ravelled:-len(surface_all.x)-len(field_nearaxis.x)] - new_dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], (dofs_curves.shape)) + new_dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], dofs_curves_shape) curves = Curves(new_dofs_curves, n_segments, nfp, stellsym) coils = Coils(curves=curves, currents=dofs_currents*currents_scale) @@ -133,8 +133,8 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len(field_nearaxis.x):], field_nearaxis) - coil_length = loss_coil_length(field) - coil_curvature = loss_coil_curvature(field) + coil_length = loss_coil_length(coils) + coil_curvature = loss_coil_curvature(coils) coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) From 43dd5a012b911c86af7a6a7924d5620632e12efb Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 25 Jun 2025 03:35:36 +0200 Subject: [PATCH 36/63] Added extra simsopt compilation runs --- analysis/comparison_coils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/analysis/comparison_coils.py b/analysis/comparison_coils.py index 03288c0..c7c7541 100644 --- a/analysis/comparison_coils.py +++ b/analysis/comparison_coils.py @@ -77,6 +77,8 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): # Running the first time for compilation [curve.gamma() for curve in curves_simsopt] + [curve.gammadash() for curve in curves_simsopt] + [curve.gammadashdash() for curve in curves_simsopt] coils_essos.gamma # Running the second time for coils characteristics comparison From e6cb9fa46e5a44fd0175ecc57573724dad1b4d79 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 25 Jun 2025 03:36:26 +0200 Subject: [PATCH 37/63] Added field line comparisons & minor improvements to fo and gc comparisons --- analysis/comparison_fl.py | 179 ++++++++++++++++++++++++++++++++++++++ analysis/comparison_fo.py | 56 ++++++++---- analysis/comparison_gc.py | 4 +- 3 files changed, 221 insertions(+), 18 deletions(-) create mode 100644 analysis/comparison_fl.py diff --git a/analysis/comparison_fl.py b/analysis/comparison_fl.py new file mode 100644 index 0000000..de78e5d --- /dev/null +++ b/analysis/comparison_fl.py @@ -0,0 +1,179 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +from jax import block_until_ready, random +from simsopt import load +from simsopt.field import (particles_to_vtk, compute_fieldlines, plot_poincare_data) +from essos.coils import Coils_from_simsopt +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +from essos.fields import BiotSavart as BiotSavart_essos +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +plt.rcParams.update({'font.size': 18}) + +tmax_fl = 2000 +nfieldlines = 5 +axis_shift=0.02 +R0 = jnp.linspace(1.2125346+axis_shift, 1.295-axis_shift, nfieldlines) +trace_tolerance_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] +trace_tolerance_ESSOS = 1e-9 +mass=PROTON_MASS +energy=5000*ONE_EV + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +nfp=2 +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +field_simsopt = load(LandremanPaulQA_json_file) +field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) + +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nfieldlines,), minval=-1, maxval=1) + +phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] + +particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) + +# Trace in SIMSOPT +runtime_SIMSOPT_array = [] +trajectories_SIMSOPT_array = [] +avg_steps_SIMSOPT_array = [] + +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}\n') +for trace_tolerance_SIMSOPT in trace_tolerance_array: + print(f'Tracing SIMSOPT field lines with tolerance={trace_tolerance_SIMSOPT}') + t1 = time() + trajectories_SIMSOPT, trajectories_SIMSOPT_phi_hits = block_until_ready(compute_fieldlines( + field_simsopt, R0, Z0, tmax=tmax_fl, tol=trace_tolerance_SIMSOPT, phis=phis_poincare)) + runtime_SIMSOPT = time() - t1 + runtime_SIMSOPT_array.append(runtime_SIMSOPT) + avg_steps_SIMSOPT = sum([len(l) for l in trajectories_SIMSOPT]) // nfieldlines + avg_steps_SIMSOPT_array.append(avg_steps_SIMSOPT) + # print(trajectories_SIMSOPT_this_tolerance[0].shape) + print(f"Time for SIMSOPT tracing={runtime_SIMSOPT:.3f}s. Avg num steps={avg_steps_SIMSOPT}\n") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT) + + # particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + +# Trace in ESSOS +runtime_ESSOS_array = [] +times_essos_array = [] +trajectories_ESSOS_array = [] +relative_energy_error_ESSOS_array = [] + +# Creating a tracing object for compilation +compile_tracing = Tracing('FieldLine', field_essos, tmax_fl, initial_conditions=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T, + timesteps=100, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_array[0]) +block_until_ready(compile_tracing.trajectories) + +for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): + num_steps_essos = avg_steps_SIMSOPT_array[index] + print(f'Tracing ESSOS field lines with tolerance={trace_tolerance_ESSOS}') + start_time = time() + tracing = Tracing('FieldLine', field_essos, tmax_fl, initial_conditions=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T, + timesteps=num_steps_essos, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS) + block_until_ready(tracing.trajectories) + runtime_ESSOS = time() - start_time + runtime_ESSOS_array.append(runtime_ESSOS) + times_essos_array.append(tracing.times) + trajectories_ESSOS_array.append(tracing.trajectories) + # print(tracing.trajectories.shape) + + trajectories_ESSOS = tracing.trajectories + print(f"Time for ESSOS tracing={runtime_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}\n") + + # tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) + +print('Plotting the results to output directory...') + +# Plot time comparison in a bar chart +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", runtime_ESSOS_array[tolerance_idx], runtime_SIMSOPT_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale('log') +ax.set_ylim(1e0, 1e2) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'times_fl_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +################################## + +def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): + time_simsopt = trajectory_SIMSOPT[:, 0] # Time values from SIMSOPT trajectory + + interp_x = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 1]) + interp_y = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 2]) + interp_z = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 3]) + + coords_SIMSOPT_interp = jnp.column_stack([interp_x, interp_y, interp_z]) + + return coords_SIMSOPT_interp + +plt.figure(figsize=(9, 6)) + +avg_relative_xyz_error_array = [] +for tolerance_idx in range(len(trace_tolerance_array)): + this_trajectory_SIMSOPT = jnp.stack([interpolate_SIMSOPT_to_ESSOS( + trajectories_SIMSOPT_array[tolerance_idx][particle_idx], times_essos_array[tolerance_idx] + ) for particle_idx in range(nfieldlines)]) + + this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] + + relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) + + avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) + avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) + + plt.plot(times_essos_array[tolerance_idx], avg_relative_xyz_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + +plt.legend() +plt.xlabel('Time (a.u.)') +plt.yscale('log') + +plt.ylabel(r'Relative $x,y,z$ Error') +plt.savefig(os.path.join(output_dir, f'relative_xyz_error_fl_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +xyz_vals = [q[1] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.4 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis, xyz_vals, bar_width, label=r"x,y,z", color="darkorange", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Time Averaged Relative Error") +ax.set_yscale('log') +ax.set_ylim(1e-6, 1e-1) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'relative_errors_fl_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + +plt.show() \ No newline at end of file diff --git a/analysis/comparison_fo.py b/analysis/comparison_fo.py index c01efef..ed513d5 100644 --- a/analysis/comparison_fo.py +++ b/analysis/comparison_fo.py @@ -14,6 +14,11 @@ from matplotlib.lines import Line2D plt.rcParams.update({'font.size': 18}) +######################################################################################## +method = 'Boris' # 'Boris' or 'Dopri5' +######################################################################################## + + tmax = 5e-5 nparticles = 5 axis_shft=0.02 @@ -75,21 +80,27 @@ relative_energy_error_ESSOS_array = [] # Creating a tracing object for compilation -compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Dopri5', - stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) -# compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Boris', -# stepsize='constant', particles=particles) +if method == 'Dopri5': + compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) +else: + compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Boris', + stepsize='constant', particles=particles) block_until_ready(compile_tracing.trajectories) for tolerance_idx, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): - num_steps_essos = 10000 # avg_steps_SIMSOPT_array[tolerance_idx] print(f'Tracing ESSOS full orbit with tolerance={trace_tolerance_ESSOS}') start_time = time() - tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Dopri5', - stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) - # tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Boris', - # stepsize='constant', particles=particles) + if method == 'Dopri5': + num_steps_essos = 10000 + tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) + else: + num_steps_essos = avg_steps_SIMSOPT_array[tolerance_idx]*10 + tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Boris', + stepsize='constant', particles=particles) + block_until_ready(tracing.trajectories) runtime_ESSOS = time() - start_time runtime_ESSOS_array.append(runtime_ESSOS) @@ -128,7 +139,10 @@ plt.xlabel('Time (ms)') plt.ylabel('Average Relative Energy Error') plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +if method == 'Dopri5': + plt.savefig(os.path.join(output_dir, f'relative_energy_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +else: + plt.savefig(os.path.join(output_dir, f'relative_energy_error_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) # Plot time comparison in a bar chart @@ -150,10 +164,13 @@ ax.set_xticklabels(labels) ax.set_ylabel("Computation time (s)") ax.set_yscale('log') -ax.set_ylim(1e0, 1e3) +ax.set_ylim(1e-1, 1e3) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) -plt.savefig(os.path.join(output_dir, 'times_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +if method == 'Dopri5': + plt.savefig(os.path.join(output_dir, 'times_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +else: + plt.savefig(os.path.join(output_dir, 'times_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) ################################## @@ -201,8 +218,12 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') v_error_ax.set_ylabel(r'Relative $v_x,v_y,v_z$ Error') -xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -v_error_fig.savefig(os.path.join(output_dir, f'relative_v_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +if method == 'Dopri5': + xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + v_error_fig.savefig(os.path.join(output_dir, f'relative_v_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +else: + xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) + v_error_fig.savefig(os.path.join(output_dir, f'relative_v_error_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) for tolerance_idx in range(len(trace_tolerance_array))] @@ -222,9 +243,12 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): ax.set_xticklabels(labels) ax.set_ylabel("Time Averaged Relative Error") ax.set_yscale('log') -ax.set_ylim(1e-8, 1e-1) +ax.set_ylim(1e-6, 1e1) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) -plt.savefig(os.path.join(output_dir, 'relative_errors_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +if method == 'Dopri5': + plt.savefig(os.path.join(output_dir, 'relative_errors_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +else: + plt.savefig(os.path.join(output_dir, 'relative_errors_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) plt.show() \ No newline at end of file diff --git a/analysis/comparison_gc.py b/analysis/comparison_gc.py index 6ac6b57..3acc059 100644 --- a/analysis/comparison_gc.py +++ b/analysis/comparison_gc.py @@ -184,10 +184,10 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) - relative_v_errros = jnp.abs(this_trajectory_SIMSOPT[:, :, 3] - this_trajectory_ESSOS[:, :, 3]) / (jnp.abs(this_trajectory_SIMSOPT[:, :, 3]) + 1e-12) + relative_v_errors = jnp.abs(this_trajectory_SIMSOPT[:, :, 3] - this_trajectory_ESSOS[:, :, 3]) / (jnp.abs(this_trajectory_SIMSOPT[:, :, 3]) + 1e-12) avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) - avg_relative_v_errors = jnp.mean(relative_v_errros, axis=0) + avg_relative_v_errors = jnp.mean(relative_v_errors, axis=0) avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) avg_relative_v_error_array.append(jnp.mean(avg_relative_v_errors)) From 97f2985f52f6697f75a39b127b0fc6e88fb2f790 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Thu, 26 Jun 2025 22:04:30 +0200 Subject: [PATCH 38/63] Improve length calculation & refactor gammas to lazy initialization & change Coils_from_Simsopt /Json to class methods --- essos/coils.py | 58 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index 817fc6f..32b083f 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -44,7 +44,11 @@ def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stell self._order = dofs.shape[2]//2 self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) self.quadpoints = jnp.linspace(0, 1, self.n_segments, endpoint=False) - self._set_gamma() + self._gamma = None + self._gamma_dash = None + self._gamma_dashdash = None + self._curvature = None + self._length = None def __str__(self): return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ @@ -73,7 +77,7 @@ def fori_createdata(order_index: int, data: jnp.ndarray) -> jnp.ndarray: gamma_dash = jnp.zeros((jnp.size(self._curves, 0), self.n_segments, 3)) gamma_dashdash = jnp.zeros((jnp.size(self._curves, 0), self.n_segments, 3)) gamma, gamma_dash, gamma_dashdash = fori_loop(1, self._order+1, fori_createdata, (gamma, gamma_dash, gamma_dashdash)) - length = jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in gamma_dash]) + length = jnp.mean(jnp.linalg.norm(gamma_dash, axis=2), axis=1) curvature = vmap(compute_curvature)(gamma_dash, gamma_dashdash) self._gamma = gamma self._gamma_dash = gamma_dash @@ -150,22 +154,32 @@ def stellsym(self, new_stellsym): @property def gamma(self): + if self._gamma is None: + self._set_gamma() return self._gamma @property def gamma_dash(self): + if self._gamma_dash is None: + self._set_gamma() return self._gamma_dash @property def gamma_dashdash(self): + if self._gamma_dashdash is None: + self._set_gamma() return self._gamma_dashdash @property def length(self): + if self._length is None: + self._set_gamma() return self._length @property def curvature(self): + if self._curvature is None: + self._set_gamma() return self._curvature def __len__(self): @@ -288,9 +302,12 @@ def wrap(data): pointData = {**pointData, **extra_data} polyLinesToVTK(str(filename), np.array(x), np.array(y), np.array(z), pointsPerLine=np.array(ppl), pointData=pointData) -class Curves_from_simsopt(Curves): - # This assumes curves have all nfp and stellsym symmetries - def __init__(self, simsopt_curves, nfp=1, stellsym=True): + @classmethod + def from_simsopt(cls, simsopt_curves, nfp=1, stellsym=True): + """ + Create a Curves object from a list of simsopt curves. + This assumes curves have all nfp and stellsym symmetries. + """ if isinstance(simsopt_curves, str): from simsopt import load bs = load(simsopt_curves) @@ -301,7 +318,7 @@ def __init__(self, simsopt_curves, nfp=1, stellsym=True): [curve.x for curve in simsopt_curves] ), (len(simsopt_curves), 3, 2*simsopt_curves[0].order+1)) n_segments = len(simsopt_curves[0].quadpoints) - super().__init__(dofs, n_segments, nfp, stellsym) + return cls(dofs, n_segments, nfp, stellsym) tree_util.register_pytree_node(Curves, Curves._tree_flatten, @@ -447,24 +464,29 @@ def to_json(self, filename: str): import json with open(filename, "w") as file: json.dump(data, file) - -class Coils_from_json(Coils): - def __init__(self, filename: str): - import json - with open(filename , "r") as file: - data = json.load(file) - super().__init__(Curves(jnp.array(data["dofs_curves"]), data["n_segments"], data["nfp"], data["stellsym"]), data["dofs_currents"]) - -class Coils_from_simsopt(Coils): - # This assumes coils have all nfp and stellsym symmetries - def __init__(self, simsopt_coils, nfp=1, stellsym=True): + + @classmethod + def from_simsopt(cls, simsopt_coils, nfp=1, stellsym=True): + """ This assumes coils have all nfp and stellsym symmetries""" if isinstance(simsopt_coils, str): from simsopt import load bs = load(simsopt_coils) simsopt_coils = bs.coils curves = [c.curve for c in simsopt_coils] currents = jnp.array([c.current.get_value() for c in simsopt_coils[0:int(len(simsopt_coils)/nfp/(1+stellsym))]]) - super().__init__(Curves_from_simsopt(curves, nfp, stellsym), currents) + return cls(Curves.from_simsopt(curves, nfp, stellsym), currents) + + @classmethod + def from_json(cls, filename: str): + """ + Create a Coils object from a json file + """ + import json + with open(filename, "r") as file: + data = json.load(file) + curves = Curves(jnp.array(data["dofs_curves"]), data["n_segments"], data["nfp"], data["stellsym"]) + currents = jnp.array(data["dofs_currents"]) + return cls(curves, currents) tree_util.register_pytree_node(Coils, Coils._tree_flatten, From 312dab3b62621b0c14b46456ca06eab4e27015f9 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Thu, 26 Jun 2025 22:07:57 +0200 Subject: [PATCH 39/63] Fix coils.from_simsopt imports --- analysis/comparison_coils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/analysis/comparison_coils.py b/analysis/comparison_coils.py index c7c7541..58a9c79 100644 --- a/analysis/comparison_coils.py +++ b/analysis/comparison_coils.py @@ -5,7 +5,7 @@ plt.rcParams.update({'font.size': 18}) from jax import block_until_ready from essos.fields import BiotSavart as BiotSavart_essos -from essos.coils import Coils_from_simsopt, Curves_from_simsopt +from essos.coils import Coils, Curves from simsopt import load from simsopt.geo import CurveXYZFourier, curves_to_vtk from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries @@ -33,17 +33,17 @@ coils_simsopt = field_simsopt.coils curves_simsopt = [coil.curve for coil in coils_simsopt] currents_simsopt = [coil.current for coil in coils_simsopt] - coils_essos = Coils_from_simsopt(json_file_stel, nfp) - curves_essos = Curves_from_simsopt(json_file_stel, nfp) + coils_essos = Coils.from_simsopt(json_file_stel, nfp) + curves_essos = Curves.from_simsopt(json_file_stel, nfp) else: coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) curves_simsopt = [c.curve for c in coils_simsopt] currents_simsopt = [c.current for c in coils_simsopt] field_simsopt = BiotSavart_simsopt(coils_simsopt) - - coils_essos = Coils_from_simsopt(coils_simsopt, nfp) - curves_essos = Curves_from_simsopt(curves_simsopt, nfp) - + + coils_essos = Coils.from_simsopt(coils_simsopt, nfp) + curves_essos = Curves.from_simsopt(curves_simsopt, nfp) + field_essos = BiotSavart_essos(coils_essos) coils_essos_to_simsopt = coils_essos.to_simsopt() From 32fa0675281050e9bfa9c0f96497a4fb49bc6027 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Thu, 26 Jun 2025 22:08:26 +0200 Subject: [PATCH 40/63] Add loss function comparison with simsopt --- analysis/comparison_losses.py | 193 ++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 analysis/comparison_losses.py diff --git a/analysis/comparison_losses.py b/analysis/comparison_losses.py new file mode 100644 index 0000000..4f58431 --- /dev/null +++ b/analysis/comparison_losses.py @@ -0,0 +1,193 @@ +import os +from time import perf_counter as time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from jax import block_until_ready +from essos.fields import BiotSavart as BiotSavart_essos +from essos.coils import Coils, Curves +from essos.objective_functions import loss_coil_curvature, loss_coil_separation, compute_candidates, loss_coil_length +from simsopt import load +from simsopt.geo import CurveXYZFourier, curves_to_vtk, CurveCurveDistance, LpCurveCurvature, CurveLength +from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries +from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +n_segments = 100 + +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +nfp_array = [3, 2, 5, 4, 2] +curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] +currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] +name_array = ["NCSX", "QA(json)", "W7-X", "HSX", "Giuliani"] + +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') +for nfp, curves_stel, currents_stel, name in zip(nfp_array, curves_array, currents_array, name_array): + print(f' Running {name} and saving to output directory...') + if currents_stel is None: + json_file_stel = curves_stel + field_simsopt = load(json_file_stel) + coils_simsopt = field_simsopt.coils + curves_simsopt = [coil.curve for coil in coils_simsopt] + currents_simsopt = [coil.current for coil in coils_simsopt] + coils_essos = Coils.from_simsopt(json_file_stel, nfp) + curves_essos = Curves.from_simsopt(json_file_stel, nfp) + else: + coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + currents_simsopt = [c.current for c in coils_simsopt] + field_simsopt = BiotSavart_simsopt(coils_simsopt) + + coils_essos = Coils.from_simsopt(coils_simsopt, nfp) + curves_essos = Curves.from_simsopt(curves_simsopt, nfp) + + field_essos = BiotSavart_essos(coils_essos) + + coils_essos_to_simsopt = coils_essos.to_simsopt() + curves_essos_to_simsopt = curves_essos.to_simsopt() + field_essos_to_simsopt = BiotSavart_simsopt(coils_essos_to_simsopt) + + # curves_to_vtk(curves_simsopt, os.path.join(output_dir,f"curves_simsopt_{name}")) + # curves_essos.to_vtk(os.path.join(output_dir,f"curves_essos_{name}")) + # curves_to_vtk(curves_essos_to_simsopt, os.path.join(output_dir,f"curves_essos_to_simsopt_{name}")) + + base_coils_simsopt = coils_simsopt[:int(len(coils_simsopt)/2/nfp)] + R = jnp.mean(jnp.array([jnp.sqrt(coil.curve.x[coil.curve.local_dof_names.index('xc(0)')]**2 + +coil.curve.x[coil.curve.local_dof_names.index('yc(0)')]**2) + for coil in base_coils_simsopt])) + x = jnp.array([R+0.01,R,R]) + y = jnp.array([R,R+0.01,R-0.01]) + z = jnp.array([0.05,0.06,0.07]) + + positions = jnp.array((x,y,z)) + + def update_nsegments_simsopt(curve_simsopt, n_segments): + new_curve = CurveXYZFourier(n_segments, curve_simsopt.order) + new_curve.x = curve_simsopt.x + return new_curve + + coils_essos.n_segments = n_segments + + base_curves_simsopt = [update_nsegments_simsopt(coil_simsopt.curve, n_segments) for coil_simsopt in base_coils_simsopt] + coils_simsopt = coils_via_symmetries(base_curves_simsopt, currents_simsopt[0:len(base_coils_simsopt)], nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + + # Running the first time for compilation + [LpCurveCurvature(curve, p=2, threshold=0).J() for curve in curves_simsopt] + loss_coil_curvature(coils_essos, 0) + [CurveLength(curve).J() for curve in curves_simsopt] + loss_coil_length(coils_essos, 10) + CurveCurveDistance(curves_simsopt, 0.5).J() + loss_coil_separation(coils_essos, 0.5) + + # Running the second time for coils characteristics comparison + + start_time = time() + curvature_loss_simsopt = block_until_ready(2*sum([LpCurveCurvature(curve, p=2, threshold=0).J() for curve in curves_simsopt])) + t_curvature_avg_simsopt = time() - start_time + + start_time = time() + curvature_loss_essos = block_until_ready(jnp.sum(loss_coil_curvature(coils_essos, 0))) + t_curvature_avg_essos = time() - start_time + + start_time = time() + length_loss_simsopt = block_until_ready(sum([(CurveLength(curve).J()/10 - 1)**2 for curve in curves_simsopt])) + t_length_avg_simsopt = time() - start_time + print(f"Length loss SIMSOPT: {length_loss_simsopt}") + + start_time = time() + length_loss_essos = block_until_ready(jnp.sum(loss_coil_length(coils_essos, 10))) + t_length_avg_essos = time() - start_time + print(f"Length loss ESSOS: {length_loss_essos}") + + start_time = time() + separation_loss_simsopt = block_until_ready(CurveCurveDistance(curves_simsopt, 0.5).J()) + t_separation_avg_simsopt = time() - start_time + print(f"Separation loss SIMSOPT: {separation_loss_simsopt}") + + start_time = time() + separation_loss_essos = block_until_ready(loss_coil_separation(coils_essos, 0.5)) + t_separation_avg_essos = time() - start_time + print(f"Separation loss ESSOS: {separation_loss_essos}") + + start_time = time() + ind_separation_loss_simsopt = block_until_ready(CurveCurveDistance(curves_simsopt, 0.5).J()) + t_ind_separation_avg_simsopt = time() - start_time + print(f"Independence separation loss SIMSOPT: {ind_separation_loss_simsopt}") + + start_time = time() + ind_separation_loss_essos = block_until_ready(loss_coil_separation(coils_essos, 0.5, candidates=compute_candidates(coils_essos, 0.5))) + t_ind_separation_avg_essos = time() - start_time + print(f"Independence separation loss ESSOS: {ind_separation_loss_essos}") + + length_error_avg = jnp.linalg.norm(length_loss_essos - length_loss_simsopt) + curvature_error_avg = jnp.linalg.norm(curvature_loss_essos - curvature_loss_simsopt) + separation_error_avg = jnp.linalg.norm(separation_loss_essos - separation_loss_simsopt) + ind_separation_error_avg = jnp.linalg.norm(ind_separation_loss_essos - ind_separation_loss_simsopt) + print(f"length_error_avg: {length_error_avg:.2e}") + print(f"curvature_error_avg: {curvature_error_avg:.2e}") + print(f"separation_error_avg: {separation_error_avg:.2e}") + # print(f"ind_separation_error_avg: {ind_separation_error_avg:.2e}") + + # Labels and corresponding absolute errors (ESSOS - SIMSOPT) + quantities_errors = [ + (r"$L_\ell$", jnp.abs(length_error_avg)), + (r"$L_\kappa$", jnp.abs(curvature_error_avg)), + (r"$L_\text{sep}$", jnp.abs(separation_error_avg)), + # (r"$L_\text{sep,ind}$", jnp.abs(ind_separation_error_avg)), + ] + + labels = [q[0] for q in quantities_errors] + error_vals = [q[1] for q in quantities_errors] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.6 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Absolute error") + ax.set_yscale("log") + ax.set_ylim(1e-17, 1e-2) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparison_error_losses_{name}.pdf"), transparent=True) + plt.close() + + + # Labels and corresponding timings + quantities = [ + (r"$L_\ell$", t_length_avg_essos, t_length_avg_simsopt), + (r"$L_\kappa$", t_curvature_avg_essos, t_curvature_avg_simsopt), + (r"$L_\text{sep}$", t_separation_avg_essos, t_separation_avg_simsopt), + # (r"$L_\text{sep,ind}$", t_ind_separation_avg_essos, t_ind_separation_avg_simsopt), + ] + + labels = [q[0] for q in quantities] + essos_vals = [q[1] for q in quantities] + simsopt_vals = [q[2] for q in quantities] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.35 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") + ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Computation time (s)") + ax.set_yscale("log") + ax.set_ylim(1e-4, 1e0) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + ax.legend(fontsize=12) + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparison_time_losses_{name}.pdf"), transparent=True) + plt.close() From d4bb38784f400ab7cbdde55316672e4f90d45e95 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Thu, 26 Jun 2025 22:08:33 +0200 Subject: [PATCH 41/63] Add surface comparison with simsopt --- analysis/comparison_surfaces.py | 116 ++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 analysis/comparison_surfaces.py diff --git a/analysis/comparison_surfaces.py b/analysis/comparison_surfaces.py new file mode 100644 index 0000000..3e69aed --- /dev/null +++ b/analysis/comparison_surfaces.py @@ -0,0 +1,116 @@ +import os +from time import time +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +import jax.numpy as jnp +from essos.coils import Coils, CreateEquallySpacedCurves +from essos.fields import Vmec, BiotSavart +from essos.surfaces import B_on_surface, BdotN_over_B, SurfaceRZFourier as SurfaceRZFourier_ESSOS, SquaredFlux as SquaredFlux_ESSOS +from simsopt.field import BiotSavart as BiotSavart_simsopt +from simsopt.geo import SurfaceRZFourier as SurfaceRZFourier_SIMSOPT +from simsopt.objectives import SquaredFlux as SquaredFlux_SIMSOPT + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Optimization parameters +max_coil_length = 42 +order_Fourier_series_coils = 4 +number_coil_points = 50 +function_evaluations_array = [30]*1 +diff_step_array = [1e-2]*1 +number_coils_per_half_field_period = 3 + +ntheta = 36 +nphi = 32 + +# Initialize VMEC field +vmec_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', + 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') +vmec = Vmec(vmec_file, ntheta=ntheta, nphi=nphi, close=False) + +# Initialize coils +current_on_each_coil = 1 +number_of_field_periods = vmec.nfp +major_radius_coils = vmec.r_axis +minor_radius_coils = vmec.r_axis/1.5 +curves_essos = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) +field_essos = BiotSavart(coils_essos) +surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) +# surface_essos.to_vtk("essos_surface") + +coils_simsopt = coils_essos.to_simsopt() +curves_simsopt = curves_essos.to_simsopt() +field_simsopt = BiotSavart_simsopt(coils_simsopt) +surface_simsopt = SurfaceRZFourier_SIMSOPT.from_wout(vmec_file, range="full torus", nphi=nphi, ntheta=ntheta) +field_simsopt.set_points(surface_simsopt.gamma().reshape((-1, 3))) +# surface_simsopt.to_vtk("simsopt_surface") + +print("Gamma") +gamma_error = jnp.sum(jnp.abs(surface_simsopt.gamma() - surface_essos.gamma)) +print(gamma_error) + +print('Gamma dash theta') +gamma_dash_theta_error = jnp.sum(jnp.abs(surface_simsopt.gammadash2()-surface_essos.gammadash_theta)) +print(gamma_dash_theta_error) + +print('Gamma dash phi') +gamma_dash_phi_error = jnp.sum(jnp.abs(surface_simsopt.gammadash1()-surface_essos.gammadash_phi)) +print(gamma_dash_phi_error) + +print('Normal') +normal_error = jnp.sum(jnp.abs(surface_simsopt.normal()-surface_essos.normal)) +print(normal_error) + +print('Unit normal') +unit_normal_error = jnp.sum(jnp.abs(surface_simsopt.unitnormal()-surface_essos.unitnormal)) +print(unit_normal_error) + +print('B on surface') +B_on_surface_error = jnp.sum(jnp.abs(field_simsopt.B().reshape((nphi, ntheta, 3)) - B_on_surface(surface_essos, field_essos))) +print(B_on_surface_error) + +definition = "local" +print("Squared flux", definition) +sf_SIMSOPT = SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt, definition=definition).J() +sf_ESSOS = SquaredFlux_ESSOS(surface_essos, field_essos, definition=definition) +squared_flux_error = jnp.abs(sf_SIMSOPT - sf_ESSOS) + +print("ESSOS: ", sf_ESSOS) +print("SIMSOPT: ", sf_SIMSOPT) + +# Labels and corresponding absolute errors (ESSOS - SIMSOPT) +quantities_errors = [ + (r"$\Gamma$", gamma_error), + (r"$\Gamma'_\theta$", gamma_dash_theta_error), + (r"$\Gamma'_\phi$", gamma_dash_phi_error), + (r"$\mathbf{n}$", unit_normal_error), + # (r"$\mathbf{B}$", B_on_surface_error), + (r"$L_\text{flux}$", squared_flux_error), +] + +labels = [q[0] for q in quantities_errors] +error_vals = [q[1] for q in quantities_errors] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.6 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Absolute error") +ax.set_yscale("log") +ax.set_ylim(1e-14, 1e-10) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + +plt.tight_layout() +# plt.savefig(os.path.join(output_dir, f"comparison_error_surfaces.pdf"), transparent=True) +plt.show() From 00f2160788390df404836ad2b6cef73096daaf23 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Thu, 26 Jun 2025 22:23:30 +0200 Subject: [PATCH 42/63] Refactor surfaces gamma to lazy initialization & add SquaredFlux function --- essos/surfaces.py | 89 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 19 deletions(-) diff --git a/essos/surfaces.py b/essos/surfaces.py index a8c7ea9..04649cc 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -26,9 +26,31 @@ def BdotN(surface, field): @partial(jit, static_argnames=['surface','field']) def BdotN_over_B(surface, field): - B_surface = B_on_surface(surface, field) - B_dot_n = jnp.sum(B_surface * surface.unitnormal, axis=2) - return B_dot_n / jnp.linalg.norm(B_surface, axis=2) + return BdotN(surface, field) / jnp.linalg.norm(B_on_surface(surface, field), axis=2) + +@partial(jit, static_argnames=['surface','field']) +def _squared_flux_local(surface, field): + return 0.5 * jnp.mean(BdotN(surface, field)**2 / jnp.sum(B_on_surface(surface, field)**2, axis=2) + * surface.area_element) + +@partial(jit, static_argnames=['surface','field']) +def _squared_flux_global(surface, field): + return 0.5 * jnp.mean(BdotN(surface, field)**2 * surface.area_element) + +@partial(jit, static_argnames=['surface','field']) +def _squared_flux_normalized(surface, field): + return 0.5 * jnp.mean(BdotN(surface, field)**2 * surface.area_element) / \ + jnp.mean(jnp.sum(B_on_surface(surface, field)**2, axis=2) * surface.area_element) + +def SquaredFlux(surface, field, definition='local'): + if definition == 'local': + return _squared_flux_local(surface, field) + elif definition == 'quadratic flux': + return _squared_flux_global(surface, field) + elif definition == 'normalized': + return _squared_flux_normalized(surface, field) + else: + raise ValueError(f"Unknown definition: {definition}") def nested_lists_to_array(ll): """ @@ -147,11 +169,13 @@ def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus=' self.angles = jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) - - if hasattr(self, 'bmnc'): - self._AbsB = self._set_AbsB() + self._gamma = None + self._gammadash_theta = None + self._gammadash_phi = None + self._normal = None + self._unitnormal = None + self._area_element = None + self._AbsB = None @property def dofs(self): @@ -165,12 +189,14 @@ def dofs(self, new_dofs): indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) - # if hasattr(self, 'bmnc'): - # self._AbsB = self._set_AbsB() + self._gamma = None + self._gammadash_theta = None + self._gammadash_phi = None + self._normal = None + self._unitnormal = None + self._area_element = None + self._AbsB = None - @partial(jit, static_argnames=['self']) def _set_gamma(self, rmnc_interp, zmns_interp): phi_2d = self.phi_2d angles = self.angles @@ -193,37 +219,62 @@ def _set_gamma(self, rmnc_interp, zmns_interp): normal = jnp.cross(gammadash_phi, gammadash_theta, axis=2) unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) - - return (gamma, gammadash_theta, gammadash_phi, normal, unitnormal) - - @partial(jit, static_argnames=['self']) + area_element = jnp.linalg.norm(jnp.cross(gammadash_theta, gammadash_phi, axis=2), axis=2) + + self._gamma = gamma + self._gammadash_theta = gammadash_theta + self._gammadash_phi = gammadash_phi + self._normal = normal + self._unitnormal = unitnormal + self._area_element = area_element + def _set_AbsB(self): angles_nyq = jnp.einsum('i,jk->ijk', self.xm_nyq, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn_nyq, self.phi_2d) AbsB = jnp.einsum('i,ijk->jk', self.bmnc_interp, jnp.cos(angles_nyq)) - return AbsB + self._AbsB = AbsB @property def gamma(self): + if self._gamma is None: + self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._gamma @property def gammadash_theta(self): + if self._gammadash_theta is None: + self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._gammadash_theta @property def gammadash_phi(self): + if self._gammadash_phi is None: + self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._gammadash_phi @property def normal(self): + if self._normal is None: + self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._normal @property def unitnormal(self): + if self._unitnormal is None: + self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._unitnormal - + + @property + def area_element(self): + if self._area_element is None: + self._set_gamma(self.rmnc_interp, self.zmns_interp) + return self._area_element + @property def AbsB(self): + if self._AbsB is None: + if not hasattr(self, 'bmnc'): + raise AttributeError("AbsB is not available. Ensure that the bmnc attribute is set.") + self._set_AbsB() return self._AbsB @property From 9c2407b5a057b5e6b9fe85685508961185859f5c Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Fri, 27 Jun 2025 11:27:23 +0200 Subject: [PATCH 43/63] Fix Coils.from_ imports --- analysis/comparison_fl.py | 4 ++-- analysis/comparison_fo.py | 4 ++-- analysis/comparison_gc.py | 4 ++-- analysis/fo_integrators.py | 4 ++-- analysis/gc_integrators.py | 4 ++-- analysis/gc_vs_fo.py | 8 ++++---- analysis/poincare_plots.py | 4 ++-- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/analysis/comparison_fl.py b/analysis/comparison_fl.py index de78e5d..71210b8 100644 --- a/analysis/comparison_fl.py +++ b/analysis/comparison_fl.py @@ -6,7 +6,7 @@ from jax import block_until_ready, random from simsopt import load from simsopt.field import (particles_to_vtk, compute_fieldlines, plot_poincare_data) -from essos.coils import Coils_from_simsopt +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV from essos.dynamics import Tracing, Particles from essos.fields import BiotSavart as BiotSavart_essos @@ -30,7 +30,7 @@ nfp=2 LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) +field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) Z0 = jnp.zeros(nfieldlines) phi0 = jnp.zeros(nfieldlines) diff --git a/analysis/comparison_fo.py b/analysis/comparison_fo.py index ed513d5..ee4ca95 100644 --- a/analysis/comparison_fo.py +++ b/analysis/comparison_fo.py @@ -6,7 +6,7 @@ from jax import block_until_ready, random from simsopt import load from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) -from essos.coils import Coils_from_simsopt +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV from essos.dynamics import Tracing, Particles from essos.fields import BiotSavart as BiotSavart_essos @@ -35,7 +35,7 @@ nfp=2 LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) +field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) Z0 = jnp.zeros(nparticles) phi0 = jnp.zeros(nparticles) diff --git a/analysis/comparison_gc.py b/analysis/comparison_gc.py index 3acc059..14ee37b 100644 --- a/analysis/comparison_gc.py +++ b/analysis/comparison_gc.py @@ -6,7 +6,7 @@ from jax import block_until_ready, random from simsopt import load from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) -from essos.coils import Coils_from_simsopt +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV from essos.dynamics import Tracing, Particles from essos.fields import BiotSavart as BiotSavart_essos @@ -30,7 +30,7 @@ nfp=2 LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) +field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) Z0 = jnp.zeros(nparticles) phi0 = jnp.zeros(nparticles) diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index 79c408a..a194da0 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles import diffrax @@ -18,7 +18,7 @@ # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Particle parameters diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index d274563..f300145 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles @@ -18,7 +18,7 @@ # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Particle parameters diff --git a/analysis/gc_vs_fo.py b/analysis/gc_vs_fo.py index b07d7ec..258b70f 100644 --- a/analysis/gc_vs_fo.py +++ b/analysis/gc_vs_fo.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.dynamics import Tracing, Particles from jax import block_until_ready @@ -17,7 +17,7 @@ # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Particle parameters @@ -64,8 +64,8 @@ plt.tight_layout() plt.figure(figsize=(9, 6)) -plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy[0]/particles.energy-1), label='Guiding Center', color='red') -plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy[0]/particles.energy-1), label='Full Orbit', color='blue') +plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy()[0]/particles.energy-1), label='Guiding Center', color='red') +plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy()[0]/particles.energy-1), label='Full Orbit', color='blue') plt.xlabel('Time (ms)') plt.ylabel('Relative Energy Error') plt.xlim(0, tmax*1000) diff --git a/analysis/poincare_plots.py b/analysis/poincare_plots.py index c2c9d87..fc878c5 100644 --- a/analysis/poincare_plots.py +++ b/analysis/poincare_plots.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE from essos.fields import BiotSavart from essos.dynamics import Tracing, Particles @@ -35,7 +35,7 @@ # Load coils and field json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +coils = Coils.from_json(json_file) field = BiotSavart(coils) R0_fieldlines = jnp.linspace(1.21, 1.41, nfieldlines) From 897593499d2bd2d114096f58be2deca345d8b2d7 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Fri, 27 Jun 2025 12:15:03 +0200 Subject: [PATCH 44/63] Finished surface comparison --- analysis/comparison_surfaces.py | 119 ++++++++++++++++++++++++++++---- 1 file changed, 104 insertions(+), 15 deletions(-) diff --git a/analysis/comparison_surfaces.py b/analysis/comparison_surfaces.py index 3e69aed..3b5c238 100644 --- a/analysis/comparison_surfaces.py +++ b/analysis/comparison_surfaces.py @@ -1,8 +1,9 @@ import os -from time import time +from time import perf_counter as time import matplotlib.pyplot as plt plt.rcParams.update({'font.size': 18}) import jax.numpy as jnp +from jax import block_until_ready from essos.coils import Coils, CreateEquallySpacedCurves from essos.fields import Vmec, BiotSavart from essos.surfaces import B_on_surface, BdotN_over_B, SurfaceRZFourier as SurfaceRZFourier_ESSOS, SquaredFlux as SquaredFlux_ESSOS @@ -52,38 +53,94 @@ field_simsopt.set_points(surface_simsopt.gamma().reshape((-1, 3))) # surface_simsopt.to_vtk("simsopt_surface") +# Running the first time for compilation +surface_simsopt.gamma() +surface_simsopt.gammadash1() +surface_simsopt.gammadash2() +surface_simsopt.unitnormal() +field_simsopt.B() +SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt).J() +surface_essos.gamma + +# Running the second time for surface characteristics comparison + print("Gamma") -gamma_error = jnp.sum(jnp.abs(surface_simsopt.gamma() - surface_essos.gamma)) +start_time = time() +gamma_essos = block_until_ready(surface_essos.gamma) +t_gamma_essos = time() - start_time + +gamma_simsopt = block_until_ready(surface_simsopt.gamma()) +start_time = time() +t_gamma_simsopt = time() - start_time + +gamma_error = jnp.sum(jnp.abs(gamma_simsopt - gamma_essos)) print(gamma_error) + print('Gamma dash theta') -gamma_dash_theta_error = jnp.sum(jnp.abs(surface_simsopt.gammadash2()-surface_essos.gammadash_theta)) +start_time = time() +gamma_dash_theta_essos = block_until_ready(surface_essos.gammadash_theta) +t_gamma_dash_theta_essos = time() - start_time + +start_time = time() +gamma_dash_theta_simsopt = block_until_ready(surface_simsopt.gammadash2()) +t_gamma_dash_theta_simsopt = time() - start_time + +gamma_dash_theta_error = jnp.sum(jnp.abs(gamma_dash_theta_simsopt - gamma_dash_theta_essos)) print(gamma_dash_theta_error) + print('Gamma dash phi') -gamma_dash_phi_error = jnp.sum(jnp.abs(surface_simsopt.gammadash1()-surface_essos.gammadash_phi)) +start_time = time() +gamma_dash_phi_essos = block_until_ready(surface_essos.gammadash_phi) +t_gamma_dash_phi_essos = time() - start_time + +start_time = time() +gamma_dash_phi_simsopt = block_until_ready(surface_simsopt.gammadash1()) +t_gamma_dash_phi_simsopt = time() - start_time + +gamma_dash_phi_error = jnp.sum(jnp.abs(gamma_dash_phi_simsopt - gamma_dash_phi_essos)) print(gamma_dash_phi_error) -print('Normal') -normal_error = jnp.sum(jnp.abs(surface_simsopt.normal()-surface_essos.normal)) -print(normal_error) print('Unit normal') -unit_normal_error = jnp.sum(jnp.abs(surface_simsopt.unitnormal()-surface_essos.unitnormal)) +start_time = time() +unit_normal_essos = block_until_ready(surface_essos.unitnormal) +t_unit_normal_essos = time() - start_time + +start_time = time() +unit_normal_simsopt = block_until_ready(surface_simsopt.unitnormal()) +t_unit_normal_simsopt = time() - start_time + +unit_normal_error = jnp.sum(jnp.abs(unit_normal_simsopt - unit_normal_essos)) print(unit_normal_error) + print('B on surface') -B_on_surface_error = jnp.sum(jnp.abs(field_simsopt.B().reshape((nphi, ntheta, 3)) - B_on_surface(surface_essos, field_essos))) +start_time = time() +B_on_surface_essos = block_until_ready(B_on_surface(surface_essos, field_essos)) +t_B_on_surface_essos = time() - start_time + +start_time = time() +B_on_surface_simsopt = block_until_ready(field_simsopt.B()) +t_B_on_surface_simsopt = time() - start_time + +B_on_surface_error = jnp.sum(jnp.abs(B_on_surface_simsopt.reshape((nphi, ntheta, 3)) - B_on_surface_essos)) print(B_on_surface_error) + definition = "local" print("Squared flux", definition) -sf_SIMSOPT = SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt, definition=definition).J() -sf_ESSOS = SquaredFlux_ESSOS(surface_essos, field_essos, definition=definition) -squared_flux_error = jnp.abs(sf_SIMSOPT - sf_ESSOS) +start_time = time() +sf_essos = block_until_ready(SquaredFlux_ESSOS(surface_essos, field_essos, definition=definition)) +t_squared_flux_essos = time() - start_time + +start_time = time() +sf_simsopt = block_until_ready(SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt, definition=definition).J()) +t_squared_flux_simsopt = time() - start_time -print("ESSOS: ", sf_ESSOS) -print("SIMSOPT: ", sf_SIMSOPT) +squared_flux_error = jnp.abs(sf_simsopt - sf_essos) +print(squared_flux_error) # Labels and corresponding absolute errors (ESSOS - SIMSOPT) quantities_errors = [ @@ -112,5 +169,37 @@ ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) plt.tight_layout() -# plt.savefig(os.path.join(output_dir, f"comparison_error_surfaces.pdf"), transparent=True) +plt.savefig(os.path.join(output_dir, f"comparison_error_surfaces.pdf"), transparent=True) + +# Labels and corresponding timings +quantities = [ + (r"$\Gamma$", t_gamma_essos, t_gamma_simsopt), + (r"$\Gamma'_\theta$", t_gamma_dash_theta_essos, t_gamma_dash_theta_simsopt), + (r"$\Gamma'_\phi$", t_gamma_dash_phi_essos, t_gamma_dash_phi_simsopt), + (r"$\mathbf{n}$", t_unit_normal_essos, t_unit_normal_simsopt), + # (r"$\mathbf{B}$", t_B_on_surface_essos, t_B_on_surface_simsopt), + (r"$L_\text{flux}$", t_squared_flux_essos, t_squared_flux_simsopt), +] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale("log") +ax.set_ylim(1e-7, 1e-1) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=12) +plt.tight_layout() +plt.savefig(os.path.join(output_dir, f"comparison_time_surfaces.pdf"), transparent=True) + plt.show() From 70dcaa6d233a469ee4af18637335b6c4669fc58b Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 9 Jul 2025 16:11:46 +0200 Subject: [PATCH 45/63] Finish simsopt comparison & improve analysis plots --- .../coils.py} | 8 +- .../field_lines.py} | 12 +- .../full_orbit.py} | 40 +++-- .../guiding_center.py} | 18 +-- .../losses.py} | 26 ++-- .../surfaces.py} | 10 +- analysis/comparisons_simsopt/vmec_import.py | 139 ++++++++++++++++++ analysis/fo_integrators.py | 2 +- analysis/gc_integrators.py | 2 +- analysis/gc_vs_fo.py | 3 +- analysis/gradients.py | 2 +- 11 files changed, 207 insertions(+), 55 deletions(-) rename analysis/{comparison_coils.py => comparisons_simsopt/coils.py} (96%) rename analysis/{comparison_fl.py => comparisons_simsopt/field_lines.py} (94%) rename analysis/{comparison_fo.py => comparisons_simsopt/full_orbit.py} (89%) rename analysis/{comparison_gc.py => comparisons_simsopt/guiding_center.py} (93%) rename analysis/{comparison_losses.py => comparisons_simsopt/losses.py} (90%) rename analysis/{comparison_surfaces.py => comparisons_simsopt/surfaces.py} (95%) create mode 100644 analysis/comparisons_simsopt/vmec_import.py diff --git a/analysis/comparison_coils.py b/analysis/comparisons_simsopt/coils.py similarity index 96% rename from analysis/comparison_coils.py rename to analysis/comparisons_simsopt/coils.py index 58a9c79..eb4d5f8 100644 --- a/analysis/comparison_coils.py +++ b/analysis/comparisons_simsopt/coils.py @@ -11,13 +11,13 @@ from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data -output_dir = os.path.join(os.path.dirname(__file__), 'output') +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) n_segments = 100 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') nfp_array = [3, 2, 5, 4, 2] curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] @@ -191,7 +191,7 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) plt.tight_layout() - plt.savefig(os.path.join(output_dir, f"comparison_error_BiotSavart_{name}.pdf"), transparent=True) + plt.savefig(os.path.join(output_dir, f"comparisons_coils_error_{name}.pdf"), transparent=True) plt.close() @@ -224,5 +224,5 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=12) plt.tight_layout() - plt.savefig(os.path.join(output_dir, f"comparison_time_BiotSavart_{name}.pdf"), transparent=True) + plt.savefig(os.path.join(output_dir, f"comparisons_coils_time_{name}.pdf"), transparent=True) plt.close() diff --git a/analysis/comparison_fl.py b/analysis/comparisons_simsopt/field_lines.py similarity index 94% rename from analysis/comparison_fl.py rename to analysis/comparisons_simsopt/field_lines.py index 71210b8..445d36e 100644 --- a/analysis/comparison_fl.py +++ b/analysis/comparisons_simsopt/field_lines.py @@ -23,12 +23,12 @@ mass=PROTON_MASS energy=5000*ONE_EV -output_dir = os.path.join(os.path.dirname(__file__), 'output') +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') field_simsopt = load(LandremanPaulQA_json_file) field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) @@ -116,7 +116,7 @@ ax.set_ylim(1e0, 1e2) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) -plt.savefig(os.path.join(output_dir, 'times_fl_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, 'comparisons_fl_times.pdf'), dpi=150) ################################## @@ -153,7 +153,7 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): plt.yscale('log') plt.ylabel(r'Relative $x,y,z$ Error') -plt.savefig(os.path.join(output_dir, f'relative_xyz_error_fl_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, f'comparisons_fl_error_xyz.pdf'), dpi=150) quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx]) for tolerance_idx in range(len(trace_tolerance_array))] @@ -169,11 +169,11 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): ax.set_xticks(X_axis) ax.set_xticklabels(labels) -ax.set_ylabel("Time Averaged Relative Error") +ax.set_ylabel("Time-averaged relative error") ax.set_yscale('log') ax.set_ylim(1e-6, 1e-1) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) -plt.savefig(os.path.join(output_dir, 'relative_errors_fl_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, 'comparisons_fl_error.pdf'), dpi=150) plt.show() \ No newline at end of file diff --git a/analysis/comparison_fo.py b/analysis/comparisons_simsopt/full_orbit.py similarity index 89% rename from analysis/comparison_fo.py rename to analysis/comparisons_simsopt/full_orbit.py index ee4ca95..ad7c8b6 100644 --- a/analysis/comparison_fo.py +++ b/analysis/comparisons_simsopt/full_orbit.py @@ -28,12 +28,12 @@ mass=PROTON_MASS energy=5000*ONE_EV -output_dir = os.path.join(os.path.dirname(__file__), 'output') +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') field_simsopt = load(LandremanPaulQA_json_file) field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) @@ -97,7 +97,7 @@ tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) else: - num_steps_essos = avg_steps_SIMSOPT_array[tolerance_idx]*10 + num_steps_essos = avg_steps_SIMSOPT_array[tolerance_idx]*3 tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Boris', stepsize='constant', particles=particles) @@ -137,12 +137,12 @@ plt.legend(handles=legend_elements, loc='lower right', title='ESSOS (─), SIMSOPT (--)', fontsize=14, title_fontsize=14) plt.yscale('log') plt.xlabel('Time (ms)') -plt.ylabel('Average Relative Energy Error') +plt.ylabel('Average relative energy error') plt.tight_layout() if method == 'Dopri5': - plt.savefig(os.path.join(output_dir, f'relative_energy_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, f'comparisons_fo_error_energy.pdf'), dpi=150) else: - plt.savefig(os.path.join(output_dir, f'relative_energy_error_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, f'comparisons_fo_boris_error_energy.pdf'), dpi=150) # Plot time comparison in a bar chart @@ -164,13 +164,17 @@ ax.set_xticklabels(labels) ax.set_ylabel("Computation time (s)") ax.set_yscale('log') -ax.set_ylim(1e-1, 1e3) +if method == 'Dopri5': + ax.set_ylim(1e0, 1e3) +else: + ax.set_ylim(1e-1, 1e3) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) if method == 'Dopri5': - plt.savefig(os.path.join(output_dir, 'times_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'comparisons_fo_times.pdf'), dpi=150) else: - plt.savefig(os.path.join(output_dir, 'times_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'comparisons_fo_boris_times.pdf'), dpi=150) ################################## @@ -219,11 +223,11 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') v_error_ax.set_ylabel(r'Relative $v_x,v_y,v_z$ Error') if method == 'Dopri5': - xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - v_error_fig.savefig(os.path.join(output_dir, f'relative_v_error_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + xyz_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_error_xyz.pdf'), dpi=150) + v_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_error_v.pdf'), dpi=150) else: - xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) - v_error_fig.savefig(os.path.join(output_dir, f'relative_v_error_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) + xyz_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_boris_error_xyz.pdf'), dpi=150) + v_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_boris_error_v.pdf'), dpi=150) quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) for tolerance_idx in range(len(trace_tolerance_array))] @@ -241,14 +245,18 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): ax.set_xticks(X_axis) ax.set_xticklabels(labels) -ax.set_ylabel("Time Averaged Relative Error") +ax.set_ylabel("Time-averaged relative error") ax.set_yscale('log') ax.set_ylim(1e-6, 1e1) +if method == 'Dopri5': + ax.set_ylim(1e-8, 1e-1) +else: + ax.set_ylim(1e-4, 1e0) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) if method == 'Dopri5': - plt.savefig(os.path.join(output_dir, 'relative_errors_fo_SIMSOPT_vs_ESSOS.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'comparisons_fo_errors.pdf'), dpi=150) else: - plt.savefig(os.path.join(output_dir, 'relative_errors_fo_SIMSOPT_vs_ESSOS_Boris.pdf'), dpi=150) + plt.savefig(os.path.join(output_dir, 'comparisons_fo_boris_errors.pdf'), dpi=150) plt.show() \ No newline at end of file diff --git a/analysis/comparison_gc.py b/analysis/comparisons_simsopt/guiding_center.py similarity index 93% rename from analysis/comparison_gc.py rename to analysis/comparisons_simsopt/guiding_center.py index 14ee37b..8798d85 100644 --- a/analysis/comparison_gc.py +++ b/analysis/comparisons_simsopt/guiding_center.py @@ -23,12 +23,12 @@ mass=PROTON_MASS energy=5000*ONE_EV -output_dir = os.path.join(os.path.dirname(__file__), 'output') +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') field_simsopt = load(LandremanPaulQA_json_file) field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) @@ -128,9 +128,9 @@ plt.legend(handles=legend_elements, loc='lower right', title='ESSOS (─), SIMSOPT (--)', fontsize=14, title_fontsize=14) plt.yscale('log') plt.xlabel('Time (ms)') -plt.ylabel('Average Relative Energy Error') +plt.ylabel('Average relative energy error') plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, f'comparisons_gc_error_energy.pdf'), dpi=150) # Plot time comparison in a bar chart @@ -155,7 +155,7 @@ ax.set_ylim(1e0, 1e2) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) -plt.savefig(os.path.join(output_dir, 'times_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, 'comparisons_gc_times.pdf'), dpi=150) ################################## @@ -201,8 +201,8 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') vpar_error_ax.set_ylabel(r'Relative $v_\parallel$ Error') -xyz_error_fig.savefig(os.path.join(output_dir, f'relative_xyz_error_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -vpar_error_fig.savefig(os.path.join(output_dir, f'relative_vpar_error_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +xyz_error_fig.savefig(os.path.join(output_dir, f'comparisons_gc_error_xyz.pdf'), dpi=150) +vpar_error_fig.savefig(os.path.join(output_dir, f'comparisons_gc_error_vpar.pdf'), dpi=150) quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) for tolerance_idx in range(len(trace_tolerance_array))] @@ -220,11 +220,11 @@ def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): ax.set_xticks(X_axis) ax.set_xticklabels(labels) -ax.set_ylabel("Time Averaged Relative Error") +ax.set_ylabel("Time-averaged relative error") ax.set_yscale('log') ax.set_ylim(1e-6, 1e-1) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=14) -plt.savefig(os.path.join(output_dir, 'relative_errors_gc_SIMSOPT_vs_ESSOS.pdf'), dpi=150) +plt.savefig(os.path.join(output_dir, 'comparisons_gc_error.pdf'), dpi=150) plt.show() \ No newline at end of file diff --git a/analysis/comparison_losses.py b/analysis/comparisons_simsopt/losses.py similarity index 90% rename from analysis/comparison_losses.py rename to analysis/comparisons_simsopt/losses.py index 4f58431..837e3ca 100644 --- a/analysis/comparison_losses.py +++ b/analysis/comparisons_simsopt/losses.py @@ -12,13 +12,13 @@ from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data -output_dir = os.path.join(os.path.dirname(__file__), 'output') +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) n_segments = 100 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') nfp_array = [3, 2, 5, 4, 2] curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] @@ -84,7 +84,7 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): CurveCurveDistance(curves_simsopt, 0.5).J() loss_coil_separation(coils_essos, 0.5) - # Running the second time for coils characteristics comparison + # Running the second time for losses comparison start_time = time() curvature_loss_simsopt = block_until_ready(2*sum([LpCurveCurvature(curve, p=2, threshold=0).J() for curve in curves_simsopt])) @@ -124,10 +124,14 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): t_ind_separation_avg_essos = time() - start_time print(f"Independence separation loss ESSOS: {ind_separation_loss_essos}") - length_error_avg = jnp.linalg.norm(length_loss_essos - length_loss_simsopt) - curvature_error_avg = jnp.linalg.norm(curvature_loss_essos - curvature_loss_simsopt) - separation_error_avg = jnp.linalg.norm(separation_loss_essos - separation_loss_simsopt) - ind_separation_error_avg = jnp.linalg.norm(ind_separation_loss_essos - ind_separation_loss_simsopt) + length_error_avg = jnp.linalg.norm(length_loss_essos - length_loss_simsopt) / jnp.linalg.norm(length_loss_simsopt) + if length_error_avg == 0: + length_error_avg = jnp.finfo(jnp.float64).eps + curvature_error_avg = jnp.linalg.norm(curvature_loss_essos - curvature_loss_simsopt) / jnp.linalg.norm(curvature_loss_simsopt) + if curvature_error_avg == 0: + curvature_error_avg = jnp.finfo(jnp.float64).eps + separation_error_avg = jnp.linalg.norm(separation_loss_essos - separation_loss_simsopt) / jnp.linalg.norm(separation_loss_simsopt) + ind_separation_error_avg = jnp.linalg.norm(ind_separation_loss_essos - ind_separation_loss_simsopt) / jnp.linalg.norm(ind_separation_loss_simsopt) print(f"length_error_avg: {length_error_avg:.2e}") print(f"curvature_error_avg: {curvature_error_avg:.2e}") print(f"separation_error_avg: {separation_error_avg:.2e}") @@ -152,13 +156,13 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): ax.set_xticks(X_axis) ax.set_xticklabels(labels) - ax.set_ylabel("Absolute error") + ax.set_ylabel("Relative error") ax.set_yscale("log") - ax.set_ylim(1e-17, 1e-2) + ax.set_ylim(1e-16, 1e-1) ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) plt.tight_layout() - plt.savefig(os.path.join(output_dir, f"comparison_error_losses_{name}.pdf"), transparent=True) + plt.savefig(os.path.join(output_dir, f"comparisons_losses_error_{name}.pdf"), transparent=True) plt.close() @@ -189,5 +193,5 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=12) plt.tight_layout() - plt.savefig(os.path.join(output_dir, f"comparison_time_losses_{name}.pdf"), transparent=True) + plt.savefig(os.path.join(output_dir, f"comparisons_losses_time_{name}.pdf"), transparent=True) plt.close() diff --git a/analysis/comparison_surfaces.py b/analysis/comparisons_simsopt/surfaces.py similarity index 95% rename from analysis/comparison_surfaces.py rename to analysis/comparisons_simsopt/surfaces.py index 3b5c238..5a6d70d 100644 --- a/analysis/comparison_surfaces.py +++ b/analysis/comparisons_simsopt/surfaces.py @@ -11,7 +11,7 @@ from simsopt.geo import SurfaceRZFourier as SurfaceRZFourier_SIMSOPT from simsopt.objectives import SquaredFlux as SquaredFlux_SIMSOPT -output_dir = os.path.join(os.path.dirname(__file__), 'output') +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -27,7 +27,7 @@ nphi = 32 # Initialize VMEC field -vmec_file = os.path.join(os.path.dirname(__file__), '../examples', 'input_files', +vmec_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') vmec = Vmec(vmec_file, ntheta=ntheta, nphi=nphi, close=False) @@ -60,7 +60,7 @@ surface_simsopt.unitnormal() field_simsopt.B() SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt).J() -surface_essos.gamma +block_until_ready(surface_essos.gamma) # Running the second time for surface characteristics comparison @@ -169,7 +169,7 @@ ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) plt.tight_layout() -plt.savefig(os.path.join(output_dir, f"comparison_error_surfaces.pdf"), transparent=True) +plt.savefig(os.path.join(output_dir, f"comparisons_surfaces_error.pdf"), transparent=True) # Labels and corresponding timings quantities = [ @@ -200,6 +200,6 @@ ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) ax.legend(fontsize=12) plt.tight_layout() -plt.savefig(os.path.join(output_dir, f"comparison_time_surfaces.pdf"), transparent=True) +plt.savefig(os.path.join(output_dir, f"comparisons_surfaces_time.pdf"), transparent=True) plt.show() diff --git a/analysis/comparisons_simsopt/vmec_import.py b/analysis/comparisons_simsopt/vmec_import.py new file mode 100644 index 0000000..adffffb --- /dev/null +++ b/analysis/comparisons_simsopt/vmec_import.py @@ -0,0 +1,139 @@ +import os +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from jax import block_until_ready, random +from essos.fields import Vmec as Vmec_essos +from simsopt.mhd import Vmec as Vmec_simsopt, vmec_compute_geometry + + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +wout_array = [os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', "wout_LandremanPaul2021_QA_reactorScale_lowres.nc"), + os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', "wout_n3are_R7.75B5.7.nc")] +name_array = ["LandremanPaulQA", 'NCSX'] + + +print(f'Output being saved to {output_dir}') +for name, wout in zip(name_array, wout_array): + print(f' Running comparison with VMEC file located at: {wout}') + + vmec_essos = Vmec_essos(wout) + vmec_simsopt = Vmec_simsopt(wout) + + s_array=jnp.linspace(0.2, 0.9, 10) + key = random.key(42) + + def absB_simsopt_func(s, theta, phi): + return vmec_compute_geometry(vmec_simsopt, s, theta, phi).modB[0][0][0] + def absB_essos_func(s, theta, phi): + return vmec_essos.AbsB([s, theta, phi]) + def B_simsopt_func(s, theta, phi): + g = vmec_compute_geometry(vmec_simsopt, s, theta, phi) + return jnp.array([g.B_sub_s * g.grad_s_X + g.B_sub_theta_vmec * g.grad_theta_vmec_X + g.B_sub_phi * g.grad_phi_X, + g.B_sub_s * g.grad_s_Y + g.B_sub_theta_vmec * g.grad_theta_vmec_Y + g.B_sub_phi * g.grad_phi_Y, + g.B_sub_s * g.grad_s_Z + g.B_sub_theta_vmec * g.grad_theta_vmec_Z + g.B_sub_phi * g.grad_phi_Z])[:,0,0,0] + def B_essos_func(s, theta, phi): + return vmec_essos.B([s, theta, phi]) + + def timed_B(s, function): + theta = random.uniform(key=key, minval=0, maxval=2 * jnp.pi) + phi = random.uniform(key=key, minval=0, maxval=2 * jnp.pi) + function(s, theta, phi) + time1 = time() + B = block_until_ready(function(s, theta, phi)) + time_taken = time()-time1 + return time_taken, B + + average_time_modB_simsopt = 0 + average_time_modB_essos = 0 + average_time_B_essos = 0 + average_time_B_simsopt = 0 + error_modB = 0 + error_B = 0 + for s in s_array: + time_modB_simsopt, modB_simsopt = timed_B(s, absB_simsopt_func) + average_time_modB_simsopt += time_modB_simsopt + + time_modB_essos, modB_essos = timed_B(s, absB_essos_func) + average_time_modB_essos += time_modB_essos + + time_B_essos, B_essos = timed_B(s, B_essos_func) + average_time_B_essos += time_B_essos + + time_B_simsopt, B_simsopt = timed_B(s, B_simsopt_func) + average_time_B_simsopt += time_B_simsopt + + error_modB += jnp.abs((modB_simsopt-modB_essos)/modB_simsopt) + error_B += jnp.abs((B_simsopt-B_essos)/B_simsopt) + + average_time_modB_simsopt /= len(s_array) + average_time_modB_essos /= len(s_array) + average_time_B_essos /= len(s_array) + average_time_B_simsopt /= len(s_array) + error_modB /= len(s_array) + error_B /= len(s_array) + + # Labels and corresponding absolute errors (ESSOS - SIMSOPT) + quantities_errors = [ + (r"$B$", jnp.mean(error_modB)), + (r"$\mathbf{B}$", jnp.mean(error_B)), + ] + + labels = [q[0] for q in quantities_errors] + error_vals = [q[1] for q in quantities_errors] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.4 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Relative error") + ax.set_yscale("log") + ax.set_ylim(1e-6, 1e-2) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparisons_VMEC_error_{name}.pdf"), transparent=True) + + # Labels and corresponding timings + print(f"Average time to compute |B| in SIMSOPT: {average_time_modB_simsopt:.6f} s") + print(f"Average time to compute B in SIMSOPT: {average_time_B_simsopt:.6f} s") + print(f"Average time to compute |B| in ESSOS: {average_time_modB_essos:.6f} s") + print(f"Average time to compute B in ESSOS: {average_time_B_essos:.6f} s") + print(f"Relative error in |B|: {jnp.mean(error_modB):.6f}") + print(f"Relative error in B: {jnp.mean(error_B):.6f}") + + quantities = [ + (r"$B$", average_time_modB_essos, average_time_modB_simsopt), + (r"$\mathbf{B}$", average_time_B_essos, average_time_B_simsopt), + ] + + labels = [q[0] for q in quantities] + essos_vals = [q[1] for q in quantities] + simsopt_vals = [q[2] for q in quantities] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.4 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") + ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Computation time (s)") + ax.set_yscale("log") + ax.set_ylim(1e-5, 1e-1) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + ax.legend(fontsize=12) + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparisons_VMEC_time_{name}.pdf"), transparent=True) + + plt.show() \ No newline at end of file diff --git a/analysis/fo_integrators.py b/analysis/fo_integrators.py index a194da0..d7b3783 100644 --- a/analysis/fo_integrators.py +++ b/analysis/fo_integrators.py @@ -76,7 +76,7 @@ ax.legend(fontsize=15, loc='upper left') ax.set_xlabel('Computation time (s)') -ax.set_ylabel('Relative Energy Error') +ax.set_ylabel('Relative energy error') ax.set_xscale('log') ax.set_yscale('log') ax.set_xlim(1e-1, 1e2) diff --git a/analysis/gc_integrators.py b/analysis/gc_integrators.py index f300145..4c5c354 100644 --- a/analysis/gc_integrators.py +++ b/analysis/gc_integrators.py @@ -82,7 +82,7 @@ for axis in [ax, ax_tol]: axis.legend(fontsize=15) - axis.set_ylabel('Relative Energy Error') + axis.set_ylabel('Relative energy error') axis.set_xscale('log') axis.set_yscale('log') axis.set_ylim(1e-16, 1e-4) diff --git a/analysis/gc_vs_fo.py b/analysis/gc_vs_fo.py index 258b70f..4517766 100644 --- a/analysis/gc_vs_fo.py +++ b/analysis/gc_vs_fo.py @@ -5,6 +5,7 @@ from time import time import jax.numpy as jnp import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) from essos.fields import BiotSavart from essos.coils import Coils from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE @@ -67,7 +68,7 @@ plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy()[0]/particles.energy-1), label='Guiding Center', color='red') plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy()[0]/particles.energy-1), label='Full Orbit', color='blue') plt.xlabel('Time (ms)') -plt.ylabel('Relative Energy Error') +plt.ylabel('Relative energy error') plt.xlim(0, tmax*1000) plt.ylim(bottom=0) plt.legend() diff --git a/analysis/gradients.py b/analysis/gradients.py index 8fa4939..8ff673e 100644 --- a/analysis/gradients.py +++ b/analysis/gradients.py @@ -110,7 +110,7 @@ plt.plot(h_list, fd_diff[3], "s-", label=f'6th order', clip_on=False, linewidth=2.5) plt.legend(fontsize=15) plt.xlabel('Finite differences stepsize h') -plt.ylabel('Relative difference') +plt.ylabel('Relative error') plt.xscale('log') plt.yscale('log') plt.ylim(1e-13, 1e-1) From bbf6cc4e08466ae1db31a1300f060a6d1979b5bf Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 9 Jul 2025 20:31:58 +0200 Subject: [PATCH 46/63] Update Ubuntu for workflows --- .github/workflows/build_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 7fb79ca..5f432c9 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -12,7 +12,7 @@ permissions: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: matrix: python-version: [ '3.9', '3.10', '3.11', '3.12'] From 7d807ff3c9197367266aa2ac18f915a69005d97e Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 22 Sep 2025 15:09:49 +0100 Subject: [PATCH 47/63] Fixed near-axis & surfaces examples --- essos/objective_functions.py | 41 ++++++++--------- essos/optimization.py | 7 ++- essos/surfaces.py | 61 +++++++------------------- examples/optimize_coils_and_surface.py | 20 ++++----- 4 files changed, 46 insertions(+), 83 deletions(-) diff --git a/essos/objective_functions.py b/essos/objective_functions.py index 92e0113..951e4ef 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -1,4 +1,5 @@ import jax +# from build.lib.essos import coils jax.config.update("jax_enable_x64", True) import jax.numpy as jnp from jax import jit, vmap @@ -72,16 +73,15 @@ def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, gradB_nearaxis = field_nearaxis.grad_B_axis.T gradB_coils = vmap(field.dB_by_dX)(points.T) - coil_length = loss_coil_length(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - + coil_length = field.coils.length + coil_curvature = field.coils.curvature B_difference_loss = jnp.sum(jnp.abs(jnp.array(B_coils)-jnp.array(B_nearaxis))) gradB_difference_loss = jnp.sum(jnp.abs(jnp.array(gradB_coils)-jnp.array(gradB_nearaxis))) - coil_length_loss = 1e3*jnp.max(loss_coil_length(coils, max_coil_length)) - coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(coils, max_coil_curvature)) - - + coil_length_loss = 1e3*jnp.max(jnp.maximum(0, coil_length - max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(jnp.maximum(0, coil_curvature - max_coil_curvature)) + + return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss # @partial(jit, static_argnums=(0, 1)) @@ -100,17 +100,17 @@ def difference_B_gradB_onaxis(nearaxis_field, coils_field): return jnp.array(B_coils)-jnp.array(B_nearaxis), jnp.array(gradB_coils)-jnp.array(gradB_nearaxis) -@partial(jit, static_argnums=(1, 2, 4, 5, 6, 7, 8)) -def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, +@partial(jit, static_argnums=(1, 4, 5, 6, 7, 8)) +def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, max_coil_length=42, n_segments=60, stellsym=True, max_coil_curvature=0.1): #len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) len_dofs_nearaxis = len(field_nearaxis.x) field=field_from_dofs(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len_dofs_nearaxis:], field_nearaxis) - coil_length = loss_coil_length(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - + coil_length = field.coils.length + coil_curvature = field.coils.curvature + elongation = new_field_nearaxis.elongation iota = new_field_nearaxis.iota @@ -118,14 +118,13 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves_shape, currents_scale B_difference_loss = 3*jnp.sum(jnp.abs(B_difference)) gradB_difference_loss = jnp.sum(jnp.abs(gradB_difference)) - coil_length_loss = 1e3*jnp.max(loss_coil_length(coils, max_coil_length)) - coil_curvature_loss = 1e3*jnp.max(loss_coil_curvature(coils, max_coil_curvature)) + coil_length_loss = 1e3*jnp.max(jnp.maximum(0, coil_length - max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(jnp.maximum(0, coil_curvature - max_coil_curvature)) elongation_loss = jnp.sum(jnp.abs(elongation)) iota_loss = 30/jnp.abs(iota) return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss+elongation_loss+iota_loss - def loss_particle_radial_drift(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True, maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) particles.to_full_orbit(field) @@ -362,14 +361,11 @@ def loss_BdotN(x, vmec, dofs_curves, currents_scale, nfp, max_coil_length=42, field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) bdotn_over_b = BdotN_over_B(vmec.surface, field) - coil_length = loss_coil_length(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - - bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) - coil_length_loss = jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) - + + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) + return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss @partial(jit, static_argnums=(1, 4, 5, 6)) @@ -377,7 +373,6 @@ def loss_BdotN_only(x, vmec, dofs_curves, currents_scale, nfp,n_segments=60, ste field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) bdotn_over_b = BdotN_over_B(vmec.surface, field) - bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) return bdotn_over_b_loss diff --git a/essos/optimization.py b/essos/optimization.py index ebe61a3..fb1a24b 100644 --- a/essos/optimization.py +++ b/essos/optimization.py @@ -31,7 +31,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- dofs_curves_shape = coils.dofs_curves.shape currents_scale = coils.currents_scale - loss_partial = partial(func, dofs_curves_shape=coils.dofs_curves.shape, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym, **kwargs) + loss_partial = partial(func, dofs_curves=coils.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym, **kwargs) ## Without JAX gradients, using finite differences result = least_squares(loss_partial, x0=initial_dofs, verbose=2, diff_step=1e-4, @@ -43,9 +43,8 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- # result = least_squares(loss_partial, x0=initial_dofs, verbose=2, jac=jac_loss_partial, # ftol=tolerance_optimization, gtol=tolerance_optimization, # xtol=1e-14, max_nfev=maximum_function_evaluations) - print("Starting optimization") - result = minimize(loss_partial, x0=initial_dofs, jac=jac_loss_partial, method=method, - tol=tolerance_optimization, options={'maxiter': maximum_function_evaluations, 'disp': True, 'gtol': 1e-14, 'ftol': 1e-14}) + ##result = minimize(loss_partial, x0=initial_dofs, jac=jac_loss_partial, method=method, + ## tol=tolerance_optimization, options={'maxiter': maximum_function_evaluations, 'disp': True, 'gtol': 1e-14, 'ftol': 1e-14}) dofs_curves = jnp.reshape(result.x[:len_dofs_curves], (dofs_curves_shape)) try: diff --git a/essos/surfaces.py b/essos/surfaces.py index 73371a3..008baa0 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -171,13 +171,11 @@ def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus=' self.angles = jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) - self._gamma = None - self._gammadash_theta = None - self._gammadash_phi = None - self._normal = None - self._unitnormal = None - self._area_element = None - self._AbsB = None + (self._gamma, self._gammadash_theta, self._gammadash_phi, + self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + + if hasattr(self, 'bmnc'): + self._AbsB = self._set_AbsB() @property def dofs(self): @@ -191,14 +189,12 @@ def dofs(self, new_dofs): indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] - self._gamma = None - self._gammadash_theta = None - self._gammadash_phi = None - self._normal = None - self._unitnormal = None - self._area_element = None - self._AbsB = None + (self._gamma, self._gammadash_theta, self._gammadash_phi, + self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + # if hasattr(self, 'bmnc'): + # self._AbsB = self._set_AbsB() + @partial(jit, static_argnames=['self']) def _set_gamma(self, rmnc_interp, zmns_interp): phi_2d = self.phi_2d angles = self.angles @@ -221,62 +217,37 @@ def _set_gamma(self, rmnc_interp, zmns_interp): normal = jnp.cross(gammadash_phi, gammadash_theta, axis=2) unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) - area_element = jnp.linalg.norm(jnp.cross(gammadash_theta, gammadash_phi, axis=2), axis=2) - - self._gamma = gamma - self._gammadash_theta = gammadash_theta - self._gammadash_phi = gammadash_phi - self._normal = normal - self._unitnormal = unitnormal - self._area_element = area_element - + + return (gamma, gammadash_theta, gammadash_phi, normal, unitnormal) + + @partial(jit, static_argnames=['self']) def _set_AbsB(self): angles_nyq = jnp.einsum('i,jk->ijk', self.xm_nyq, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn_nyq, self.phi_2d) AbsB = jnp.einsum('i,ijk->jk', self.bmnc_interp, jnp.cos(angles_nyq)) - self._AbsB = AbsB + return AbsB @property def gamma(self): - if self._gamma is None: - self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._gamma @property def gammadash_theta(self): - if self._gammadash_theta is None: - self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._gammadash_theta @property def gammadash_phi(self): - if self._gammadash_phi is None: - self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._gammadash_phi @property def normal(self): - if self._normal is None: - self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._normal @property def unitnormal(self): - if self._unitnormal is None: - self._set_gamma(self.rmnc_interp, self.zmns_interp) return self._unitnormal - - @property - def area_element(self): - if self._area_element is None: - self._set_gamma(self.rmnc_interp, self.zmns_interp) - return self._area_element - + @property def AbsB(self): - if self._AbsB is None: - if not hasattr(self, 'bmnc'): - raise AttributeError("AbsB is not available. Ensure that the bmnc attribute is set.") - self._set_AbsB() return self._AbsB @property diff --git a/examples/optimize_coils_and_surface.py b/examples/optimize_coils_and_surface.py index 7ff5e58..587daa3 100644 --- a/examples/optimize_coils_and_surface.py +++ b/examples/optimize_coils_and_surface.py @@ -117,25 +117,23 @@ def loss_normal_cross_GradB_dot_grad_B_dot_GradB_surface(surface, field): normal_cross_GradB_dot_grad_B_dot_GradB_surface = jnp.sum(normal_cross_GradB_surface * grad_B_dot_GradB_surface, axis=-1) return normal_cross_GradB_dot_grad_B_dot_GradB_surface -@partial(jit, static_argnums=(1, 3, 5, 6, 7, 8, 9, 10)) -def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves_shape, currents_scale, nfp, max_coil_length=42, +@partial(jit, static_argnums=(1, 5, 6, 7, 8, 9, 10)) +def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents_scale, nfp, max_coil_length=42, n_segments=60, stellsym=True, max_coil_curvature=0.5, target_B_on_surface=5.7): - len_dofs_curves_ravelled = dofs_curves_shape[0]*dofs_curves_shape[1]*dofs_curves_shape[2] - dofs_currents = x[len_dofs_curves_ravelled:-len(surface_all.x)-len(field_nearaxis.x)] - new_dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], dofs_curves_shape) field=field_from_dofs(x[:-len(surface_all.x)-len(field_nearaxis.x)] ,dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) surface.dofs = x[-len(surface_all.x)-len(field_nearaxis.x):-len(field_nearaxis.x)] field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len(field_nearaxis.x):], field_nearaxis) + + coil_length = field.coils.length + coil_curvature = field.coils.curvature + - coil_length = loss_coil_length(coils) - coil_curvature = loss_coil_curvature(coils) - - coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) - + coil_length_loss = 1e3*jnp.max(jnp.maximum(0, coil_length - max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(jnp.maximum(0, coil_curvature - max_coil_curvature)) + normal_cross_GradB_dot_grad_B_dot_GradB_surface = jnp.sum(jnp.abs(loss_normal_cross_GradB_dot_grad_B_dot_GradB_surface(surface, field))) bdotn_over_b = BdotN_over_B(surface, field) From 98534496cbe7914311f10065f91977a72c88eb87 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 22 Sep 2025 20:22:41 +0100 Subject: [PATCH 48/63] Fixed merge changes --- analysis/gc_vs_fo.py | 7 +- analysis/gradients.py | 4 +- essos/dynamics.py | 67 +++++-------------- essos/objective_functions.py | 24 +++---- essos/surfaces.py | 11 ++- examples/optimize_coils_and_surface.py | 2 +- examples/trace_fieldlines_coils.py | 4 +- .../trace_particles_coils_guidingcenter.py | 8 +-- 8 files changed, 47 insertions(+), 80 deletions(-) diff --git a/analysis/gc_vs_fo.py b/analysis/gc_vs_fo.py index 4517766..8090ae7 100644 --- a/analysis/gc_vs_fo.py +++ b/analysis/gc_vs_fo.py @@ -46,13 +46,16 @@ # Trace in ESSOS time0 = time() tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, - maxtime=tmax, timesteps=num_steps_gc, tol_step_size=trace_tolerance) + maxtime=tmax, timestep=num_steps_gc, atol=trace_tolerance, rtol=trace_tolerance, + times_to_trace=200) trajectories_guidingcenter = block_until_ready(tracing_gc.trajectories) print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") time0 = time() tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax, - timesteps=num_steps_fo, tol_step_size=trace_tolerance) + timestep=num_steps_fo, atol=trace_tolerance, rtol=trace_tolerance, + times_to_trace=200) + block_until_ready(tracing_fo.trajectories) print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") diff --git a/analysis/gradients.py b/analysis/gradients.py index 8ff673e..4fb04fe 100644 --- a/analysis/gradients.py +++ b/analysis/gradients.py @@ -45,10 +45,10 @@ coils = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) -loss_partial = partial(loss_BdotN, dofs_curves_shape=coils.dofs_curves.shape, currents_scale=coils.currents_scale, +loss_partial = partial(loss_BdotN, dofs_curves=coils.dofs_curves, currents_scale=coils.currents_scale, nfp=coils.nfp, n_segments=coils.n_segments, stellsym=coils.stellsym, vmec=vmec, max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature) - +print(loss_partial(coils.x)) grad_loss_partial = jit(grad(loss_partial)) time0 = time() diff --git a/essos/dynamics.py b/essos/dynamics.py index 7510da8..47da623 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -1,3 +1,4 @@ +from pyexpat import model import jax jax.config.update("jax_enable_x64", True) import jax.numpy as jnp @@ -598,50 +599,6 @@ def condition_BioSavart(t, y, args, **kwargs): self._trajectories = self.trace() - if self.particles is not None: - self.energy = jnp.zeros((self.particles.nparticles, self.times_to_trace)) - - if model == 'GuidingCenter' or model == 'GuidingCenterAdaptative' : - @jit - def compute_energy_gc(trajectory): - xyz = trajectory[:, :3] - vpar = trajectory[:, 3] - AbsB = vmap(self.field.AbsB)(xyz) - mu = (self.particles.energy - self.particles.mass * vpar[0]**2 / 2) / AbsB[0] - return self.particles.mass * vpar**2 / 2 + mu * AbsB - self.energy = vmap(compute_energy_gc)(self._trajectories) - elif model == 'GuidingCenterCollisions': - @jit - def compute_energy_gc(trajectory): - return 0.5*self.particles.mass* trajectory[:, 3]**2 - self.energy = vmap(compute_energy_gc)(self._trajectories) - elif model == 'GuidingCenterCollisionsMuIto' or model == 'GuidingCenterCollisionsMuFixed' or model == 'GuidingCenterCollisionsMuAdaptative' : - @jit - def compute_energy_gc(trajectory): - xyz = trajectory[:, :3] - vpar = trajectory[:, 3]*SPEED_OF_LIGHT - mu = trajectory[:, 4]*self.particles.mass*SPEED_OF_LIGHT**2 - AbsB = vmap(self.field.AbsB)(xyz) - return self.particles.mass * vpar**2 / 2 + mu*AbsB - self.energy = vmap(compute_energy_gc)(self._trajectories) - @jit - def compute_vperp_gc(trajectory): - xyz = trajectory[:, :3] - mu = trajectory[:, 4]*self.particles.mass*SPEED_OF_LIGHT**2 - AbsB = vmap(self.field.AbsB)(xyz) - return jnp.sqrt(2.*mu*AbsB/self.particles.mass) - self.vperp_final = vmap(compute_vperp_gc)(self._trajectories) - elif model == 'FullOrbit' or model == 'FullOrbit_Boris' or model == 'FullOrbitCollisions': - @jit - def compute_energy_fo(trajectory): - vxvyvz = trajectory[:, 3:] - return self.particles.mass / 2 * (vxvyvz[:, 0]**2 + vxvyvz[:, 1]**2 + vxvyvz[:, 2]**2) - self.energy = vmap(compute_energy_fo)(self._trajectories) - elif model == 'FieldLine' or model== 'FieldLineAdaptative': - self.energy = jnp.ones((len(initial_conditions), self.times_to_trace)) - - - self.trajectories_xyz = vmap(lambda xyz: vmap(lambda point: self.field.to_xyz(point[:3]))(xyz))(self.trajectories) if isinstance(field, Vmec): @@ -883,10 +840,11 @@ def trajectories(self, value): self._trajectories = value def energy(self): - assert self.model in ['GuidingCenter', 'FullOrbit'], "Energy calculation is only available for GuidingCenter and FullOrbit models" + assert 'GuidingCenter' in self.model or 'FullOrbit' in self.model, "Energy calculation is only available for GuidingCenter and FullOrbit models" mass = self.particles.mass - if self.model == 'GuidingCenter': + if self.model == 'GuidingCenter' or self.model == 'GuidingCenterAdaptative' or \ + self.model == 'GuidingCenterCollisionsMuIto' or self.model == 'GuidingCenterCollisionsMuFixed' or self.model == 'GuidingCenterCollisionsMuAdaptative': initial_xyz = self.initial_conditions[:, :3] initial_vparallel = self.initial_conditions[:, 3] initial_B = vmap(self.field.AbsB)(initial_xyz) @@ -898,17 +856,24 @@ def compute_energy(trajectory, mu): return 0.5 * mass * jnp.square(vpar) + mu * AbsB energy = vmap(compute_energy)(self.trajectories, mu_array) - + + elif self.model == 'GuidingCenterCollisions': + def compute_energy(trajectory): + return 0.5 * mass * trajectory[:, 3]**2 + energy = vmap(compute_energy)(self.trajectories) + elif self.model == 'FullOrbit': def compute_energy(trajectory): vxvyvz = trajectory[:, 3:] v_squared = jnp.sum(jnp.square(vxvyvz), axis=1) return 0.5 * mass * v_squared - energy = vmap(compute_energy)(self.trajectories) + elif self.model == 'FieldLine' or self.model == 'FieldLineAdaptative': + energy = jnp.ones((len(self.initial_conditions), self.times_to_trace)) + return energy - + def to_vtk(self, filename): try: import numpy as np except ImportError: raise ImportError("The 'numpy' library is required. Please install it using 'pip install numpy'.") @@ -1085,8 +1050,8 @@ def process_trajectory(X_i, Y_i, T_i): def _tree_flatten(self): children = (self.trajectories, self.initial_conditions, self.times) # arrays / dynamic values - aux_data = {'field': self.field, 'model': self.model, 'method': self.method, 'maxtime': self.maxtime, 'timesteps': self.timesteps,'stepsize': - self.stepsize, 'tol_step_size': self.tol_step_size, 'particles': self.particles, 'condition': self.condition} # static values + aux_data = {'field': self.field, 'electric_field': self.electric_field, 'model': self.model, 'maxtime': self.maxtime, 'timestep': self.timestep, + 'rtol': self.rtol, 'atol': self.atol, 'particles': self.particles, 'condition': self.condition, 'tag_gc': self.tag_gc} # static values return (children, aux_data) @classmethod diff --git a/essos/objective_functions.py b/essos/objective_functions.py index 951e4ef..a07aac4 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -72,15 +72,12 @@ def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, gradB_nearaxis = field_nearaxis.grad_B_axis.T gradB_coils = vmap(field.dB_by_dX)(points.T) - - coil_length = field.coils.length - coil_curvature = field.coils.curvature + B_difference_loss = jnp.sum(jnp.abs(jnp.array(B_coils)-jnp.array(B_nearaxis))) gradB_difference_loss = jnp.sum(jnp.abs(jnp.array(gradB_coils)-jnp.array(gradB_nearaxis))) - coil_length_loss = 1e3*jnp.max(jnp.maximum(0, coil_length - max_coil_length)) - coil_curvature_loss = 1e3*jnp.max(jnp.maximum(0, coil_curvature - max_coil_curvature)) - + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss @@ -107,9 +104,6 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, len_dofs_nearaxis = len(field_nearaxis.x) field=field_from_dofs(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len_dofs_nearaxis:], field_nearaxis) - - coil_length = field.coils.length - coil_curvature = field.coils.curvature elongation = new_field_nearaxis.elongation iota = new_field_nearaxis.iota @@ -118,8 +112,8 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, B_difference_loss = 3*jnp.sum(jnp.abs(B_difference)) gradB_difference_loss = jnp.sum(jnp.abs(gradB_difference)) - coil_length_loss = 1e3*jnp.max(jnp.maximum(0, coil_length - max_coil_length)) - coil_curvature_loss = 1e3*jnp.max(jnp.maximum(0, coil_curvature - max_coil_curvature)) + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) elongation_loss = jnp.sum(jnp.abs(elongation)) iota_loss = 30/jnp.abs(iota) @@ -199,7 +193,7 @@ def loss_particle_r_cross_final(x,particles,dofs_curves, currents_scale, nfp,n_s r_cross=jnp.sqrt(jnp.square(jnp.sqrt(jnp.square(xyz[:,:,0])+jnp.square(xyz[:,:,1]))-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)) return jnp.linalg.norm((jnp.average(r_cross,axis=1))) -def loss_particle_r_cross_max_constraint(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_r=0.4,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): +def loss_particle_r_cross_max(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_r=0.4,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) #particles.to_full_orbit(field) tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime, @@ -337,8 +331,8 @@ def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, curr particles_drift_loss = loss_particle_radial_drift(x,dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym, particles=particles, maxtime=maxtime, num_steps=num_steps, trace_tolerance=trace_tolerance, model=model,boundary=boundary) normB_axis_loss = loss_normB_axis(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) - coil_length_loss = loss_coil_length(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature_loss = loss_coil_curvature(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) loss = jnp.concatenate((normB_axis_loss, coil_length_loss, coil_curvature_loss,particles_drift_loss)) return jnp.sum(loss) @@ -364,7 +358,7 @@ def loss_BdotN(x, vmec, dofs_curves, currents_scale, nfp, max_coil_length=42, bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) - coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) + coil_curvature_loss = jnp.maximum(0, jnp.max(jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature)) return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss diff --git a/essos/surfaces.py b/essos/surfaces.py index 008baa0..75361cb 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -172,7 +172,7 @@ def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus=' self.angles = jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + self._normal, self._unitnormal, self._area_element) = self._set_gamma(self.rmnc_interp, self.zmns_interp) if hasattr(self, 'bmnc'): self._AbsB = self._set_AbsB() @@ -190,7 +190,7 @@ def dofs(self, new_dofs): self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + self._normal, self._unitnormal, self._area_element) = self._set_gamma(self.rmnc_interp, self.zmns_interp) # if hasattr(self, 'bmnc'): # self._AbsB = self._set_AbsB() @@ -217,8 +217,9 @@ def _set_gamma(self, rmnc_interp, zmns_interp): normal = jnp.cross(gammadash_phi, gammadash_theta, axis=2) unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) + area_element = jnp.linalg.norm(jnp.cross(gammadash_theta, gammadash_phi, axis=2), axis=2) - return (gamma, gammadash_theta, gammadash_phi, normal, unitnormal) + return (gamma, gammadash_theta, gammadash_phi, normal, unitnormal, area_element) @partial(jit, static_argnames=['self']) def _set_AbsB(self): @@ -246,6 +247,10 @@ def normal(self): def unitnormal(self): return self._unitnormal + @property + def area_element(self): + return self._area_element + @property def AbsB(self): return self._AbsB diff --git a/examples/optimize_coils_and_surface.py b/examples/optimize_coils_and_surface.py index 587daa3..f005c50 100644 --- a/examples/optimize_coils_and_surface.py +++ b/examples/optimize_coils_and_surface.py @@ -20,7 +20,7 @@ ntheta=30 nphi=30 -input = os.path.join('input_files','input.rotating_ellipse') +input = os.path.join(os.path.dirname(__file__), 'input_files','input.rotating_ellipse') surface_initial = SurfaceRZFourier(input, ntheta=ntheta, nphi=nphi, range_torus='half period') # Optimization parameters diff --git a/examples/trace_fieldlines_coils.py b/examples/trace_fieldlines_coils.py index 2ea2305..c92148a 100644 --- a/examples/trace_fieldlines_coils.py +++ b/examples/trace_fieldlines_coils.py @@ -6,7 +6,7 @@ from jax import block_until_ready import matplotlib.pyplot as plt from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.dynamics import Tracing # Input parameters @@ -19,7 +19,7 @@ # Load coils and field json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Initialize particles diff --git a/examples/trace_particles_coils_guidingcenter.py b/examples/trace_particles_coils_guidingcenter.py index 018317c..6634674 100644 --- a/examples/trace_particles_coils_guidingcenter.py +++ b/examples/trace_particles_coils_guidingcenter.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, ONE_EV from essos.dynamics import Tracing, Particles @@ -21,8 +21,8 @@ energy=4000*ONE_EV # Load coils and field -json_file = os.path.join(os.path.dirname(__name__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Initialize particles @@ -49,7 +49,7 @@ tracing.plot(ax=ax1, show=False) for i, trajectory in enumerate(trajectories): - ax2.plot(tracing.times, jnp.abs(tracing.energy[i]-particles.energy)/particles.energy, label=f'Particle {i+1}') + ax2.plot(tracing.times, jnp.abs(tracing.energy()[i]-particles.energy)/particles.energy, label=f'Particle {i+1}') ax3.plot(tracing.times, trajectory[:, 3]/particles.total_speed, label=f'Particle {i+1}') ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') ax2.set_xlabel('Time (s)') From 1395650a8d34aa7d7646daa278782fa3a5cb1be2 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 22 Sep 2025 20:25:08 +0100 Subject: [PATCH 49/63] Deleting comparison_simsopt folder in examples --- .../coils_biotsavart_SIMSOPT_vs_ESSOS.py | 260 ------------------ .../fieldlines_SIMSOPT_vs_ESSOS.py | 195 ------------- .../fullorbit_SIMSOPT_vs_ESSOS.py | 260 ------------------ .../guiding_center_SIMSOPT_vs_ESSOS.py | 248 ----------------- .../surfaces_SIMSOPT_vs_ESSOS.py | 72 ----- .../vmec_SIMSOPT_vs_ESSOS.py | 111 -------- 6 files changed, 1146 deletions(-) delete mode 100644 examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py delete mode 100644 examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py delete mode 100644 examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py delete mode 100644 examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py delete mode 100644 examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py delete mode 100644 examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py diff --git a/examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index c89fc99..0000000 --- a/examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,260 +0,0 @@ -import os -from time import time -import jax.numpy as jnp -import matplotlib.pyplot as plt -from jax import block_until_ready -from essos.fields import BiotSavart as BiotSavart_essos -from essos.coils import Coils_from_simsopt, Curves_from_simsopt -from simsopt import load -from simsopt.geo import CurveXYZFourier, curves_to_vtk -from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries -from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -list_segments = [30, 100, 300, 1000, 3000] - -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -nfp_array = [3, 2, 5, 4, 2] -curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] -currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] -name_array = ["NCSX", "QA(json)", "W7-X", "HSX", "Giuliani"] - -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for nfp, curves_stel, currents_stel, name in zip(nfp_array, curves_array, currents_array, name_array): - print(f' Running {name} and saving to output directory...') - if currents_stel is None: - json_file_stel = curves_stel - field_simsopt = load(json_file_stel) - coils_simsopt = field_simsopt.coils - curves_simsopt = [coil.curve for coil in coils_simsopt] - currents_simsopt = [coil.current for coil in coils_simsopt] - coils_essos = Coils_from_simsopt(json_file_stel, nfp) - curves_essos = Curves_from_simsopt(json_file_stel, nfp) - else: - coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) - curves_simsopt = [c.curve for c in coils_simsopt] - currents_simsopt = [c.current for c in coils_simsopt] - field_simsopt = BiotSavart_simsopt(coils_simsopt) - - coils_essos = Coils_from_simsopt(coils_simsopt, nfp) - curves_essos = Curves_from_simsopt(curves_simsopt, nfp) - - field_essos = BiotSavart_essos(coils_essos) - - coils_essos_to_simsopt = coils_essos.to_simsopt() - curves_essos_to_simsopt = curves_essos.to_simsopt() - field_essos_to_simsopt = BiotSavart_simsopt(coils_essos_to_simsopt) - - curves_to_vtk(curves_simsopt, os.path.join(output_dir,f"curves_simsopt_{name}")) - curves_essos.to_vtk(os.path.join(output_dir,f"curves_essos_{name}")) - curves_to_vtk(curves_essos_to_simsopt, os.path.join(output_dir,f"curves_essos_to_simsopt_{name}")) - - base_coils_simsopt = coils_simsopt[:int(len(coils_simsopt)/2/nfp)] - R = jnp.mean(jnp.array([jnp.sqrt(coil.curve.x[coil.curve.local_dof_names.index('xc(0)')]**2 - +coil.curve.x[coil.curve.local_dof_names.index('yc(0)')]**2) - for coil in base_coils_simsopt])) - x = jnp.array([R+0.01,R,R]) - y = jnp.array([R,R+0.01,R-0.01]) - z = jnp.array([0.05,0.06,0.07]) - - positions = jnp.array((x,y,z)) - - len_list_segments = len(list_segments) - t_gamma_avg_essos = jnp.zeros(len_list_segments) - t_gamma_avg_simsopt = jnp.zeros(len_list_segments) - gamma_error_avg = jnp.zeros(len_list_segments) - t_gammadash_avg_essos = jnp.zeros(len_list_segments) - t_gammadash_avg_simsopt = jnp.zeros(len_list_segments) - gammadash_error_avg = jnp.zeros(len_list_segments) - t_gammadashdash_avg_essos = jnp.zeros(len_list_segments) - t_gammadashdash_avg_simsopt = jnp.zeros(len_list_segments) - gammadashdash_error_avg = jnp.zeros(len_list_segments) - t_curvature_avg_essos = jnp.zeros(len_list_segments) - t_curvature_avg_simsopt = jnp.zeros(len_list_segments) - curvature_error_avg = jnp.zeros(len_list_segments) - t_B_avg_essos = jnp.zeros(len_list_segments) - t_B_avg_simsopt = jnp.zeros(len_list_segments) - B_error_avg = jnp.zeros(len_list_segments) - t_dB_by_dX_avg_essos = jnp.zeros(len_list_segments) - t_dB_by_dX_avg_simsopt = jnp.zeros(len_list_segments) - dB_by_dX_error_avg = jnp.zeros(len_list_segments) - - gamma_error_simsopt_to_essos = 0 - gamma_error_essos_to_simsopt = 0 - - for i, (coil_simsopt, coil_essos_gamma, coil_essos_to_simsopt) in enumerate(zip(coils_simsopt, coils_essos.gamma, coils_essos_to_simsopt)): - gamma_error_simsopt_to_essos += jnp.linalg.norm(coil_simsopt.curve.gamma()-coil_essos_gamma) - gamma_error_essos_to_simsopt += jnp.linalg.norm(coil_simsopt.curve.gamma()-coil_essos_to_simsopt.curve.gamma()) - - B_error_avg_simsopt_to_essos = 0 - B_error_avg_essos_to_simsopt = 0 - for j, position in enumerate(positions): - field_simsopt.set_points([position]) - field_essos_to_simsopt.set_points([position]) - B_simsopt = field_simsopt.B() - B_essos_to_simsopt = field_essos_to_simsopt.B() - B_simsopt_to_essos = field_essos.B(position) - B_error_avg_simsopt_to_essos += jnp.abs(jnp.linalg.norm(B_simsopt) - jnp.linalg.norm(B_simsopt_to_essos)) - B_error_avg_essos_to_simsopt += jnp.abs(jnp.linalg.norm(B_simsopt) - jnp.linalg.norm(B_essos_to_simsopt)) - B_error_avg_simsopt_to_essos = B_error_avg_simsopt_to_essos/len(positions) - B_error_avg_essos_to_simsopt = B_error_avg_essos_to_simsopt/len(positions) - - fig = plt.figure(figsize = (8, 6)) - X_axis = jnp.arange(2) - plt.bar(X_axis[0] - 0.2, gamma_error_simsopt_to_essos+1e-19, 0.3, label='SIMSOPT to ESSOS coils', color='blue', edgecolor='black', hatch='/') - plt.bar(X_axis[0] + 0.2, gamma_error_essos_to_simsopt+1e-19, 0.3, label='ESSOS to SIMSOPT coils', color='red', edgecolor='black', hatch='-') - plt.bar(X_axis[1] - 0.2, B_error_avg_simsopt_to_essos+1e-19, 0.3, label=r'SIMSOPT to ESSOS $B$', color='blue', edgecolor='black', hatch='||') - plt.bar(X_axis[1] + 0.2, B_error_avg_essos_to_simsopt+1e-19, 0.3, label=r'ESSOS to SIMSOPT $B$', color='red', edgecolor='black', hatch='*') - plt.xticks(X_axis, ['Coil Error', 'B Error']) - plt.xlabel('Parameter', fontsize=14) - plt.ylabel('Error Magnitude', fontsize=14) - plt.yscale('log') - plt.ylim(1e-20, 1e-11) - plt.legend(fontsize=14) - plt.grid(axis='y') - plt.title(f"{name}", fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"error_gamma_B_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() - - def update_nsegments_simsopt(curve_simsopt, n_segments): - new_curve = CurveXYZFourier(n_segments, curve_simsopt.order) - new_curve.x = curve_simsopt.x - return new_curve - - for index, n_segments in enumerate(list_segments): - coils_essos.n_segments = n_segments - - base_curves_simsopt = [update_nsegments_simsopt(coil_simsopt.curve, n_segments) for coil_simsopt in base_coils_simsopt] - coils_simsopt = coils_via_symmetries(base_curves_simsopt, currents_simsopt[0:len(base_coils_simsopt)], nfp, True) - curves_simsopt = [c.curve for c in coils_simsopt] - - [curve.gamma() for curve in curves_simsopt] - coils_essos.gamma - - start_time = time() - gamma_curves_simsopt = block_until_ready(jnp.array([curve.gamma() for curve in curves_simsopt])) - t_gamma_avg_simsopt = t_gamma_avg_simsopt.at[index].set(t_gamma_avg_simsopt[index] + time() - start_time) - - start_time = time() - gamma_curves_essos = block_until_ready(jnp.array(coils_essos.gamma)) - t_gamma_avg_essos = t_gamma_avg_essos.at[index].set(t_gamma_avg_essos[index] + time() - start_time) - - start_time = time() - gammadash_curves_simsopt = block_until_ready(jnp.array([curve.gammadash() for curve in curves_simsopt])) - t_gammadash_avg_simsopt = t_gammadash_avg_simsopt.at[index].set(t_gammadash_avg_simsopt[index] + time() - start_time) - - start_time = time() - gammadash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dash)) - t_gammadash_avg_essos = t_gammadash_avg_essos.at[index].set(t_gammadash_avg_essos[index] + time() - start_time) - - start_time = time() - gammadashdash_curves_simsopt = block_until_ready(jnp.array([curve.gammadashdash() for curve in curves_simsopt])) - t_gammadashdash_avg_simsopt = t_gammadashdash_avg_simsopt.at[index].set(t_gammadashdash_avg_simsopt[index] + time() - start_time) - - start_time = time() - gammadashdash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dashdash)) - t_gammadashdash_avg_essos = t_gammadashdash_avg_essos.at[index].set(t_gammadashdash_avg_essos[index] + time() - start_time) - - start_time = time() - curvature_curves_simsopt = block_until_ready(jnp.array([curve.kappa() for curve in curves_simsopt])) - t_curvature_avg_simsopt = t_curvature_avg_simsopt.at[index].set(t_curvature_avg_simsopt[index] + time() - start_time) - - start_time = time() - curvature_curves_essos = block_until_ready(jnp.array(coils_essos.curvature)) - t_curvature_avg_essos = t_curvature_avg_essos.at[index].set(t_curvature_avg_essos[index] + time() - start_time) - - gamma_error_avg = gamma_error_avg. at[index].set(gamma_error_avg[index] + jnp.linalg.norm(gamma_curves_essos - gamma_curves_simsopt)) - gammadash_error_avg = gammadash_error_avg. at[index].set(gammadash_error_avg[index] + jnp.linalg.norm(gammadash_curves_essos - gammadash_curves_simsopt)) - gammadashdash_error_avg = gammadashdash_error_avg.at[index].set(gammadashdash_error_avg[index] + jnp.linalg.norm(gammadashdash_curves_essos - gammadashdash_curves_simsopt)) - curvature_error_avg = curvature_error_avg.at[index].set(curvature_error_avg[index] + jnp.linalg.norm(curvature_curves_essos - curvature_curves_simsopt)) - - field_essos = BiotSavart_essos(coils_essos) - field_simsopt = BiotSavart_simsopt(coils_simsopt) - - for j, position in enumerate(positions): - field_essos.B(position) - time1 = time() - result_B_essos = field_essos.B(position) - t_B_avg_essos = t_B_avg_essos.at[index].set(t_B_avg_essos[index] + time() - time1) - normB_essos = jnp.linalg.norm(result_B_essos) - - field_simsopt.set_points(jnp.array([position])) - field_simsopt.B() - time3 = time() - field_simsopt.set_points(jnp.array([position])) - result_simsopt = field_simsopt.B() - t_B_avg_simsopt = t_B_avg_simsopt.at[index].set(t_B_avg_simsopt[index] + time() - time3) - normB_simsopt = jnp.linalg.norm(jnp.array(result_simsopt)) - - B_error_avg = B_error_avg.at[index].set(B_error_avg[index] + jnp.abs(normB_essos - normB_simsopt)) - - field_essos.dB_by_dX(position) - time1 = time() - field_simsopt.set_points(jnp.array([position])) - result_dB_by_dX_essos = field_essos.dB_by_dX(position) - t_dB_by_dX_avg_essos = t_dB_by_dX_avg_essos.at[index].set(t_dB_by_dX_avg_essos[index] + time() - time1) - norm_dB_by_dX_essos = jnp.linalg.norm(result_dB_by_dX_essos) - - field_simsopt.dB_by_dX() - time3 = time() - field_simsopt.set_points(jnp.array([position])) - result_dB_by_dX_simsopt = field_simsopt.dB_by_dX() - t_dB_by_dX_avg_simsopt = t_dB_by_dX_avg_simsopt.at[index].set(t_dB_by_dX_avg_simsopt[index] + time() - time3) - norm_dB_by_dX_simsopt = jnp.linalg.norm(jnp.array(result_dB_by_dX_simsopt)) - - dB_by_dX_error_avg = dB_by_dX_error_avg.at[index].set(dB_by_dX_error_avg[index] + jnp.abs(norm_dB_by_dX_essos - norm_dB_by_dX_simsopt)) - - X_axis = jnp.arange(len_list_segments) - - fig = plt.figure(figsize = (8, 6)) - plt.bar(X_axis-0.2, B_error_avg, 0.1, label = r"$B_{\text{essos}} - B_{\text{simsopt}}$", color="green", edgecolor="black", hatch="/") - plt.bar(X_axis-0.1, dB_by_dX_error_avg, 0.1, label = r"${B'}_{\text{essos}} - {B'}_{\text{simsopt}}$", color="purple", edgecolor="black", hatch="x") - plt.bar(X_axis+0.0, gamma_error_avg, 0.1, label = r"$\Gamma_{\text{essos}} - \Gamma_{\text{simsopt}}$", color="orange", edgecolor="black", hatch="|") - plt.bar(X_axis+0.1, gammadash_error_avg, 0.1, label = r"${\Gamma'}_{\text{essos}} - {\Gamma'}_{\text{simsopt}}$", color="gray", edgecolor="black", hatch="-") - plt.bar(X_axis+0.2, gammadashdash_error_avg, 0.1, label = r"${\Gamma''}_{\text{essos}} - {\Gamma''}_{\text{simsopt}}$", color="black", edgecolor="black", hatch="*") - plt.bar(X_axis+0.3, curvature_error_avg, 0.1, label = r"$\kappa_{\text{essos}} - \kappa_{\text{simsopt}}$", color="brown", edgecolor="black", hatch="\\") - plt.xticks(X_axis, list_segments) - plt.xlabel("Number of segments of each coil", fontsize=14) - plt.ylabel(f"Difference SIMSOPT vs ESSOS", fontsize=14) - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.legend(fontsize=14) - plt.yscale("log") - plt.grid(axis='y') - plt.ylim(1e-18, 1e-10) - plt.title(f"{name}", fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"error_BiotSavart_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() - - fig = plt.figure(figsize = (8, 6)) - plt.bar(X_axis - 0.30, t_B_avg_essos, 0.05, label = r'B ESSOS', color="red", edgecolor="black") - plt.bar(X_axis - 0.25, t_B_avg_simsopt, 0.05, label = r'B SIMSOPT', color="blue", edgecolor="black") - plt.bar(X_axis - 0.20, t_dB_by_dX_avg_essos, 0.05, label = r"$B'$ ESSOS", color="red", edgecolor="black") - plt.bar(X_axis - 0.15, t_dB_by_dX_avg_simsopt, 00.05, label = r"$B'$ SIMSOPT", color="blue", edgecolor="black") - plt.bar(X_axis - 0.10, t_gamma_avg_essos, 0.05, label = r'$\Gamma$ ESSOS', color="red", edgecolor="black", hatch="//") - plt.bar(X_axis - 0.05, t_gamma_avg_simsopt, 0.05, label = r'$\Gamma$ SIMSOPT', color="blue", edgecolor="black", hatch="-") - plt.bar(X_axis + 0.0, t_gammadash_avg_essos, 0.05, label = r"${\Gamma'}$ ESSOS", color="red", edgecolor="black", hatch="\\") - plt.bar(X_axis + 0.05, t_gammadash_avg_simsopt, 0.05, label = r"${\Gamma'}$ SIMSOPT", color="blue", edgecolor="black", hatch="||") - plt.bar(X_axis + 0.10, t_gammadashdash_avg_essos, 0.05, label = r"${\Gamma''}$ ESSOS", color="red", edgecolor="black", hatch="*") - plt.bar(X_axis + 0.15, t_gammadashdash_avg_simsopt, 0.05, label = r"${\Gamma''}$ SIMSOPT", color="blue", edgecolor="black", hatch="|") - plt.bar(X_axis + 0.20, t_curvature_avg_essos, 0.05, label = r"$\kappa$ ESSOS", color="red", edgecolor="black", hatch="x") - plt.bar(X_axis + 0.25, t_curvature_avg_simsopt, 0.05, label = r"$\kappa$ SIMSOPT", color="blue", edgecolor="black", hatch="+") - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.xticks(X_axis, list_segments) - plt.xlabel("Number of segments of each coil", fontsize=14) - plt.ylabel("Time to evaluate SIMSOPT vs ESSOS (s)", fontsize=14) - plt.grid(axis='y') - # plt.gca().set_ylim((None,0.03)) - plt.yscale("log") - plt.legend(fontsize=14) - plt.title(f"{name}", fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"time_BiotSavart_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() diff --git a/examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index fd0473e..0000000 --- a/examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,195 +0,0 @@ -import os -import time -import jax.numpy as jnp -from jax import block_until_ready -from simsopt import load -from simsopt.field import (particles_to_vtk, compute_fieldlines, plot_poincare_data) -from essos.coils import Coils_from_simsopt -from essos.dynamics import Tracing -from essos.fields import BiotSavart as BiotSavart_essos -import matplotlib.pyplot as plt - -tmax_fl = 150 -nfieldlines = 3 -axis_shft=0.02 -R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nfieldlines) -nfp = 2 -trace_tolerance_SIMSOPT_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] -trace_tolerance_ESSOS = 1e-7 - -Z0 = jnp.zeros(nfieldlines) -phi0 = jnp.zeros(nfieldlines) - -phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) - -fieldlines_SIMSOPT_array = [] -time_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 - -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: - print(f' Tracing SIMSOPT fieldlines with tolerance={trace_tolerance_SIMSOPT}') - t1 = time.time() - fieldlines_SIMSOPT_this_tolerance, fieldlines_SIMSOPT_phi_hits = block_until_ready(compute_fieldlines(field_simsopt, R0, Z0, tmax=tmax_fl, tol=trace_tolerance_SIMSOPT, phis=phis_poincare)) - time_SIMSOPT_array.append(time.time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in fieldlines_SIMSOPT_this_tolerance])//nfieldlines - print(f" Time for SIMSOPT tracing={time.time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - fieldlines_SIMSOPT_array.append(fieldlines_SIMSOPT_this_tolerance) - -particles_to_vtk(fieldlines_SIMSOPT_this_tolerance, os.path.join(output_dir,f'fieldlines_SIMSOPT')) -# plot_poincare_data(fieldlines_phi_hits, phis_poincare, os.path.join(output_dir,f'poincare_fieldline_SIMSOPT.pdf'), dpi=150) - -# Trace in ESSOS -num_steps_essos = int(jnp.mean(jnp.array([len(fieldlines_SIMSOPT[0]) for fieldlines_SIMSOPT in fieldlines_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_fl, num_steps_essos) - -print(f'Tracing ESSOS fieldlines with tolerance={trace_tolerance_ESSOS}') -t1 = time.time() -tracing = block_until_ready(Tracing(field=field_essos, model='FieldLine', initial_conditions=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T, - maxtime=tmax_fl, timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) -fieldlines_ESSOS = tracing.trajectories -time_ESSOS = time.time()-t1 -print(f" Time for ESSOS tracing={time.time()-t1:.3f}s. Num steps={len(fieldlines_ESSOS[0])}") - -tracing.to_vtk(os.path.join(output_dir,f'fieldlines_ESSOS')) -# tracing.poincare_plot(phis_poincare, show=False) - -print('Plotting the results to output directory...') -# Plot time comparison in a bar chart -labels = [f'SIMSOPT\nTol={tol}' for tol in trace_tolerance_SIMSOPT_array] + [f'ESSOS\nTol={trace_tolerance_ESSOS}'] -times = time_SIMSOPT_array + [time_ESSOS] -plt.figure() -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black'], hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -orange_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -plt.legend(handles=[blue_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(fieldine_SIMSOPT, fieldline_ESSOS): - time_SIMSOPT = jnp.array(fieldine_SIMSOPT)[:, 0] # Time values from fieldlines_SIMSOPT - # coords_SIMSOPT = jnp.array(fieldine_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from fieldlines_SIMSOPT - coords_ESSOS = jnp.array(fieldline_ESSOS) - - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z]) - - return coords_ESSOS_interp - -relative_error_array = [] -for i, fieldlines_SIMSOPT in enumerate(fieldlines_SIMSOPT_array): - fieldlines_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(fieldlines_SIMSOPT[i], fieldlines_ESSOS[i]) for i in range(nfieldlines)] - tracing.trajectories = fieldlines_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'fieldlines_ESSOS_interp')) - - relative_error_fieldlines_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nfieldlines): - this_fieldline_SIMSOPT = jnp.array(fieldlines_SIMSOPT[j])[:,1:] - this_fieldlines_ESSOS = fieldlines_ESSOS_interp[j] - average_relative_error = [] - for fieldline_SIMSOPT_t, fieldline_ESSOS_t in zip(this_fieldline_SIMSOPT, this_fieldlines_ESSOS): - relative_error_x = jnp.abs(fieldline_SIMSOPT_t[0] - fieldline_ESSOS_t[0])/(jnp.abs(fieldline_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(fieldline_SIMSOPT_t[1] - fieldline_ESSOS_t[1])/(jnp.abs(fieldline_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(fieldline_SIMSOPT_t[2] - fieldline_ESSOS_t[2])/(jnp.abs(fieldline_SIMSOPT_t[2])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z)/3) - average_relative_error = jnp.array(average_relative_error) - relative_error_fieldlines_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_fl, len(average_relative_error))[1:], average_relative_error[1:], label=f'Fieldline {j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_fieldlines_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - # relative_error_fieldlines_SIMSOPT_vs_ESSOS = jnp.array(relative_error_fieldlines_SIMSOPT_vs_ESSOS) - # print(f"Relative difference between SIMSOPT and ESSOS fieldlines={relative_error_fieldlines_SIMSOPT_vs_ESSOS}") - relative_error_array.append(relative_error_fieldlines_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nfieldlines): - R_SIMSOPT = jnp.sqrt(fieldlines_SIMSOPT[j][:,1]**2+fieldlines_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(fieldlines_SIMSOPT[j][:,2], fieldlines_SIMSOPT[j][:,1]) - Z_SIMSOPT = fieldlines_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(fieldlines_ESSOS_interp[j][:,0]**2+fieldlines_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(fieldlines_ESSOS_interp[j][:,1], fieldlines_ESSOS_interp[j][:,0]) - Z_ESSOS = fieldlines_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.savefig(os.path.join(output_dir,f'fieldlines_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - -# Calculate RMS error for each tolerance -rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - -# Plot RMS error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Fieldline {i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('RMS Error') -plt.yscale('log') -plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'rms_error_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate maximum error for each tolerance -max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot maximum error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Fieldline {i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Maximum Error') -plt.yscale('log') -plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'max_error_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate mean error for each tolerance -mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot mean error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Fieldline {i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Mean Error') -plt.yscale('log') -plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'mean_error_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() diff --git a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index fa5fe45..0000000 --- a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,260 +0,0 @@ -import os -import time -import jax.numpy as jnp -from jax import block_until_ready, random -from simsopt import load -from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) -from essos.coils import Coils_from_simsopt -from essos.constants import PROTON_MASS, ONE_EV -from essos.dynamics import Tracing, Particles -from essos.fields import BiotSavart as BiotSavart_essos -import matplotlib.pyplot as plt -from diffrax import Dopri8 - -tmax_full = 1e-5 -nparticles = 3 -axis_shft=0.02 -R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) -trace_tolerance_SIMSOPT_array = [1e-3, 1e-5, 1e-7, 1e-9]#, 1e-11] -trace_tolerance_ESSOS = 1e-5 -mass=PROTON_MASS -energy=5000*ONE_EV -method_ESSOS_array = ['Boris', Dopri8] - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) - -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) - - -phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] - -particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy, field=field_essos) - -# Trace in SIMSOPT -time_SIMSOPT_array = [] -trajectories_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 -relative_energy_error_SIMSOPT_array = [] -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: - print(f' Tracing SIMSOPT full orbit with tolerance={trace_tolerance_SIMSOPT}') - t1 = time.time() - trajectories_SIMSOPT_this_tolerance, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( - field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, - parallel_speeds=particles.initial_vparallel, tmax=tmax_full, mode='full', - charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) - time_SIMSOPT_array.append(time.time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in trajectories_SIMSOPT_this_tolerance])//nparticles - print(f" Time for SIMSOPT tracing={time.time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - trajectories_SIMSOPT_array.append(trajectories_SIMSOPT_this_tolerance) - - relative_energy_error_SIMSOPT_array.append([jnp.abs(mass*(trajectory[:,4]**2+trajectory[:,5]**2+trajectory[:,6]**2)/2-particles.energy)/particles.energy - for trajectory in trajectories_SIMSOPT_this_tolerance]) - -particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'full_orbit_SIMSOPT')) - - -# Trace in ESSOS -num_steps_essos = int(jnp.max(jnp.array([len(trajectories_SIMSOPT[0]) for trajectories_SIMSOPT in trajectories_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_full, num_steps_essos) - - -tracing_array = [] -trajectories_ESSOS_array = [] -time_ESSOS_array = [] -for method_ESSOS in method_ESSOS_array: - print(f'Tracing ESSOS full orbit '+('Boris' if method_ESSOS=='Boris' else f'with tolerance={trace_tolerance_ESSOS}')+f' and plotting the result.') - t1 = time.time() - tracing = block_until_ready(Tracing('FullOrbit', field_essos, tmax_full, method=method_ESSOS, particles=particles, - timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) - trajectories_ESSOS = tracing.trajectories - time_ESSOS = time.time()-t1 - print(f" Time for ESSOS tracing={time.time()-t1:.3f}s "+('Boris' if method_ESSOS=='Boris' else f'')+f". Num steps={len(trajectories_ESSOS[0])}") - tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='Boris' else '')+'_ESSOS')) - tracing_array.append(tracing) - trajectories_ESSOS_array.append(trajectories_ESSOS) - time_ESSOS_array.append(time_ESSOS) - -print('Plotting the results to output directory...') -plt.figure() -SIMSOPT_energy_interp_this_particle = jnp.zeros((len(trace_tolerance_SIMSOPT_array), nparticles, len(trajectories_SIMSOPT_array[-1][-1][:,0]))) -for j in range(nparticles): - for i, relative_energy_error_SIMSOPT in enumerate(relative_energy_error_SIMSOPT_array): - SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) -for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): - plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') -for method_ESSOS, tracing, trajectories_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array): - relative_energy_error_ESSOS = jnp.abs(tracing.energy()-particles.energy)/particles.energy - plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS'+(' Boris' if method_ESSOS=='Boris' else f' Tol={trace_tolerance_ESSOS}')) -plt.legend() -plt.yscale('log') -plt.xlabel('Time (s)') -plt.ylabel('Average Relative Energy Error') -plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_full_orbit_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -labels = [f'SIMSOPT Tol={tol}' for tol in trace_tolerance_SIMSOPT_array] -times = time_SIMSOPT_array -plt.figure() -for method_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): - # Plot time comparison in a bar chart - labels += ([f'ESSOS Boris Algorithm'] if method_ESSOS=='FullOrbit_Boris' else [f'ESSOS Tol={trace_tolerance_ESSOS}']) - times += [time_ESSOS] -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red', 'orange'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black']*2, hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']*2) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -red_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -orange_patch = plt.Line2D([0], [0], color='orange', lw=4, label=f'ESSOS\nBoris Algorithm') -plt.legend(handles=[blue_patch, red_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_full_orbit'+('_boris' if method_ESSOS=='Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): - time_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 0] # Time values from full orbit SIMSOPT - # coords_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from full orbit SIMSOPT - coords_ESSOS = jnp.array(trajectory_ESSOS) - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - interp_vx = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 3]) - interp_vy = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 4]) - interp_vz = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 5]) - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_vx, interp_vy, interp_vz]) - return coords_ESSOS_interp - -for method_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(method_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): - - relative_error_array = [] - for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): - trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] - tracing.trajectories = trajectories_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_ESSOS_interp')) - - relative_error_trajectories_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nparticles): - this_trajectory_SIMSOPT = jnp.array(trajectories_SIMSOPT[j])[:,1:] - this_trajectory_ESSOS = trajectories_ESSOS_interp[j] - average_relative_error = [] - for trajectory_SIMSOPT_t, trajectory_ESSOS_t in zip(this_trajectory_SIMSOPT, this_trajectory_ESSOS): - relative_error_x = jnp.abs(trajectory_SIMSOPT_t[0] - trajectory_ESSOS_t[0])/(jnp.abs(trajectory_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(trajectory_SIMSOPT_t[1] - trajectory_ESSOS_t[1])/(jnp.abs(trajectory_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(trajectory_SIMSOPT_t[2] - trajectory_ESSOS_t[2])/(jnp.abs(trajectory_SIMSOPT_t[2])+1e-12) - relative_error_vx = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[3])+1e-12) - relative_error_vy = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[4])+1e-12) - relative_error_vz = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[5])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z + relative_error_vx + relative_error_vy + relative_error_vz)/6) - average_relative_error = jnp.array(average_relative_error) - relative_error_trajectories_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_full, len(average_relative_error))[1:], average_relative_error[1:], label=f'Particle {1+j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+f'_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nparticles): - R_SIMSOPT = jnp.sqrt(trajectories_SIMSOPT[j][:,1]**2+trajectories_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(trajectories_SIMSOPT[j][:,2], trajectories_SIMSOPT[j][:,1]) - Z_SIMSOPT = trajectories_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(trajectories_ESSOS_interp[j][:,0]**2+trajectories_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(trajectories_ESSOS_interp[j][:,1], trajectories_ESSOS_interp[j][:,0]) - Z_ESSOS = trajectories_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+f'_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - plt.figure() - for j in range(nparticles): - time_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,0]) - vx_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,4]) - vx_ESSOS = jnp.array(trajectories_ESSOS_interp[j][:,3]) - # plt.plot(time_SIMSOPT, jnp.abs((vx_SIMSOPT-vx_ESSOS)/vx_SIMSOPT), '-', linewidth=2.5, label=f'Particle {1+j}') - plt.plot(time_SIMSOPT, vx_SIMSOPT/particles.total_speed, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(time_SIMSOPT, vx_ESSOS/particles.total_speed, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('Time (s)') - plt.ylabel(r'$v_x/v$') - # plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+f'_vx_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - # Calculate RMS error for each tolerance - rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - - # Plot RMS error in a bar chart - plt.figure() - bar_width = 0.15 - x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) - for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Particle {1+i}') - plt.xlabel('Tracing Tolerance of SIMSOPT') - plt.ylabel('RMS Error') - plt.yscale('log') - plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) - plt.legend() - plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'rms_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - plt.close() - - # Calculate maximum error for each tolerance - max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) - # Plot maximum error in a bar chart - plt.figure() - bar_width = 0.15 - x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) - for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Particle {1+i}') - plt.xlabel('Tracing Tolerance of SIMSOPT') - plt.ylabel('Maximum Error') - plt.yscale('log') - plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) - plt.legend() - plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'max_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - plt.close() - - # Calculate mean error for each tolerance - mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) - # Plot mean error in a bar chart - plt.figure() - bar_width = 0.15 - x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) - for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Particle {1+i}') - plt.xlabel('Tracing Tolerance of SIMSOPT') - plt.ylabel('Mean Error') - plt.yscale('log') - plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) - plt.legend() - plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'mean_error_full_orbit'+('_boris' if method_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - plt.close() diff --git a/examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index eb102a7..0000000 --- a/examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,248 +0,0 @@ -import os -import time -import jax.numpy as jnp -from jax import block_until_ready, random -from simsopt import load -from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) -from essos.coils import Coils_from_simsopt -from essos.constants import PROTON_MASS, ONE_EV -from essos.dynamics import Tracing, Particles -from essos.fields import BiotSavart as BiotSavart_essos -import matplotlib.pyplot as plt - -tmax_gc = 1e-4 -nparticles = 5 -axis_shft=0.02 -R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) -trace_tolerance_SIMSOPT_array = [1e-5, 1e-7, 1e-9, 1e-11] -trace_tolerance_ESSOS = 1e-7 -mass=PROTON_MASS -energy=5000*ONE_EV - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) - -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) - -phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] - -particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) - -# Trace in SIMSOPT -time_SIMSOPT_array = [] -trajectories_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 -relative_energy_error_SIMSOPT_array = [] -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: - print(f'Tracing SIMSOPT guiding center with tolerance={trace_tolerance_SIMSOPT}') - t1 = time.time() - trajectories_SIMSOPT_this_tolerance, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( - field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, - parallel_speeds=particles.initial_vparallel, tmax=tmax_gc, mode='gc_vac', - charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) - time_SIMSOPT_array.append(time.time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in trajectories_SIMSOPT_this_tolerance])//nparticles - print(f" Time for SIMSOPT tracing={time.time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - trajectories_SIMSOPT_array.append(trajectories_SIMSOPT_this_tolerance) - - relative_energy_SIMSOPT = [] - for i, trajectory in enumerate(trajectories_SIMSOPT_this_tolerance): - xyz = jnp.asarray(trajectory[:, 1:4]) - vpar = trajectory[:, 4] - field_simsopt.set_points(xyz) - AbsB = field_simsopt.AbsB()[:,0] - mu = (particles.energy - particles.mass*vpar[0]**2/2)/AbsB[0] - relative_energy_SIMSOPT.append(jnp.abs(particles.mass*vpar**2/2+mu*AbsB-particles.energy)/particles.energy) - relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) - -particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) - -# Trace in ESSOS -num_steps_essos = int(jnp.mean(jnp.array([len(trajectories_SIMSOPT[0]) for trajectories_SIMSOPT in trajectories_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_gc, num_steps_essos) - -print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') -t1 = time.time() -tracing = block_until_ready(Tracing(field=field_essos, model='GuidingCenter', particles=particles, - maxtime=tmax_gc, timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) -trajectories_ESSOS = tracing.trajectories -time_ESSOS = time.time()-t1 -print(f" Time for ESSOS tracing={time.time()-t1:.3f}s. Num steps={len(trajectories_ESSOS[0])}") -tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) - -relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy - -print('Plotting the results to output directory...') -plt.figure() -SIMSOPT_energy_interp_this_particle = jnp.zeros((len(trace_tolerance_SIMSOPT_array), nparticles, len(trajectories_SIMSOPT_array[-1][-1][:,0]))) -for j in range(nparticles): - for i, relative_energy_error_SIMSOPT in enumerate(relative_energy_error_SIMSOPT_array): - SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) -plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS Tol={trace_tolerance_ESSOS}') -for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): - plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') -plt.legend() -plt.yscale('log') -plt.xlabel('Time (s)') -plt.ylabel('Average Relative Energy Error') -plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Plot time comparison in a bar chart -labels = [f'SIMSOPT\nTol={tol}' for tol in trace_tolerance_SIMSOPT_array] + [f'ESSOS\nTol={trace_tolerance_ESSOS}'] -times = time_SIMSOPT_array + [time_ESSOS] -plt.figure() -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black'], hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -orange_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -plt.legend(handles=[blue_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): - time_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 0] # Time values from guiding center SIMSOPT - # coords_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from guiding center SIMSOPT - coords_ESSOS = jnp.array(trajectory_ESSOS) - - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - interp_v = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 3]) - - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_v]) - - return coords_ESSOS_interp - -relative_error_array = [] -for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): - trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] - tracing.trajectories = trajectories_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS_interp')) - - relative_error_trajectories_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nparticles): - this_trajectory_SIMSOPT = jnp.array(trajectories_SIMSOPT[j])[:,1:] - this_trajectory_ESSOS = trajectories_ESSOS_interp[j] - average_relative_error = [] - for trajectory_SIMSOPT_t, trajectory_ESSOS_t in zip(this_trajectory_SIMSOPT, this_trajectory_ESSOS): - relative_error_x = jnp.abs(trajectory_SIMSOPT_t[0] - trajectory_ESSOS_t[0])/(jnp.abs(trajectory_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(trajectory_SIMSOPT_t[1] - trajectory_ESSOS_t[1])/(jnp.abs(trajectory_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(trajectory_SIMSOPT_t[2] - trajectory_ESSOS_t[2])/(jnp.abs(trajectory_SIMSOPT_t[2])+1e-12) - relative_error_v = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[3])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z + relative_error_v)/4) - average_relative_error = jnp.array(average_relative_error) - relative_error_trajectories_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_gc, len(average_relative_error))[1:], average_relative_error[1:], label=f'Particle {1+j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_guiding_center_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nparticles): - R_SIMSOPT = jnp.sqrt(trajectories_SIMSOPT[j][:,1]**2+trajectories_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(trajectories_SIMSOPT[j][:,2], trajectories_SIMSOPT[j][:,1]) - Z_SIMSOPT = trajectories_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(trajectories_ESSOS_interp[j][:,0]**2+trajectories_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(trajectories_ESSOS_interp[j][:,1], trajectories_ESSOS_interp[j][:,0]) - Z_ESSOS = trajectories_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'guiding_center_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - plt.figure() - for j in range(nparticles): - time_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,0]) - vpar_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,4]) - vpar_ESSOS = jnp.array(trajectories_ESSOS_interp[j][:,3]) - # plt.plot(time_SIMSOPT, jnp.abs((vpar_SIMSOPT-vpar_ESSOS)/vpar_SIMSOPT), '-', linewidth=2.5, label=f'Particle {1+j}') - plt.plot(time_SIMSOPT, vpar_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(time_SIMSOPT, vpar_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('Time (s)') - plt.ylabel(r'$v_{\parallel}/v$') - # plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'guiding_center_vpar_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - -# Calculate RMS error for each tolerance -rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - -# Plot RMS error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('RMS Error') -plt.yscale('log') -plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'rms_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate maximum error for each tolerance -max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot maximum error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Maximum Error') -plt.yscale('log') -plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'max_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate mean error for each tolerance -mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot mean error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Mean Error') -plt.yscale('log') -plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'mean_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() \ No newline at end of file diff --git a/examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index 7e1780b..0000000 --- a/examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -from time import time -import matplotlib.pyplot as plt -from jax import vmap -import jax.numpy as jnp -from essos.coils import Coils, CreateEquallySpacedCurves -from essos.fields import Vmec, BiotSavart -from essos.surfaces import B_on_surface, BdotN_over_B, SurfaceRZFourier as SurfaceRZFourier_ESSOS -from simsopt.field import BiotSavart as BiotSavart_simsopt -from simsopt.geo import SurfaceRZFourier as SurfaceRZFourier_SIMSOPT -from simsopt.objectives import SquaredFlux - -# Optimization parameters -max_coil_length = 42 -order_Fourier_series_coils = 4 -number_coil_points = 50 -function_evaluations_array = [30]*1 -diff_step_array = [1e-2]*1 -number_coils_per_half_field_period = 3 - -ntheta = 36 -nphi = 32 - -# Initialize VMEC field -vmec_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', - 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') -vmec = Vmec(vmec_file, ntheta=ntheta, nphi=nphi, close=False) - -# Initialize coils -current_on_each_coil = 1 -number_of_field_periods = vmec.nfp -major_radius_coils = vmec.r_axis -minor_radius_coils = vmec.r_axis/1.5 -curves_essos = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) -field_essos = BiotSavart(coils_essos) -surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) -# surface_essos.to_vtk("essos_surface") - -coils_simsopt = coils_essos.to_simsopt() -curves_simsopt = curves_essos.to_simsopt() -field_simsopt = BiotSavart_simsopt(coils_simsopt) -surface_simsopt = SurfaceRZFourier_SIMSOPT.from_wout(vmec_file, range="full torus", nphi=nphi, ntheta=ntheta) -field_simsopt.set_points(surface_simsopt.gamma().reshape((-1, 3))) -# surface_simsopt.to_vtk("simsopt_surface") - -print("Gamma") -print(jnp.sum(jnp.abs(surface_simsopt.gamma()-surface_essos.gamma))) - -print('Gamma dash theta') -print(jnp.sum(jnp.abs(surface_simsopt.gammadash2()-surface_essos.gammadash_theta))) - -print('Gamma dash phi') -print(jnp.sum(jnp.abs(surface_simsopt.gammadash1()-surface_essos.gammadash_phi))) - -print('Normal') -print(jnp.sum(jnp.abs(surface_simsopt.normal()-surface_essos.normal))) - -print('Unit normal') -print(jnp.sum(jnp.abs(surface_simsopt.unitnormal()-surface_essos.unitnormal))) - -BdotN_over_B_SIMSOPT = SquaredFlux(surface_simsopt, field_simsopt, definition="normalized").J() -BdotN_over_B_ESSOS = BdotN_over_B(surface_essos, field_essos) - -B_on_surface_simsopt = field_simsopt.B().reshape(surface_simsopt.normal().shape) -B_on_surface_ESSOS = B_on_surface(surface_essos, field_essos) -# print("ESSOS: ", BdotN_over_B_ESSOS) -# print("SIMSOPT: ", BdotN_over_B_SIMSOPT) diff --git a/examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index 8c0c1d2..0000000 --- a/examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -from time import time -import jax.numpy as jnp -import matplotlib.pyplot as plt -from jax import block_until_ready, random -from essos.fields import Vmec as Vmec_essos -from simsopt.mhd import Vmec as Vmec_simsopt, vmec_compute_geometry - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -wout_array = [os.path.join(os.path.dirname(__file__), '..', 'input_files', "wout_LandremanPaul2021_QA_reactorScale_lowres.nc"), - os.path.join(os.path.dirname(__file__), '..', 'input_files', "wout_n3are_R7.75B5.7.nc")] -name_array = ["LandremanPaulQA", 'NCSX'] - -print(f'Output being saved to {output_dir}') -for name, wout in zip(name_array, wout_array): - print(f' Running comparison with VMEC file located at: {wout}') - - vmec_essos = Vmec_essos(wout) - vmec_simsopt = Vmec_simsopt(wout) - - s_array=jnp.linspace(0.2, 0.9, 10) - key = random.key(42) - - def absB_simsopt_func(s, theta, phi): - return vmec_compute_geometry(vmec_simsopt, s, theta, phi).modB[0][0][0] - def absB_essos_func(s, theta, phi): - return vmec_essos.AbsB([s, theta, phi]) - def B_simsopt_func(s, theta, phi): - g = vmec_compute_geometry(vmec_simsopt, s, theta, phi) - return jnp.array([g.B_sub_s * g.grad_s_X + g.B_sub_theta_vmec * g.grad_theta_vmec_X + g.B_sub_phi * g.grad_phi_X, - g.B_sub_s * g.grad_s_Y + g.B_sub_theta_vmec * g.grad_theta_vmec_Y + g.B_sub_phi * g.grad_phi_Y, - g.B_sub_s * g.grad_s_Z + g.B_sub_theta_vmec * g.grad_theta_vmec_Z + g.B_sub_phi * g.grad_phi_Z])[:,0,0,0] - def B_essos_func(s, theta, phi): - return vmec_essos.B([s, theta, phi]) - - def timed_B(s, function): - theta = random.uniform(key=key, minval=0, maxval=2 * jnp.pi) - phi = random.uniform(key=key, minval=0, maxval=2 * jnp.pi) - function(s, theta, phi) - time1 = time() - B = block_until_ready(function(s, theta, phi)) - time_taken = time()-time1 - return time_taken, B - - average_time_modB_simsopt = 0 - average_time_modB_essos = 0 - average_time_B_essos = 0 - average_time_B_simsopt = 0 - error_modB = 0 - error_B = 0 - for s in s_array: - time_modB_simsopt, modB_simsopt = timed_B(s, absB_simsopt_func) - average_time_modB_simsopt += time_modB_simsopt - - time_modB_essos, modB_essos = timed_B(s, absB_essos_func) - average_time_modB_essos += time_modB_essos - - time_B_essos, B_essos = timed_B(s, B_essos_func) - average_time_B_essos += time_B_essos - - time_B_simsopt, B_simsopt = timed_B(s, B_simsopt_func) - average_time_B_simsopt += time_B_simsopt - - error_modB += jnp.abs((modB_simsopt-modB_essos)/modB_simsopt) - error_B += jnp.abs((B_simsopt-B_essos)/B_simsopt) - - average_time_modB_simsopt /= len(s_array) - average_time_modB_essos /= len(s_array) - average_time_B_essos /= len(s_array) - average_time_B_simsopt /= len(s_array) - - fig = plt.figure(figsize = (8, 6)) - X_axis = jnp.arange(4) - Y_axis = [average_time_modB_simsopt, average_time_B_simsopt, average_time_modB_essos, average_time_B_essos] - colors = ['blue', 'blue', 'red', 'red'] - hatches = ['/', '\\', '/', '\\'] - bars = plt.bar(X_axis, Y_axis, width=0.4, color=colors) - for bar, hatch in zip(bars, hatches): bar.set_hatch(hatch) - plt.xticks(X_axis, [r"$|\boldsymbol{B}|$ SIMSOPT", r"$\boldsymbol{B}$ SIMSOPT", r"$|\boldsymbol{B}|$ ESSOS", r"$\boldsymbol{B}$ ESSOS"], fontsize=16) - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.ylabel("Time to evaluate VMEC field (s)", fontsize=14) - plt.grid(axis='y') - plt.yscale("log") - plt.ylim(1e-6, 1) - plt.title(name, fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"time_VMEC_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() - - fig = plt.figure(figsize = (8, 6)) - X_axis = jnp.arange(2) - Y_axis = [jnp.mean(error_modB), jnp.mean(error_B)] - colors = ['purple', 'orange'] - hatches = ['/', '//'] - bars = plt.bar(X_axis, Y_axis, width=0.4, color=colors) - for bar, hatch in zip(bars, hatches): bar.set_hatch(hatch) - plt.xticks(X_axis, [r"$|\boldsymbol{B}|$", r"$\boldsymbol{B}$"], fontsize=16) - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.ylabel("Relative error SIMSOPT vs ESSOS (%)", fontsize=14) - plt.grid(axis='y') - plt.yscale("log") - plt.ylim(1e-6, 1e-1) - plt.title(name, fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"error_VMEC_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() \ No newline at end of file From ddb228255396a76d81d30bb8545721348e86fbd5 Mon Sep 17 00:00:00 2001 From: eduardolneto Date: Tue, 14 Oct 2025 20:42:18 +0000 Subject: [PATCH 50/63] Changing surfaces.py to correct the number of modes used, added option to scale the modes with different norms, optimization.py slightly changed to accomodate changes in surfaces. The example optimize_coils_and_surfaces.py was also changed to accomodate the changes --- essos/optimization.py | 4 +- essos/surfaces.py | 315 ++++++++++++++++---- examples/input_files/input.rotating_ellipse | 11 +- examples/input_files/input.toroidal_surface | 13 +- examples/optimize_coils_and_surface.py | 8 +- 5 files changed, 275 insertions(+), 76 deletions(-) diff --git a/essos/optimization.py b/essos/optimization.py index fb1a24b..6291fec 100644 --- a/essos/optimization.py +++ b/essos/optimization.py @@ -64,7 +64,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- dofs_currents = result.x[len_dofs_curves:-len(surface_all.x)] curves = Curves(dofs_curves, n_segments, nfp, stellsym) new_coils = Coils(curves=curves, currents=dofs_currents * coils.currents_scale) - new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) + new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta,mpol=surface_all.mpol,ntor=surface_all.ntor) new_surface.dofs = result.x[-len(surface_all.x):] return new_coils, new_surface elif 'surface_all' in kwargs and 'field_nearaxis' in kwargs and len(initial_dofs) == len(coils.x) + len(kwargs['surface_all'].x) + len(kwargs['field_nearaxis'].x): @@ -73,7 +73,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- dofs_currents = result.x[len_dofs_curves:-len(surface_all.x)-len(field_nearaxis.x)] curves = Curves(dofs_curves, n_segments, nfp, stellsym) new_coils = Coils(curves=curves, currents=dofs_currents * coils.currents_scale) - new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) + new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta,mpol=surface_all.mpol,ntor=surface_all.ntor) new_surface.dofs = result.x[-len(surface_all.x)-len(field_nearaxis.x):-len(field_nearaxis.x)] new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(result.x[-len(field_nearaxis.x):], field_nearaxis) return new_coils, new_surface, new_field_nearaxis diff --git a/essos/surfaces.py b/essos/surfaces.py index 0048e3c..5ef1d3e 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -9,6 +9,35 @@ mesh = Mesh(devices(), ("dev",)) sharding = NamedSharding(mesh, PartitionSpec("dev", None)) + +@partial(jit, static_argnames=['surface','field']) +def toroidal_flux(surface, field, idx=0) -> jnp.ndarray: + curve = surface.gamma[idx] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(field.A)(curve) + Adl = jnp.sum(A_vals * dl, axis=1) + tf = jnp.sum(Adl) + #curve = surface.gamma[idx] + #dl = surface.gammadash_theta[idx] + #A_vals = vmap(field.A)(curve) + #Adl = jnp.sum(A_vals * dl, axis=1)/surface.ntheta + #tf = jnp.sum(Adl) + return tf + +@partial(jit, static_argnames=['surface','field']) +def poloidal_flux(surface, field, idx=0) -> jnp.ndarray: + curve = surface.gamma[:,idx,:] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(field.A)(curve) + Adl = jnp.sum(A_vals * dl, axis=1) + tf = jnp.sum(Adl) + #curve = surface.gamma[:,idx,:] + #dl = surface.gammadash_phi[:,idx,:] + #A_vals = vmap(field.A)(curve) + #Adl = jnp.sum(A_vals * dl, axis=1)/surface.nphi + #tf = jnp.sum(Adl) + return tf + @partial(jit, static_argnames=['surface','field']) def B_on_surface(surface, field): ntheta = surface.ntheta @@ -20,6 +49,8 @@ def B_on_surface(surface, field): B_on_surface = B_on_surface.reshape(nphi, ntheta, 3) return B_on_surface + + @partial(jit, static_argnames=['surface','field']) def BdotN(surface, field): B_surface = B_on_surface(surface, field) @@ -47,62 +78,72 @@ def nested_lists_to_array(ll): for jm, l in enumerate(ll): arr = arr.at[jm, :len(l)].set(jnp.array([x if x is not None else 0 for x in l])) return arr + class SurfaceRZFourier: def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus', - rc=None, zs=None, nfp=None): + rc=None, zs=None, nfp=None, mpol=None, ntor=None,rescaling_type=None,rescaling_factor=None): if rc is not None: self.rc = rc self.zs = zs self.nfp = nfp - self.mpol = rc.shape[0] - self.ntor = (rc.shape[1] - 1) // 2 - m1d = jnp.arange(self.mpol) - n1d = jnp.arange(-self.ntor, self.ntor + 1) - n2d, m2d = jnp.meshgrid(n1d, m1d) - self.xm = m2d.flatten()[self.ntor:] - self.xn = self.nfp*n2d.flatten()[self.ntor:] - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] + self.mpol = mpol + self.ntor = ntor + #m1d = jnp.tile(jnp.arange(-self.ntor, self.ntor + 1),self.mpol) + #n1d = jnp.arange(-self.ntor, self.ntor + 1) + #n2d, m2d = jnp.meshgrid(n1d, m1d) + self.xm = jnp.repeat(jnp.arange(self.mpol+1), 2*self.ntor+1)[self.ntor:]#m2d.flatten()[self.ntor:] + self.xn = self.nfp*jnp.tile(jnp.arange(-self.ntor, self.ntor + 1), self.mpol+1)[self.ntor:]#m2d.flatten()[self.ntor:] + #indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T + self.rmnc_interp = self.rc + self.zmns_interp = self.zs elif isinstance(vmec, str): self.input_filename = vmec import f90nml - all_namelists = f90nml.read(vmec) + all_namelists = f90nml.Parser().read(vmec) nml = all_namelists['indata'] if 'nfp' in nml: self.nfp = nml['nfp'] else: self.nfp = 1 - rc = nested_lists_to_array(nml['rbc']) - zs = nested_lists_to_array(nml['zbs']) - rbc_first_n = nml.start_index['rbc'][0] - rbc_last_n = rbc_first_n + rc.shape[1] - 1 - zbs_first_n = nml.start_index['zbs'][0] - zbs_last_n = zbs_first_n + zs.shape[1] - 1 - self.ntor = jnp.max(jnp.abs(jnp.array([rbc_first_n, rbc_last_n, zbs_first_n, zbs_last_n], dtype='i'))) - rbc_first_m = nml.start_index['rbc'][1] - rbc_last_m = rbc_first_m + rc.shape[0] - 1 - zbs_first_m = nml.start_index['zbs'][1] - zbs_last_m = zbs_first_m + zs.shape[0] - 1 - self.mpol = max(rbc_last_m, zbs_last_m) - self.rc = jnp.zeros((self.mpol, 2 * self.ntor + 1)) - self.zs = jnp.zeros((self.mpol, 2 * self.ntor + 1)) - m_indices_rc = jnp.arange(rc.shape[0]) + nml.start_index['rbc'][1] - n_indices_rc = jnp.arange(rc.shape[1]) + nml.start_index['rbc'][0] + self.ntor - self.rc = self.rc.at[m_indices_rc[:, None], n_indices_rc].set(rc) - m_indices_zs = jnp.arange(zs.shape[0]) + nml.start_index['zbs'][1] - n_indices_zs = jnp.arange(zs.shape[1]) + nml.start_index['zbs'][0] + self.ntor - self.zs = self.zs.at[m_indices_zs[:, None], n_indices_zs].set(zs) - m1d = jnp.arange(self.mpol) - n1d = jnp.arange(-self.ntor, self.ntor + 1) - n2d, m2d = jnp.meshgrid(n1d, m1d) - self.xm = m2d.flatten()[self.ntor:] - self.xn = self.nfp*n2d.flatten()[self.ntor:] - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] + rc = jnp.ravel(nested_lists_to_array(nml['rbc']))[2:] + zs = jnp.ravel(nested_lists_to_array(nml['zbs']))[2:] + #rbc_first_n = nml.start_index['rbc'][0] + #rbc_last_n = rbc_first_n + rc.shape[1] - 1 + #zbs_first_n = nml.start_index['zbs'][0] + #zbs_last_n = zbs_first_n + zs.shape[1] - 1 + #self.ntor = jnp.max(jnp.abs(jnp.array([rbc_first_n, rbc_last_n, zbs_first_n, zbs_last_n], dtype='i'))) + #rbc_first_m = nml.start_index['rbc'][1] + #rbc_last_m = rbc_first_m + rc.shape[0] - 1 + #zbs_first_m = nml.start_index['zbs'][1] + #zbs_last_m = zbs_first_m + zs.shape[0] - 1 + self.ntor = nml['ntor'] + self.mpol = nml['mpol'] + self.rc = jnp.zeros((self.mpol*( 2 * self.ntor + 1)-self.ntor)) + self.zs = jnp.zeros((self.mpol*( 2 * self.ntor + 1)-self.ntor)) + #self.rc = jnp.zeros((self.mpol, 2 * self.ntor + 1)) + #self.zs = jnp.zeros((self.mpol, 2 * self.ntor + 1)) + #m_indices_rc = jnp.arange(rc.shape[0]) + nml.start_index['rbc'][1] + #n_indices_rc = jnp.arange(rc.shape[1]) + nml.start_index['rbc'][0] + self.ntor + #self.rc = self.rc.at[m_indices_rc[:, None], n_indices_rc].set(rc) + #m_indices_zs = jnp.arange(zs.shape[0]) + nml.start_index['zbs'][1] + #n_indices_zs = jnp.arange(zs.shape[1]) + nml.start_index['zbs'][0] + self.ntor + #self.zs = self.zs.at[m_indices_zs[:, None], n_indices_zs].set(zs) + #m1d = jnp.arange(self.mpol) + #n1d = jnp.arange(-self.ntor, self.ntor + 1) + #n2d, m2d = jnp.meshgrid(n1d, m1d) + #self.xm = m2d.flatten()[self.ntor:] + #self.xn = self.nfp*n2d.flatten()[self.ntor:] + #indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T + #self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] + #self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] + self.rc=rc + self.zs=zs + self.rmnc_interp = self.rc + self.zmns_interp = self.zs + self.xm = jnp.repeat(jnp.arange(self.mpol+1), 2*self.ntor+1)[self.ntor:]#m2d.flatten()[self.ntor:] + self.xn = self.nfp*jnp.tile(jnp.arange(-self.ntor, self.ntor + 1), self.mpol+1)[self.ntor:]#m2d.flatten()[self.ntor:] else: try: self.nfp = vmec.nfp @@ -124,13 +165,16 @@ def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus=' self.bmnc_interp = vmap(lambda row: jnp.interp(s, self.s_half_grid, row, left='extrapolate'), in_axes=1)(self.bmnc[1:, :]) self.mpol = vmec.mpol self.ntor = vmec.ntor - self.num_dofs = 2 * (self.mpol + 1) * (2 * self.ntor + 1) - self.ntor - (self.ntor + 1) - shape = (int(jnp.max(self.xm)) + 1, int(jnp.max(self.xn)) + 1) - self.rc = jnp.zeros(shape) - self.zs = jnp.zeros(shape) + self.num_dofs = 2 * ((self.mpol + 1) * (2 * self.ntor + 1) - self.ntor ) + #shape = (int(jnp.max(self.xm)) + 1, int(jnp.max(self.xn)) + 1) + #self.rc = jnp.zeros(shape) + #self.zs = jnp.zeros(shape) indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rc = self.rc.at[indices[:, 0], indices[:, 1]].set(self.rmnc_interp) - self.zs = self.zs.at[indices[:, 0], indices[:, 1]].set(self.zmns_interp) + self.rc = self.rmnc_interp + self.zs = self.zmns_interp + #self.zs = self.zs.at[indices[:, 0], indices[:, 1]].set(self.zmns_interp) + #self.rc = self.rc.at[indices[:, 0], indices[:, 1]].set(self.rmnc_interp) + #self.zs = self.zs.at[indices[:, 0], indices[:, 1]].set(self.zmns_interp) except: raise ValueError("vmec must be a Vmec object or a string pointing to a VMEC input file.") self.ntheta = ntheta @@ -143,10 +187,25 @@ def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus=' self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=self.ntheta, endpoint=True if close else False) self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=self.nphi, endpoint=True if close else False) self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) - self.num_dofs_rc = len(jnp.ravel(self.rc)[self.ntor:]) - self.num_dofs_zs = len(jnp.ravel(self.zs)[self.ntor:]) - self._dofs = jnp.concatenate((jnp.ravel(self.rc)[self.ntor:], jnp.ravel(self.zs)[self.ntor:])) - + self.num_dofs_rc = len(jnp.ravel(self.rc)) + self.num_dofs_zs = len(jnp.ravel(self.zs)) + + self.rescaling_factor=rescaling_factor + if rescaling_type is None: + self.rescaling_function=lambda x: x + self.unscaling_function=lambda x: x + elif rescaling_type=='L_infty': + self.rescaling_function=self.scaling_L_infty + self.unscaling_function=self.unscaling_L_infty + elif rescaling_type=='L_1': + self.rescaling_function=self.scaling_L_1 + self.unscaling_function=self.unscaling_L_1 + elif rescaling_type=='L_2': + self.rescaling_function=self.scaling_L_2 + self.unscaling_function=self.unscaling_L_2 + + self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) + self.angles = jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) (self._gamma, self._gammadash_theta, self._gammadash_phi, @@ -160,13 +219,21 @@ def dofs(self): return self._dofs @dofs.setter - def dofs(self, new_dofs): - self._dofs = new_dofs - self.rc = jnp.concatenate((jnp.zeros(self.ntor),new_dofs[:self.num_dofs_rc])).reshape(self.rc.shape) - self.zs = jnp.concatenate((jnp.zeros(self.ntor),new_dofs[self.num_dofs_rc:])).reshape(self.zs.shape) + def dofs(self, new_dofs,scaled=True): + if scaled==True: + self._dofs = new_dofs + else: + self._dofs = self.rescaling_function(new_dofs) + if scaled==True: + self.rc=self.unscaling_function(new_dofs)[:self.num_dofs_rc] + self.zs=self.unscaling_function(new_dofs)[self.num_dofs_rc:] + else: + self.rc = new_dofs[:self.num_dofs_rc] + self.zs = new_dofs[self.num_dofs_rc:] + indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] + self.rmnc_interp = self.rc + self.zmns_interp = self.zs (self._gamma, self._gammadash_theta, self._gammadash_phi, self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) # if hasattr(self, 'bmnc'): @@ -235,7 +302,119 @@ def x(self): @x.setter def x(self, new_dofs): self.dofs = new_dofs - + + @property + def volume(self): + + xyz = self.gamma # shape: (nphi, ntheta, 3) + n = self.normal # shape: (nphi, ntheta, 3) + + integrand = jnp.sum(xyz * n, axis=2) # dot(x, n), shape: (nphi, ntheta) + volume = jnp.mean(integrand) / 3.0 + return volume + + @property + def area(self): + #n = self.normal # (nphi, ntheta, 3) + #norm_n = jnp.linalg.norm(n, axis=2) # shape: (nphi, ntheta) + #avg_area = jnp.mean(norm_n) + #return avg_area + n = self.normal # shape: (nphi, ntheta, 3) + norm_n = jnp.linalg.norm(n, axis=2) + + dphi = 2 * jnp.pi / self.nphi + dtheta = 2 * jnp.pi / self.ntheta + + area = jnp.sum(norm_n) * dphi * dtheta + return area + + def scaling_L_infty(self,x): + return x / jnp.exp(-self.rescaling_factor*jnp.maximum(jnp.abs(self.xm),jnp.abs(self.xn))) + + def scaling_L_1(self,x): + return x / jnp.exp(-self.rescaling_factor*(jnp.abs(self.xm)+jnp.abs(self.xn))) + + def scaling_L_2(x): + return x / jnp.exp(-self.rescaling_factor*jnp.sqrt(self.xm**2+self.xn**2)) + + def unscaling_L_infty(self,x): + return x * jnp.exp(-self.rescaling_factor*jnp.maximum(jnp.abs(self.xm),jnp.abs(self.xn))) + + def unscaling_L_1(self,x): + return x * jnp.exp(-self.rescaling_factor*(jnp.abs(self.xm)+jnp.abs(self.xn))) + + def unscaling_L_2(self,x): + return x * jnp.exp(-self.rescaling_factor*jnp.sqrt(self.xm**2+self.xn**2)) + + def change_resolution(self, mpol: int, ntor: int, ntheta=None, nphi=None,close=True): + """ + Change the values of `mpol` and `ntor`. + New Fourier coefficients are zero by default. + Old coefficients outside the new range are discarded. + """ + rc_old, zs_old = self.rc, self.zs + mpol_old, ntor_old = self.mpol, self.ntor + if ntheta is not None: + self.ntheta = ntheta + else: + ntheta = self.ntheta + + if nphi is not None: + self.nphi = nphi + else: + nphi = self.nphi + + #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) + #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) + rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) + zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) + m_keep = min(mpol_old, mpol) + n_keep = min(ntor_old, ntor) + + xm_old=self.xm + xn_old=self.xn + self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] + self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] + # Copy overlapping region + for l in range(len(self.xm)): + if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: + index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp + rc_new=rc_new.at[l].set(self.rc[index]) + zs_new=zs_new.at[l].set(self.zs[index]) + + + # Update attributes + self.mpol, self.ntor = mpol, ntor + self.rc, self.zs = rc_new, zs_new + + self.rmnc_interp = self.rc + self.zmns_interp = self.zs + + # Update degrees of freedom + self.num_dofs_rc = len(jnp.ravel(self.rc)) + self.num_dofs_zs = len(jnp.ravel(self.zs)) + self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) + + # Recompute angles and geometry + if self.range_torus == 'full torus': div = 1 + else: div = self.nfp + if self.range_torus == 'half period': end_val = 0.5 + else: end_val = 1.0 + self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) + self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) + self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) + + self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) + (self._gamma, self._gammadash_theta, self._gammadash_phi, + self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + + + # Recompute AbsB if available + if hasattr(self, 'bmnc'): + self._AbsB = self._set_AbsB() + + return self + def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.") @@ -299,15 +478,11 @@ def to_vmec(self, filename): nml += 'LASYM = .FALSE.\n' nml += f'NFP = {self.nfp}\n' - for m in range(self.mpol + 1): - nmin = -self.ntor - if m == 0: - nmin = 0 - for n in range(nmin, self.ntor + 1): - rc = self.rc[m, n + self.ntor] - zs = self.zs[m, n + self.ntor] - if jnp.abs(rc) > 0 or jnp.abs(zs) > 0: - nml += f"RBC({n:4d},{m:4d}) ={rc:23.15e}, ZBS({n:4d},{m:4d}) ={zs:23.15e}\n" + # Copy overlapping region + for l in range(len(self.xm)): + rc = self.rc[l] + zs = self.zs[l] + nml += f"RBC({self.xn[l]:4d},{self.xm[l]:4d}) ={rc:23.15e}, ZBS({self.xn[l]:4d},{self.xm[l]:4d}) ={zs:23.15e}\n" nml += '/\n' with open(filename, 'w') as f: @@ -454,3 +629,11 @@ def signed_distance_from_surface_extras(xyz, surface): +def plot_scalar_on_flux_surface(surface, scalar_map): + ''' + surface: the surface object in which to plot the scalar_map + scalar_map: a scalar_map as function of theta and phi + ''' + + + diff --git a/examples/input_files/input.rotating_ellipse b/examples/input_files/input.rotating_ellipse index a35f3af..bce19ba 100644 --- a/examples/input_files/input.rotating_ellipse +++ b/examples/input_files/input.rotating_ellipse @@ -5,10 +5,17 @@ MPOL = 002 NTOR = 002 !----- Boundary Parameters (n,m) ----- - RBC( 000,000) = 10 ZBS( 000,000) = 0 - RBC( 001,000) = 1 ZBS( 001,000) = -1 + RBC( 000,000) = 10. ZBS( 000,000) = 0. + RBC( 001,000) = 1. ZBS( 001,000) = -1. + RBC( 002,000) = 0. ZBS( 002,000) = 0. + RBC( -002,001) = 0. ZBS( -002,001) = 0. RBC(-001,001) = 0.1 ZBS(-001,001) = 0.1 RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 RBC( 001,001) = -1 ZBS( 001,001) = 1 + RBC( 002,001) = 0 ZBS( 002,001) = 0 RBC(-002,002) = 1E-4 ZBS(-002,002) = 1E-4 + RBC(-001,002) = 0. ZBS(-001,002) = 0. + RBC( 000,002) = 0. ZBS( 000,002) = 0. + RBC( 001,002) = 0. ZBS( 001,002) = 0. + RBC( 002,002) = 0. ZBS( 002,002) = 0. / diff --git a/examples/input_files/input.toroidal_surface b/examples/input_files/input.toroidal_surface index 3a133b2..533ae61 100644 --- a/examples/input_files/input.toroidal_surface +++ b/examples/input_files/input.toroidal_surface @@ -1,14 +1,21 @@ !----- Runtime Parameters ----- &INDATA LASYM = F - NFP = 0001 + NFP = 0002 MPOL = 002 NTOR = 002 !----- Boundary Parameters (n,m) ----- - RBC( 000,000) = 7.75 ZBS( 000,000) = 0 + RBC( 000,000) = 10.0 ZBS( 000,000) = 0 RBC( 001,000) = 0.000001 ZBS( 001,000) = -0.000001 + RBC( 002,000) = 0. ZBS( 002,000) = 0. + RBC( -002,001) = 0. ZBS( -002,001) = 0. RBC(-001,001) = 0.000001 ZBS(-001,001) = 0.000001 - RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 + RBC( 000,001) = 0.5 ZBS( 000,001) = 0.5 RBC( 001,001) = 0.000001 ZBS( 001,001) = 0.000001 + RBC( 002,001) = 0 ZBS( 002,001) = 0 RBC(-002,002) = 1E-7 ZBS(-002,002) = 1E-7 + RBC(-001,002) = 0. ZBS(-001,002) = 0. + RBC( 000,002) = 0. ZBS( 000,002) = 0. + RBC( 001,002) = 0. ZBS( 001,002) = 0. + RBC( 002,002) = 0. ZBS( 002,002) = 0. / diff --git a/examples/optimize_coils_and_surface.py b/examples/optimize_coils_and_surface.py index 1bef5b8..93fc7b4 100644 --- a/examples/optimize_coils_and_surface.py +++ b/examples/optimize_coils_and_surface.py @@ -20,8 +20,10 @@ ntheta=30 nphi=30 +mpol=2 +ntor=2 input = os.path.join('input_files','input.rotating_ellipse') -surface_initial = SurfaceRZFourier(input, ntheta=ntheta, nphi=nphi, range_torus='half period') +surface_initial = SurfaceRZFourier(input, ntheta=ntheta, nphi=nphi, range_torus='half period',mpol=mpol,ntor=ntor) # Optimization parameters max_coil_length = 38 @@ -122,7 +124,7 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents n_segments=60, stellsym=True, max_coil_curvature=0.5, target_B_on_surface=5.7): field=field_from_dofs(x[:-len(surface_all.x)-len(field_nearaxis.x)] ,dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) - surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) + surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta,mpol=surface_all.mpol,ntor=surface_all.ntor) surface.dofs = x[-len(surface_all.x)-len(field_nearaxis.x):-len(field_nearaxis.x)] field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len(field_nearaxis.x):], field_nearaxis) @@ -233,7 +235,6 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents # tracing_optimized.plot(ax=ax2, show=False) plt.tight_layout() plt.show() - # Save the surface to a VMEC file surface_optimized.to_vmec('input.optimized') @@ -244,6 +245,7 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents surface_optimized.to_vtk('optimized_surface', field=BiotSavart(coils_optimized)) coils_optimized.to_vtk('optimized_coils') field_nearaxis_optimized.to_vtk('optimized_field_nearaxis', r=major_radius_coils/12, field=BiotSavart(coils_optimized)) + # tracing_initial.to_vtk('initial_tracing') # tracing_optimized.to_vtk('optimized_tracing') From a0374ccabf1413d5c354fb51579c83e1135f7d73 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Tue, 21 Oct 2025 18:47:52 +0200 Subject: [PATCH 51/63] Fix(coils, fields): turned field BiotSavart into a PyTree and modified Coils and Curves into correct PyTrees --- essos/coils.py | 454 ++++++++++++++++++++++++++++++------------------ essos/fields.py | 206 +++++++++++----------- 2 files changed, 390 insertions(+), 270 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index 6f94b78..c543086 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -6,31 +6,31 @@ from functools import partial from .plot import fix_matplotlib_3d -def compute_curvature(gammadash, gammadashdash): - return jnp.linalg.norm(jnp.cross(gammadash, gammadashdash, axis=1), axis=1) / jnp.linalg.norm(gammadash, axis=1)**3 - class Curves: - """ - Class to store the curves + """ Class to store the curves - ----------- Attributes: - dofs (jnp.ndarray - shape (n_indcurves, 3, 2*order+1)): Fourier Coefficients of the independent curves + dofs (jnp.ndarray - shape (n_base_curves, 3, 2*order+1)): Fourier Coefficients of the base curves n_segments (int): Number of segments to discretize the curves + quadpoints (jnp.ndarray - shape (n_segments,)): Quadrature points used to discretize the curves nfp (int): Number of field periods stellsym (bool): Stellarator symmetry order (int): Order of the Fourier series - curves jnp.ndarray - shape (n_indcurves*nfp*(1+stellsym), 3, 2*order+1)): Curves obtained by applying rotations and flipping corresponding to nfp fold rotational symmetry and optionally stellarator symmetry - gamma (jnp.array - shape (n_curves, n_segments, 3)): Discretized curves - gamma_dash (jnp.array - shape (n_curves, n_segments, 3)): Discretized curves derivatives - + n_base_curves (int): Number of base curves before applying symmetries + curves (jnp.ndarray - shape (n_base_curves*nfp*(1+stellsym), 3, 2*order+1)): Curves obtained by applying rotations and flipping corresponding to nfp fold rotational symmetry and optionally stellarator symmetry + gamma (jnp.ndarray - shape (n_curves, n_segments, 3)): Discretized curves + gamma_dash (jnp.ndarray - shape (n_curves, n_segments, 3)): Discretized curves derivatives + gamma_dashdash (jnp.ndarray - shape (n_curves, n_segments, 3)): Discretized curves second derivatives """ - def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stellsym: bool = True): - dofs = jnp.array(dofs) - # assert isinstance(dofs, jnp.ndarray), "dofs must be a jnp.ndarray" - assert dofs.ndim == 3, "dofs must be a 3D array with shape (n_curves, 3, 2*order+1)" - assert dofs.shape[1] == 3, "dofs must have shape (n_curves, 3, 2*order+1)" - assert dofs.shape[2] % 2 == 1, "dofs must have shape (n_curves, 3, 2*order+1)" + def __init__(self, + dofs: jnp.ndarray, + n_segments: int = 100, + nfp: int = 1, + stellsym: bool = True): + if hasattr(dofs, 'shape'): + assert len(dofs.shape) == 3, "dofs must be a 3D array with shape (n_curves, 3, 2*order+1)" + assert dofs.shape[1] == 3, "dofs must have shape (n_curves, 3, 2*order+1)" + assert dofs.shape[2] % 2 == 1, "dofs must have shape (n_curves, 3, 2*order+1)" assert isinstance(n_segments, int), "n_segments must be an integer" assert n_segments > 2, "n_segments must be greater than 2" assert isinstance(nfp, int), "nfp must be a positive integer" @@ -41,164 +41,153 @@ def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stell self._n_segments = n_segments self._nfp = nfp self._stellsym = stellsym - self._order = dofs.shape[2]//2 - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self.quadpoints = jnp.linspace(0, 1, self.n_segments, endpoint=False) + + self.quadpoints = jnp.linspace(0, 1, self._n_segments, endpoint=False) + self._curves = None + self._gamma = None + self._gamma_dash = None + self._gamma_dashdash = None + self._length = None + self._curvature = None + + # reset_cache method + def reset_cache(self): + self._curves = None self._gamma = None self._gamma_dash = None self._gamma_dashdash = None self._curvature = None self._length = None - self.n_base_curves=dofs.shape[0] - - def __str__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" - - def __repr__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" - - def _tree_flatten(self): - children = (self._dofs,) # arrays / dynamic values - aux_data = {"n_segments": self._n_segments, "nfp": self._nfp, "stellsym": self._stellsym} # static values - return (children, aux_data) - @classmethod - def _tree_unflatten(cls, aux_data, children): - return cls(*children, **aux_data) - - # @partial(jit, static_argnames=['self']) - def _set_gamma(self): - def fori_createdata(order_index: int, data: jnp.ndarray) -> jnp.ndarray: - return data[0] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], jnp.cos(2 * jnp.pi * order_index * self.quadpoints)), \ - data[1] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], 2*jnp.pi *order_index *jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], -2*jnp.pi *order_index *jnp.sin(2 * jnp.pi * order_index * self.quadpoints)), \ - data[2] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], -4*jnp.pi**2*order_index**2*jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], -4*jnp.pi**2*order_index**2*jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) - gamma = jnp.einsum("ij,k->ikj", self._curves[:, :, 0], jnp.ones(self.n_segments)) - gamma_dash = jnp.zeros((jnp.size(self._curves, 0), self.n_segments, 3)) - gamma_dashdash = jnp.zeros((jnp.size(self._curves, 0), self.n_segments, 3)) - gamma, gamma_dash, gamma_dashdash = fori_loop(1, self._order+1, fori_createdata, (gamma, gamma_dash, gamma_dashdash)) - length = jnp.mean(jnp.linalg.norm(gamma_dash, axis=2), axis=1) - curvature = vmap(compute_curvature)(gamma_dash, gamma_dashdash) - self._gamma = gamma - self._gamma_dash = gamma_dash - self._gamma_dashdash = gamma_dashdash - self._curvature = curvature - self._length = length - + # dofs property and setter @property def dofs(self): - return self._dofs + return jnp.array(self._dofs) @dofs.setter def dofs(self, new_dofs): - assert isinstance(new_dofs, jnp.ndarray) - assert new_dofs.ndim == 3 - assert jnp.size(new_dofs, 1) == 3 - assert jnp.size(new_dofs, 2) % 2 == 1 + self.reset_cache() self._dofs = new_dofs - self._order = jnp.size(new_dofs, 2)//2 - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() - @property - def curves(self): - return self._curves - - @property - def order(self): - return self._order - - @order.setter - def order(self, new_order): - assert isinstance(new_order, int) - assert new_order > 0 - self._dofs = jnp.pad(self.dofs, ((0, 0), (0, 0), (0, 2*(new_order-self._order)))) if new_order > self._order else self.dofs[:, :, :2*(new_order)+1] - self._order = new_order - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() - + # n_segments property and setter @property def n_segments(self): return self._n_segments @n_segments.setter def n_segments(self, new_n_segments): - assert isinstance(new_n_segments, int) - assert new_n_segments > 2 + self.reset_cache() self._n_segments = new_n_segments self.quadpoints = jnp.linspace(0, 1, self._n_segments, endpoint=False) - self._set_gamma() - + + # nfp property and setter @property def nfp(self): return self._nfp @nfp.setter def nfp(self, new_nfp): - assert isinstance(new_nfp, int) - assert new_nfp > 0 + self.reset_cache() self._nfp = new_nfp - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() - + + # stellsym property and setter @property def stellsym(self): return self._stellsym @stellsym.setter def stellsym(self, new_stellsym): - assert isinstance(new_stellsym, bool) + self.reset_cache() self._stellsym = new_stellsym - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() + # order property and setter + @property + def order(self): + return self.dofs.shape[2]//2 + + @order.setter + def order(self, new_order): + self.reset_cache() + self._dofs = jnp.pad(self.dofs, ((0,0), (0,0), (0, max(0, 2*(new_order-self.order)))))[:, :, :2*(new_order)+1] + + # n_base_curves property + @property + def n_base_curves(self): + return self.dofs.shape[0] + + # curves property + @property + def all_curves(self): + if self._curves is None: + self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) + return self._curves + + # compute_curvature static method + @staticmethod + def compute_curvature(gammadash, gammadashdash): + return jnp.linalg.norm(jnp.cross(gammadash, gammadashdash, axis=1), axis=1) / jnp.linalg.norm(gammadash, axis=1)**3 + + # _compute_gamma method + @jit + def _compute_gamma(self): + def fori_createdata(order_index: int, data: jnp.ndarray) -> jnp.ndarray: + return data[0] + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index - 1], jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index], jnp.cos(2 * jnp.pi * order_index * self.quadpoints)), \ + data[1] + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index - 1], 2*jnp.pi *order_index *jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index], -2*jnp.pi *order_index *jnp.sin(2 * jnp.pi * order_index * self.quadpoints)), \ + data[2] + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index - 1], -4*jnp.pi**2*order_index**2*jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index], -4*jnp.pi**2*order_index**2*jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) + + gamma0 = jnp.einsum("ij,k->ikj", self.curves[:, :, 0], jnp.ones(self.n_segments)) + gamma_dash0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) + gamma_dashdash0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) + + gamma, gamma_dash, gamma_dashdash = fori_loop(1, self.order+1, fori_createdata, (gamma0, gamma_dash0, gamma_dashdash0)) + return gamma, gamma_dash, gamma_dashdash + + # gamma property @property def gamma(self): if self._gamma is None: - self._set_gamma() + self._gamma, self._gamma_dash, self._gamma_dashdash = self._compute_gamma() return self._gamma - - @gamma.setter - def gamma(self, new_gamma): - self._gamma = new_gamma + # gamma_dash property @property def gamma_dash(self): if self._gamma_dash is None: - self._set_gamma() + self._gamma, self._gamma_dash, self._gamma_dashdash = self._compute_gamma() return self._gamma_dash - @gamma_dash.setter - def gamma_dash(self, new_gamma_dash): - self._gamma_dash = new_gamma_dash - - - + # gamma_dashdash property @property def gamma_dashdash(self): if self._gamma_dashdash is None: - self._set_gamma() + self._gamma, self._gamma_dash, self._gamma_dashdash = self._compute_gamma() return self._gamma_dashdash - @gamma_dashdash.setter - def gamma_dashdash(self, new_gamma_dashdash): - self._gamma_dashdash = new_gamma_dashdash - + # length property @property def length(self): if self._length is None: - self._set_gamma() + self._length = jnp.mean(jnp.linalg.norm(self.gamma_dash, axis=2), axis=1) return self._length + # curvature property @property def curvature(self): if self._curvature is None: - self._set_gamma() + self._curvature = vmap(self.compute_curvature)(self.gamma_dash, self.gamma_dashdash) return self._curvature + # magic methods + def __str__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" + + def __repr__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" + def __len__(self): - return jnp.size(self.curves, 0) + return self.curves.shape[0] def __getitem__(self, key): if isinstance(key, int): @@ -339,78 +328,196 @@ def from_simsopt(cls, simsopt_curves, nfp=1, stellsym=True): ), (len(simsopt_curves), 3, 2*simsopt_curves[0].order+1)) n_segments = len(simsopt_curves[0].quadpoints) return cls(dofs, n_segments, nfp, stellsym) + + def _tree_flatten(self): + children = (self._dofs,) # arrays / dynamic values + aux_data = {"n_segments": self._n_segments, + "nfp": self._nfp, + "stellsym": self._stellsym} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) tree_util.register_pytree_node(Curves, Curves._tree_flatten, Curves._tree_unflatten) -class Coils(Curves): +class Coils: + """ Class to store the coils + + Attributes: + curves (Curves): Curves object storing the coil geometry + dofs_currents_raw (jnp.ndarray - shape (n_base_curves,)): Non-normalized currents of the base curves + currents_scale (float): Normalization factor for the currents + dofs_currents (jnp.ndarray - shape (n_base_curves,)): Normalized currents of the base curves + currents (jnp.ndarray - shape (n_base_curves * nfp * (1 + stellsym),)): Currents obtained by applying symmetries to the base currents + dofs_curves (jnp.ndarray - shape (n_base_curves, 3, 2*order+1)): Degrees of freedom of the curves + dofs (jnp.ndarray - shape (n_base_curves * 3 * (2 * order + 1) + n_base_curves,)): Degrees of freedom of the coils (curves and normalized currents) + + """ def __init__(self, curves: Curves, currents: jnp.ndarray): - assert isinstance(curves, Curves) - currents = jnp.array(currents) - assert jnp.size(currents) == jnp.size(curves.dofs, 0) - super().__init__(curves.dofs, curves.n_segments, curves.nfp, curves.stellsym) - self._currents_scale = jnp.mean(jnp.abs(currents)) - self._dofs_currents = currents/self._currents_scale - self._currents = apply_symmetries_to_currents(self._dofs_currents*self._currents_scale, self.nfp, self.stellsym) + if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): + assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" - def __str__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ - + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ - + f"Currents scaling factor\n{self.currents_scale}\n" - - def __repr__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ - + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ - + f"Currents scaling factor\n{self.currents_scale}\n" + self.curves = curves + self._dofs_currents_raw = currents # Non-normalized base currents + self._currents_scale = None + self._dofs_currents = None + self._currents = None + + # reset_cache method + def reset_cache(self): + self._dofs_currents = None + self._currents_scale = None + self._currents = None + + # dofs_curves property and setter @property def dofs_curves(self): - return self._dofs + return self.curves.dofs @dofs_curves.setter def dofs_curves(self, new_dofs_curves): - self.dofs = new_dofs_curves + self.curves.dofs = new_dofs_curves + + # dofs_currents_raw property and setter + @property + def dofs_currents_raw(self): + return jnp.array(self._dofs_currents_raw) + + @dofs_currents_raw.setter + def dofs_currents_raw(self, new_dofs_currents_raw): + self.reset_cache() + self._dofs_currents_raw = new_dofs_currents_raw + + # currents_scale property and setter + @property + def currents_scale(self): + if self._currents_scale is None: + self._currents_scale = jnp.mean(jnp.abs(self.dofs_currents_raw)) + return self._currents_scale + + @currents_scale.setter + def currents_scale(self, new_currents_scale): + self._dofs_currents_raw = self.dofs_currents * new_currents_scale + self._currents_scale = new_currents_scale + self._currents = None + # dofs_currents property and setter @property def dofs_currents(self): + if self._dofs_currents is None: + self._dofs_currents = self.dofs_currents_raw / self.currents_scale return self._dofs_currents @dofs_currents.setter def dofs_currents(self, new_dofs_currents): - self._dofs_currents = new_dofs_currents - self._currents = apply_symmetries_to_currents(self._dofs_currents*self.currents_scale, self.nfp, self.stellsym) - + self.dofs_currents_raw = new_dofs_currents * self.currents_scale + + # currents property @property - def currents_scale(self): - return self._currents_scale + def currents(self): + if self._currents is None: + self._currents = apply_symmetries_to_currents(self.dofs_currents_raw, self.nfp, self.stellsym) + return self._currents + + # dofs property and setter + @property + def dofs(self): + return jnp.hstack([self.dofs_curves.ravel(), self.dofs_currents]) - @currents_scale.setter - def currents_scale(self, new_currents_scale): - self._currents_scale = new_currents_scale - self._currents = apply_symmetries_to_currents(self.dofs_currents*new_currents_scale, self.nfp, self.stellsym) + @dofs.setter + def dofs(self, new_dofs): + n_curve_dofs = jnp.size(self.dofs_curves) + self.dofs_curves = jnp.reshape(new_dofs[:n_curve_dofs], self.dofs_curves.shape) + self.dofs_currents = new_dofs[n_curve_dofs:] + # TODO: remove x property. This is a placeholder for compatibility with the examples that need to be updated. + # x property and setter @property def x(self): - dofs_curves = jnp.ravel(self.dofs_curves) - dofs_currents = jnp.ravel(self.dofs_currents) - return jnp.concatenate((dofs_curves, dofs_currents)) + return self.dofs @x.setter def x(self, new_dofs): - old_dofs_curves = jnp.ravel(self.dofs) - old_dofs_currents = jnp.ravel(self.dofs_currents) - new_dofs_curves = new_dofs[:old_dofs_curves.shape[0]] - new_dofs_currents = new_dofs[old_dofs_currents.shape[0]:] - self.dofs_curves = jnp.reshape(new_dofs_curves, (self.dofs_curves.shape)) - self.dofs_currents = new_dofs_currents - + self.dofs = new_dofs + + # currents property @property def currents(self): + if self._currents is None: + self._currents = apply_symmetries_to_currents(self.dofs_currents*self.currents_scale, self.nfp, self.stellsym) return self._currents + # gamma property + @property + def gamma(self): + return self.curves.gamma + + # gamma_dash property + @property + def gamma_dash(self): + return self.curves.gamma_dash + + # gamma_dashdash property + @property + def gamma_dashdash(self): + return self.curves.gamma_dashdash + + # length property + @property + def length(self): + return self.curves.length + + # curvature property + @property + def curvature(self): + return self.curves.curvature + + # nfp property + @property + def nfp(self): + return self.curves.nfp + + # stellsym property + @property + def stellsym(self): + return self.curves.stellsym + + # order property + @property + def order(self): + return self.curves.order + + # n_segments property and setter + @property + def n_segments(self): + return self.curves.n_segments + + @n_segments.setter + def n_segments(self, new_n_segments): + self.curves.n_segments = new_n_segments + + # magic methods + + def __str__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ + + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ + + f"Currents scaling factor\n{self.currents_scale}\n" + + def __repr__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ + + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ + + f"Currents scaling factor\n{self.currents_scale}\n" + + def __len__(self): + return len(self.curves) + def __getitem__(self, key): if isinstance(key, int): return Coils(Curves(jnp.expand_dims(self.curves[key], 0), self.n_segments, 1, False), jnp.expand_dims(self.currents[key], 0)) @@ -443,12 +550,6 @@ def __eq__(self, other): else: raise TypeError(f"Invalid argument type. Got {type(other)}, expected Coils.") - - def _tree_flatten(self): - children = (Curves(self.dofs, self.n_segments, self.nfp, self.stellsym), self._dofs_currents*self._currents_scale) # arrays / dynamic values - aux_data = {} # static values - return (children, aux_data) - def save_coils(self, filename: str, text=""): """ Save the coils to a file @@ -486,9 +587,15 @@ def to_json(self, filename: str): "dofs_currents": self.dofs_currents.tolist(), } import json - with open(filename, "w") as file: + with open(filename, 'w') as file: json.dump(data, file) - + + def plot(self, *args, **kwargs): + self.curves.plot(*args, **kwargs) + + def to_vtk(self, *args, **kwargs): + self.curves.to_vtk(*args, **kwargs) + @classmethod def from_simsopt(cls, simsopt_coils, nfp=1, stellsym=True): """ This assumes coils have all nfp and stellsym symmetries""" @@ -502,23 +609,37 @@ def from_simsopt(cls, simsopt_coils, nfp=1, stellsym=True): @classmethod def from_json(cls, filename: str): - """ - Create a Coils object from a json file - """ + """ Creates a Coils object from a json file""" import json with open(filename, "r") as file: data = json.load(file) curves = Curves(jnp.array(data["dofs_curves"]), data["n_segments"], data["nfp"], data["stellsym"]) currents = jnp.array(data["dofs_currents"]) return cls(curves, currents) + + def _tree_flatten(self): + children = (self.curves, self._dofs_currents_raw) # arrays / dynamic values + aux_data = {} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) tree_util.register_pytree_node(Coils, Coils._tree_flatten, Coils._tree_unflatten) -def CreateEquallySpacedCurves(n_curves: int, order: int, R: float, r: float, n_segments: int = 100, - nfp: int = 1, stellsym: bool = False) -> jnp.ndarray: +def CreateEquallySpacedCurves(n_curves: int, + order: int, + R: float, + r: float, + n_segments: int = 100, + nfp: int = 1, + stellsym: bool = False) -> Curves: + """ Creates n_curves equally spaced on a torus of major radius R and minor radius r using Fourier + representation up to the specified order.""" angles = (jnp.arange(n_curves) + 0.5) * (2 * jnp.pi) / ((1 + int(stellsym)) * nfp * n_curves) curves = jnp.zeros((n_curves, 3, 1 + 2 * order)) @@ -529,6 +650,7 @@ def CreateEquallySpacedCurves(n_curves: int, order: int, R: float, r: float, n_s curves = curves.at[:, 2, 1].set(-r) # z[1] (constant for all) return Curves(curves, n_segments=n_segments, nfp=nfp, stellsym=stellsym) +@partial(jit, static_argnames=["flip"]) def RotatedCurve(curve, phi, flip): rotmat_T = jnp.array( [[ jnp.cos(phi), jnp.sin(phi), 0], diff --git a/essos/fields.py b/essos/fields.py index deb50fd..82c2154 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -1,164 +1,162 @@ import jax jax.config.update("jax_enable_x64", True) from jax import vmap -from essos.coils import compute_curvature +from essos.coils import Curves import jax.numpy as jnp from functools import partial from jax import jit, jacfwd, grad, vmap, tree_util, lax -from essos.surfaces import SurfaceRZFourier, BdotN_over_B,SurfaceClassifier +from essos.surfaces import SurfaceRZFourier, BdotN_over_B, SurfaceClassifier from essos.plot import fix_matplotlib_3d from essos.util import newton -class BiotSavart(): - def __init__(self, coils): - self.coils = coils - self.currents = coils.currents - self.gamma = coils.gamma - self.gamma_dash = coils.gamma_dash - #self.gamma_dashdash = coils.gamma_dashdash - self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) - self.coils_curvature= vmap(compute_curvature)(self.gamma_dash, coils.gamma_dashdash) - self.r_axis=jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(self.coils.dofs_curves))) - self.z_axis=jnp.mean(vmap(lambda dofs: dofs[2, 0])(self.coils.dofs_curves)) +class MagneticField(): + def __init__(self): + pass - - @partial(jit, static_argnames=['self']) + @jit def sqrtg(self, points): - return 1. + raise NotImplementedError("sqrtg method not implemented") - @partial(jit, static_argnames=['self']) + @jit def B(self, points): - dif_R = (jnp.array(points)-self.gamma).T - dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 - dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") - return jnp.mean(dB_sum, axis=0) - - @partial(jit, static_argnames=['self']) + raise NotImplementedError("B method not implemented") + + @jit def B_covariant(self, points): return self.B(points) - - @partial(jit, static_argnames=['self']) + + @jit def B_contravariant(self, points): return self.B(points) - @partial(jit, static_argnames=['self']) + @jit def AbsB(self, points): return jnp.linalg.norm(self.B(points)) - @partial(jit, static_argnames=['self']) + @jit def dB_by_dX(self, points): return jacfwd(self.B)(points) - - @partial(jit, static_argnames=['self']) + @jit def dAbsB_by_dX(self, points): return grad(self.AbsB)(points) - @partial(jit, static_argnames=['self']) + @jit def grad_B_covariant(self, points): - return jacfwd(self.B_covariant)(points) - - @partial(jit, static_argnames=['self']) + return jacfwd(self.B_covariant)(points) + + @jit def curl_B(self, points): grad_B_cov=self.grad_B_covariant(points) - return jnp.array([grad_B_cov[2][1] -grad_B_cov[1][2], - grad_B_cov[0][2] -grad_B_cov[2][0], - grad_B_cov[1][0] -grad_B_cov[0][1]])/self.sqrtg(points) - - @partial(jit, static_argnames=['self']) - def curl_b(self, points): - return self.curl_B(points)/self.AbsB(points)+jnp.cross(self.B_covariant(points),jnp.array(self.dAbsB_by_dX(points)))/self.AbsB(points)**2/self.sqrtg(points) + return jnp.array([grad_B_cov[2][1] - grad_B_cov[1][2], + grad_B_cov[0][2] - grad_B_cov[2][0], + grad_B_cov[1][0] - grad_B_cov[0][1]])/self.sqrtg(points) - @partial(jit, static_argnames=['self']) + @jit + def curl_b(self, points): + return self.curl_B(points) / self.AbsB(points) + jnp.cross(self.B_covariant(points), jnp.array(self.dAbsB_by_dX(points))) / self.AbsB(points)**2 / self.sqrtg(points) + + @jit def kappa(self, points): - return -jnp.cross(self.B_contravariant(points),self.curl_b(points))*self.sqrtg(points)/self.AbsB(points) + return -jnp.cross(self.B_contravariant(points), self.curl_b(points)) * self.sqrtg(points) / self.AbsB(points) - @partial(jit, static_argnames=['self']) + @jit + def to_xyz(self, points): + raise NotImplementedError("to_xyz method not implemented") + +class BiotSavart(MagneticField): + def __init__(self, coils): + self.coils = coils + # self.r_axis=jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(self.coils.dofs_curves))) + # self.z_axis=jnp.mean(vmap(lambda dofs: dofs[2, 0])(self.coils.dofs_curves)) + + @property + def dofs(self): + return self.coils.dofs + @dofs.setter + def dofs(self, new_dofs): + self.coils.dofs = new_dofs + + @jit + def sqrtg(self, points): + return 1. + + @jit + def B(self, points): + dif_R = (jnp.array(points) - self.coils.gamma).T + dB = jnp.cross(self.coils.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0) / jnp.linalg.norm(dif_R, axis=0)**3 + dB_sum = jnp.einsum("i,bai", self.coils.currents*1e-7, dB, optimize="greedy") + return jnp.mean(dB_sum, axis=0) + + @jit def to_xyz(self, points): return points -# def _tree_flatten(self): -# children = (self.coils,) -# aux_data = {} -# return (children, aux_data) + def _tree_flatten(self): + children = (self.coils,) + aux_data = {} + return (children, aux_data) -# @classmethod -# def _tree_unflatten(cls, aux_data, children): -# return cls(*children, **aux_data) + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) -# tree_util.register_pytree_node(BiotSavart, -# BiotSavart._tree_flatten, -# BiotSavart._tree_unflatten) +tree_util.register_pytree_node(BiotSavart, + BiotSavart._tree_flatten, + BiotSavart._tree_unflatten) -class BiotSavart_from_gamma(): - def __init__(self, gamma,gamma_dash,gamma_dashdash, currents): +class BiotSavart_from_gamma(MagneticField): + def __init__(self, gamma, gamma_dash, gamma_dashdash, currents): self.currents = currents self.gamma = gamma self.gamma_dash = gamma_dash - #self.gamma_dashdash = gamma_dashdash - self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in gamma_dash]) - self.coils_curvature= vmap(compute_curvature)(gamma_dash, gamma_dashdash) - self.r_axis=jnp.average(jnp.linalg.norm(jnp.average(gamma,axis=1)[:,0:2],axis=1)) - self.z_axis=jnp.average(jnp.average(gamma,axis=1)[:,2]) + self.gamma_dashdash = gamma_dashdash + + self.coils_length = None + self.coils_curvature = None + self.r_axis = None + self.z_axis = None + + @property + def coils_length(self): + if self.coils_length is None: + self.coils_length = jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) + return self.coils_length + @property + def coils_curvature(self): + if self._coils_curvature is None: + self._coils_curvature = vmap(Curves.compute_curvature)(self.gamma_dash, self.gamma_dashdash) + return self._coils_curvature + + @property + def r_axis(self): + if self._r_axis is None: + self._r_axis = jnp.average(jnp.linalg.norm(jnp.average(self.gamma, axis=1)[:, 0:2], axis=1)) + return self._r_axis + + @property + def z_axis(self): + if self._z_axis is None: + self._z_axis = jnp.average(jnp.average(self.gamma, axis=1)[:, 2]) + return self._z_axis + @partial(jit, static_argnames=['self']) def sqrtg(self, points): return 1. @partial(jit, static_argnames=['self']) def B(self, points): - dif_R = (jnp.array(points)-self.gamma).T - dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 + dif_R = (jnp.array(points) - self.gamma).T + dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0) / jnp.linalg.norm(dif_R, axis=0)**3 dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") return jnp.mean(dB_sum, axis=0) - @partial(jit, static_argnames=['self']) - def B_covariant(self, points): - return self.B(points) - - @partial(jit, static_argnames=['self']) - def B_contravariant(self, points): - return self.B(points) - - @partial(jit, static_argnames=['self']) - def AbsB(self, points): - return jnp.linalg.norm(self.B(points)) - - @partial(jit, static_argnames=['self']) - def dB_by_dX(self, points): - return jacfwd(self.B)(points) - - - @partial(jit, static_argnames=['self']) - def dAbsB_by_dX(self, points): - return grad(self.AbsB)(points) - - @partial(jit, static_argnames=['self']) - def grad_B_covariant(self, points): - return jacfwd(self.B_covariant)(points) - - @partial(jit, static_argnames=['self']) - def curl_B(self, points): - grad_B_cov=self.grad_B_covariant(points) - return jnp.array([grad_B_cov[2][1] -grad_B_cov[1][2], - grad_B_cov[0][2] -grad_B_cov[2][0], - grad_B_cov[1][0] -grad_B_cov[0][1]])/self.sqrtg(points) - - @partial(jit, static_argnames=['self']) - def curl_b(self, points): - return self.curl_B(points)/self.AbsB(points)+jnp.cross(self.B_covariant(points),jnp.array(self.dAbsB_by_dX(points)))/self.AbsB(points)**2/self.sqrtg(points) - - @partial(jit, static_argnames=['self']) - def kappa(self, points): - return -jnp.cross(self.B_contravariant(points),self.curl_b(points))*self.sqrtg(points)/self.AbsB(points) - @partial(jit, static_argnames=['self']) def to_xyz(self, points): return points - - class Vmec(): def __init__(self, wout_filename, ntheta=50, nphi=50, close=True, range_torus='full torus'): self.wout_filename = wout_filename From c7378bbe799def1bf0cf53bc2c4bb14de9b8de2f Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Mon, 27 Oct 2025 19:34:45 +0100 Subject: [PATCH 52/63] Fix: minor fixes --- analysis/comparisons_simsopt/coils.py | 1 + analysis/gc_vs_fo.py | 8 ++++---- essos/objective_functions.py | 10 +++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/analysis/comparisons_simsopt/coils.py b/analysis/comparisons_simsopt/coils.py index eb4d5f8..e2248d0 100644 --- a/analysis/comparisons_simsopt/coils.py +++ b/analysis/comparisons_simsopt/coils.py @@ -80,6 +80,7 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): [curve.gammadash() for curve in curves_simsopt] [curve.gammadashdash() for curve in curves_simsopt] coils_essos.gamma + coils_essos.reset_cache() # Running the second time for coils characteristics comparison start_time = time() diff --git a/analysis/gc_vs_fo.py b/analysis/gc_vs_fo.py index 8090ae7..6a8c03f 100644 --- a/analysis/gc_vs_fo.py +++ b/analysis/gc_vs_fo.py @@ -36,7 +36,7 @@ particles = particles_passing.join(particles_traped, field=field) # Tracing parameters -tmax = 1e-3 +tmax = 1e-5 trace_tolerance = 1e-14 dt_gc = 1e-7 dt_fo = 1e-9 @@ -46,15 +46,15 @@ # Trace in ESSOS time0 = time() tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, - maxtime=tmax, timestep=num_steps_gc, atol=trace_tolerance, rtol=trace_tolerance, + maxtime=tmax, timestep=dt_gc, atol=trace_tolerance, rtol=trace_tolerance, times_to_trace=200) trajectories_guidingcenter = block_until_ready(tracing_gc.trajectories) print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") time0 = time() tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax, - timestep=num_steps_fo, atol=trace_tolerance, rtol=trace_tolerance, - times_to_trace=200) + timestep=dt_fo, atol=trace_tolerance, rtol=trace_tolerance, + times_to_trace=600) block_until_ready(tracing_fo.trajectories) print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") diff --git a/essos/objective_functions.py b/essos/objective_functions.py index a07aac4..b4366e7 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -8,7 +8,7 @@ from essos.dynamics import Tracing from essos.fields import BiotSavart,BiotSavart_from_gamma from essos.surfaces import BdotN_over_B, BdotN -from essos.coils import Curves, Coils,compute_curvature +from essos.coils import Curves, Coils from essos.optimization import new_nearaxis_from_x_and_old_nearaxis from essos.constants import mu_0 from essos.coil_perturbation import perturb_curves_systematic, perturb_curves_statistic @@ -292,9 +292,9 @@ def loss_coil_curvature(coils, max_coil_curvature=0): return jnp.mean(pointwise_curvature_loss*jnp.linalg.norm(coils.gamma_dash, axis=-1), axis=1) def compute_candidates(coils, min_separation): - centers = coils.curves[:, :, 0] - a_n = coils.curves[:, :, 2 : 2*coils.order+1 : 2] - b_n = coils.curves[:, :, 1 : 2*coils.order : 2] + centers = coils.curves.curves[:, :, 0] + a_n = coils.curves.curves[:, :, 2 : 2*coils.order+1 : 2] + b_n = coils.curves.curves[:, :, 1 : 2*coils.order : 2] radii = jnp.sum(jnp.linalg.norm(a_n, axis=1)+jnp.linalg.norm(b_n, axis=1), axis=1) i_vals, j_vals = jnp.triu_indices(len(coils), k=1) @@ -570,7 +570,7 @@ def loss_lorentz_force_coils(x,dofs_curves,currents_scale,nfp,n_segments=60,stel def lp_force_pure(index,gamma, gamma_dash,gamma_dashdash,currents,quadpoints,p, threshold): """Pure function for minimizing the Lorentz force on a coil. """ - regularization = regularization_circ(1./jnp.average(compute_curvature( gamma_dash.at[index].get(), gamma_dashdash.at[index].get()))) + regularization = regularization_circ(1./jnp.average(Curves.compute_curvature( gamma_dash.at[index].get(), gamma_dashdash.at[index].get()))) B_mutual=jax.vmap(BiotSavart_from_gamma(jnp.roll(gamma, -index, axis=0)[1:], jnp.roll(gamma_dash, -index, axis=0)[1:], jnp.roll(gamma_dashdash, -index, axis=0)[1:], From d40fa9cec9418ff104fea8c55d5f02205a74036a Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 29 Oct 2025 00:08:34 +0100 Subject: [PATCH 53/63] Fix(coils): assertion removal in Coils class; Perf(coils): vmap & separate gamma computation --- essos/coils.py | 58 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index ee84692..104b71d 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -121,46 +121,56 @@ def curves(self): if self._curves is None: self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) return self._curves - - # compute_curvature static method - @staticmethod - def compute_curvature(gammadash, gammadashdash): - return jnp.linalg.norm(jnp.cross(gammadash, gammadashdash, axis=1), axis=1) / jnp.linalg.norm(gammadash, axis=1)**3 # _compute_gamma method @jit def _compute_gamma(self): - def fori_createdata(order_index: int, data: jnp.ndarray) -> jnp.ndarray: - return data[0] + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index - 1], jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index], jnp.cos(2 * jnp.pi * order_index * self.quadpoints)), \ - data[1] + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index - 1], 2*jnp.pi *order_index *jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index], -2*jnp.pi *order_index *jnp.sin(2 * jnp.pi * order_index * self.quadpoints)), \ - data[2] + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index - 1], -4*jnp.pi**2*order_index**2*jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order_index], -4*jnp.pi**2*order_index**2*jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) - - gamma0 = jnp.einsum("ij,k->ikj", self.curves[:, :, 0], jnp.ones(self.n_segments)) - gamma_dash0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) - gamma_dashdash0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) + def create_data(order: int) -> jnp.ndarray: + return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], jnp.sin(2 * jnp.pi * order * self.quadpoints)) \ + + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], jnp.cos(2 * jnp.pi * order * self.quadpoints)) + gamma_0 = jnp.einsum("ij,k->ikj", self.curves[:, :, 0], jnp.ones(self.n_segments)) + gamma_n = vmap(create_data)(jnp.arange(1, self.order+1)) + return gamma_0 + jnp.sum(gamma_n, axis=0) - gamma, gamma_dash, gamma_dashdash = fori_loop(1, self.order+1, fori_createdata, (gamma0, gamma_dash0, gamma_dashdash0)) - return gamma, gamma_dash, gamma_dashdash - # gamma property @property def gamma(self): if self._gamma is None: - self._gamma, self._gamma_dash, self._gamma_dashdash = self._compute_gamma() + self._gamma = self._compute_gamma() return self._gamma + # _compute_gamma_dash method + @jit + def _compute_gamma_dash(self): + def create_data(order: int) -> jnp.ndarray: + return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], 2*jnp.pi * order * jnp.cos(2 * jnp.pi * order * self.quadpoints)) \ + + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], -2 * jnp.pi * order * jnp.sin(2 * jnp.pi * order * self.quadpoints)) + gamma_dash_0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) + gamma_dash_n = vmap(create_data)(jnp.arange(1, self.order+1)) + return gamma_dash_0 + jnp.sum(gamma_dash_n, axis=0) + # gamma_dash property @property def gamma_dash(self): if self._gamma_dash is None: - self._gamma, self._gamma_dash, self._gamma_dashdash = self._compute_gamma() + self._gamma_dash = self._compute_gamma_dash() return self._gamma_dash + # _compute_gamma_dashdash method + @jit + def _compute_gamma_dashdash(self): + def create_data(order: int) -> jnp.ndarray: + return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], -4*jnp.pi**2 * order**2 * jnp.sin(2 * jnp.pi * order * self.quadpoints)) \ + + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], -4*jnp.pi**2 * order**2 * jnp.cos(2 * jnp.pi * order * self.quadpoints)) + gamma_dashdash_0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) + gamma_dashdash_n = vmap(create_data)(jnp.arange(1, self.order+1)) + return gamma_dashdash_0 + jnp.sum(gamma_dashdash_n, axis=0) + # gamma_dashdash property @property def gamma_dashdash(self): if self._gamma_dashdash is None: - self._gamma, self._gamma_dash, self._gamma_dashdash = self._compute_gamma() + self._gamma_dashdash = self._compute_gamma_dashdash() return self._gamma_dashdash # length property @@ -170,6 +180,12 @@ def length(self): self._length = jnp.mean(jnp.linalg.norm(self.gamma_dash, axis=2), axis=1) return self._length + # compute_curvature static method + @staticmethod + @jit + def compute_curvature(gammadash, gammadashdash): + return jnp.linalg.norm(jnp.cross(gammadash, gammadashdash, axis=1), axis=1) / jnp.linalg.norm(gammadash, axis=1)**3 + # curvature property @property def curvature(self): @@ -358,8 +374,8 @@ class Coils: """ def __init__(self, curves: Curves, currents: jnp.ndarray): - if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): - assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" + # if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): + # assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" self.curves = curves self._dofs_currents_raw = currents # Non-normalized base currents From 37252efe8a6e4a2670dfe1565939d72407109511 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 29 Oct 2025 00:09:43 +0100 Subject: [PATCH 54/63] Fix(analysis): precompile coil gammas for comparison with simsopt --- analysis/comparisons_simsopt/coils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/analysis/comparisons_simsopt/coils.py b/analysis/comparisons_simsopt/coils.py index e2248d0..efb7017 100644 --- a/analysis/comparisons_simsopt/coils.py +++ b/analysis/comparisons_simsopt/coils.py @@ -80,6 +80,9 @@ def update_nsegments_simsopt(curve_simsopt, n_segments): [curve.gammadash() for curve in curves_simsopt] [curve.gammadashdash() for curve in curves_simsopt] coils_essos.gamma + coils_essos.gamma_dash + coils_essos.gamma_dashdash + coils_essos.curvature coils_essos.reset_cache() # Running the second time for coils characteristics comparison From df5a525957a9dab6c9286be1559d1e859f94d6e8 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 5 Nov 2025 19:18:00 +0100 Subject: [PATCH 55/63] Fix(surfaces): PyTree able Surfaces --- essos/surfaces.py | 686 ++++++++++++++++++++++++++++------------------ 1 file changed, 416 insertions(+), 270 deletions(-) diff --git a/essos/surfaces.py b/essos/surfaces.py index 998135e..e040233 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -1,7 +1,8 @@ from functools import partial +import jax import jax.numpy as jnp from jax.scipy.interpolate import RegularGridInterpolator -from jax import jit, vmap, devices, device_put +from jax import tree_util, jit, vmap, devices, device_put from jax.sharding import Mesh, NamedSharding, PartitionSpec from essos.plot import fix_matplotlib_3d import jaxkd @@ -58,7 +59,7 @@ def BdotN(surface, field): return B_dot_n @partial(jit, static_argnames=['surface','field']) -def BdotN_over_B(surface, field): +def BdotN_over_B(surface, field, **kwargs): return BdotN(surface, field) / jnp.linalg.norm(B_on_surface(surface, field), axis=2) @partial(jit, static_argnames=['surface','field']) @@ -104,224 +105,364 @@ def nested_lists_to_array(ll): class SurfaceRZFourier: - def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus', - rc=None, zs=None, nfp=None, mpol=None, ntor=None,rescaling_type=None,rescaling_factor=None): - if rc is not None: - self.rc = rc - self.zs = zs - self.nfp = nfp - self.mpol = mpol - self.ntor = ntor - #m1d = jnp.tile(jnp.arange(-self.ntor, self.ntor + 1),self.mpol) - #n1d = jnp.arange(-self.ntor, self.ntor + 1) - #n2d, m2d = jnp.meshgrid(n1d, m1d) - self.xm = jnp.repeat(jnp.arange(self.mpol+1), 2*self.ntor+1)[self.ntor:]#m2d.flatten()[self.ntor:] - self.xn = self.nfp*jnp.tile(jnp.arange(-self.ntor, self.ntor + 1), self.mpol+1)[self.ntor:]#m2d.flatten()[self.ntor:] - #indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc - self.zmns_interp = self.zs - elif isinstance(vmec, str): - self.input_filename = vmec - import f90nml - all_namelists = f90nml.Parser().read(vmec) - nml = all_namelists['indata'] - if 'nfp' in nml: - self.nfp = nml['nfp'] - else: - self.nfp = 1 - rc = jnp.ravel(nested_lists_to_array(nml['rbc']))[2:] - zs = jnp.ravel(nested_lists_to_array(nml['zbs']))[2:] - #rbc_first_n = nml.start_index['rbc'][0] - #rbc_last_n = rbc_first_n + rc.shape[1] - 1 - #zbs_first_n = nml.start_index['zbs'][0] - #zbs_last_n = zbs_first_n + zs.shape[1] - 1 - #self.ntor = jnp.max(jnp.abs(jnp.array([rbc_first_n, rbc_last_n, zbs_first_n, zbs_last_n], dtype='i'))) - #rbc_first_m = nml.start_index['rbc'][1] - #rbc_last_m = rbc_first_m + rc.shape[0] - 1 - #zbs_first_m = nml.start_index['zbs'][1] - #zbs_last_m = zbs_first_m + zs.shape[0] - 1 - self.ntor = nml['ntor'] - self.mpol = nml['mpol'] - self.rc = jnp.zeros((self.mpol*( 2 * self.ntor + 1)-self.ntor)) - self.zs = jnp.zeros((self.mpol*( 2 * self.ntor + 1)-self.ntor)) - #self.rc = jnp.zeros((self.mpol, 2 * self.ntor + 1)) - #self.zs = jnp.zeros((self.mpol, 2 * self.ntor + 1)) - #m_indices_rc = jnp.arange(rc.shape[0]) + nml.start_index['rbc'][1] - #n_indices_rc = jnp.arange(rc.shape[1]) + nml.start_index['rbc'][0] + self.ntor - #self.rc = self.rc.at[m_indices_rc[:, None], n_indices_rc].set(rc) - #m_indices_zs = jnp.arange(zs.shape[0]) + nml.start_index['zbs'][1] - #n_indices_zs = jnp.arange(zs.shape[1]) + nml.start_index['zbs'][0] + self.ntor - #self.zs = self.zs.at[m_indices_zs[:, None], n_indices_zs].set(zs) - #m1d = jnp.arange(self.mpol) - #n1d = jnp.arange(-self.ntor, self.ntor + 1) - #n2d, m2d = jnp.meshgrid(n1d, m1d) - #self.xm = m2d.flatten()[self.ntor:] - #self.xn = self.nfp*n2d.flatten()[self.ntor:] - #indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - #self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - #self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] - self.rc=rc - self.zs=zs - self.rmnc_interp = self.rc - self.zmns_interp = self.zs - self.xm = jnp.repeat(jnp.arange(self.mpol+1), 2*self.ntor+1)[self.ntor:]#m2d.flatten()[self.ntor:] - self.xn = self.nfp*jnp.tile(jnp.arange(-self.ntor, self.ntor + 1), self.mpol+1)[self.ntor:]#m2d.flatten()[self.ntor:] - else: - try: - self.nfp = vmec.nfp - self.bmnc = vmec.bmnc - self.xm = vmec.xm - self.xn = vmec.xn - self.rmnc = vmec.rmnc - self.zmns = vmec.zmns - self.xm_nyq = vmec.xm_nyq - self.xn_nyq = vmec.xn_nyq - self.len_xm_nyq = len(self.xm_nyq) - self.ns = vmec.ns - self.s_full_grid = vmec.s_full_grid - self.ds = vmec.ds - self.s_half_grid = vmec.s_half_grid - self.r_axis = vmec.r_axis - self.rmnc_interp = vmap(lambda row: jnp.interp(s, self.s_full_grid, row, left='extrapolate'), in_axes=1)(self.rmnc) - self.zmns_interp = vmap(lambda row: jnp.interp(s, self.s_full_grid, row, left='extrapolate'), in_axes=1)(self.zmns) - self.bmnc_interp = vmap(lambda row: jnp.interp(s, self.s_half_grid, row, left='extrapolate'), in_axes=1)(self.bmnc[1:, :]) - self.mpol = vmec.mpol - self.ntor = vmec.ntor - self.num_dofs = 2 * ((self.mpol + 1) * (2 * self.ntor + 1) - self.ntor ) - #shape = (int(jnp.max(self.xm)) + 1, int(jnp.max(self.xn)) + 1) - #self.rc = jnp.zeros(shape) - #self.zs = jnp.zeros(shape) - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rc = self.rmnc_interp - self.zs = self.zmns_interp - #self.zs = self.zs.at[indices[:, 0], indices[:, 1]].set(self.zmns_interp) - #self.rc = self.rc.at[indices[:, 0], indices[:, 1]].set(self.rmnc_interp) - #self.zs = self.zs.at[indices[:, 0], indices[:, 1]].set(self.zmns_interp) - except: - raise ValueError("vmec must be a Vmec object or a string pointing to a VMEC input file.") - self.ntheta = ntheta - self.nphi = nphi - self.range_torus = range_torus - if range_torus == 'full torus': div = 1 - else: div = self.nfp - if range_torus == 'half period': end_val = 0.5 - else: end_val = 1.0 - self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=self.ntheta, endpoint=True if close else False) - self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=self.nphi, endpoint=True if close else False) - self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) - self.num_dofs_rc = len(jnp.ravel(self.rc)) - self.num_dofs_zs = len(jnp.ravel(self.zs)) - - self.rescaling_factor=rescaling_factor - if rescaling_type is None: - self.rescaling_function=lambda x: x - self.unscaling_function=lambda x: x - elif rescaling_type=='L_infty': - self.rescaling_function=self.scaling_L_infty - self.unscaling_function=self.unscaling_L_infty - elif rescaling_type=='L_1': - self.rescaling_function=self.scaling_L_1 - self.unscaling_function=self.unscaling_L_1 - elif rescaling_type=='L_2': - self.rescaling_function=self.scaling_L_2 - self.unscaling_function=self.unscaling_L_2 - - self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) - - self.angles = jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) + def __init__(self, rc, zs, nfp, mpol, ntor, ntheta=30, nphi=30, close=True, range_torus='full torus', + scaling_type=2, scaling_factor=0): + """ rc, zs: dynamic arrays + nfp, mpol, ntor: static """ + + assert isinstance(nfp, int) and nfp > 0, "nfp must be a positive integer." + assert isinstance(mpol, int) and mpol >= 0, "mpol must be a non-negative integer." + assert isinstance(ntor, int) and ntor >= 0, "ntor must be a non-negative integer." + assert isinstance(ntheta, int) and ntheta > 0, "ntheta must be a positive integer." + assert isinstance(nphi, int) and nphi > 0, "nphi must be a positive integer." + assert isinstance(close, bool), "close must be a boolean." + assert range_torus in ['full torus', 'half period'], f"Unknown range_torus: {range_torus}. Choose 'full torus' or 'half period'." + + self._rc = rc + self._zs = zs + self._nfp = nfp + self._mpol = mpol + self._ntor = ntor + + self._gamma = None + self._gammadash_theta = None + self._gammadash_phi = None + self._normal = None + self._unitnormal = None + self._area_element = None + self._xm = None + self._xn = None + + self._ntheta = ntheta + self._nphi = nphi + self._close = close + self._range_torus = range_torus + + self._quadpoints_theta = None + self._quadpoints_phi = None + self._theta2d = None + self._phi2d = None + self._angles = None + + self._scaling_type = scaling_type # 1 for L-1 norm, 2 for L-2 norm, jnp.inf for L-infinity norm + self._scaling_factor = scaling_factor + self._scaling = None + + + @classmethod + def from_input_file(cls, file, ntheta=30, nphi=30, close=True, range_torus='full torus'): + from f90nml import Parser + nml = Parser().read(file)['indata'] + + nfp = nml["nfp"] if "nfp" in nml else 1 + mpol = nml['mpol'] + ntor = nml['ntor'] + + rc = jnp.ravel(nested_lists_to_array(nml['rbc']))[2:] + zs = jnp.ravel(nested_lists_to_array(nml['zbs']))[2:] + + surface = cls(rc, zs, nfp, mpol, ntor, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) + return surface - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal, self._area_element) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + @classmethod + def from_vmec(cls, vmec, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus'): + nfp = vmec.nfp + mpol = vmec.mpol + ntor = vmec.ntor + + s_full_grid = vmec.s_full_grid + rc = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(vmec.rmnc) + zs = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(vmec.zmns) + + surface = cls(rc, zs, nfp, mpol, ntor, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) + surface._xm = vmec.xm + surface._xn = vmec.xn + + return surface + + @classmethod + def from_wout_file(cls, file, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus'): + from netCDF4 import Dataset + nc = Dataset(file) + + nfp = int(nc.variables["nfp"][0]) + xm = jnp.array(nc.variables["xm"][:]) + xn = jnp.array(nc.variables["xn"][:]) + mpol = int(jnp.max(xm)+1) + ntor = int(jnp.max(jnp.abs(xn)) / nfp) - if hasattr(self, 'bmnc'): - self._AbsB = self._set_AbsB() + ns = nc.variables["ns"][0] + s_full_grid = jnp.linspace(0, 1, ns) + rc = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(jnp.array(nc.variables["rmnc"][:])) + zs = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(jnp.array(nc.variables["zmns"][:])) + + surface = cls(rc, zs, nfp, mpol, ntor, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) + surface._xm = xm + surface._xn = xn + + return surface + + # reset_cache method + def reset_cache(self): + self._gamma = None + self._gammadash_theta = None + self._gammadash_phi = None + self._normal = None + self._unitnormal = None + self._area_element = None + self._xm = None + self._xn = None + self._angles = None + + # reset_mesh method + def reset_mesh(self): + self._quadpoints_theta = None + self._quadpoints_phi = None + self._theta2d = None + self._phi2d = None + self._angles = None + + # rc property and setter + @property + def rc(self): + return self._rc + + @rc.setter + def rc(self, new_rc): + self._rc = new_rc + self.reset_cache() + + # zs property and setter + @property + def zs(self): + return self._zs + + @zs.setter + def zs(self, new_zs): + self._zs = new_zs + self.reset_cache() + + # nfp property + @property + def nfp(self): + return self._nfp + + # mpol property + @property + def mpol(self): + return self._mpol + + # ntor property + @property + def ntor(self): + return self._ntor + + # xm property + @property + def xm(self): + if self._xm is None: + self._xm = jnp.repeat(jnp.arange(self.mpol + 1), 2 * self.ntor + 1)[self.ntor:] + return self._xm + + # xn property + @property + def xn(self): + if self._xn is None: + self._xn = self.nfp * jnp.tile(jnp.arange(-self.ntor, self.ntor + 1), self.mpol + 1)[self.ntor:] + return self._xn + + # _ntheta property and setter + @property + def ntheta(self): + return self._ntheta + + @ntheta.setter + def ntheta(self, new_ntheta): + self._ntheta = new_ntheta + self.reset_mesh() + + # n_phi property and setter + @property + def nphi(self): + return self._nphi + + @nphi.setter + def nphi(self, new_nphi): + self._nphi = new_nphi + self.reset_mesh() + + # close property and setter + @property + def close(self): + return self._close + + @close.setter + def close(self, new_close): + self._close = new_close + self.reset_mesh() + + # range_torus property and setter + @property + def range_torus(self): + return self._range_torus + + @range_torus.setter + def range_torus(self, new_range): + self._range_torus = new_range + self.reset_mesh() + + # _compute_meshgrid method + @jit + def _compute_meshgrid(self): + if self.range_torus == "full torus": + div, end_val = 1., 1. + elif self.range_torus == "half period": + div, end_val = self.nfp, 0.5 + quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=self.ntheta, endpoint=self.close) + quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=self.nphi, endpoint=self.close) + theta2d, phi2d = jnp.meshgrid(quadpoints_theta, quadpoints_phi) + return quadpoints_theta, quadpoints_phi, theta2d, phi2d + + # theta2d property + @property + def theta2d(self): + if self._theta2d is None: + self._quadpoints_theta, self._quadpoints_phi, self._theta2d, self._phi2d = self._compute_meshgrid() + return self._theta2d + + # phi2d property + @property + def phi2d(self): + if self._phi2d is None: + self._quadpoints_theta, self._quadpoints_phi, self._theta2d, self._phi2d = self._compute_meshgrid() + return self._phi2d + + # angles property + @property + def angles(self): + if self._angles is None: + self._angles = jnp.einsum('i,jk->ijk', self.xm, self.theta2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi2d) + return self._angles + + # scaling_type property and setter + @property + def scaling_type(self): + return self._scaling_type + + @scaling_type.setter + def scaling_type(self, new_type): + self._scaling_type = new_type + self._scaling = None + # scaling_factor property and setter + @property + def scaling_factor(self): + return self._scaling_factor + + @scaling_factor.setter + def scaling_factor(self, new_factor): + self._scaling_factor = new_factor + self._scaling = None + + # scaling property + @property + def scaling(self): + if self._scaling is None: + self._scaling = jnp.exp(self.scaling_factor * jnp.linalg.norm(jnp.vstack([self.xm, self.xn]), ord=self.scaling_type, axis=0)) + return self._scaling + + # dofs property and setter @property def dofs(self): - return self._dofs + return jnp.hstack([self.rc * self.scaling, self.zs * self.scaling]) @dofs.setter - def dofs(self, new_dofs,scaled=True): - if scaled==True: - self._dofs = new_dofs - else: - self._dofs = self.rescaling_function(new_dofs) - if scaled==True: - self.rc=self.unscaling_function(new_dofs)[:self.num_dofs_rc] - self.zs=self.unscaling_function(new_dofs)[self.num_dofs_rc:] - else: - self.rc = new_dofs[:self.num_dofs_rc] - self.zs = new_dofs[self.num_dofs_rc:] - - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc - self.zmns_interp = self.zs - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal, self._area_element) = self._set_gamma(self.rmnc_interp, self.zmns_interp) - # if hasattr(self, 'bmnc'): - # self._AbsB = self._set_AbsB() + def dofs(self, new_dofs): + self._rc = new_dofs[:self.rc.size] / self.scaling + self._zs = new_dofs[self.rc.size:] / self.scaling + self.reset_cache() - @partial(jit, static_argnames=['self']) - def _set_gamma(self, rmnc_interp, zmns_interp): - phi_2d = self.phi_2d + # _compute_gamma method + @jit + def _compute_gamma(self): angles = self.angles - + print(angles.shape) sin_angles = jnp.sin(angles) cos_angles = jnp.cos(angles) - r_coordinate = jnp.einsum('i,ijk->jk', rmnc_interp, cos_angles) - z_coordinate = jnp.einsum('i,ijk->jk', zmns_interp, sin_angles) - gamma = jnp.transpose(jnp.array([r_coordinate * jnp.cos(phi_2d), r_coordinate * jnp.sin(phi_2d), z_coordinate]), (1, 2, 0)) - - dX_dtheta = jnp.einsum('i,ijk,i->jk', -self.xm, sin_angles, rmnc_interp) * jnp.cos(phi_2d) - dY_dtheta = jnp.einsum('i,ijk,i->jk', -self.xm, sin_angles, rmnc_interp) * jnp.sin(phi_2d) - dZ_dtheta = jnp.einsum('i,ijk,i->jk', self.xm, cos_angles, zmns_interp) - gammadash_theta = 2*jnp.pi*jnp.transpose(jnp.array([dX_dtheta, dY_dtheta, dZ_dtheta]), (1, 2, 0)) - - dX_dphi = jnp.einsum('i,ijk,i->jk', self.xn, sin_angles, rmnc_interp) * jnp.cos(phi_2d) - r_coordinate * jnp.sin(phi_2d) - dY_dphi = jnp.einsum('i,ijk,i->jk', self.xn, sin_angles, rmnc_interp) * jnp.sin(phi_2d) + r_coordinate * jnp.cos(phi_2d) - dZ_dphi = jnp.einsum('i,ijk,i->jk', -self.xn, cos_angles, zmns_interp) - gammadash_phi = 2*jnp.pi*jnp.transpose(jnp.array([dX_dphi, dY_dphi, dZ_dphi]), (1, 2, 0)) - - normal = jnp.cross(gammadash_phi, gammadash_theta, axis=2) - unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) - area_element = jnp.linalg.norm(jnp.cross(gammadash_theta, gammadash_phi, axis=2), axis=2) + phi2d = self.phi2d + sin_phi2d = jnp.sin(phi2d) + cos_phi2d = jnp.cos(phi2d) + rc = self.rc; zs = self.zs; xm = self.xm; xn = self.xn + + print(rc.shape, cos_angles.shape) + R = jnp.einsum('i,ijk->jk', rc, cos_angles) + Z = jnp.einsum('i,ijk->jk', zs, sin_angles) + X = R * cos_phi2d + Y = R * sin_phi2d + gamma = jnp.stack([X, Y, Z], axis=-1) + + dR_dtheta = -jnp.einsum('i,ijk->jk', xm * rc, sin_angles) + dZ_dtheta = jnp.einsum('i,ijk->jk', xm * zs, cos_angles) + dX_dtheta = dR_dtheta * cos_phi2d + dY_dtheta = dR_dtheta * sin_phi2d + gammadash_theta = jnp.stack([dX_dtheta, dY_dtheta, dZ_dtheta], axis=-1) + + dR_dphi = jnp.einsum('i,ijk->jk', xn*rc, sin_angles) + dZ_dphi = -jnp.einsum('i,ijk->jk', xn*zs, cos_angles) + dX_dphi = dR_dphi * cos_phi2d - R * sin_phi2d + dY_dphi = dR_dphi * sin_phi2d + R * cos_phi2d + gammadash_phi = jnp.stack([dX_dphi, dY_dphi, dZ_dphi], axis=-1) - return (gamma, gammadash_theta, gammadash_phi, normal, unitnormal, area_element) - - @partial(jit, static_argnames=['self']) - def _set_AbsB(self): - angles_nyq = jnp.einsum('i,jk->ijk', self.xm_nyq, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn_nyq, self.phi_2d) - AbsB = jnp.einsum('i,ijk->jk', self.bmnc_interp, jnp.cos(angles_nyq)) - return AbsB + return gamma, gammadash_theta, gammadash_phi + # gamma, gammadash_theta, gammadash_phi properties @property def gamma(self): + if self._gamma is None: + self._gamma, self._gammadash_theta, self._gammadash_phi = self._compute_gamma() return self._gamma @property def gammadash_theta(self): + if self._gammadash_theta is None: + self._gamma, self._gammadash_theta, self._gammadash_phi = self._compute_gamma() return self._gammadash_theta @property def gammadash_phi(self): + if self._gammadash_phi is None: + self._gamma, self._gammadash_theta, self._gammadash_phi = self._compute_gamma() return self._gammadash_phi + + # _compute_properties method + @jit + def _compute_properties(self): + normal = jnp.cross(self.gammadash_theta, self.gammadash_phi, axis=2) + unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) + area_element = jnp.linalg.norm(normal, axis=2) + return normal, unitnormal, area_element + # normal, unitnormal, area_element properties @property def normal(self): + if self._normal is None: + self._normal, self._unitnormal, self._area_element = self._compute_properties() return self._normal @property def unitnormal(self): + if self._unitnormal is None: + self._normal, self._unitnormal, self._area_element = self._compute_properties() return self._unitnormal @property def area_element(self): + if self._area_element is None: + self._normal, self._unitnormal, self._area_element = self._compute_properties() return self._area_element - @property - def AbsB(self): - return self._AbsB - + # TODO: remove x property. This is a placeholder for compatibility with the examples that need to be updated. + # x property and setter @property def x(self): return self.dofs @@ -355,92 +496,74 @@ def area(self): area = jnp.sum(norm_n) * dphi * dtheta return area - def scaling_L_infty(self,x): - return x / jnp.exp(-self.rescaling_factor*jnp.maximum(jnp.abs(self.xm),jnp.abs(self.xn))) - - def scaling_L_1(self,x): - return x / jnp.exp(-self.rescaling_factor*(jnp.abs(self.xm)+jnp.abs(self.xn))) - - def scaling_L_2(x): - return x / jnp.exp(-self.rescaling_factor*jnp.sqrt(self.xm**2+self.xn**2)) - - def unscaling_L_infty(self,x): - return x * jnp.exp(-self.rescaling_factor*jnp.maximum(jnp.abs(self.xm),jnp.abs(self.xn))) - - def unscaling_L_1(self,x): - return x * jnp.exp(-self.rescaling_factor*(jnp.abs(self.xm)+jnp.abs(self.xn))) - - def unscaling_L_2(self,x): - return x * jnp.exp(-self.rescaling_factor*jnp.sqrt(self.xm**2+self.xn**2)) - - def change_resolution(self, mpol: int, ntor: int, ntheta=None, nphi=None,close=True): - """ - Change the values of `mpol` and `ntor`. - New Fourier coefficients are zero by default. - Old coefficients outside the new range are discarded. - """ - rc_old, zs_old = self.rc, self.zs - mpol_old, ntor_old = self.mpol, self.ntor - if ntheta is not None: - self.ntheta = ntheta - else: - ntheta = self.ntheta - - if nphi is not None: - self.nphi = nphi - else: - nphi = self.nphi - - #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) - #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) - rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) - zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) - m_keep = min(mpol_old, mpol) - n_keep = min(ntor_old, ntor) - - xm_old=self.xm - xn_old=self.xn - self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] - self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] - # Copy overlapping region - for l in range(len(self.xm)): - if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: - index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp - rc_new=rc_new.at[l].set(self.rc[index]) - zs_new=zs_new.at[l].set(self.zs[index]) - - - # Update attributes - self.mpol, self.ntor = mpol, ntor - self.rc, self.zs = rc_new, zs_new - - self.rmnc_interp = self.rc - self.zmns_interp = self.zs - - # Update degrees of freedom - self.num_dofs_rc = len(jnp.ravel(self.rc)) - self.num_dofs_zs = len(jnp.ravel(self.zs)) - self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) - - # Recompute angles and geometry - if self.range_torus == 'full torus': div = 1 - else: div = self.nfp - if self.range_torus == 'half period': end_val = 0.5 - else: end_val = 1.0 - self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) - self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) - self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) - - self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) - - - # Recompute AbsB if available - if hasattr(self, 'bmnc'): - self._AbsB = self._set_AbsB() - - return self + # def change_resolution(self, mpol: int, ntor: int, ntheta=None, nphi=None,close=True): + # """ + # Change the values of `mpol` and `ntor`. + # New Fourier coefficients are zero by default. + # Old coefficients outside the new range are discarded. + # """ + # rc_old, zs_old = self.rc, self.zs + # mpol_old, ntor_old = self.mpol, self.ntor + # if ntheta is not None: + # self.ntheta = ntheta + # else: + # ntheta = self.ntheta + + # if nphi is not None: + # self.nphi = nphi + # else: + # nphi = self.nphi + + # #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) + # #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) + # rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) + # zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) + # m_keep = min(mpol_old, mpol) + # n_keep = min(ntor_old, ntor) + + # xm_old=self.xm + # xn_old=self.xn + # self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] + # self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] + # # Copy overlapping region + # for l in range(len(self.xm)): + # if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: + # index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp + # rc_new=rc_new.at[l].set(self.rc[index]) + # zs_new=zs_new.at[l].set(self.zs[index]) + + + # # Update attributes + # self.mpol, self.ntor = mpol, ntor + # self.rc, self.zs = rc_new, zs_new + + # self.rmnc_interp = self.rc + # self.zmns_interp = self.zs + + # # Update degrees of freedom + # self.num_dofs_rc = len(jnp.ravel(self.rc)) + # self.num_dofs_zs = len(jnp.ravel(self.zs)) + # self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) + + # # Recompute angles and geometry + # if self.range_torus == 'full torus': div = 1 + # else: div = self.nfp + # if self.range_torus == 'half period': end_val = 0.5 + # else: end_val = 1.0 + # self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) + # self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) + # self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) + + # self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) + # (self._gamma, self._gammadash_theta, self._gammadash_phi, + # self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + + + # # Recompute AbsB if available + # if hasattr(self, 'bmnc'): + # self._AbsB = self._set_AbsB() + + # return self def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.") @@ -452,7 +575,7 @@ def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if ax is None or ax.name != "3d": fig = plt.figure() ax = fig.add_subplot(projection='3d') - + boundary = self.gamma if hasattr(self, 'bmnc'): @@ -531,6 +654,29 @@ def mean_cross_sectional_area(self): mean_cross_sectional_area = jnp.abs(jnp.mean(jnp.sqrt(x2y2) * dZ_dtheta * detJ))/(2 * jnp.pi) return mean_cross_sectional_area + def _tree_flatten(self): + children = (self._rc, self._zs) # arrays / dynamic values + aux_data = {"nfp": self._nfp, + "mpol": self._mpol, + "ntor": self._ntor, + "ntheta": self._ntheta, + "nphi": self._nphi, + "close": self._close, + "range_torus": self._range_torus, + "scaling_type": self._scaling_type, + "scaling_factor": self._scaling_factor} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + print([jax.core.get_aval(c) for c in children]) + print([jax.core.get_aval(val) for val in aux_data.values() if not isinstance(val, str)]) + return cls(*children, **aux_data) + +tree_util.register_pytree_node(SurfaceRZFourier, + SurfaceRZFourier._tree_flatten, + SurfaceRZFourier._tree_unflatten) + #This class is based on simsopt classifier but translated to fit jax class SurfaceClassifier(): """ From a4fad74a53b737e15697db5289737363337a658f Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 5 Nov 2025 20:06:02 +0100 Subject: [PATCH 56/63] Fix(fields,surfaces): fixed mpol in vmec import --- essos/fields.py | 2 +- essos/surfaces.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/essos/fields.py b/essos/fields.py index 20fc869..b8c324f 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -226,7 +226,7 @@ def __init__(self, wout_filename, ntheta=50, nphi=50, close=True, range_torus='f self.s_half_grid = self.s_full_grid[1:] - 0.5 * self.ds self.r_axis = self.rmnc[0, 0] self.z_axis=self.zmns[0,0] - self.mpol = int(jnp.max(self.xm)+1) + self.mpol = int(jnp.max(self.xm)) self.ntor = int(jnp.max(jnp.abs(self.xn)) / self.nfp) self.range_torus = range_torus self._surface = SurfaceRZFourier(self, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) diff --git a/essos/surfaces.py b/essos/surfaces.py index e040233..0a38cfb 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -188,7 +188,7 @@ def from_wout_file(cls, file, s=1, ntheta=30, nphi=30, close=True, range_torus=' nfp = int(nc.variables["nfp"][0]) xm = jnp.array(nc.variables["xm"][:]) xn = jnp.array(nc.variables["xn"][:]) - mpol = int(jnp.max(xm)+1) + mpol = int(jnp.max(xm)) ntor = int(jnp.max(jnp.abs(xn)) / nfp) ns = nc.variables["ns"][0] From ce7caf710022e797cbe7c005d24ce3a6ad510c87 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Wed, 12 Nov 2025 11:45:11 +0100 Subject: [PATCH 57/63] Fix(surfaces): jit and pjit fix --- essos/surfaces.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/essos/surfaces.py b/essos/surfaces.py index 0a38cfb..2ed0a12 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -4,14 +4,15 @@ from jax.scipy.interpolate import RegularGridInterpolator from jax import tree_util, jit, vmap, devices, device_put from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental.pjit import pjit from essos.plot import fix_matplotlib_3d import jaxkd mesh = Mesh(devices(), ("dev",)) -sharding = NamedSharding(mesh, PartitionSpec("dev", None)) +sharding = NamedSharding(mesh, PartitionSpec("dev")) -@partial(jit, static_argnames=['surface','field']) +@jit def toroidal_flux(surface, field, idx=0) -> jnp.ndarray: curve = surface.gamma[idx] dl = jnp.roll(curve, -1, axis=0) - curve @@ -25,7 +26,7 @@ def toroidal_flux(surface, field, idx=0) -> jnp.ndarray: #tf = jnp.sum(Adl) return tf -@partial(jit, static_argnames=['surface','field']) +@jit def poloidal_flux(surface, field, idx=0) -> jnp.ndarray: curve = surface.gamma[:,idx,:] dl = jnp.roll(curve, -1, axis=0) - curve @@ -39,39 +40,40 @@ def poloidal_flux(surface, field, idx=0) -> jnp.ndarray: #tf = jnp.sum(Adl) return tf -@partial(jit, static_argnames=['surface','field']) +# @jit +@partial(pjit, in_shardings=(sharding, None), out_shardings=sharding) def B_on_surface(surface, field): ntheta = surface.ntheta nphi = surface.nphi gamma = surface.gamma gamma_reshaped = gamma.reshape(nphi * ntheta, 3) - gamma_sharded = device_put(gamma_reshaped, sharding) - B_on_surface = jit(vmap(field.B), in_shardings=sharding, out_shardings=sharding)(gamma_sharded) - B_on_surface = B_on_surface.reshape(nphi, ntheta, 3) - return B_on_surface + # Map field.B over all positions + B_on_surface = vmap(field.B)(gamma_reshaped) + + return B_on_surface.reshape(nphi, ntheta, 3) -@partial(jit, static_argnames=['surface','field']) +@jit def BdotN(surface, field): B_surface = B_on_surface(surface, field) B_dot_n = jnp.sum(B_surface * surface.unitnormal, axis=2) return B_dot_n -@partial(jit, static_argnames=['surface','field']) +@jit def BdotN_over_B(surface, field, **kwargs): return BdotN(surface, field) / jnp.linalg.norm(B_on_surface(surface, field), axis=2) -@partial(jit, static_argnames=['surface','field']) +@jit def _squared_flux_local(surface, field): return 0.5 * jnp.mean(BdotN(surface, field)**2 / jnp.sum(B_on_surface(surface, field)**2, axis=2) * surface.area_element) -@partial(jit, static_argnames=['surface','field']) +@jit def _squared_flux_global(surface, field): return 0.5 * jnp.mean(BdotN(surface, field)**2 * surface.area_element) -@partial(jit, static_argnames=['surface','field']) +@jit def _squared_flux_normalized(surface, field): return 0.5 * jnp.mean(BdotN(surface, field)**2 * surface.area_element) / \ jnp.mean(jnp.sum(B_on_surface(surface, field)**2, axis=2) * surface.area_element) @@ -669,8 +671,6 @@ def _tree_flatten(self): @classmethod def _tree_unflatten(cls, aux_data, children): - print([jax.core.get_aval(c) for c in children]) - print([jax.core.get_aval(val) for val in aux_data.values() if not isinstance(val, str)]) return cls(*children, **aux_data) tree_util.register_pytree_node(SurfaceRZFourier, From 43fd4040443194c96566b556339ffa9b81e5f305 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Tue, 18 Nov 2025 00:09:37 +0100 Subject: [PATCH 58/63] Feat (losses): Create losses files --- essos/losses.py | 206 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 essos/losses.py diff --git a/essos/losses.py b/essos/losses.py new file mode 100644 index 0000000..e499e9a --- /dev/null +++ b/essos/losses.py @@ -0,0 +1,206 @@ +import os +from functools import partial +import jax +import jax.numpy as jnp +from jax import tree_util, jit, grad +from essos.coils import Curves, Coils, CreateEquallySpacedCurves +from essos.surfaces import SurfaceRZFourier +from essos.fields import BiotSavart + +from essos.surfaces import BdotN_over_B +from scipy.optimize import least_squares + +class base_loss: + def __init__(self): + self.losses = [] + self.weights = [] + self._depends_on = {} # Dict of the objects that the losses depend on, e.g., {"coils": Coils, "surface": SurfaceRZFourier, ...} + self._dofs_size = {} # Dict of slices indicating the size of the dofs for each dependency, e.g., {"coils": slice(0, 10), "surface": slice(10, 20), ...} + + + @property + def depends_on(self): + return self._depends_on + + + @depends_on.setter + def depends_on(self, value): + if not isinstance(value, dict): + raise ValueError("depends_on must be a dictionary mapping dependency names to their corresponding objects.") + + sum = 0 + for dependency, obj in value.items(): + if not hasattr(obj, 'dofs'): + raise ValueError(f"The object for dependency '{dependency}' must have a 'dofs' attribute.") + self._dofs_size[dependency] = slice(sum, sum + obj.dofs.size) + sum += obj.dofs.size + + self._depends_on = value + + + @property + def dofs(self): + dofs = jnp.array([]) + for obj in self.depends_on.values(): + dofs = jnp.concatenate([dofs, jnp.ravel(obj.dofs)]) + return dofs + + @dofs.setter + def dofs(self, value): + for dependency in self.depends_on: + self.depends_on[dependency].dofs = jnp.array(jnp.reshape(value[self._dofs_size[dependency]], self.depends_on[dependency].dofs.shape)) + + + def __call__(self, dofs): + if len(self.losses) == 0: + raise ValueError("No losses have been defined in base_loss. Use the 'losses' attribute to specify the loss functions.") + if len(self.depends_on) == 0: + raise ValueError("No dependencies have been defined in base_loss. Use the 'depends_on' attribute to specify the objects that the losses depend on.") + + self.dofs = dofs + return sum(self.weights[ii] * loss(**self.depends_on) for ii, loss in enumerate(self.losses)) + + + def __add__(self, other): + if not isinstance(other, base_loss): + raise TypeError("Addition is only defined between base_loss objects.") + new = base_loss() + new.losses = [*self.losses, *other.losses] # Flatten the losses + new.weights = [*self.weights, *other.weights] # Flatten the weights + return new + + + def __iter__(self): + return iter(self.losses) + + + def __mul__(self, other): + if not isinstance(other, (int, float)): + raise TypeError("Multiplication is only defined between base_loss and a scalar.") + new = base_loss() + new.losses = self.losses # Share reference + new.weights = [w * other for w in self.weights] + return new + + + def __rmul__(self, other): + return self.__mul__(other) + + +class target_loss(base_loss): + def __init__(self, quantity, target=0, mode="max"): + self.losses = [self] + self.weights = [1.] + self.target = target + + if not quantity in ["coil_length", "coil_curvature", "coil_separation"]: + raise ValueError("quantity must be one of 'coil_length', 'coil_curvature', or 'coil_separation'.") + self.quantity = quantity + + if not mode in ["max", "min"]: + raise ValueError("mode must be one of 'max' or 'min'.") + self.mode = mode + + @partial(jit, static_argnames=['self']) + def __call__(self, **kwargs): + optimizable = None + + if self.quantity == 'coil_length': + coils = kwargs.get("coils") + if coils is None: + raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_length'.") + optimizable = coils.length + elif self.quantity == 'coil_curvature': + coils = kwargs.get("coils") + if coils is None: + raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_curvature'.") + optimizable = jnp.mean(coils.curvature, axis=1) + elif self.quantity == 'coil_separation': + coils = kwargs.get("coils") + if coils is None: + raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_separation'.") + optimizable = coils.separation + elif self.quantity == 'surface_area': + coils = kwargs.get("surface") + if coils is None: + raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_separation'.") + optimizable = coils.separation + else: + raise ValueError(f"Unknown quantity: {self.quantity}") + + if self.mode == "max": + return jnp.max(jnp.maximum(0, optimizable - self.target)) + elif self.mode == "min": + return jnp.min(jnp.maximum(0, optimizable - self.target)) + elif self.mode == "abs": + return jnp.sum(jnp.abs(optimizable - self.target)) + + else: + raise ValueError(f"Unknown mode: {self.mode}") + +class custom_loss(base_loss): + def __init__(self, fun): + self.losses = [self] + self.weights = [1.] + self.fun = fun + + @partial(jit, static_argnames=['self']) + def __call__(self, **kwargs): + return self.fun(**kwargs) + +if __name__ == "__main__": + vmec_input = os.path.join(os.path.dirname(__file__), '../examples/input_files/wout_LandremanPaul2021_QA_reactorScale_lowres.nc') + + # JF = Jf \ + # + LENGTH_WEIGHT * sum(Jls) \ + # + CC_WEIGHT * Jccdist \ + # + CS_WEIGHT * Jcsdist \ + # + CURVATURE_WEIGHT * sum(Jcs) \ + # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD, "max") for J in Jmscs) + + """ Creating starting coils and surface """ + N_COILS = 3; FOURIER_ORDER = 3; LARGE_R = 10; SMALL_R = 5.6; NFP = 2; N_SEGMENTS = 45; STELLSYM = True # Curve parameters + COIL_CURRENT = 1. # 1.714e7 # Amperes + + curves = CreateEquallySpacedCurves(N_COILS, FOURIER_ORDER, LARGE_R, SMALL_R, n_segments=N_SEGMENTS, nfp=NFP, stellsym=STELLSYM) + coils = Coils(curves=curves, currents=[COIL_CURRENT]*N_COILS) + coils_initial = coils.copy() + surface = SurfaceRZFourier.from_wout_file(vmec_input, s=1, ntheta=30, nphi=30, range_torus='half period') + field = BiotSavart(coils) + + """ Setting the losses and their weights """ + LENGTH_WEIGHT = 0.; LENGTH_TARGET = 43. + CURVATURE_WEIGHT = 0.; CURVATURE_TARGET = 0.1 + NORMAL_FIELD_WEIGHT = 1. + + L_length = target_loss("coil_length", target=LENGTH_TARGET, mode="max") + L_curvature = target_loss("coil_curvature", target=CURVATURE_TARGET, mode="max") + + def loss(field, **kwargs): + return jnp.sum(jnp.abs(BdotN_over_B(surface, field))) + + L_normal_field = custom_loss(loss) + + L_total = NORMAL_FIELD_WEIGHT*L_normal_field #+ LENGTH_WEIGHT*L_length + CURVATURE_WEIGHT*L_curvature + + print(L_total.losses) + print(L_total.weights) + + L_total.depends_on = {"coils": coils, "field": field} + print(L_total(dofs=L_total.dofs)) + + res = least_squares(L_total, L_total.dofs, diff_step=1e-4, verbose=2, ftol=1e-5, gtol=1e-5, xtol=1e-14, max_nfev=100) + + print(L_total(dofs=res.x)) + + import matplotlib.pyplot as plt + fig = plt.figure(figsize=(8, 4)) + ax1 = fig.add_subplot(121, projection='3d') + ax2 = fig.add_subplot(122, projection='3d') + coils_initial.plot(ax=ax1, show=False) + surface.plot(ax=ax1, show=False) + coils.plot(ax=ax2, show=False) + surface.plot(ax=ax2, show=False) + plt.tight_layout() + plt.show() + From ac5c7bc49e351b1b0c811fc5b05acb5f55f5c066 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Tue, 18 Nov 2025 00:18:21 +0100 Subject: [PATCH 59/63] Fix (losses): Allow for grad usage during optimization --- essos/losses.py | 336 ++++++++++++++++++++++++++++-------------------- 1 file changed, 195 insertions(+), 141 deletions(-) diff --git a/essos/losses.py b/essos/losses.py index e499e9a..997ab25 100644 --- a/essos/losses.py +++ b/essos/losses.py @@ -2,7 +2,9 @@ from functools import partial import jax import jax.numpy as jnp -from jax import tree_util, jit, grad +from jax import tree_util, jit, grad as grad_jax +from jax.flatten_util import ravel_pytree + from essos.coils import Curves, Coils, CreateEquallySpacedCurves from essos.surfaces import SurfaceRZFourier from essos.fields import BiotSavart @@ -12,195 +14,247 @@ class base_loss: def __init__(self): - self.losses = [] - self.weights = [] - self._depends_on = {} # Dict of the objects that the losses depend on, e.g., {"coils": Coils, "surface": SurfaceRZFourier, ...} - self._dofs_size = {} # Dict of slices indicating the size of the dofs for each dependency, e.g., {"coils": slice(0, 10), "surface": slice(10, 20), ...} - + self.losses = [self] + self._dependencies = {} + self._dependencies_buffer = None + self._starting_dofs = None + self._dofs_to_pytree = None - @property - def depends_on(self): - return self._depends_on - + def clear_cache(self): + self._dependencies_buffer = None + self._starting_dofs = None + self._dofs_to_pytree = None - @depends_on.setter - def depends_on(self, value): - if not isinstance(value, dict): - raise ValueError("depends_on must be a dictionary mapping dependency names to their corresponding objects.") - - sum = 0 - for dependency, obj in value.items(): - if not hasattr(obj, 'dofs'): - raise ValueError(f"The object for dependency '{dependency}' must have a 'dofs' attribute.") - self._dofs_size[dependency] = slice(sum, sum + obj.dofs.size) - sum += obj.dofs.size - - self._depends_on = value + @property + def dependencies(self): + return self._dependencies + @dependencies.setter + def dependencies(self, value): + assert isinstance(value, dict), "dependencies must be a dictionary mapping dependency names to their corresponding objects." + self.clear_cache() + self._dependencies = value @property - def dofs(self): - dofs = jnp.array([]) - for obj in self.depends_on.values(): - dofs = jnp.concatenate([dofs, jnp.ravel(obj.dofs)]) - return dofs - - @dofs.setter - def dofs(self, value): - for dependency in self.depends_on: - self.depends_on[dependency].dofs = jnp.array(jnp.reshape(value[self._dofs_size[dependency]], self.depends_on[dependency].dofs.shape)) - - - def __call__(self, dofs): - if len(self.losses) == 0: - raise ValueError("No losses have been defined in base_loss. Use the 'losses' attribute to specify the loss functions.") - if len(self.depends_on) == 0: - raise ValueError("No dependencies have been defined in base_loss. Use the 'depends_on' attribute to specify the objects that the losses depend on.") - - self.dofs = dofs - return sum(self.weights[ii] * loss(**self.depends_on) for ii, loss in enumerate(self.losses)) - + def dependencies_buffer(self): + if self._dependencies_buffer is None: + self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) + return self._dependencies_buffer def __add__(self, other): if not isinstance(other, base_loss): raise TypeError("Addition is only defined between base_loss objects.") - new = base_loss() - new.losses = [*self.losses, *other.losses] # Flatten the losses - new.weights = [*self.weights, *other.weights] # Flatten the weights - return new - + + losses_list = [*self.losses, *other.losses] # Flatten the losses + out_loss = composite_loss(losses_list) + out_loss.dependencies = self.dependencies | other.dependencies + return out_loss def __iter__(self): return iter(self.losses) - def __mul__(self, other): - if not isinstance(other, (int, float)): - raise TypeError("Multiplication is only defined between base_loss and a scalar.") - new = base_loss() - new.losses = self.losses # Share reference - new.weights = [w * other for w in self.weights] - return new - + raise NotImplementedError("Multiplication is only defined in subclasses of base_loss.") def __rmul__(self, other): return self.__mul__(other) - -class target_loss(base_loss): - def __init__(self, quantity, target=0, mode="max"): - self.losses = [self] - self.weights = [1.] - self.target = target - if not quantity in ["coil_length", "coil_curvature", "coil_separation"]: - raise ValueError("quantity must be one of 'coil_length', 'coil_curvature', or 'coil_separation'.") - self.quantity = quantity +class custom_loss(base_loss): + def __init__(self, fun, *args_names, **kwargs): + """ A custom loss function that can take multiple arguments and compute gradients with respect to specified arguments. + + Args: + fun (callable): + The loss function to be optimized. It may take multiple arguments. + All dynamic arguments (i.e., those that require gradients) should be passed as positional arguments, while static arguments (i.e., those that do not require gradients) should be passed as keyword arguments. + args_names (tuple): + A tuple of strings indicating the names of the dynamic arguments. This is used for gradient computation. + *args: Dynamic (differentiable) arguments to be passed to the loss function. + **kwargs: Static (non-differentiable) keyword arguments to be passed to the loss function. + + Returns: + custom_loss: An instance of the custom_loss class. + """ + super().__init__() + self.fun = fun + self.args_names = args_names + self.kwargs = kwargs - if not mode in ["max", "min"]: - raise ValueError("mode must be one of 'max' or 'min'.") - self.mode = mode + # The dofs of a custom loss are the dofs of its arguments + @property + def starting_dofs(self): + if self._starting_dofs is None: + self._starting_dofs, self.dofs_to_pytree = ravel_pytree(tuple(self.dependencies[arg] for arg in self.args_names)) + return self._starting_dofs + + @property + def dofs_to_pytree(self): + if self._dofs_to_pytree is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(tuple(self.dependencies[arg] for arg in self.args_names)) + return self._dofs_to_pytree + + @partial(jit, static_argnames=['self']) + def __call__(self, dofs: jnp.ndarray) -> float: + args = self.dofs_to_pytree(dofs) + return self.fun(*args, **self.kwargs) + + @partial(jit, static_argnames=['self']) + def call_pytree(self, dofs_pytree) -> float: + return self.fun(*dofs_pytree, **self.kwargs) @partial(jit, static_argnames=['self']) - def __call__(self, **kwargs): - optimizable = None - - if self.quantity == 'coil_length': - coils = kwargs.get("coils") - if coils is None: - raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_length'.") - optimizable = coils.length - elif self.quantity == 'coil_curvature': - coils = kwargs.get("coils") - if coils is None: - raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_curvature'.") - optimizable = jnp.mean(coils.curvature, axis=1) - elif self.quantity == 'coil_separation': - coils = kwargs.get("coils") - if coils is None: - raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_separation'.") - optimizable = coils.separation - elif self.quantity == 'surface_area': - coils = kwargs.get("surface") - if coils is None: - raise ValueError("Coils must be provided in when calling target_loss with quantity 'coil_separation'.") - optimizable = coils.separation - else: - raise ValueError(f"Unknown quantity: {self.quantity}") + def grad(self, dofs: jnp.ndarray) -> jnp.ndarray: + args = self.dofs_to_pytree(dofs) + gradient = grad_jax(self.fun, argnums=tuple(range(len(args))))(*args, **self.kwargs) + return ravel_pytree(gradient)[0] + + @partial(jit, static_argnames=['self']) + def grad_pytree(self, dofs_pytree) -> dict: + gradient = grad_jax(self.fun, argnums=tuple(range(len(dofs_pytree))))(*dofs_pytree, **self.kwargs) + buffer = self.dependencies_buffer.copy() + for dep, g in zip(self.args_names, gradient): + buffer[dep] = g + return buffer + + def __mul__(self, other): + if not isinstance(other, (int, float)): + raise TypeError("Multiplication is only defined between base_loss and a scalar.") - if self.mode == "max": - return jnp.max(jnp.maximum(0, optimizable - self.target)) - elif self.mode == "min": - return jnp.min(jnp.maximum(0, optimizable - self.target)) - elif self.mode == "abs": - return jnp.sum(jnp.abs(optimizable - self.target)) + new_fun = lambda *args, **kwargs: other * self.fun(*args, **kwargs) + out_loss = custom_loss(new_fun, *self.args_names, **self.kwargs) + return out_loss - else: - raise ValueError(f"Unknown mode: {self.mode}") + -class custom_loss(base_loss): - def __init__(self, fun): - self.losses = [self] - self.weights = [1.] - self.fun = fun +class composite_loss(base_loss): + def __init__(self, losses: list): + """ A composite loss function that combines multiple loss functions. + + Args: + losses (list): + A list of loss functions to be combined. Each loss function should be an instance of base_loss or its subclasses. + Returns: + composite_loss: An instance of the composite_loss class. + """ + super().__init__() + self.losses = losses + + @property + def dependencies(self): + return self._dependencies + + @dependencies.setter + def dependencies(self, value): + assert isinstance(value, dict), "dependencies must be a dictionary mapping dependency names to their corresponding objects." + self.clear_cache() + self._dependencies = value + for loss in self.losses: + loss.dependencies = self._dependencies + + # The dofs of a composite loss are all the dofs of its dependencies + @property + def starting_dofs(self): + if self._starting_dofs is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(self.dependencies) + return self._starting_dofs + + @property + def dofs_to_pytree(self): + if self._dofs_to_pytree is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(self.dependencies) + return self._dofs_to_pytree + + @partial(jit, static_argnames=['self']) + def __call__(self, dofs: jnp.ndarray) -> float: + dependencies = self.dofs_to_pytree(dofs) + each_loss = [loss.call_pytree(tuple(dependencies[arg] for arg in loss.args_names))\ + for loss in self.losses] + return sum(each_loss) @partial(jit, static_argnames=['self']) - def __call__(self, **kwargs): - return self.fun(**kwargs) + def grad(self, dofs: jnp.ndarray) -> jnp.ndarray: + dependencies = self.dofs_to_pytree(dofs) + + grads_each_loss = [loss.grad_pytree(tuple(dependencies[arg] for arg in loss.args_names))\ + for loss in self.losses] + + grad = jax.tree_util.tree_map(lambda *dofs: jnp.sum(jnp.stack(dofs), axis=0), *grads_each_loss) + dofs_grad = ravel_pytree(grad)[0] + return dofs_grad + + + + if __name__ == "__main__": - vmec_input = os.path.join(os.path.dirname(__file__), '../examples/input_files/wout_LandremanPaul2021_QA_reactorScale_lowres.nc') + import matplotlib.pyplot as plt - # JF = Jf \ - # + LENGTH_WEIGHT * sum(Jls) \ - # + CC_WEIGHT * Jccdist \ - # + CS_WEIGHT * Jcsdist \ - # + CURVATURE_WEIGHT * sum(Jcs) \ - # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD, "max") for J in Jmscs) + vmec_input = os.path.join(os.path.dirname(__file__), '../examples/input_files/wout_LandremanPaul2021_QA_reactorScale_lowres.nc') """ Creating starting coils and surface """ N_COILS = 3; FOURIER_ORDER = 3; LARGE_R = 10; SMALL_R = 5.6; NFP = 2; N_SEGMENTS = 45; STELLSYM = True # Curve parameters - COIL_CURRENT = 1. # 1.714e7 # Amperes + COIL_CURRENT = 1. # Amperes - curves = CreateEquallySpacedCurves(N_COILS, FOURIER_ORDER, LARGE_R, SMALL_R, n_segments=N_SEGMENTS, nfp=NFP, stellsym=STELLSYM) - coils = Coils(curves=curves, currents=[COIL_CURRENT]*N_COILS) - coils_initial = coils.copy() + init_curves = CreateEquallySpacedCurves(N_COILS, FOURIER_ORDER, LARGE_R, SMALL_R, n_segments=N_SEGMENTS, nfp=NFP, stellsym=STELLSYM) + init_coils = Coils(curves=init_curves, currents=[COIL_CURRENT]*N_COILS) + init_field = BiotSavart(init_coils) surface = SurfaceRZFourier.from_wout_file(vmec_input, s=1, ntheta=30, nphi=30, range_torus='half period') - field = BiotSavart(coils) - """ Setting the losses and their weights """ - LENGTH_WEIGHT = 0.; LENGTH_TARGET = 43. - CURVATURE_WEIGHT = 0.; CURVATURE_TARGET = 0.1 + """ Setting the losses weights and targets """ + LENGTH_WEIGHT = 1.; LENGTH_TARGET = 32. + CURVATURE_WEIGHT = 1.; CURVATURE_TARGET = 0.1 NORMAL_FIELD_WEIGHT = 1. - - L_length = target_loss("coil_length", target=LENGTH_TARGET, mode="max") - L_curvature = target_loss("coil_curvature", target=CURVATURE_TARGET, mode="max") - def loss(field, **kwargs): + """ Creating the loss functions """ + def loss(field, surface): return jnp.sum(jnp.abs(BdotN_over_B(surface, field))) - L_normal_field = custom_loss(loss) + def loss_length(field): + return jnp.mean(jnp.maximum(0, field.coils.length - LENGTH_TARGET)) + + def loss_curvature(field): + return jnp.mean(jnp.maximum(0, field.coils.curvature - CURVATURE_TARGET)) + + """ Defining custom losses """ + L_normal_field = custom_loss(loss, "field", surface=surface) + L_length = custom_loss(loss_length, "field") + L_curvature = custom_loss(loss_curvature, "field") - L_total = NORMAL_FIELD_WEIGHT*L_normal_field #+ LENGTH_WEIGHT*L_length + CURVATURE_WEIGHT*L_curvature + """ Defining total loss + setting dependencies """ + L_total = NORMAL_FIELD_WEIGHT*L_normal_field + LENGTH_WEIGHT*L_length + CURVATURE_WEIGHT*L_curvature + L_total.dependencies = {"field": init_field} - print(L_total.losses) - print(L_total.weights) - - L_total.depends_on = {"coils": coils, "field": field} - print(L_total(dofs=L_total.dofs)) + """ Optimizing the total loss """ + res = least_squares(L_total, L_total.starting_dofs, L_total.grad, verbose=2, ftol=1e-5, gtol=1e-5, xtol=1e-14, max_nfev=200) - res = least_squares(L_total, L_total.dofs, diff_step=1e-4, verbose=2, ftol=1e-5, gtol=1e-5, xtol=1e-14, max_nfev=100) + print("Initial loss:", L_total(L_total.starting_dofs)) + print("Loss after optimization:", L_total(res.x)) - print(L_total(dofs=res.x)) + opt_field = L_total.dofs_to_pytree(res.x)["field"] + opt_coils = opt_field.coils - import matplotlib.pyplot as plt fig = plt.figure(figsize=(8, 4)) + ax1 = fig.add_subplot(121, projection='3d') - ax2 = fig.add_subplot(122, projection='3d') - coils_initial.plot(ax=ax1, show=False) + init_coils.plot(ax=ax1, show=False) surface.plot(ax=ax1, show=False) - coils.plot(ax=ax2, show=False) + ax2 = fig.add_subplot(122, projection='3d') + opt_coils.plot(ax=ax2, show=False) surface.plot(ax=ax2, show=False) plt.tight_layout() plt.show() + EXPORT = False + if EXPORT: + output_filepath = os.path.join(os.path.dirname(__file__), "output") + + """ Save the coils to a json file """ + init_coils.to_json(os.path.join(output_filepath, "init_coils_vmec_surface.json")) + opt_coils.to_json(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) + + """ Save results in vtk format to analyze in Paraview """ + surface.to_vtk(os.path.join(output_filepath, "init_surface_vmec_surface.json"), field=init_field) + surface.to_vtk(os.path.join(output_filepath, "final_surface_vmec_surface.json"), field=opt_field) + init_coils.to_vtk(os.path.join(output_filepath, "init_coils_vmec_surface.json")) + opt_coils.to_vtk(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) \ No newline at end of file From c9efcf721f3110655fd53767fff423fe1b385a82 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Tue, 18 Nov 2025 00:33:45 +0100 Subject: [PATCH 60/63] Clean (losses): Move example to correct file --- essos/losses.py | 96 +++---------------------------------------------- 1 file changed, 5 insertions(+), 91 deletions(-) diff --git a/essos/losses.py b/essos/losses.py index 997ab25..a1f965e 100644 --- a/essos/losses.py +++ b/essos/losses.py @@ -1,16 +1,7 @@ -import os from functools import partial -import jax import jax.numpy as jnp -from jax import tree_util, jit, grad as grad_jax +from jax import tree_util, jit, grad as jax_grad from jax.flatten_util import ravel_pytree - -from essos.coils import Curves, Coils, CreateEquallySpacedCurves -from essos.surfaces import SurfaceRZFourier -from essos.fields import BiotSavart - -from essos.surfaces import BdotN_over_B -from scipy.optimize import least_squares class base_loss: def __init__(self): @@ -85,7 +76,7 @@ def __init__(self, fun, *args_names, **kwargs): @property def starting_dofs(self): if self._starting_dofs is None: - self._starting_dofs, self.dofs_to_pytree = ravel_pytree(tuple(self.dependencies[arg] for arg in self.args_names)) + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(tuple(self.dependencies[arg] for arg in self.args_names)) return self._starting_dofs @property @@ -106,12 +97,12 @@ def call_pytree(self, dofs_pytree) -> float: @partial(jit, static_argnames=['self']) def grad(self, dofs: jnp.ndarray) -> jnp.ndarray: args = self.dofs_to_pytree(dofs) - gradient = grad_jax(self.fun, argnums=tuple(range(len(args))))(*args, **self.kwargs) + gradient = jax_grad(self.fun, argnums=tuple(range(len(args))))(*args, **self.kwargs) return ravel_pytree(gradient)[0] @partial(jit, static_argnames=['self']) def grad_pytree(self, dofs_pytree) -> dict: - gradient = grad_jax(self.fun, argnums=tuple(range(len(dofs_pytree))))(*dofs_pytree, **self.kwargs) + gradient = jax_grad(self.fun, argnums=tuple(range(len(dofs_pytree))))(*dofs_pytree, **self.kwargs) buffer = self.dependencies_buffer.copy() for dep, g in zip(self.args_names, gradient): buffer[dep] = g @@ -125,7 +116,6 @@ def __mul__(self, other): out_loss = custom_loss(new_fun, *self.args_names, **self.kwargs) return out_loss - class composite_loss(base_loss): def __init__(self, losses: list): @@ -179,82 +169,6 @@ def grad(self, dofs: jnp.ndarray) -> jnp.ndarray: grads_each_loss = [loss.grad_pytree(tuple(dependencies[arg] for arg in loss.args_names))\ for loss in self.losses] - grad = jax.tree_util.tree_map(lambda *dofs: jnp.sum(jnp.stack(dofs), axis=0), *grads_each_loss) + grad = tree_util.tree_map(lambda *dofs: jnp.sum(jnp.stack(dofs), axis=0), *grads_each_loss) dofs_grad = ravel_pytree(grad)[0] return dofs_grad - - - - - -if __name__ == "__main__": - import matplotlib.pyplot as plt - - vmec_input = os.path.join(os.path.dirname(__file__), '../examples/input_files/wout_LandremanPaul2021_QA_reactorScale_lowres.nc') - - """ Creating starting coils and surface """ - N_COILS = 3; FOURIER_ORDER = 3; LARGE_R = 10; SMALL_R = 5.6; NFP = 2; N_SEGMENTS = 45; STELLSYM = True # Curve parameters - COIL_CURRENT = 1. # Amperes - - init_curves = CreateEquallySpacedCurves(N_COILS, FOURIER_ORDER, LARGE_R, SMALL_R, n_segments=N_SEGMENTS, nfp=NFP, stellsym=STELLSYM) - init_coils = Coils(curves=init_curves, currents=[COIL_CURRENT]*N_COILS) - init_field = BiotSavart(init_coils) - surface = SurfaceRZFourier.from_wout_file(vmec_input, s=1, ntheta=30, nphi=30, range_torus='half period') - - """ Setting the losses weights and targets """ - LENGTH_WEIGHT = 1.; LENGTH_TARGET = 32. - CURVATURE_WEIGHT = 1.; CURVATURE_TARGET = 0.1 - NORMAL_FIELD_WEIGHT = 1. - - """ Creating the loss functions """ - def loss(field, surface): - return jnp.sum(jnp.abs(BdotN_over_B(surface, field))) - - def loss_length(field): - return jnp.mean(jnp.maximum(0, field.coils.length - LENGTH_TARGET)) - - def loss_curvature(field): - return jnp.mean(jnp.maximum(0, field.coils.curvature - CURVATURE_TARGET)) - - """ Defining custom losses """ - L_normal_field = custom_loss(loss, "field", surface=surface) - L_length = custom_loss(loss_length, "field") - L_curvature = custom_loss(loss_curvature, "field") - - """ Defining total loss + setting dependencies """ - L_total = NORMAL_FIELD_WEIGHT*L_normal_field + LENGTH_WEIGHT*L_length + CURVATURE_WEIGHT*L_curvature - L_total.dependencies = {"field": init_field} - - """ Optimizing the total loss """ - res = least_squares(L_total, L_total.starting_dofs, L_total.grad, verbose=2, ftol=1e-5, gtol=1e-5, xtol=1e-14, max_nfev=200) - - print("Initial loss:", L_total(L_total.starting_dofs)) - print("Loss after optimization:", L_total(res.x)) - - opt_field = L_total.dofs_to_pytree(res.x)["field"] - opt_coils = opt_field.coils - - fig = plt.figure(figsize=(8, 4)) - - ax1 = fig.add_subplot(121, projection='3d') - init_coils.plot(ax=ax1, show=False) - surface.plot(ax=ax1, show=False) - ax2 = fig.add_subplot(122, projection='3d') - opt_coils.plot(ax=ax2, show=False) - surface.plot(ax=ax2, show=False) - plt.tight_layout() - plt.show() - - EXPORT = False - if EXPORT: - output_filepath = os.path.join(os.path.dirname(__file__), "output") - - """ Save the coils to a json file """ - init_coils.to_json(os.path.join(output_filepath, "init_coils_vmec_surface.json")) - opt_coils.to_json(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) - - """ Save results in vtk format to analyze in Paraview """ - surface.to_vtk(os.path.join(output_filepath, "init_surface_vmec_surface.json"), field=init_field) - surface.to_vtk(os.path.join(output_filepath, "final_surface_vmec_surface.json"), field=opt_field) - init_coils.to_vtk(os.path.join(output_filepath, "init_coils_vmec_surface.json")) - opt_coils.to_vtk(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) \ No newline at end of file From 41a0c1a530e6fa73daa95c418fd7a62af02e48a0 Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Tue, 18 Nov 2025 00:34:22 +0100 Subject: [PATCH 61/63] Refactor (example): Simplify coils_from_vmec_surface example with new logic --- examples/optimize_coils_vmec_surface.py | 145 ++++++++++++------------ 1 file changed, 75 insertions(+), 70 deletions(-) diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py index 16aa6ec..b10aab6 100644 --- a/examples/optimize_coils_vmec_surface.py +++ b/examples/optimize_coils_vmec_surface.py @@ -1,81 +1,86 @@ import os -number_of_processors_to_use = 5 # Parallelization, this should divide ntheta*nphi -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time import jax.numpy as jnp import matplotlib.pyplot as plt -from essos.surfaces import BdotN_over_B + from essos.coils import Coils, CreateEquallySpacedCurves -from essos.fields import Vmec, BiotSavart -from essos.objective_functions import loss_BdotN -from essos.optimization import optimize_loss_function - -# Optimization parameters -max_coil_length = 10 -max_coil_curvature = 1.0 -order_Fourier_series_coils = 3 -number_coil_points = order_Fourier_series_coils*15 -maximum_function_evaluations = 50 -number_coils_per_half_field_period = 3 -tolerance_optimization = 1e-5 -ntheta=35 -nphi=35 - -# Initialize VMEC field -vmec = Vmec(os.path.join(os.path.dirname(__file__), 'input_files', - 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), - ntheta=ntheta, nphi=nphi, range_torus='half period') - -# Initialize coils -current_on_each_coil = 1 -number_of_field_periods = vmec.nfp -major_radius_coils = vmec.r_axis -minor_radius_coils = vmec.r_axis/1.8 -curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) -print(coils_initial.dofs_curves.shape) -# Optimize coils -print(f'Optimizing coils with {maximum_function_evaluations} function evaluations.') -time0 = time() -coils_optimized = optimize_loss_function(loss_BdotN, initial_dofs=coils_initial.x, coils=coils_initial, tolerance_optimization=tolerance_optimization, - maximum_function_evaluations=maximum_function_evaluations, vmec=vmec, - max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,) -print(f"Optimization took {time()-time0:.2f} seconds") - - -BdotN_over_B_initial = BdotN_over_B(vmec.surface, BiotSavart(coils_initial)) -BdotN_over_B_optimized = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized)) -curvature=jnp.mean(BiotSavart(coils_optimized).coils.curvature, axis=1) -length=jnp.max(jnp.ravel(BiotSavart(coils_optimized).coils.length)) -print(f"Mean curvature: ",curvature) -print(f"Length:", length) -print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") -print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") - -# Plot coils, before and after optimization +from essos.fields import BiotSavart +from essos.surfaces import SurfaceRZFourier, BdotN_over_B +from essos.losses import custom_loss + +# In this exmple, `scipy.optimize.least_squares` is used, but any other optimizer, e.g. from +# `scipy.optimize.minimize` or `jaxopt`, can be used as well and may even be preferable. +from scipy.optimize import least_squares + +input_filepath = os.path.join(os.path.dirname(__file__), "input_files") +vmec_input = os.path.join(input_filepath, 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') + +""" Creating starting coils and surface """ +N_COILS = 3; FOURIER_ORDER = 3; LARGE_R = 10; SMALL_R = 5.6; NFP = 2; N_SEGMENTS = 45; STELLSYM = True # Curve parameters +COIL_CURRENT = 1. # Amperes (optimization does not depend on current magnitude) + +init_curves = CreateEquallySpacedCurves(N_COILS, FOURIER_ORDER, LARGE_R, SMALL_R, n_segments=N_SEGMENTS, nfp=NFP, stellsym=STELLSYM) +init_coils = Coils(curves=init_curves, currents=[COIL_CURRENT]*N_COILS) +init_field = BiotSavart(init_coils) +surface = SurfaceRZFourier.from_wout_file(vmec_input, s=1, ntheta=30, nphi=30, range_torus='half period') + +""" Setting the losses weights and targets """ +LENGTH_WEIGHT = 1.; LENGTH_TARGET = 32. +CURVATURE_WEIGHT = 1.; CURVATURE_TARGET = 0.1 +NORMAL_FIELD_WEIGHT = 1. + +""" Creating the loss functions """ +def loss(field, surface): + return jnp.sum(jnp.abs(BdotN_over_B(surface, field))) + +def loss_length(field): + return jnp.mean(jnp.maximum(0, field.coils.length - LENGTH_TARGET)) + +def loss_curvature(field): + return jnp.mean(jnp.maximum(0, field.coils.curvature - CURVATURE_TARGET)) + +""" Defining custom losses """ +L_normal_field = custom_loss(loss, "field", surface=surface) +L_length = custom_loss(loss_length, "field") +L_curvature = custom_loss(loss_curvature, "field") + +""" Defining total loss + setting dependencies """ +L_total = NORMAL_FIELD_WEIGHT*L_normal_field + LENGTH_WEIGHT*L_length + CURVATURE_WEIGHT*L_curvature +L_total.dependencies = {"field": init_field} + +""" Optimizing the total loss """ +t_start = time() +res = least_squares(L_total, L_total.starting_dofs, L_total.grad, verbose=2, ftol=1e-5, gtol=1e-5, xtol=1e-14, max_nfev=200) +t_end = time() + +print(f"\nOptimization took {t_end - t_start:.2f} seconds") +print("Initial loss:", L_total(L_total.starting_dofs)) +print("Loss after optimization:", L_total(res.x)) + +opt_field = L_total.dofs_to_pytree(res.x)["field"] +opt_coils = opt_field.coils + fig = plt.figure(figsize=(8, 4)) + ax1 = fig.add_subplot(121, projection='3d') +init_coils.plot(ax=ax1, show=False) +surface.plot(ax=ax1, show=False) ax2 = fig.add_subplot(122, projection='3d') -coils_initial.plot(ax=ax1, show=False) -vmec.surface.plot(ax=ax1, show=False) -coils_optimized.plot(ax=ax2, show=False) -vmec.surface.plot(ax=ax2, show=False) +opt_coils.plot(ax=ax2, show=False) +surface.plot(ax=ax2, show=False) plt.tight_layout() plt.show() -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") - -# # Save results in vtk format to analyze in Paraview -# from essos.fields import BiotSavart -# vmec.surface.to_vtk('surface_initial', field=BiotSavart(coils_initial)) -# vmec.surface.to_vtk('surface_final', field=BiotSavart(coils_optimized)) -# coils_initial.to_vtk('coils_initial') -# coils_optimized.to_vtk('coils_optimized') \ No newline at end of file +EXPORT = False +if EXPORT: + output_filepath = os.path.join(os.path.dirname(__file__), "output") + + """ Save the coils to a json file """ + init_coils.to_json(os.path.join(output_filepath, "init_coils_vmec_surface.json")) + opt_coils.to_json(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) + + """ Save results in vtk format to analyze in Paraview """ + surface.to_vtk(os.path.join(output_filepath, "init_surface_vmec_surface.json"), field=init_field) + surface.to_vtk(os.path.join(output_filepath, "final_surface_vmec_surface.json"), field=opt_field) + init_coils.to_vtk(os.path.join(output_filepath, "init_coils_vmec_surface.json")) + opt_coils.to_vtk(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) \ No newline at end of file From 2324cfe969d7f0327d953fd78a5d08dc7f3f129b Mon Sep 17 00:00:00 2001 From: EstevaoMGomes Date: Tue, 18 Nov 2025 00:55:54 +0100 Subject: [PATCH 62/63] Refactor (coils): Decache gamma calculation --- essos/coils.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/essos/coils.py b/essos/coils.py index 104b71d..c329640 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -1,7 +1,6 @@ import jax jax.config.update("jax_enable_x64", True) import jax.numpy as jnp -from jax.lax import fori_loop from jax import tree_util, jit, vmap from functools import partial from .plot import fix_matplotlib_3d @@ -132,12 +131,11 @@ def create_data(order: int) -> jnp.ndarray: gamma_n = vmap(create_data)(jnp.arange(1, self.order+1)) return gamma_0 + jnp.sum(gamma_n, axis=0) + # TODO change gamma from a property to a method # gamma property @property def gamma(self): - if self._gamma is None: - self._gamma = self._compute_gamma() - return self._gamma + return self._compute_gamma() # _compute_gamma_dash method @jit @@ -145,16 +143,13 @@ def _compute_gamma_dash(self): def create_data(order: int) -> jnp.ndarray: return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], 2*jnp.pi * order * jnp.cos(2 * jnp.pi * order * self.quadpoints)) \ + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], -2 * jnp.pi * order * jnp.sin(2 * jnp.pi * order * self.quadpoints)) - gamma_dash_0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) gamma_dash_n = vmap(create_data)(jnp.arange(1, self.order+1)) - return gamma_dash_0 + jnp.sum(gamma_dash_n, axis=0) + return jnp.sum(gamma_dash_n, axis=0) # gamma_dash property @property def gamma_dash(self): - if self._gamma_dash is None: - self._gamma_dash = self._compute_gamma_dash() - return self._gamma_dash + return self._compute_gamma_dash() # _compute_gamma_dashdash method @jit @@ -162,16 +157,13 @@ def _compute_gamma_dashdash(self): def create_data(order: int) -> jnp.ndarray: return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], -4*jnp.pi**2 * order**2 * jnp.sin(2 * jnp.pi * order * self.quadpoints)) \ + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], -4*jnp.pi**2 * order**2 * jnp.cos(2 * jnp.pi * order * self.quadpoints)) - gamma_dashdash_0 = jnp.zeros((jnp.size(self.curves, 0), self.n_segments, 3)) gamma_dashdash_n = vmap(create_data)(jnp.arange(1, self.order+1)) - return gamma_dashdash_0 + jnp.sum(gamma_dashdash_n, axis=0) + return jnp.sum(gamma_dashdash_n, axis=0) # gamma_dashdash property @property def gamma_dashdash(self): - if self._gamma_dashdash is None: - self._gamma_dashdash = self._compute_gamma_dashdash() - return self._gamma_dashdash + return self._compute_gamma_dashdash() # length property @property @@ -189,10 +181,13 @@ def compute_curvature(gammadash, gammadashdash): # curvature property @property def curvature(self): - if self._curvature is None: - self._curvature = vmap(self.compute_curvature)(self.gamma_dash, self.gamma_dashdash) - return self._curvature + return vmap(self.compute_curvature)(self.gamma_dash, self.gamma_dashdash) + # copy method + def copy(self): + deep_copy = tree_util.tree_map(lambda x: x.copy(), self) + return deep_copy + # magic methods def __str__(self): return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ @@ -360,6 +355,7 @@ def _tree_unflatten(cls, aux_data, children): Curves._tree_flatten, Curves._tree_unflatten) +# TODO: change currents logic: save dofs_currents as dynamic -> alter main class Coils: """ Class to store the coils @@ -517,8 +513,18 @@ def n_segments(self): def n_segments(self, new_n_segments): self.curves.n_segments = new_n_segments - # magic methods + # copy method + def copy(self): + coils = Coils(self.curves.copy(), self.dofs_currents_raw.copy()) + + # Initialize caches + coils._dofs_currents = self.dofs_currents + coils._currents_scale = self.currents_scale + coils._currents = self._currents + return coils + + # magic methods def __str__(self): return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ From a2af497e4f3b9adc36c6e0aaf0cffd17d89a0ea1 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Tue, 9 Dec 2025 09:59:46 -0600 Subject: [PATCH 63/63] refactored examples folder --- .../optimize_coils_and_nearaxis.py | 0 .../optimize_coils_and_surface.py | 0 .../optimize_coils_for_nearaxis.py | 0 ...ze_coils_particle_confinement_fullorbit.py | 0 ...particle_confinement_guidingcenter_adam.py | 0 ...ment_guidingcenter_augmented_lagrangian.py | 0 ...article_confinement_guidingcenter_lbfgs.py | 0 ...ment_loss_fraction_augmented_lagrangian.py | 0 .../optimize_coils_vmec_surface.py | 0 ...coils_vmec_surface_augmented_lagrangian.py | 0 ...surface_augmented_lagrangian_stochastic.py | 0 .../optimize_multiple_objectives.py | 0 examples/compare_guidingcenter_fullorbit.py | 95 ------------------- .../comparisons_simsopt/coils.py | 0 .../comparisons_simsopt/field_lines.py | 0 .../comparisons_simsopt/full_orbit.py | 0 .../comparisons_simsopt/guiding_center.py | 0 .../comparisons_simsopt/losses.py | 0 .../comparisons_simsopt/surfaces.py | 0 .../comparisons_simsopt/vmec_import.py | 0 .../poincare_guiding_center_coils.py | 0 .../trace_fieldlines_coils.py | 0 .../trace_fieldlines_vmec.py | 0 .../paper}/fo_integrators.py | 1 - .../paper}/gc_integrators.py | 1 - {analysis => examples/paper}/gc_vs_fo.py | 10 +- {analysis => examples/paper}/gradients.py | 1 - .../paper}/poincare_plots.py | 8 +- .../trace_particles_coils_fullorbit.py | 0 .../trace_particles_coils_guidingcenter.py | 0 ...les_coils_guidingcenter_with_classifier.py | 0 ...gcenter_with_classifier_scaled_currents.py | 0 .../trace_particles_vmec.py | 0 .../trace_particles_vmec_Electric_field.py | 0 ...s_velocity_distributions_mu_Adaptative.py} | 0 ...isions_velocity_distributions_mu_Fixed.py} | 0 ...lisions_velocity_distributions_mu_time.py} | 0 ...enter_with_classifier_with_collisionsMu.py | 0 .../trace_particles_vmec_collisionsMu.py | 0 .../create_perturbed_coils.py | 0 .../create_stellarator_coils.py | 0 41 files changed, 9 insertions(+), 107 deletions(-) rename examples/{ => coil_optimization}/optimize_coils_and_nearaxis.py (100%) rename examples/{ => coil_optimization}/optimize_coils_and_surface.py (100%) rename examples/{ => coil_optimization}/optimize_coils_for_nearaxis.py (100%) rename examples/{ => coil_optimization}/optimize_coils_particle_confinement_fullorbit.py (100%) rename examples/{ => coil_optimization}/optimize_coils_particle_confinement_guidingcenter_adam.py (100%) rename examples/{ => coil_optimization}/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py (100%) rename examples/{ => coil_optimization}/optimize_coils_particle_confinement_guidingcenter_lbfgs.py (100%) rename examples/{ => coil_optimization}/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py (100%) rename examples/{ => coil_optimization}/optimize_coils_vmec_surface.py (100%) rename examples/{ => coil_optimization}/optimize_coils_vmec_surface_augmented_lagrangian.py (100%) rename examples/{ => coil_optimization}/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py (100%) rename examples/{ => coil_optimization}/optimize_multiple_objectives.py (100%) delete mode 100644 examples/compare_guidingcenter_fullorbit.py rename {analysis => examples}/comparisons_simsopt/coils.py (100%) rename {analysis => examples}/comparisons_simsopt/field_lines.py (100%) rename {analysis => examples}/comparisons_simsopt/full_orbit.py (100%) rename {analysis => examples}/comparisons_simsopt/guiding_center.py (100%) rename {analysis => examples}/comparisons_simsopt/losses.py (100%) rename {analysis => examples}/comparisons_simsopt/surfaces.py (100%) rename {analysis => examples}/comparisons_simsopt/vmec_import.py (100%) rename examples/{ => fieldline_tracing}/poincare_guiding_center_coils.py (100%) rename examples/{ => fieldline_tracing}/trace_fieldlines_coils.py (100%) rename examples/{ => fieldline_tracing}/trace_fieldlines_vmec.py (100%) rename {analysis => examples/paper}/fo_integrators.py (97%) rename {analysis => examples/paper}/gc_integrators.py (97%) rename {analysis => examples/paper}/gc_vs_fo.py (85%) rename {analysis => examples/paper}/gradients.py (97%) rename {analysis => examples/paper}/poincare_plots.py (90%) rename examples/{ => particle_tracing}/trace_particles_coils_fullorbit.py (100%) rename examples/{ => particle_tracing}/trace_particles_coils_guidingcenter.py (100%) rename examples/{ => particle_tracing}/trace_particles_coils_guidingcenter_with_classifier.py (100%) rename examples/{ => particle_tracing}/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py (100%) rename examples/{ => particle_tracing}/trace_particles_vmec.py (100%) rename examples/{ => particle_tracing}/trace_particles_vmec_Electric_field.py (100%) rename examples/{testing_collisions_velocity_distributions_mu_Adaptative.py => particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Adaptative.py} (100%) rename examples/{testing_collisions_velocity_distributions_mu_Fixed.py => particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Fixed.py} (100%) rename examples/{testing_collisions_velocity_distributions_mu_time.py => particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_time.py} (100%) rename examples/{ => particle_tracing_collisions}/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py (100%) rename examples/{ => particle_tracing_collisions}/trace_particles_vmec_collisionsMu.py (100%) rename examples/{ => simple_examples}/create_perturbed_coils.py (100%) rename examples/{ => simple_examples}/create_stellarator_coils.py (100%) diff --git a/examples/optimize_coils_and_nearaxis.py b/examples/coil_optimization/optimize_coils_and_nearaxis.py similarity index 100% rename from examples/optimize_coils_and_nearaxis.py rename to examples/coil_optimization/optimize_coils_and_nearaxis.py diff --git a/examples/optimize_coils_and_surface.py b/examples/coil_optimization/optimize_coils_and_surface.py similarity index 100% rename from examples/optimize_coils_and_surface.py rename to examples/coil_optimization/optimize_coils_and_surface.py diff --git a/examples/optimize_coils_for_nearaxis.py b/examples/coil_optimization/optimize_coils_for_nearaxis.py similarity index 100% rename from examples/optimize_coils_for_nearaxis.py rename to examples/coil_optimization/optimize_coils_for_nearaxis.py diff --git a/examples/optimize_coils_particle_confinement_fullorbit.py b/examples/coil_optimization/optimize_coils_particle_confinement_fullorbit.py similarity index 100% rename from examples/optimize_coils_particle_confinement_fullorbit.py rename to examples/coil_optimization/optimize_coils_particle_confinement_fullorbit.py diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_adam.py b/examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_adam.py similarity index 100% rename from examples/optimize_coils_particle_confinement_guidingcenter_adam.py rename to examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_adam.py diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py b/examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py similarity index 100% rename from examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py rename to examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py b/examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_lbfgs.py similarity index 100% rename from examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py rename to examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_lbfgs.py diff --git a/examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py b/examples/coil_optimization/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py similarity index 100% rename from examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py rename to examples/coil_optimization/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py diff --git a/examples/optimize_coils_vmec_surface.py b/examples/coil_optimization/optimize_coils_vmec_surface.py similarity index 100% rename from examples/optimize_coils_vmec_surface.py rename to examples/coil_optimization/optimize_coils_vmec_surface.py diff --git a/examples/optimize_coils_vmec_surface_augmented_lagrangian.py b/examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian.py similarity index 100% rename from examples/optimize_coils_vmec_surface_augmented_lagrangian.py rename to examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian.py diff --git a/examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py b/examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py similarity index 100% rename from examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py rename to examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py diff --git a/examples/optimize_multiple_objectives.py b/examples/coil_optimization/optimize_multiple_objectives.py similarity index 100% rename from examples/optimize_multiple_objectives.py rename to examples/coil_optimization/optimize_multiple_objectives.py diff --git a/examples/compare_guidingcenter_fullorbit.py b/examples/compare_guidingcenter_fullorbit.py deleted file mode 100644 index e063299..0000000 --- a/examples/compare_guidingcenter_fullorbit.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -number_of_processors_to_use = 1 # Parallelization, this should divide nparticles -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from jax import vmap -from time import time -import jax.numpy as jnp -import matplotlib.pyplot as plt -from essos.fields import BiotSavart -from essos.coils import Coils_from_json -from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE -from essos.dynamics import Tracing, Particles -from jax import block_until_ready - -# Load coils and field -json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) -field = BiotSavart(coils) - -# Particle parameters -nparticles = number_of_processors_to_use -mass=PROTON_MASS -energy=5000*ONE_EV -cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass -print("cyclotron period:", 1/cyclotron_frequency) - -# Particles initialization -initial_xyz=jnp.array([[1.23, 0, 0]]) - -particles_passing = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.1], phase_angle_full_orbit=0) -particles_traped = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.9], phase_angle_full_orbit=0) -particles = particles_passing.join(particles_traped, field=field) - -# Tracing parameters -tmax = 1e-4 -trace_tolerance = 1e-15 -dt_gc = 1e-7 -dt_fo = 1e-9 -num_steps_gc = int(tmax/dt_gc) -num_steps_fo = int(tmax/dt_fo) - -# Trace in ESSOS -time0 = time() -tracing_guidingcenter = Tracing(field=field, model='GuidingCenterAdaptative', particles=particles, - maxtime=tmax,times_to_trace=num_steps_gc, atol=trace_tolerance,rtol=trace_tolerance) -trajectories_guidingcenter = block_until_ready(tracing_guidingcenter.trajectories) -print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") - -time0 = time() -tracing_fullorbit = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax, - timesteps=num_steps_fo, tol_step_size=trace_tolerance) -block_until_ready(tracing_fullorbit.trajectories) -print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") - -# Plot trajectories, velocity parallel to the magnetic field, and energy error -fig = plt.figure(figsize=(9, 8)) -ax1 = fig.add_subplot(221, projection='3d') -ax2 = fig.add_subplot(222) -ax3 = fig.add_subplot(223) -ax4 = fig.add_subplot(224) - -coils.plot(ax=ax1, show=False) -tracing_guidingcenter.plot(ax=ax1, show=False) -tracing_fullorbit.plot(ax=ax1, show=False) - -for i, (trajectory_gc, trajectory_fo) in enumerate(zip(trajectories_guidingcenter, tracing_fullorbit.trajectories)): - ax2.plot(tracing_guidingcenter.times, jnp.abs(tracing_guidingcenter.energy[i]-particles.energy)/particles.energy, '-', label=f'Particle {i+1} GC', linewidth=1.0, alpha=0.7) - ax2.plot(tracing_fullorbit.times, jnp.abs(tracing_fullorbit.energy[i]-particles.energy)/particles.energy, '--', label=f'Particle {i+1} FO', linewidth=1.0, markersize=0.5, alpha=0.7) - def compute_v_parallel(trajectory_t): - magnetic_field_unit_vector = field.B(trajectory_t[:3]) / field.AbsB(trajectory_t[:3]) - return jnp.dot(trajectory_t[3:], magnetic_field_unit_vector) - v_parallel_fo = vmap(compute_v_parallel)(trajectory_fo) - ax3.plot(tracing_guidingcenter.times, trajectory_gc[:, 3] / particles.total_speed, '-', label=f'Particle {i+1} GC', linewidth=1.1, alpha=0.95) - ax3.plot(tracing_fullorbit.times, v_parallel_fo / particles.total_speed, '--', label=f'Particle {i+1} FO', linewidth=0.5, markersize=0.5, alpha=0.2) - # ax4.plot(jnp.sqrt(trajectory_gc[:,0]**2+trajectory_gc[:,1]**2), trajectory_gc[:, 2], '-', label=f'Particle {i+1} GC', linewidth=1.5, alpha=0.3) - # ax4.plot(jnp.sqrt(trajectory_fo[:,0]**2+trajectory_fo[:,1]**2), trajectory_fo[:, 2], '--', label=f'Particle {i+1} FO', linewidth=1.5, markersize=0.5, alpha=0.2) -tracing_guidingcenter.poincare_plot(ax=ax4, show=False, color='k', label=f'GC', shifts=[jnp.pi/2])#, 0]) -tracing_fullorbit.poincare_plot( ax=ax4, show=False, color='r', label=f'FO', shifts=[jnp.pi/2])#, 0]) - -ax2.set_xlabel('Time (s)') -ax2.set_ylabel('Relative Energy Error') -ax3.set_ylabel(r'$v_{\parallel}/v$') -ax2.legend(loc='upper right') -ax3.set_xlabel('Time (s)') -ax3.legend(loc='upper right') -ax4.set_xlabel('R (m)') -ax4.set_ylabel('Z (m)') -ax4.legend(loc='upper right') -plt.tight_layout() -plt.show() - - -## Save results in vtk format to analyze in Paraview -tracing_guidingcenter.to_vtk('trajectories_gc') -tracing_fullorbit.to_vtk('trajectories_fo') -coils.to_vtk('coils') \ No newline at end of file diff --git a/analysis/comparisons_simsopt/coils.py b/examples/comparisons_simsopt/coils.py similarity index 100% rename from analysis/comparisons_simsopt/coils.py rename to examples/comparisons_simsopt/coils.py diff --git a/analysis/comparisons_simsopt/field_lines.py b/examples/comparisons_simsopt/field_lines.py similarity index 100% rename from analysis/comparisons_simsopt/field_lines.py rename to examples/comparisons_simsopt/field_lines.py diff --git a/analysis/comparisons_simsopt/full_orbit.py b/examples/comparisons_simsopt/full_orbit.py similarity index 100% rename from analysis/comparisons_simsopt/full_orbit.py rename to examples/comparisons_simsopt/full_orbit.py diff --git a/analysis/comparisons_simsopt/guiding_center.py b/examples/comparisons_simsopt/guiding_center.py similarity index 100% rename from analysis/comparisons_simsopt/guiding_center.py rename to examples/comparisons_simsopt/guiding_center.py diff --git a/analysis/comparisons_simsopt/losses.py b/examples/comparisons_simsopt/losses.py similarity index 100% rename from analysis/comparisons_simsopt/losses.py rename to examples/comparisons_simsopt/losses.py diff --git a/analysis/comparisons_simsopt/surfaces.py b/examples/comparisons_simsopt/surfaces.py similarity index 100% rename from analysis/comparisons_simsopt/surfaces.py rename to examples/comparisons_simsopt/surfaces.py diff --git a/analysis/comparisons_simsopt/vmec_import.py b/examples/comparisons_simsopt/vmec_import.py similarity index 100% rename from analysis/comparisons_simsopt/vmec_import.py rename to examples/comparisons_simsopt/vmec_import.py diff --git a/examples/poincare_guiding_center_coils.py b/examples/fieldline_tracing/poincare_guiding_center_coils.py similarity index 100% rename from examples/poincare_guiding_center_coils.py rename to examples/fieldline_tracing/poincare_guiding_center_coils.py diff --git a/examples/trace_fieldlines_coils.py b/examples/fieldline_tracing/trace_fieldlines_coils.py similarity index 100% rename from examples/trace_fieldlines_coils.py rename to examples/fieldline_tracing/trace_fieldlines_coils.py diff --git a/examples/trace_fieldlines_vmec.py b/examples/fieldline_tracing/trace_fieldlines_vmec.py similarity index 100% rename from examples/trace_fieldlines_vmec.py rename to examples/fieldline_tracing/trace_fieldlines_vmec.py diff --git a/analysis/fo_integrators.py b/examples/paper/fo_integrators.py similarity index 97% rename from analysis/fo_integrators.py rename to examples/paper/fo_integrators.py index d7b3783..1a01571 100644 --- a/analysis/fo_integrators.py +++ b/examples/paper/fo_integrators.py @@ -85,7 +85,6 @@ plt.grid(axis='y', which='major', linestyle='--', linewidth=0.6) plt.tight_layout() plt.savefig(os.path.join(output_dir, 'fo_integration.pdf')) -plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'fo_integration.pdf')) plt.show() ## Save results in vtk format to analyze in Paraview diff --git a/analysis/gc_integrators.py b/examples/paper/gc_integrators.py similarity index 97% rename from analysis/gc_integrators.py rename to examples/paper/gc_integrators.py index 4c5c354..f18e3c4 100644 --- a/analysis/gc_integrators.py +++ b/examples/paper/gc_integrators.py @@ -95,7 +95,6 @@ spine.set_zorder(0) fig.savefig(os.path.join(output_dir, 'gc_integration.pdf')) -fig.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/", 'gc_integration.pdf')) fig_tol.savefig(os.path.join(output_dir, 'energy_vs_tol.pdf')) plt.show() diff --git a/analysis/gc_vs_fo.py b/examples/paper/gc_vs_fo.py similarity index 85% rename from analysis/gc_vs_fo.py rename to examples/paper/gc_vs_fo.py index 6a8c03f..216a9b2 100644 --- a/analysis/gc_vs_fo.py +++ b/examples/paper/gc_vs_fo.py @@ -17,7 +17,7 @@ os.makedirs(output_dir) # Load coils and field -json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +json_file = os.path.join(os.path.dirname(__file__), '../input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils.from_json(json_file) field = BiotSavart(coils) @@ -68,16 +68,16 @@ plt.tight_layout() plt.figure(figsize=(9, 6)) -plt.plot(tracing_gc.times*1000, jnp.abs(tracing_gc.energy()[0]/particles.energy-1), label='Guiding Center', color='red') -plt.plot(tracing_fo.times*1000, jnp.abs(tracing_fo.energy()[0]/particles.energy-1), label='Full Orbit', color='blue') +plt.plot(tracing_gc.times[1:]*1000, jnp.abs(tracing_gc.energy()[0][1:]/particles.energy-1)+1e-17, label='Guiding Center', color='red') +plt.plot(tracing_fo.times[1:]*1000, jnp.abs(tracing_fo.energy()[0][1:]/particles.energy-1)+1e-17, label='Full Orbit', color='blue') plt.xlabel('Time (ms)') plt.ylabel('Relative energy error') plt.xlim(0, tmax*1000) -plt.ylim(bottom=0) +# plt.ylim(bottom=0) +plt.yscale('log') plt.legend() plt.tight_layout() plt.savefig(os.path.join(output_dir, 'energies.png'), dpi=300) -plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" ,'energies.png'), dpi=300) plt.show() diff --git a/analysis/gradients.py b/examples/paper/gradients.py similarity index 97% rename from analysis/gradients.py rename to examples/paper/gradients.py index 4fb04fe..f0eab35 100644 --- a/analysis/gradients.py +++ b/examples/paper/gradients.py @@ -121,5 +121,4 @@ spine.set_zorder(0) plt.tight_layout() plt.savefig(os.path.join(output_dir, 'gradients.pdf')) -plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" ,'gradients.pdf')) plt.show() \ No newline at end of file diff --git a/analysis/poincare_plots.py b/examples/paper/poincare_plots.py similarity index 90% rename from analysis/poincare_plots.py rename to examples/paper/poincare_plots.py index fc878c5..95d2724 100644 --- a/analysis/poincare_plots.py +++ b/examples/paper/poincare_plots.py @@ -34,7 +34,7 @@ print("cyclotron period:", 1/(ELEMENTARY_CHARGE*0.3/mass)) # Load coils and field -json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +json_file = os.path.join(os.path.dirname(__file__), '../input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') coils = Coils.from_json(json_file) field = BiotSavart(coils) @@ -101,7 +101,7 @@ # plt.grid(visible=False) # plt.tight_layout() # plt.savefig(os.path.join(output_dir, 'poincare_plot_fl.png'), dpi=300) -# plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_fl.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fl.png'), dpi=300) # fig, ax = plt.subplots(figsize=(9, 6)) @@ -115,7 +115,7 @@ # plt.grid(visible=False) # plt.tight_layout() # plt.savefig(os.path.join(output_dir 'poincare_plot_fo.png'), dpi=300) -# plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_fo.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fo.png'), dpi=300) # fig, ax = plt.subplots(figsize=(9, 6)) @@ -129,6 +129,6 @@ # plt.grid(visible=False) # plt.tight_layout() # plt.savefig(os.path.join(output_dir, 'poincare_plot_gc.png'), dpi=300) -# plt.savefig(os.path.join(os.path.dirname(__file__), "../../../../UW/article/figures/" , 'poincare_plot_gc.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_gc.png'), dpi=300) # plt.show() \ No newline at end of file diff --git a/examples/trace_particles_coils_fullorbit.py b/examples/particle_tracing/trace_particles_coils_fullorbit.py similarity index 100% rename from examples/trace_particles_coils_fullorbit.py rename to examples/particle_tracing/trace_particles_coils_fullorbit.py diff --git a/examples/trace_particles_coils_guidingcenter.py b/examples/particle_tracing/trace_particles_coils_guidingcenter.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter.py rename to examples/particle_tracing/trace_particles_coils_guidingcenter.py diff --git a/examples/trace_particles_coils_guidingcenter_with_classifier.py b/examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter_with_classifier.py rename to examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier.py diff --git a/examples/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py b/examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py rename to examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py diff --git a/examples/trace_particles_vmec.py b/examples/particle_tracing/trace_particles_vmec.py similarity index 100% rename from examples/trace_particles_vmec.py rename to examples/particle_tracing/trace_particles_vmec.py diff --git a/examples/trace_particles_vmec_Electric_field.py b/examples/particle_tracing/trace_particles_vmec_Electric_field.py similarity index 100% rename from examples/trace_particles_vmec_Electric_field.py rename to examples/particle_tracing/trace_particles_vmec_Electric_field.py diff --git a/examples/testing_collisions_velocity_distributions_mu_Adaptative.py b/examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Adaptative.py similarity index 100% rename from examples/testing_collisions_velocity_distributions_mu_Adaptative.py rename to examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Adaptative.py diff --git a/examples/testing_collisions_velocity_distributions_mu_Fixed.py b/examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Fixed.py similarity index 100% rename from examples/testing_collisions_velocity_distributions_mu_Fixed.py rename to examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Fixed.py diff --git a/examples/testing_collisions_velocity_distributions_mu_time.py b/examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_time.py similarity index 100% rename from examples/testing_collisions_velocity_distributions_mu_time.py rename to examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_time.py diff --git a/examples/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py b/examples/particle_tracing_collisions/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py rename to examples/particle_tracing_collisions/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py diff --git a/examples/trace_particles_vmec_collisionsMu.py b/examples/particle_tracing_collisions/trace_particles_vmec_collisionsMu.py similarity index 100% rename from examples/trace_particles_vmec_collisionsMu.py rename to examples/particle_tracing_collisions/trace_particles_vmec_collisionsMu.py diff --git a/examples/create_perturbed_coils.py b/examples/simple_examples/create_perturbed_coils.py similarity index 100% rename from examples/create_perturbed_coils.py rename to examples/simple_examples/create_perturbed_coils.py diff --git a/examples/create_stellarator_coils.py b/examples/simple_examples/create_stellarator_coils.py similarity index 100% rename from examples/create_stellarator_coils.py rename to examples/simple_examples/create_stellarator_coils.py