Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,5 @@ cython_debug/
*.bsp
outputs/**/*.json
*.pdf
*.npz
*.npz
outputs/**/*.txt
8 changes: 7 additions & 1 deletion qlipper/cli/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ def main():
parser.add_argument(
"--folder", type=str, help="Folder containing the run.", required=False
)
parser.add_argument(
"--show",
action="store_true",
help="Whether to show the plots, default False.",
required=False,
)

args = parser.parse_args()

Expand All @@ -18,7 +24,7 @@ def main():

args.folder = askdirectory(initialdir=OUTPUT_DIR, title="Select case folder")

postprocess_from_folder(Path(args.folder))
postprocess_from_folder(Path(args.folder), args.show)


if __name__ == "__main__":
Expand Down
21 changes: 15 additions & 6 deletions qlipper/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,31 @@ class SimConfig:
steering_law: str
t_span: tuple[float, float]
conv_tol: float
w_oe: Array
w_penalty: float
penalty_function: str
kappa: float
dynamics: str
earth_w_oe: Array
earth_penalty_weight: float
earth_penalty_scaling: float
earth_rp_min: float
moon_w_oe: Array
moon_penalty_weight: float
moon_penalty_scaling: float
moon_rp_min: float
perturbations: list[str]
characteristic_accel: float
epoch_jd: float
ephemeris: str = "real"

def __post_init__(self):
# Validate the configuration
assert len(self.y0) == 6, "Initial state vector must have length 6"
assert len(self.y_target) == 5, "Target state vector must have length 5"

# enforce array types
for key in ["y0", "y_target", "w_oe"]:
for key in [
"y0",
"y_target",
"earth_w_oe",
"moon_w_oe",
]:
object.__setattr__(
self, key, jnp.array(getattr(self, key), dtype=jnp.float64)
)
Expand Down
7 changes: 5 additions & 2 deletions qlipper/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
MU = 3.986004415e14 # Gravitational parameter of Earth [m^3/s^2]
MU_EARTH = 3.986004415e14 # Gravitational parameter of Earth [m^3/s^2]
MU_MOON = 4.9048695e12 # Gravitational parameter of Moon [m^3/s^2]
MU_SUN = 1.32712440018e20 # Gravitational parameter of Sun [m^3/s^2]
R_EARTH = 6378e3 # Earth radius [m]
P_SCALING = 6378e3 # Scaling factor for p [m]
R_MOON = 1737e3 # Moon radius [m]
LENGTH_SCALING = 6378e3 # Scaling factor for p [m]
OUTPUT_DIR = "outputs"
CFG_DIR = "configs"

Expand Down
110 changes: 88 additions & 22 deletions qlipper/converters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from qlipper.constants import MU


@jax.jit
def lvlh_to_steering(dir_lvlh: ArrayLike) -> tuple[float, float]:
Expand Down Expand Up @@ -54,15 +54,18 @@ def steering_to_lvlh(alpha: float, beta: float) -> jax.Array:
return dir_lvlh


@jax.jit
def mee_to_cartesian(mee: ArrayLike) -> jax.Array:
@partial(jax.jit, static_argnums=(1,))
def mee_to_cartesian(mee: ArrayLike, mu: float) -> jax.Array:
"""
Convert modified equinoctial elements to Cartesian elements.

Parameters
----------
mee : ArrayLike
Modified equinoctial elements [p(m), f, g, h, k, L(rad)].
Modified equinoctial elements [a(m), f, g, h, k, L(rad)].
mu : float
Gravitational parameter of the central body;
changing mu triggers a JIT recompile.

Returns
-------
Expand All @@ -72,9 +75,13 @@ def mee_to_cartesian(mee: ArrayLike) -> jax.Array:
Notes
-----
Formulation from https://spsweb.fltops.jpl.nasa.gov/portaldataops/mpg/MPG_Docs/Source%20Docs/EquinoctalElements-modified.pdf

"""
# unpack state vector
p, f, g, h, k, L = mee
a, f, g, h, k, L = mee

# convert SMA and ecc to p
p = a * (1 - f**2 - g**2)

# shorthand quantities defined in the document
alpha_sq = h**2 - k**2
Expand All @@ -99,7 +106,7 @@ def mee_to_cartesian(mee: ArrayLike) -> jax.Array:
vel = (
1
/ s_sq
* jnp.sqrt(MU / p)
* jnp.sqrt(mu / p)
* jnp.array(
[
-(
Expand All @@ -126,26 +133,30 @@ def mee_to_cartesian(mee: ArrayLike) -> jax.Array:
return jnp.concatenate([pos, vel])


@jax.jit
def cartesian_to_mee(cart: ArrayLike) -> jax.Array:
@partial(jax.jit, static_argnums=(1,))
def cartesian_to_mee(cart: ArrayLike, mu: float) -> jax.Array:
"""
Convert Cartesian elements to modified equinoctial elements.

Parameters
----------
cart : ArrayLike
Cartesian elements [x, y, z, vx, vy, vz] (m and m/s).
mu : float
Gravitational parameter of the central body.

Returns
-------
mee : Array
Modified equinoctial elements [p(m), f, g, h, k, L(rad)].
Modified equinoctial elements [a(m), f, g, h, k, L(rad)].
mu : float
Gravitational parameter of the central body.

Notes
-----
Transcribed from Fortran Astrodynamics Toolkit by jacobwilliams
"""

"""
pos = cart[0:3]
vel = cart[3:6]
rdv = pos @ vel
Expand All @@ -155,14 +166,14 @@ def cartesian_to_mee(cart: ArrayLike) -> jax.Array:
hmag = jnp.linalg.norm(hvec, ord=2)
hhat = hvec / hmag
vhat = (rmag * vel - rdv * rhat) / hmag
p = hmag**2 / MU
k = hhat[0] / (1 + hhat[2])
h = -hhat[1] / (1 + hhat[2])
p = hmag**2 / mu
k = hhat[0] / (1 + hhat[2] + 1e-10)
h = -hhat[1] / (1 + hhat[2] + 1e-10)
kk = k**2
hh = h**2
s2 = 1 + hh + kk
tkh = 2 * k * h
ecc = jnp.cross(vel, hvec) / MU - rhat
ecc = jnp.cross(vel, hvec) / mu - rhat
fhat = jnp.array([1 - kk + hh, tkh, -2 * k])
ghat = jnp.array([tkh, 1 + kk - hh, 2 * h])
fhat = fhat / s2
Expand All @@ -171,11 +182,14 @@ def cartesian_to_mee(cart: ArrayLike) -> jax.Array:
g = ecc @ ghat
L = jnp.atan2(rhat[1] - vhat[0], rhat[0] + vhat[1])

return jnp.array([p, f, g, h, k, L])
# convert a to p
a = p / (1 - f**2 - g**2)

return jnp.array([a, f, g, h, k, L])

@jax.jit
def batch_cartesian_to_mee(cart: ArrayLike) -> jax.Array:

@partial(jax.jit, static_argnums=(1,))
def batch_cartesian_to_mee(cart: ArrayLike, mu: float) -> jax.Array:
"""
Vmapped version of cartesian_to_mee, which ensures
that true longitude is unwrapped correctly.
Expand All @@ -184,36 +198,42 @@ def batch_cartesian_to_mee(cart: ArrayLike) -> jax.Array:
----------
cart : ArrayLike
Cartesian elements (N, 6)
mu : float
Gravitational parameter of the central body.

Returns
-------
mee : Array
Modified equinoctial elements (N, 6)
"""
mee = jax.vmap(cartesian_to_mee)(cart)
mee = jax.vmap(partial(cartesian_to_mee, mu=mu))(cart)

# unwrap true longitude
l_unwrap = jnp.unwrap(mee[:, 5])

return jnp.column_stack((mee[:, :5], l_unwrap))


@jax.jit
def batch_mee_to_cartesian(mee: ArrayLike) -> jax.Array:
@partial(jax.jit, static_argnums=(1,))
def batch_mee_to_cartesian(mee: ArrayLike, mu: float) -> jax.Array:
"""
Vmapped version of mee_to_cartesian.

Does nothing special, just for consistency with its sibling function.

Parameters
----------
mee : ArrayLike
Modified equinoctial elements (N, 6)
mu : float
Gravitational parameter of the central body.

Returns
-------
cart : Array
Cartesian elements (N, 6)
"""
return jax.vmap(mee_to_cartesian)(mee)
return jax.vmap(partial(mee_to_cartesian, mu=mu))(mee)


@jax.jit
Expand Down Expand Up @@ -262,3 +282,49 @@ def rot_lvlh_inertial(cart: ArrayLike) -> jax.Array:
"""

return rot_inertial_lvlh(cart).T


@jax.jit
def delta_angle_mod(a: float, b: float) -> float:
"""
Shortest phase difference between two angles, i.e.
how much a is ahead of b. Returns in range [-pi, pi].

Source: https://stackoverflow.com/a/2007279

Parameters
----------
a : float
First angle.
b : float
Second angle.

Returns
-------
float
Shortest phase difference between the two angles.

"""
return ((a - b + jnp.pi) % (2 * jnp.pi)) - jnp.pi


def a_mee_to_p_mee(y_mee: ArrayLike) -> jax.Array:
"""
Converts a-based MEE to p-based MEE.
"""
a = y_mee[0]
f = y_mee[1]
g = y_mee[2]
p = a * (1 - f**2 - g**2)
return y_mee.at[0].set(p)


def p_mee_to_a_mee(y_mee: ArrayLike) -> jax.Array:
"""
Converts p-based MEE to a-based MEE.
"""
p = y_mee[0]
f = y_mee[1]
g = y_mee[2]
a = p / (1 - f**2 - g**2)
return y_mee.at[0].set(a)
Loading