From 38990c78ec3d7c88aac96ad4cc97884328697420 Mon Sep 17 00:00:00 2001 From: Mark Showalter Date: Fri, 26 Sep 2025 17:03:40 -0700 Subject: [PATCH 1/2] Initial check-in --- oops/__init__.py | 6 +- oops/backplane/__init__.py | 37 +- oops/cache.py | 117 ++++++ oops/cadence/__init__.py | 1 + oops/cadence/dualcadence.py | 14 + oops/cadence/metronome.py | 15 +- oops/cadence/reshapedcadence.py | 8 + oops/cadence/reversedcadence.py | 6 + oops/cadence/snapcadence.py | 30 ++ oops/cadence/timeshift.py | 205 ++++++++++ oops/event.py | 2 +- oops/fittable.py | 669 ++++++++++++++++++++++++++++--- oops/fov/__init__.py | 1 + oops/fov/offsetfov.py | 49 +-- oops/fov/platescale.py | 101 +++++ oops/frame/__init__.py | 1 + oops/frame/cmatrix.py | 151 ++++--- oops/frame/frame_.py | 86 +++- oops/frame/frameshift.py | 92 +++++ oops/frame/inclinedframe.py | 169 ++++---- oops/frame/laplaceframe.py | 229 +++++------ oops/frame/navigation.py | 192 ++++----- oops/frame/poleframe.py | 234 +++++------ oops/frame/postargframe.py | 110 ++--- oops/frame/ringframe.py | 228 +++++------ oops/frame/rotation.py | 180 ++++----- oops/frame/spiceframe.py | 232 +++++------ oops/frame/spicetype1frame.py | 124 +++--- oops/frame/spinframe.py | 123 +++--- oops/frame/synchronousframe.py | 92 +++-- oops/frame/trackerframe.py | 158 ++++---- oops/frame/twovectorframe.py | 175 ++++---- oops/observation/insitu.py | 20 +- oops/observation/observation_.py | 56 ++- oops/observation/pixel.py | 13 +- oops/observation/rasterslit1d.py | 6 +- oops/observation/slit1d.py | 52 +-- oops/observation/snapshot.py | 61 ++- oops/observation/timedimage.py | 8 +- oops/path/__init__.py | 1 + oops/path/circlepath.py | 129 +++--- oops/path/coordpath.py | 90 +++-- oops/path/fixedpath.py | 78 ++-- oops/path/keplerpath.py | 491 +++++++++++------------ oops/path/linearcoordpath.py | 88 ++-- oops/path/linearpath.py | 90 +++-- oops/path/multipath.py | 157 ++++---- oops/path/path_.py | 103 ++++- oops/path/pathshift.py | 90 +++++ oops/path/spicepath.py | 164 ++++---- tests/frame/test_frame.py | 2 +- tests/frame/test_poleframe.py | 41 -- tests/path/test_keplerpath.py | 24 +- tests/path/test_multipath.py | 2 +- tests/test_cache.py | 148 +++++++ tests/test_fittable.py | 179 +++++++++ tests/unittester.py | 2 + tests/unittester_with_hosts.py | 4 +- 58 files changed, 3809 insertions(+), 2127 deletions(-) create mode 100644 oops/cache.py create mode 100755 oops/cadence/snapcadence.py create mode 100755 oops/cadence/timeshift.py create mode 100755 oops/fov/platescale.py create mode 100755 oops/frame/frameshift.py create mode 100755 oops/path/pathshift.py create mode 100755 tests/test_cache.py create mode 100755 tests/test_fittable.py diff --git a/oops/__init__.py b/oops/__init__.py index fb63e168..4120cf87 100755 --- a/oops/__init__.py +++ b/oops/__init__.py @@ -45,8 +45,9 @@ from oops.backplane import Backplane from oops.body import Body +from oops.cache import Cache from oops.event import Event -from oops.fittable import Fittable +from oops.fittable import Fittable, Fittable_ from oops.meshgrid import Meshgrid from oops.transform import Transform @@ -82,6 +83,9 @@ frame.Frame.J2000, path.Path.SSB) +Cache.FRAME_CLASS = Frame +Cache.PATH_CLASS = Path + Frame.EVENT_CLASS = Event Frame.PATH_CLASS = Path Frame.SPICEPATH_CLASS = oops.path.SpicePath diff --git a/oops/backplane/__init__.py b/oops/backplane/__init__.py index ed4416a5..357a9ae2 100755 --- a/oops/backplane/__init__.py +++ b/oops/backplane/__init__.py @@ -16,9 +16,9 @@ class Backplane(object): """Class that supports the generation and manipulation of sets of backplanes - with a particular observation. + with a particular Observation. - intermediate results are cached to speed up calculations. + Intermediate results are cached to speed up calculations. """ DIAGNOSTICS = False # set True to log diagnostics @@ -118,15 +118,7 @@ def __init__(self, obs, meshgrid=None, time=None, inventory=None, else: self.meshgrid = meshgrid - if time is None: - self.time = obs.timegrid(self.meshgrid) - else: - self.time = Scalar.as_scalar(time) - - # For some cases, times are all equal. If so, collapse the times. - dt = self.time - obs.midtime - if abs(dt).max() < 1.e-3: # simplifies cases with jitter in time tags - self.time = Scalar(obs.midtime) + self._input_time = time # Intialize the inventory self._input_inventory = inventory @@ -139,8 +131,23 @@ def __init__(self, obs, meshgrid=None, time=None, inventory=None, self.inventory_border = inventory_border + self._refresh() # Fill in all internals + + def _refresh(self): + + if self._input_time is None: + self.time = self.obs.timegrid(self.meshgrid) + else: + self.time = Scalar.as_scalar(self._input_time) + + # For some cases, times are all equal. If so, collapse the times. + dt = self.time - self.obs.midtime + if abs(dt).max() < 1.e-3: # simplifies cases with jitter in time tags + self.time = Scalar(self.obs.midtime) + # Define events - self.obs_event = obs.event_at_grid(self.meshgrid, time=self.time) + self.obs_event = self.obs.event_at_grid(self.meshgrid, time=self.time) + self.shape = self.obs_event.shape # dict[derivs] = event self.obs_events = { @@ -148,10 +155,8 @@ def __init__(self, obs, meshgrid=None, time=None, inventory=None, True : self.obs_event.with_los_derivs() } - self.obs_gridless_event = obs.gridless_event(self.meshgrid, - time=self.time) - - self.shape = self.obs_event.shape + self.obs_gridless_event = self.obs.gridless_event(self.meshgrid, + time=self.time) # The surface_events dictionary comes in two versions, with and without # derivatives with respect to los and time. diff --git a/oops/cache.py b/oops/cache.py new file mode 100644 index 00000000..69c42854 --- /dev/null +++ b/oops/cache.py @@ -0,0 +1,117 @@ +########################################################################################## +# oops/cache.py: Support for caching of OOPS objects +########################################################################################## + +import numpy as np +from polymath import Qube + + +class Cache: + """Class that can be indexed like a dictionary, where `maxsize` items are preserved. + When the size of the cache exceeds `maxsize` by ~ 10%, the least-recently accessed + items are deleted. + + Indexing a Cache using a key that is not present, or has been deleted, returns None. + A KeyError is never raised. + + Dictionary keys can include mutable items, which are converted to immutable. The class + method `clean_key` performs this conversion. + """ + + # These are filled in by oops/__init__.py to avoid circular imports + FRAME_CLASS = None + PATH_CLASS = None + + def __init__(self, maxsize=100): + """Constructor for a Cache. + + Parameters: + maxsize (int, optional): The rough limit on the number of items stored in the + Cache. When this value is exceeded by ~ 10%, the number of elements is + reduced back to `maxsize` by removing the items accessed least recently. + """ + + self._maxsize = maxsize + self._extras = max(3, maxsize//10) + self._limit = maxsize + self._extras + self._dict = {} + self._counter = 0 + + def __len__(self): + """The number of items currently in this Cache.""" + return len(self._dict) + + @staticmethod + def clean_key(key): + """Convert the given key to immutable so it can be used as a dictionary key.""" + + def clean_item(item): + match item: + case Qube(): + vals = tuple(item.vals.ravel()) if np.shape(item.vals) else item.vals + mask = tuple(item.mask.ravel()) if np.shape(item.mask) else item.mask + return (type(item).__name__, item.shape, vals, mask) + case np.ndarray(): + return (item.shape, tuple(item.ravel())) + case Cache.PATH_CLASS(): + return Cache.PATH_CLASS.as_waypoint(item) + case Cache.FRAME_CLASS(): + return Cache.FRAME_CLASS.as_wayframe(item) + case x if hasattr(x, '__data__'): + return id(item) + case list(): + return tuple(item) + case _: + return item + + if isinstance(key, (list, tuple)): + return tuple(clean_item(item) for item in key) + + return clean_item(key) + + def __contains__(self, key): + """True if the given key is currently in the Cache.""" + + if self._maxsize: + key = Cache.clean_key(key) + if key in self._dict: + self._counter += 1 + self._dict[key][0] = self._counter + return True + return False + + def __getitem__(self, key): + """The value associated with the given key, or None if the key is missing. + + Supports index notation using square brackets "[]". + """ + + if self._maxsize: + key = Cache.clean_key(key) + if key in self._dict: + self._counter += 1 + count_key_value = self._dict[key] + count_key_value[0] = self._counter + return count_key_value[2] + + return None + + def __setitem__(self, key, value): + """Set the value associated with the given key. + + Supports index notation using square brackets "[]". + """ + + if self._maxsize: + key = Cache.clean_key(key) + self._counter += 1 + self._dict[key] = [self._counter, key, value] + + if len(self._dict) > self._limit: + tuples = list(self._dict.values()) + tuples.sort() + extras = tuples[:-self._maxsize] + for (_, k, _) in extras: + del self._dict[k] + +########################################################################################## diff --git a/oops/cadence/__init__.py b/oops/cadence/__init__.py index 93b23852..67caaf54 100755 --- a/oops/cadence/__init__.py +++ b/oops/cadence/__init__.py @@ -9,6 +9,7 @@ from oops.cadence.reshapedcadence import ReshapedCadence from oops.cadence.reversedcadence import ReversedCadence from oops.cadence.sequence import Sequence +from oops.cadence.snapcadence import SnapCadence from oops.cadence.tdicadence import TDICadence ################################################################################ diff --git a/oops/cadence/dualcadence.py b/oops/cadence/dualcadence.py index cc5c7089..951510d9 100755 --- a/oops/cadence/dualcadence.py +++ b/oops/cadence/dualcadence.py @@ -46,11 +46,25 @@ def __init__(self, long, short): self._max_long_tstep = self.long.shape[0] - 1 + def _refresh(self): + """Update internals if self.long or self.short is Fittable.""" + self.time = (self.long.time[0], + self.long.lasttime + self.short.time[1]) + self.midtime = (self.time[0] + self.time[1]) * 0.5 + self.lasttime = self.long.lasttime + self.short.lasttime + def __getstate__(self): + self.refresh() return (self.long, self.short) def __setstate__(self, state): self.__init__(*state) + self.freeze() + + self.time = (self.long.time[0], + self.long.lasttime + self.short.time[1]) + self.midtime = (self.time[0] + self.time[1]) * 0.5 + self.lasttime = self.long.lasttime + self.short.lasttime #=========================================================================== def time_at_tstep(self, tstep, remask=False, derivs=False, inclusive=True): diff --git a/oops/cadence/metronome.py b/oops/cadence/metronome.py index e38e2e12..a97e3933 100755 --- a/oops/cadence/metronome.py +++ b/oops/cadence/metronome.py @@ -51,8 +51,7 @@ def __init__(self, tstart, tstride, texp, steps, clip=True): self._max_step = self.steps - 1 def __getstate__(self): - return (self.tstart, self.tstride, self.texp, self.steps, - self.clip) + return (self.tstart, self.tstride, self.texp, self.steps, self.clip) def __setstate__(self, state): self.__init__(*state) @@ -366,16 +365,4 @@ def for_array1d(steps, tstart, texp, interstep_delay=0.): return Metronome(tstart, texp + interstep_delay, texp, steps) - #=========================================================================== - @staticmethod - def for_array0d(tstart, texp): - """Alternative constructor for a product with no time-axis. - - Input: - tstart start time in seconds TDB. - texp exposure duration in seconds. - """ - - return Metronome(tstart, texp, texp, 1) - ################################################################################ diff --git a/oops/cadence/reshapedcadence.py b/oops/cadence/reshapedcadence.py index 9cb9cff8..28d16cd4 100755 --- a/oops/cadence/reshapedcadence.py +++ b/oops/cadence/reshapedcadence.py @@ -49,11 +49,19 @@ def __init__(self, cadence, shape): self._old_rank = len(self.cadence.shape) self._old_stride = np.cumprod((self._old_shape + (1,))[::-1])[-2::-1] + def _refresh(self): + """Update internals if self.cadence is Fittable.""" + self.time = self.cadence.time + self.midtime = self.cadence.midtime + self.lasttime = self.cadence.lasttime + def __getstate__(self): + self.refresh() return (self.cadence, self.shape) def __setstate__(self, state): self.__init__(*state) + self.freeze() #=========================================================================== @staticmethod diff --git a/oops/cadence/reversedcadence.py b/oops/cadence/reversedcadence.py index bd28d63f..c6078f00 100755 --- a/oops/cadence/reversedcadence.py +++ b/oops/cadence/reversedcadence.py @@ -43,6 +43,12 @@ def __init__(self, cadence, axis=0): self._first_time = self.cadence.time_range_at_tstep(self._max_step)[0] self._last_time = self.cadence.time_range_at_tstep(0)[1] + def _refresh(self): + """Update internals if self.cadence is Fittable.""" + self.time = self.cadence.time + self.midtime = self.cadence.midtime + self.lasttime = self.cadence.lasttime + def __getstate__(self): return (self.cadence,) diff --git a/oops/cadence/snapcadence.py b/oops/cadence/snapcadence.py new file mode 100755 index 00000000..29fb68f2 --- /dev/null +++ b/oops/cadence/snapcadence.py @@ -0,0 +1,30 @@ +################################################################################ +# oops/cadence/metronome.py: Metronome subclass of class Cadence +################################################################################ + +from oops.cadence import Metronome + +class SnapCadence(Metronome): + """A shapeless Cadence subclass with a single start and stop.""" + + def __init__(self, tstart, texp, clip=True): + """Constructor for a SnapCadence. + + Input: + tstart the start time of the observation in seconds TDB. + texp the exposure time in seconds associated with each step. + This may be shorter than tstride due to readout times, + etc. It may also be longer. + clip if True (the default), times and index values are always + clipped into the valid range. + """ + + Metronome.__init__(self, tstart, texp, texp, 1, clip=clip) + + def __getstate__(self): + return (self.tstart, self.texp, self.clip) + + def __setstate__(self, state): + self.__init__(*state) + +################################################################################ diff --git a/oops/cadence/timeshift.py b/oops/cadence/timeshift.py new file mode 100755 index 00000000..e5db0831 --- /dev/null +++ b/oops/cadence/timeshift.py @@ -0,0 +1,205 @@ +########################################################################################## +# oops/cadence/timeshift.py: Class TimeShift +########################################################################################## + +from fittable import Fittable +from polymath import Scalar +from oops.cadence import Cadence + + +class TimeShift(Cadence, Fittable): + """A Fittable time shift applied to another Cadence object.""" + + def __init__(self, arg, /, cadence): + """Constructor for a TimeShift. + + Parameters: + arg (float or TimeShift): The initial time shift in seconds. A positive value + shifts times later. Alternatively, specify another TimeShift object and + this object will be linked to that one, meaning that the time shifts will + always be equal. + cadence (Cadence): The Cadence object to be shifted. + """ + + if isinstance(arg, TimeShift): + self.link = arg + else: + self.dt = arg + self.link = None + + self.cadence = cadence + self._refresh() + + self.shape = cadence.shape + self.is_continuous = cadence.is_continuous + self.is_unique = cadence.is_unique + self.min_tstride = cadence.min_tstride + self.max_tstride = cadence.max_tstride + + ###################################################################################### + # Fittable support + ###################################################################################### + + _fittable_nparams = 1 + + def _set_params(self, params): + """Update the time shift in seconds.""" + + if self.link: + self.link.set_params(params) + + self.dt = params[0] + + @property + def _params(self): + return (self.dt,) + + def _refresh(self): + """Update the internals.""" + + if self.link: + self.dt = self.link.dt + + self.time = (self.cadence.time[0] + self.dt, self.cadence.time[1] + self.dt) + self.midtime = 0.5 * (self.time[0] + self.time[1]) + self.lasttime = self.cadence.lasttime + self.dt + + ###################################################################################### + # Serialization support + ###################################################################################### + + def __getstate__(self): + self.refresh() + return (self.dt, self.cadence) + + def __setstate__(self, state): + self.__init__(*state) + self.freeze() + + ###################################################################################### + # Cadence API + ###################################################################################### + + def time_at_tstep(self, tstep, remask=False, derivs=False, inclusive=True): + """The time associated with the given time step. + + This method supports non-integer time step values. + + In multidimensional cadences, indexing beyond the dimensions of the cadence + returns the time at the nearest edge of the cadence's shape. + + Parameters: + tstep (Scalar, Pair, array-like, float, or int): Time step index, 1-D or 2-D. + remask (bool, optional): True to mask values outside the time limits. + derivs (bool, optional): True to include derivatives of tstep in the returned + time. + inclusive (bool, optional): True to treat the end time of the cadence as part + of the cadence; False to exclude it. + + Returns: + (Scalar): Time(s) in seconds TDB. + """ + + return (self.cadence.time_at_tstep(tstep=tstep, remask=remask, derivs=derivs, + inclusive=inclusive) + + self.dt) + + def time_range_at_tstep(self, tstep, remask=False, inclusive=True, + shift=True): + """The range of times for the given time step. + + In multidimensional cadences, indexing beyond the dimensions of the + cadence returns the time range at the nearest edge. + + Parameters: + tstep (Scalar, Pair, array-like, float, or int): Time step index, 1-D or 2-D. + remask (bool, optional): True to mask values outside the time limits. + inclusive (bool, optional): True to treat the end time of the cadence as part + of the cadence; False to exclude it. + shift (bool, optional): True to shift the end of the last time step (with + index==shape) into the previous time step. + + Returns: + (tuple): Two Scalars defining the minimum and maximum times associated with + the index. It is given in seconds TDB. + """ + + times = self.cadence.time_range_at_tstep(tstep, remask=remask, + inclusive=inclusive, shift=shift) + return (times[0] + self.dt, times[1] + self.dt) + + def tstep_at_time(self, time, remask=False, derivs=False, inclusive=True): + """Time step for the given time. + + This method returns non-integer time steps via interpolation. + + In multidimensional cadences, times before first time step refer to the + first; times after the last time step refer to the last. + + Parameters: + time (Scalar, array-like, float, or int): Time(s) in seconds TDB. + remask (bool, optional): True to mask values outside the time limits. + derivs (bool, optional): True to include derivatives of tstep in the returned + time. + inclusive (bool, optional): True to treat the end time of the cadence as part + of the cadence; False to exclude it. + + Returns: + (Scalar or Pair): Time step index or indices. + """ + + return self.cadence.tstep_at_time(time - self.dt, remask=remask, derivs=derivs, + inclusive=inclusive) + + def tstep_range_at_time(self, time, remask=False, inclusive=True): + """Integer range of time steps active at the given time. + + Parameters: + time (Scalar, array-like, float, or int): Time(s) in seconds TDB. + remask (bool, optional): True to mask values outside the time limits. + inclusive (bool, optional): True to treat the end time of the cadence as part + of the cadence; False to exclude it. + + Returns: + (tuple): Two Scalars defining the minimum and maximum tstep values. + + Notes: + All returned indices will be in the allowed range for the cadence, inclusive, + regardless of mask. If the time is not inside the cadence, `tstep_max < + tstep_min`. + """ + + time = Scalar.as_scalar(time) + return self.cadence.tstep_range_at_time(time - self.dt, remask=remask, + inclusive=inclusive) + + def time_is_outside(self, time, inclusive=True): + """A Boolean mask of times that fall outside the cadence. + + Parameters: + time (Scalar, array-like, float, or int): Time(s) in seconds TDB. + inclusive (bool, optional): True to treat the end time of the cadence as part + of the cadence; False to exclude it. + + Returns: + (Boolean): True where time values are not sampled by the Cadence. + """ + + time = Scalar.as_scalar(time) + return self.cadence.time_is_outside(time - self.dt, inclusive=inclusive) + + def time_shift(self, secs): + """Construct a duplicate of this Cadence with all times shifted by given amount. + + Parameters: + secs (float): The number of seconds to shift the time later. + """ + + return TimeShift(self.link or self.dt, self.cadence.time_shift(secs)) + + def as_continuous(self): + """Construct a shallow copy of this Cadence, forced to be continuous.""" + + return TimeShift(self.link or self.dt, self.cadence.as_continuous()) + +########################################################################################## diff --git a/oops/event.py b/oops/event.py index 4f53839b..1dfbe9c2 100755 --- a/oops/event.py +++ b/oops/event.py @@ -392,7 +392,7 @@ def empty_cache(self): self._ssb_._antimask_ = None self._ssb_._shape_ = None - def reinit(self): + def _refresh(self): """Remove all internal information; needed for Events that involve Fittable objects. """ diff --git a/oops/fittable.py b/oops/fittable.py index 83d702c2..947710dd 100755 --- a/oops/fittable.py +++ b/oops/fittable.py @@ -1,86 +1,643 @@ -################################################################################ +########################################################################################## # oops/fittable.py: Fittable interface -################################################################################ +########################################################################################## +"""Support for Fittable (mutable) OOPS objects.""" + +import numpy as np + class Fittable(object): - """The Fittable interface enables any class to be used within a - least-squares fitting procedure. - - Every Fittable object has these attributes: - nparams the number of parameters required. - param_name the name of the attribute holding the parameters. - cache a dictionary containing prior values of the object, - keyed by the parameter set as a tuple. - - It is also necessary to define these methods: - set_params() or set_params_new() - copy() + """The Fittable interface enables any class to be used within a fitting procedure. + + Most OOPS objects are static, but objects that subclass Fittable can be modified + in-place. This is primarily used when fitting unknown values such as pointing + corrections, time shifts, or plate scales, or orbital elements. + + The following methods are defined: + + * `set_params`: Update the parameter values for this object and/or any of its + sub-objects. + * `get_params`: Retrieve the current parameter values. + * `refresh`: Make sure this object is internally consistent. Always call this object + before making use of an object if any of the sub-objects might have changed + underneath it. + * `freeze`: Freeze the parameter values, preventing any further changes to this object + or any of its sub-objects. + * `is_frozen`: True if this object is frozen. + * `version`: An integer that starts at zero and increases whenever this or one of its + sub-objects changes. + + When an object class subclasses Fittable, the method:: + _set_params(self, params) + must be defined. It receives one or more floating-point parameters and uses their + values to update the object. (Note that the programmer should call the method + `set_params`, not `_set_params`, because the former handles a variety of additional + "bookkeeping" tasks.) + + A Fittable object must have a property or attribute `_params`, which returns the + tuple of the object's current parameters. + + If an object class maintains cached information internally, it should also have a + method `_refresh`, which updates any internal attributes based on the new parameters. + It may also need `_freeze`, which is called when the object is frozen. + + Note that it is possible for a Fittable object to have sub-objects that are also + Fittable. + + Information about the fittable state of such an object is maintained by a set of added + attributes. These attributes are maintained entirely by the Fittable interface and + should not be touched by the programmer. + + * `_fittables`: An ordered list of the names of all the fittable sub-objects. + * `_fittable_nparams`: The number of fittable parameters required by the object and + its sub-objects. + * `_fittable_params`: The current values of the fittable parameters for the object or + its sub-objects. + * `_fittable_version`: The version number of the parameters. This values begins at + zero and is incremented each time the object or any of its sub-objects is modified. + * `_fittable_is_frozen`: True if this object is now frozen. Once frozen, it can no + longer be modified. + * `_fittable_state`: A dictionary, keyed by attribute name, providing the version + number of each fittable sub-object. """ - #=========================================================================== def set_params(self, params): - """Redefine the object using this set of parameters. + """Redefine this object using a new set of parameters. - This implementation checks the cache first, and then calls - set_params_new() if the instance is not cached. Override this method - if the Fittable object does not need a cache. + This calls a defined `_set_params` method for any sub-objects and then for the + object itself. It also refreshes as it goes. - Input: - params a list, tuple or 1-D Numpy array of floating-point - numbers, defining the parameters to be used in the - object returned. + Parameters: + obj (object): Object for which parameters are to be set. + params (tuple, list, or np.ndarray): Parameter values to use. """ - key = tuple(params) - if key in self.cache: - return self.cache[key] + Fittable_.set_params(self, params) + + def get_params(self, *, frozen=False, as_dict=False): + """The parameters defining the current state of this object. + + Parameters: + obj (object): Object for which parameters are to be retrieved. + frozen (bool, optional): True to include parameters associated with frozen + objects as well. + """ + + return Fittable_.get_params(self, frozen=frozen, as_dict=as_dict) + + def refresh(self): + """Update any internally cached information if this object has been modified. + + Use this call to ensure that an object is fully self-consistent, not containing + any stale information. + + If the given object and any Fittable sub-object(s) are already up to date, the + object is not changed. + + Parameters: + obj (object): Object to be refreshed if necessary. + + Returns: + (bool): True if this object was modified as a result of this call. + """ + + return Fittable_.refresh(self) + + def freeze(self): + """Freeze this object and any Fittable subobjects. + + A frozen object can no longer be modified. + """ + + Fittable_.freeze(self) + + def is_frozen(self): + """True if the given object and all Fittable sub-objects are frozen.""" + + return Fittable_.is_frozen(self) + + def version(self): + """The Fittable version number of this object. + + The version number starts at zero and is incremented each time the object or one + of its sub-objects is modified by a call to `set_params` or possibly `refresh`. + """ + + return Fittable_.version(self) + +########################################################################################## +# Class Fittable_, holding static versions of all needed functions +########################################################################################## + +class Fittable_: + """Static functions that support Fittable operations on objects that are not + necessarily subclasses of Fittable. + + Most OOPS objects are static. An object that subclasses Fittable can be modified + in-place by making a call to a function named `_set_params`. - result = self.set_params_new(params) - self.cache[key] = result + If an object contains one or more Fittable sub-objects, it is effectively fittable + even if it does not subclass the Fittable class. The following static methods can be + applied to any object: + + * `set_params`: Update the parameter values for the given object and/or any of its + sub-objects. + * `get_params`: Retrieve the current parameter values. + * `refresh`: Make sure the given object is internally consistent. Always call this + function before making use of an object if it or any of the sub-objects might have + changed via a recent call to `set_params`. + * `freeze`: Freeze the parameter values, preventing any further changes to the given + object or its sub-objects. + * `is_frozen`: True if the given object is frozen. + * `fittables`: A list of the names of any fittable sub-objects. + * `is_fittable`: True if the given object is fittable, whether or not it is frozen. + * `version`: An integer that starts at zero and increases whenever an object or one of + its sub-objects changes. + + The programmer may wish to define these items for a class, even if the object class is + not a subclass of Fittable: + + * `_refresh`: This optional method should update any internal attributes that depend + on an updated sub-object. It ensures that this information cannot become "stale" if + an underlying sub-object changes. + * `_freeze`: This optional method should update any internal attributes that depend + on the object being frozen. + + Information about the fittable state of such an object is maintained by a set of added + attributes. These attributes are maintained entirely by the module and should not be + touched by the programmer. + + * `_fittables`: An ordered list of the names of all the fittable sub-objects. + * `_fittable_nparams`: The number of fittable parameters required by the object and + its sub-objects. + * `_fittable_params`: The current values of the fittable parameters for the object or + its sub-objects. + * `_fittable_version`: The version number of the parameters. This values begins at + zero and is incremented each time the object or any of its sub-objects is modified. + * `_fittable_is_frozen`: True if this object is now frozen. Once frozen, it can no + longer be modified. + * `_fittable_state`: A dictionary, keyed by attribute name, providing the version + number of each fittable sub-object. + """ + + _FROZEN_IDS = set() # for objects with __dict__ that can't have attributes set + + @staticmethod + def set_params(obj, /, params): + """Redefine the given object using a new set of parameters. + + This calls a defined `_set_params` method for any sub-objects and then for the + object itself. It also refreshes as it goes. + + Parameters: + obj (object): Object for which parameters are to be set. + params (tuple, list, np.ndarray, or dict): Parameter values to use. Use a dict + to apply parameters to subobjects, where the key is the name of the + fittable attribute; in this case, use a blank key "" to set the parameters + of the given object. + """ + + def clean_params(params): + if isinstance(params, (dict, tuple)): + return params + if isinstance(params, (list, np.ndarray)): + return tuple(params) + return (params,) + + if not Fittable_.is_fittable(obj): + raise ValueError(f'{type(obj).__name__} object is not Fittable') + + if Fittable_.is_frozen(obj): + if clean_params(params) == obj.get_params(): # no error if unchanged + return + raise ValueError(f'{type(obj).__name__} object is frozen') + + # Make sure state is initialized + _ = Fittable_._get_state(obj) + + # Handle a dictionary + if isinstance(params, dict): + for key, sub_params in params.items(): + if not key: + continue + Fittable_.set_params(obj.__dict__[key], sub_params) + + if '' in params: + params = params[''] + else: + if hasattr(obj, '_refresh'): + obj._refresh() + Fittable_._increment(obj) + obj._fittable_state[''] = obj._fittable_version + Fittable_.refresh(obj) + return + + # Handle a tuple + if not isinstance(obj, Fittable): + raise ValueError(f'{type(obj).__name__} object is not Fittable') + + params = clean_params(params) + if not params: + raise ValueError(f'missing parameters for {type(obj).__name__}.set_params()') + + # Check or set the number of parameters + if hasattr(obj, '_fittable_nparams'): + if len(params) != obj._fittable_nparams: + plural = 's' if obj._fittable_nparams > 1 else '' + raise ValueError(f'{type(obj).__name__} object requires ' + f'{obj._fittable_nparams} fit parameter{plural}') + else: + obj._fittable_nparams = len(params) + + # Update the object with the new parameters and refresh + obj._set_params(params) + obj._fittable_params = params + + if hasattr(obj, '_refresh'): + obj._refresh() + Fittable_._increment(obj) + obj._fittable_state[''] = obj._fittable_version + Fittable_.refresh(obj) + + @staticmethod + def get_params(obj, /, *, frozen=False, as_dict=False, _memo=None): + """The parameters defining the current state of the given object. + + Parameters: + obj (object): Object for which parameters are to be retrieved. + frozen (bool, optional): True to include parameters associated with frozen + objects. + as_dict (bool, optional): True to return a dictionary that also includes the + parameters of sub-objects keyed by attribute name; the parameters of the + given object will be keyed by a blank string "". If False a tuple of the + given object's parameters are returned. + _memo (set, optional): Set to use internally for tracking sub-objects, which + is needed to handle circular references. + + Returns: + (tuple or dict): Parameters as a tuple or dictionary. + """ + + _memo = set() if _memo is None else _memo + Fittable_.refresh(obj) + + result = () + if frozen or not Fittable_.is_frozen(obj): + result = obj.__dict__.get('_fittable_params', ()) + if not result and hasattr(obj, '_params'): + result = obj._params + _memo.add(id(obj)) + + if as_dict: + result = {'': result} if result else {} + for name in Fittable_.fittables(obj, frozen=frozen): + subobj = obj.__dict__[name] + subobj_id = id(subobj) + if subobj_id in _memo: + continue + + value = Fittable_.get_params(subobj, frozen=frozen, as_dict=True, + _memo=_memo) + _memo.add(subobj_id) + + # Convert dict to tuple where appropriate + result[name] = value[''] if list(value.keys()) == [''] else value return result - #=========================================================================== - def set_params_new(self, params): - """Redefine using this set of parameters. Do not check the cache first. + @staticmethod + def refresh(obj, /, *, _memo=None): + """Update any internally cached information if the given object has been modified. + + Use this call to ensure that an object is fully self-consistent, not containing + any stale information. - Override this method if the subclass uses a cache. Then calls to - set_params() will check the cache and only invoke this function when - needed. + If the given object and any Fittable sub-object(s) are already up to date, the + object is not changed. - Input: - params a list, tuple or 1-D Numpy array of floating-point - numbers, defining the parameters to be used in the - object returned. + Parameters: + obj (object): Object to be refreshed if necessary. + _memo (set, optional): Set to use internally for tracking sub-objects, which + is needed to handle circular references. + + Returns: + (bool): True if the given object was modified as a result of this call. """ - pass + def refresh1(obj, _memo): + """Apply a single iteration of `refresh`.""" + + # Refresh any Fittable sub-objects first + changed = False + state = Fittable_._get_state(obj) + for name, prev_version in state.items(): + if not name: # skip the object itself for now + continue - #=========================================================================== - def get_params(self): - """The current set of parameters defining this fittable object. + subobj = obj.__dict__[name] - This method normally does not need to be overridden. + # Make sure this object wasn't already refreshed + subobj_id = id(subobj) + if subobj_id in _memo: # avoid a circular reference + state[name] = subobj.version() # ...but update the state dict + continue + _memo.add(subobj_id) - Return: a Numpy 1-D array of floating-point numbers containing - the parameter values defining this object. + # Refresh this sub-object recursively + Fittable_.refresh(subobj, _memo=_memo) + + # Update the version of this sub-object + new_version = Fittable_.version(subobj) + if new_version != prev_version: + state[name] = new_version + subobj._fittable_params = Fittable_.get_params(subobj) + changed = True + + # If any sub-object has changed, this object needs a `_refresh` + if changed: + if hasattr(obj, '_refresh'): + obj._refresh() + obj._fittable_params = Fittable_.get_params(obj) + Fittable_._increment(obj) + state[''] = obj._fittable_version + + return changed + + # Begin active code + _memo = set() if _memo is None else _memo + + if not hasattr(obj, '__dict__'): + return False + + if hasattr(obj, '_is_fittable') and not obj._is_fittable: + return False + + if hasattr(obj, '_fittable_is_frozen') and obj._fittable_is_frozen: + return False + + # Refresh once + changed = refresh1(obj, _memo) + + # A circular reference can require multiple refreshes to get the correct state + if changed: + for i in range(9): + _memo = set() + if not refresh1(obj, _memo): + return True + + raise RuntimeError('Fittable_.refresh() did not complete after 10 iterations') + + return False + + @staticmethod + def freeze(obj, /, _memo=None): + """Freeze the given object and any Fittable subobjects. + + A frozen object can no longer be modified. + + Parameters: + obj (object): Object to freeze. + _memo (set, optional): Set to use internally for tracking sub-objects, which + is needed to handle circular references. + """ + + _memo = set() if _memo is None else _memo + + # If not freezable, return + if not hasattr(obj, '__dict__'): + return + + # Don't revisit this object + obj_id = id(obj) + if obj_id in _memo: + return + _memo.add(obj_id) + + if hasattr(obj, '_fittable_is_frozen') and obj._fittable_is_frozen: + return + + # If the frozen attribute can't be set, check the global list of frozen IDs + if obj_id in Fittable_._FROZEN_IDS: + return + + # Freeze the sub-objects + values = list(obj.__dict__.values()) + for subobj in values: + subobj_id = id(subobj) + if subobj_id in _memo: + continue + Fittable_.freeze(subobj, _memo=_memo) + _memo.add(subobj_id) + + # Refresh and freeze this + if hasattr(obj, '_refresh'): + obj._refresh() + + if hasattr(obj, '_freeze'): + obj._freeze() + + # Set this object as frozen if possible + try: + obj._fittable_is_frozen = True + except (AttributeError, TypeError): + pass + + @staticmethod + def is_frozen(obj, /, _memo=None): + """True if the given object and all Fittable sub-objects are frozen. + + Parameters: + obj (object): Object to test. + _memo (set, optional): Set to use internally for tracking sub-objects, which + is needed to handle circular references. + + Returns: + (bool): True if the given object is frozen. + """ + + _memo = set() if _memo is None else _memo + + # If not a freezable object, return True + if not hasattr(obj, '__dict__'): + return True + + # Don't revisit this object + obj_id = id(obj) + _memo.add(obj_id) + + # If this object is already frozen, return True + if hasattr(obj, '_fittable_is_frozen'): + if obj._fittable_is_frozen: + return True + elif obj_id in Fittable_._FROZEN_IDS: + return True + + # If any sub-object is not frozen, this is not frozen + values = list(obj.__dict__.values()) + for subobj in values: + subobj_id = id(subobj) + if subobj_id in _memo: + continue + if not Fittable_.is_frozen(subobj, _memo=_memo): + return False + _memo.add(subobj_id) + + # If this object fittable, return False + if isinstance(obj, Fittable): + obj._fittable_is_frozen = False + return False + + # Designate this object as frozen + try: + obj._fittable_is_frozen = True + except (AttributeError, TypeError): + Fittable_._FROZEN_IDS.add(obj_id) + + return True + + @staticmethod + def fittables(obj, /, *, frozen=False, _memo=None): + """Ordered list of Fittable subobjects of the given object. + + This is the list of Fittable sub-objects sorted alphabetically. + + Parameters: + obj (object): Object for which to obtain the fittable attribute names. + frozen (bool, optional): True to include names of attributes that have been + frozen. + _memo (set, optional): Set to use internally for tracking sub-objects, which + is needed to handle circular references. + + Returns: + (list): List of the Fittable sub-objects of the given objects. + """ + + _memo = set() if _memo is None else _memo + + if not hasattr(obj, '__dict__'): + return [] + + if hasattr(obj, '_fittables'): + names = obj._fittables + else: + keys = list(obj.__dict__.keys()) + + names = [] + for name in keys: + subobj = obj.__dict__[name] + subobj_id = id(subobj) + if subobj_id in _memo: + continue + if Fittable_.is_fittable(subobj, _memo=_memo): + names.append(name) + _memo.add(subobj_id) + + names.sort() + try: + obj._fittables = names + except (AttributeError, TypeError): + pass + + _memo.add(id(obj)) + + if not frozen: + _memo = set() + names = [n for n in names if not Fittable_.is_frozen(obj.__dict__[n], + _memo=_memo)] + + _memo.add(id(obj)) + return names + + @staticmethod + def is_fittable(obj, /, _memo=None): + """True if the given object or any of its sub-objects is fittable, whether or not + it is frozen. + + Parameters: + obj (object): Object to test. + _memo (set, optional): Set to use internally for tracking sub-objects, needed + to handle circular references. + + Returns: + (bool): True if the object is fittable, either by being a Fittable subclass or + by having a Fittable subobject. """ - return self.__dict__[self.param_name] + _memo = set() if _memo is None else _memo + _memo.add(id(obj)) + + if not hasattr(obj, '__dict__'): + return False + + if hasattr(obj, '_is_fittable'): + return obj._is_fittable + + if isinstance(obj, Fittable): + obj._is_fittable = True + return True + + fittables = Fittable_.fittables(obj, frozen=True, _memo=_memo) + try: + return obj.__dict__.setdefault('_is_fittable', bool(fittables)) + except (AttributeError, TypeError): + return False - #=========================================================================== - def copy(self): - """A deep copy of this object. + @staticmethod + def version(obj, /): + """The Fittable version number of this object. - The copy can be safely modified without affecting the original. + The version number starts at zero and is incremented each time the object or one + of its sub-objects is modified by a call to `set_params` or possibly `refresh`. """ - pass + if hasattr(obj, '_fittable_version'): + return obj._fittable_version + + try: + obj._fittable_version = 0 + except (AttributeError, TypeError): + pass + + return 0 + + @staticmethod + def _get_state(obj, /): + """The state of the given object. + + This is a dictionary containing the version number of the object itself, as well + as each Fittable sub-object, at the time of construction or the most recent call + to `refresh`. Each dictionary entry is keyed by the attribute name; the version + number of the object itself is keyed by "". + """ + + if Fittable_.is_frozen(obj): + return {} + + if hasattr(obj, '_fittable_state'): + return obj._fittable_state + + state = {} + for name in Fittable_.fittables(obj): + state[name] = 0 + state[''] = 0 + + try: + obj._fittable_state = state + except (AttributeError, TypeError): + return {} + + return state + + @staticmethod + def _increment(obj, /): + """Increment the Fittable version number of this object.""" + + if not hasattr(obj, '__dict__'): + return - #=========================================================================== - def clear_cache(self): - """Clear the current cache.""" + if hasattr(obj, '_fittable_version'): + obj._fittable_version += 1 - self.cache = {} + obj._fittable_version = 1 -################################################################################ +########################################################################################## diff --git a/oops/fov/__init__.py b/oops/fov/__init__.py index 65d70cdc..4c55e8a0 100755 --- a/oops/fov/__init__.py +++ b/oops/fov/__init__.py @@ -7,6 +7,7 @@ from oops.fov.flatfov import FlatFOV from oops.fov.nullfov import NullFOV from oops.fov.offsetfov import OffsetFOV +from oops.fov.platescale import Platescale from oops.fov.polynomialfov import PolynomialFOV from oops.fov.slicefov import SliceFOV from oops.fov.subarray import Subarray diff --git a/oops/fov/offsetfov.py b/oops/fov/offsetfov.py index 07f26a9c..36430cf6 100755 --- a/oops/fov/offsetfov.py +++ b/oops/fov/offsetfov.py @@ -59,10 +59,21 @@ def __init__(self, fov, uv_offset=None, xy_offset=None): self.uv_area = self.fov.uv_area self.uv_los = self.fov.uv_los - self.uv_offset - # Required attributes for Fittable - self.nparams = 2 - self.param_name = 'uv_offset' - self.cache = {} # not used + ############################################################################ + # Fittable interface + ############################################################################ + + _fittable_nparams = 2 + + def _set_params(self, params): + """Redefine the (u,v) offsets of this OffsetFOV.""" + + self.uv_offset = Pair.as_pair(params) + self.xy_offset = self.fov.xy_from_uv(self.uv_offset - self.fov.uv_los) + + ############################################################################ + # Serialization support + ############################################################################ def __getstate__(self): return (self.fov, self.uv_offset, self.xy_offset) @@ -70,7 +81,10 @@ def __getstate__(self): def __setstate__(self, state): self.__init__(*state) - #=========================================================================== + ############################################################################ + # FOV API + ############################################################################ + def xy_from_uvt(self, uv_pair, time=None, derivs=False, remask=False, **keywords): """The (x,y) camera frame coordinates given the FOV coordinates (u,v) at @@ -119,29 +133,4 @@ def uv_from_xyt(self, xy_pair, time=None, derivs=False, remask=False, return self.fov.uv_from_xyt(xy_pair + self.xy_offset, time=time, derivs=derivs, remask=remask, **keywords) - ############################################################################ - # Fittable interface - ############################################################################ - - def set_params(self, params): - """Redefine the Fittable object, using this set of parameters. - - Input: - params a list, tuple or 1-D Numpy array of floating-point - numbers, defining the parameters to be used in the - object returned. - """ - - self.uv_offset = Pair.as_pair(params) - self.xy_offset = self.fov.xy_from_uv(self.uv_offset - self.fov.uv_los) - - #=========================================================================== - def copy(self): - """A deep copy of the Fittable object. - - The copy can be safely modified without affecting the original. - """ - - return OffsetFOV(self.fov, self.uv_offset.copy(), xy_offset=None) - ################################################################################ diff --git a/oops/fov/platescale.py b/oops/fov/platescale.py new file mode 100755 index 00000000..8f43ea5d --- /dev/null +++ b/oops/fov/platescale.py @@ -0,0 +1,101 @@ +########################################################################################## +# oops/fov/platescale.py: Platescale subclass of class FOV +########################################################################################## + +from oops.fittable import Fittable +from oops.fov import FOV +from polymath import Pair + + +class Platescale(FOV, Fittable): + """An FOV defined by applying a plate scale to another FOV. + + PLACEHOLDER CODE. "CONCEPTUALLY" CORRECT BUT NOT YET TESTED. + """ + + def __init__(self, factor, /, fov): + """Constructor for a Platescale FOV. + + Parameters: + factor (float): The scale factor to apply to the given FOV. A value greater + than one enlarges the FOV. + fov (FOV): The FOV object to which the scale factor applies. + """ + + self.factor = factor + self.fov = fov + + self.uv_los = self.fov.uv_los + self.uv_shape = self.fov.uv_shape + self.uv_area = self.fov.uv_area + + self._refresh() + + ###################################################################################### + # Fittable API + ###################################################################################### + + def _refresh(self): + self.uv_scale = self.fov.uv_scale * self.factor + + def _set_params(self, params): + self.factor = params[0] + + @property + def _params(self): + return (self.factor,) + + ###################################################################################### + # Serialization support + ###################################################################################### + + def __getstate__(self): + return (self.factor, self.fov) + + def __setstate__(self, state): + self.__init__(*state) + + ###################################################################################### + # FOV API + ###################################################################################### + + def xy_from_uvt(self, uv_pair, time=None, derivs=False, remask=False): + """The (x,y) camera frame coordinates given FOV coordinates (u,v). + + Parameters: + uv_pair (Pair or array-like): (u,v) pixel coordinates in the FOV. + time (Scalar, array-like, or float, optional): Time in TDB seconds, ignored by + time-independent FOVs. + derivs (bool, optional): True to propagate any derivatives in (u,v) into the + returned (x,y) Pair. + remask (bool, optional): True to mask (u,v) coordinates that fall outside the + field of view; False to leave them unmasked. + + Returns: + (Pair): (x,y) coordinates in the FOV's frame. + """ + + xy_pair = self.fov.xy_from_uvt(uv_pair, time=time, derivs=derivs, remask=remask) + return xy_pair * self.factor + + def uv_from_xyt(self, xy_pair, time=None, derivs=False, remask=False): + """The (u,v) FOV coordinates given (x,y) camera frame coordinates. + + Parameters: + xy_pair (Pair or array-like): (x,y) coordinates in the FOV. + time (Scalar, array-like, or float, optional): Time in TDB seconds, ignored by + time-independent FOVs. + derivs (bool, optional): True to propagate any derivatives in (x,y) into the + returned (u,v) Pair. + remask (bool, optional): True to mask (u,v) coordinates that fall outside the + field of view; False to leave them unmasked. + + Returns: + (Pair): (u,v) pixel coordinates in the FOV. + """ + + xy_pair = Pair.as_pair(xy_pair) + return self.fov.uv_from_xyt(xy_pair / self.factor, time=time, derivs=derivs, + remask=remask) + +########################################################################################## diff --git a/oops/frame/__init__.py b/oops/frame/__init__.py index 52c05376..33d44de6 100755 --- a/oops/frame/__init__.py +++ b/oops/frame/__init__.py @@ -6,6 +6,7 @@ LinkedFrame, RelativeFrame, ReversedFrame, QuickFrame) from oops.frame.cmatrix import Cmatrix +from oops.frame.frameshift import FrameShift from oops.frame.inclinedframe import InclinedFrame from oops.frame.laplaceframe import LaplaceFrame from oops.frame.navigation import Navigation diff --git a/oops/frame/cmatrix.py b/oops/frame/cmatrix.py index ffad650f..97fb3b32 100755 --- a/oops/frame/cmatrix.py +++ b/oops/frame/cmatrix.py @@ -1,14 +1,16 @@ -################################################################################ +########################################################################################## # oops/frame/cmatrix.py: Subclass Cmatrix of class Frame -################################################################################ +########################################################################################## import numpy as np from polymath import Matrix3, Qube, Scalar, Vector3 +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform from oops.constants import RPD + class Cmatrix(Frame): """Frame subclass in which the frame is defined by a fixed rotation matrix. @@ -17,67 +19,78 @@ class Cmatrix(Frame): and the Y-axis points downward. """ - # Note: Navigation frames are not generally re-used, so their IDs are - # expendable. Frame IDs are not preserved during pickling. + _FRAME_IDS = {} - #=========================================================================== def __init__(self, cmatrix, reference=None, frame_id=None): """Constructor for a Cmatrix frame. - Input: - cmatrix a Matrix3 object. - reference the ID or frame relative to which this frame is defined; - None for J2000. - frame_id the ID under which the frame will be registered; None - to leave the frame unregistered + Parameters: + cmatrix (Matrix3): the C matrix. + reference (Frame or str): Frame or Frame ID elative to which this frame is + defined; None for J2000. + frame_id (str, optional): The ID under which the frame will be registered; + None to leave the frame unregistered """ self.cmatrix = Matrix3.as_matrix3(cmatrix) # Required attributes - self.frame_id = frame_id self.reference = Frame.as_wayframe(reference) or Frame.J2000 - self.origin = self.reference.origin - self.shape = Qube.broadcasted_shape(self.cmatrix, self.reference) - self.keys = set() + self.origin = self.reference.origin + self.shape = Qube.broadcasted_shape(self.cmatrix, self.reference) + self.frame_id = self._recover_id(frame_id) - # Update wayframe and frame_id; register if not temporary + # Update wayframe and frame_id; register if not temporary self.register() + self._refresh() + self._cache_id() + + def _refresh(self): + self.transform = Transform(self.cmatrix, Vector3.ZERO, self.wayframe, + self.reference) - # It needs a wayframe before we can construct the transform - self.transform = Transform(cmatrix, Vector3.ZERO, - self.wayframe, self.reference) + ###################################################################################### + # Serialization support + ###################################################################################### + + def _frame_key(self): + return (self.cmatrix, self.reference) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (self.cmatrix, Frame.as_primary_frame(self.reference)) + Fittable_.refresh(self) + self._cache_id() + return (self.cmatrix, Frame.as_primary_frame(self.reference), self._state_id()) def __setstate__(self, state): - self.__init__(*state) + (cmatrix, reference, frame_id) = state + self.__init__(cmatrix, reference, frame_id=frame_id) + Fittable_.freeze(self) + + ###################################################################################### + # Alternative constructor + ###################################################################################### - #=========================================================================== @staticmethod def from_ra_dec(ra, dec, clock, reference=None, frame_id=None): """Construct a Cmatrix from RA, dec and celestial north clock angles. - Input: - ra a Scalar defining the right ascension of the optic axis - in degrees. - dec a Scalar defining the declination of the optic axis in - degrees. - clock a Scalar defining the angle of celestial north in - degrees, measured clockwise from the "up" direction in - the observation. - reference the reference frame or ID; None for J2000. - frame_id the ID to use when registering this frame; None to leave - it unregistered - - Note that this Frame can have an arbitrary shape. This shape is defined - by broadcasting the shapes of the ra, dec, twist and reference. + Parameters: + ra (Scalar, array-like, or float): The right ascension of the optic axis in + degrees. + dec (Scalar, array-like, or float): The declination of the optic axis in + degrees. + clock (Scalar, array-like, or float): The angle of celestial north in degrees, + measured clockwise from the "up" direction in the observation. + reference (Frame, optional): The reference frame or ID; None for J2000. + frame_id (str, optional): The ID to use when registering this frame; None to + leave it unregistered + + Note that this Frame can have an arbitrary shape. This shape is defined by + broadcasting the shapes of the ra, dec, twist and reference. """ - ra = Scalar.as_scalar(ra) - dec = Scalar.as_scalar(dec) + ra = Scalar.as_scalar(ra) + dec = Scalar.as_scalar(dec) clock = Scalar.as_scalar(clock) mask = Qube.or_(ra.mask, dec.mask, clock.mask) @@ -88,36 +101,54 @@ def from_ra_dec(ra, dec, clock, reference=None, frame_id=None): cosr = np.cos(ra) sinr = np.sin(ra) - cosd = np.cos(dec) sind = np.sin(dec) - cost = np.cos(twist) sint = np.sin(twist) - - (cosr, cosd, cost, - sinr, sind, sint) = np.broadcast_arrays(cosr, cosd, cost, - sinr, sind, sint) + (cosr, cosd, cost, sinr, sind, sint) = np.broadcast_arrays(cosr, cosd, cost, + sinr, sind, sint) # Extracted from the PDS Data Dictionary definition, which is appended # below cmatrix_values = np.empty(cosr.shape + (3,3)) - cmatrix_values[...,0,0] = -sinr * cost - cosr * sind * sint - cmatrix_values[...,0,1] = cosr * cost - sinr * sind * sint - cmatrix_values[...,0,2] = cosd * sint - cmatrix_values[...,1,0] = sinr * sint - cosr * sind * cost - cmatrix_values[...,1,1] = -cosr * sint - sinr * sind * cost - cmatrix_values[...,1,2] = cosd * cost - cmatrix_values[...,2,0] = cosr * cosd - cmatrix_values[...,2,1] = sinr * cosd - cmatrix_values[...,2,2] = sind - - return Cmatrix(Matrix3(cmatrix_values,mask), reference, frame_id) - - #=========================================================================== + cmatrix_values[..., 0, 0] = -sinr * cost - cosr * sind * sint + cmatrix_values[..., 0, 1] = cosr * cost - sinr * sind * sint + cmatrix_values[..., 0, 2] = cosd * sint + cmatrix_values[..., 1, 0] = sinr * sint - cosr * sind * cost + cmatrix_values[..., 1, 1] = -cosr * sint - sinr * sind * cost + cmatrix_values[..., 1, 2] = cosd * cost + cmatrix_values[..., 2, 0] = cosr * cosd + cmatrix_values[..., 2, 1] = sinr * cosd + cmatrix_values[..., 2, 2] = sind + cmatrix = Matrix3(cmatrix_values, mask) + + return Cmatrix(cmatrix, reference, frame_id=frame_id) + + ###################################################################################### + # Frame API + ###################################################################################### + def transform_at_time(self, time, quick=False): - """Transform into this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + + Notes: + Cmatrix is a fixed frame, so the transform relative to the `reference` frame + is independent of time. + """ return self.transform -################################################################################ +########################################################################################## diff --git a/oops/frame/frame_.py b/oops/frame/frame_.py index 41510fa2..68d24ae9 100755 --- a/oops/frame/frame_.py +++ b/oops/frame/frame_.py @@ -6,9 +6,12 @@ from scipy.interpolate import InterpolatedUnivariateSpline from polymath import Matrix, Matrix3, Quaternion, Qube, Scalar, Vector3 +from oops.cache import Cache from oops.config import QUICK, LOGGING, PICKLE_CONFIG +from oops.fittable import Fittable_ from oops.transform import Transform + class Frame(object): """A Frame is an abstract class that returns a Transform (rotation matrix and spin vector) given a Scalar time. @@ -242,7 +245,7 @@ def reset_registry(): Frame.initialize_registry() #=========================================================================== - def register(self, shortcut=None, override=False, unpickled=False): + def register(self, shortcut=None, override=False): """Register a Frame's definition. A shortcut makes it possible to calculate one SPICE frame relative to @@ -255,10 +258,6 @@ def register(self, shortcut=None, override=False, unpickled=False): definition of any previous frame with the same name. The old frame might still exist, but it will not be available from the registry. - If unpickled is True and a frame with the same ID is already in the - registry, then this frame is not registered. Instead, its wayframe will - be defined by the frame with the same name that is already registered. - If the frame ID is None, blank, or begins with '.', it is treated as a temporary path and is not registered. """ @@ -268,6 +267,8 @@ def register(self, shortcut=None, override=False, unpickled=False): Frame.initialize_registry() frame_id = self.frame_id + if not hasattr(self, 'keys'): + self.keys = set() # Handle a shortcut if shortcut is not None: @@ -341,16 +342,13 @@ def register(self, shortcut=None, override=False, unpickled=False): if not hasattr(self, 'wayframe') or self.wayframe is None: self.wayframe = Frame.WAYFRAME_REGISTRY[frame_id] - # If this is not an unpickled frame, make it the frame returned by - # any of the standard keys. - if not unpickled: - # Cache (self.wayframe, self.reference); overwrite if necessary - key = (self.wayframe, self.reference) - if key in Frame.FRAME_CACHE: # remove an old version - Frame.FRAME_CACHE[key].keys -= {key} + # Cache (self.wayframe, self.reference); overwrite if necessary + key = (self.wayframe, self.reference) + if key in Frame.FRAME_CACHE: # remove an old version + Frame.FRAME_CACHE[key].keys -= {key} - Frame.FRAME_CACHE[key] = self - self.keys |= {key} + Frame.FRAME_CACHE[key] = self + self.keys |= {key} #=========================================================================== @staticmethod @@ -405,8 +403,7 @@ def as_frame_id(frame): #=========================================================================== @staticmethod def temporary_frame_id(): - """A temporary frame ID. This is assigned once and never re-used. - """ + """A temporary frame ID. This is assigned once and never re-used.""" while True: Frame.TEMPORARY_FRAME_ID += 1 @@ -415,12 +412,69 @@ def temporary_frame_id(): if frame_id not in Frame.WAYFRAME_REGISTRY: return frame_id + #=========================================================================== + @staticmethod + def id_is_temporary(frame_id): + """True if this is a temporary frame ID.""" + + return frame_id.startswith('TEMPORARY_') + #=========================================================================== def is_registered(self): """True if this frame is registered.""" return (self.frame_id in Frame.WAYFRAME_REGISTRY) + ############################################################################ + # Serialization support + ############################################################################ + + def _cache_id(self): + """Save this object's frame ID in a class dictionary `_FRAME_IDS`. + + This dictionary is keyed by a tuple of attributes of the object, as + returned by the method `_frame_key`. It returns the frame ID. + + If an object is constructed with a default frame ID, but an existing + frame with the same key already exists, the frame ID is reused (although + it will still be a different, unique object). + """ + + if self.shape != (): # shapeless + return + if not Frame.id_is_temporary(self.frame_id): # permanent id + return + if self.frame_id in self._FRAME_IDS.values(): # don't overwrite + return + if not Fittable_.is_frozen(self): # frozen + return + + key = Cache.clean_key(self._frame_key()) + if key in self._FRAME_IDS: # don't overwrite + return + + self._FRAME_IDS[key] = self.frame_id + + def _recover_id(self, frame_id=None): + """If the given frame ID is None, check the class's `_FRAME_IDS` + dictionary for a matching object and use its ID if found. + """ + + if frame_id is not None: + return frame_id + + if hasattr(self, '_FRAME_IDS'): + key = Cache.clean_key(self._frame_key()) + if key in self._FRAME_IDS: + return self._FRAME_IDS[key] + + return None + + def _state_id(self): + if Frame.id_is_temporary(self.frame_id): + return None + return self.frame_id + ############################################################################ # Frame Generators ############################################################################ diff --git a/oops/frame/frameshift.py b/oops/frame/frameshift.py new file mode 100755 index 00000000..c1be03d4 --- /dev/null +++ b/oops/frame/frameshift.py @@ -0,0 +1,92 @@ +########################################################################################## +# oops/frame/frameshift.py: Subclass FrameShift of class Frame +########################################################################################## + +from polymath import Scalar +from oops.fittable import Fittable +from oops.frame import Frame + + +class FrameShift(Frame, Fittable): + """A path defined by a time-shift of another frame. + + PLACEHOLDER CODE. "CONCEPTUALLY" CORRECT BUT NOT YET TESTED. + """ + + _FRAME_IDS = {} + + def __init__(self, dt, /, frame, *, frame_id=None): + """Constructor for a FrameShift. + + Parameters: + dt (float): The initial time shift in seconds. + frame (Frame or str): The Framee or ID to which the time shift applies. + frame_id (str, optional): The new frame ID; None to leave this frame + unregistered. + """ + + self.dt = dt + self.frame = frame + + # Required attributes + self.reference = self.frame.reference + self.origin = self.reference.origin + self.shape = self.frame.shape + self.frame_id = self._recover_id(frame_id) + + self.register() + self._cache_id() + + ###################################################################################### + # Fittable interface + ###################################################################################### + + def _set_params(self, params): + self.dt = params[0] + + @property + def _params(self): + return (self.dt,) + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _frame_key(self): + return (self.dt, self.frame) + + def __getstate__(self): + self.refresh(self) + self._cache_id() + return (self.dt, Frame.as_primary_frame(self.frame), self._state_id()) + + def __setstate__(self, state): + (dt, frame, frame_id) = state + self.__init__(dt, frame, frame_id=frame_id) + self.freeze() + + ###################################################################################### + # Frame API + ###################################################################################### + + def transform_at_time(self, time, quick={}): + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + """ + + time = Scalar.as_scalar(time) + return self.frame.transform_at_time(time + self.dt, quick=quick) + +########################################################################################## diff --git a/oops/frame/inclinedframe.py b/oops/frame/inclinedframe.py index 087b870a..e8ac4d29 100755 --- a/oops/frame/inclinedframe.py +++ b/oops/frame/inclinedframe.py @@ -1,117 +1,116 @@ -################################################################################ +########################################################################################## # oops/frame/inclinedframe.py: Subclass InclinedFrame of class Frame -################################################################################ +########################################################################################## from polymath import Qube, Scalar +from oops.fittable import Fittable_ from oops.frame import Frame -from oops.frame.poleframe import PoleFrame from oops.frame.rotation import Rotation from oops.frame.spinframe import SpinFrame + class InclinedFrame(Frame): - """InclinedFrame is a Frame subclass describing a frame that is inclined to - the equator of another frame. - - It is defined by an inclination, a node at epoch, and a nodal regression - rate. This frame is oriented to be "nearly inertial," meaning that a - longitude in the new frame is determined by measuring from the reference - longitude in the reference frame, along that frame's equator to the - ascending node, and thence along the ascending node. + """InclinedFrame is a Frame subclass describing a frame that is inclined to the + equator of another frame. + + It is defined by an inclination, a node at epoch, and a nodal regression rate. This + frame is oriented to be "nearly inertial," meaning that a longitude in the new frame + is determined by measuring from the reference longitude in the reference frame, along + that frame's equator to the ascending node, and thence along the ascending node. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _FRAME_IDS = {} - #=========================================================================== - def __init__(self, inc, node, rate, epoch, reference, despin=True, - frame_id=None, unpickled=False): + def __init__(self, inc, node, rate, epoch, reference, *, despin=True, frame_id=None): """Constructor for a InclinedFrame. - Input: - inc the inclination of the plane in radians. - - node the longitude of ascending node of the inclined plane - at the specified epoch, in radians. This measured - relative to the ascending node of the planet's equator - relative to its parent frame, which is typically J2000. - - rate the nodal regression rate of the inclined plane in - radians per second. Should be negative for a ring about - an oblate planet. - - epoch the time TDB at which the node is defined. - - reference a reference frame describing the central planet of the - inclined plane. - - despin True to return a nearly inertial frame; False to return - a frame in which the x-axis is tied to the ascending - node. - - frame_id the ID under which the frame will be registered; None - to leave the frame unregistered. - - unpickled True if this frame has been read from a pickle file. - - Note that inc, node, rate and epoch can all be scalars of arbitrary - shape. The shape of the InclinedFrame is the result of broadcasting all - these shapes together. + Parameters: + inc (Scalar, array-like, or float): Inclination angle in radians. + node (Scalar, array-like, or float): Longitude of None at epoch in radians. + rate (Scalar, array-like, or float): Rate of nodal presession in radians/s. + epoch Scalar, array-like, or float): Time in seconds TDB at which the `node` + applies. + reference (Frame or str): Frame or Frame ID describing the central planet of + the inclined plane. + despin (bool, optional): True for a nearly inertial frame, in which the x and + yaxes vary as little as possible while the zaxis rotates; False for a + frame in which the x axis is tied to the ascending node. + frame_id (str, optional): The ID under which the frame will be registered; + None to leave the frame unregistered + + Note that inc, node, rate and epoch can all be Scalars of arbitrary shape. The + shape of the InclinedFrame is the result of broadcasting all these shapes + together. """ self.inc = Scalar.as_scalar(inc) self.node = Scalar.as_scalar(node) self.rate = Scalar.as_scalar(rate) self.epoch = Scalar.as_scalar(epoch) + self.despin = bool(despin) + # Required attributes + self.reference = Frame.as_wayframe(reference) + self.origin = self.reference.origin self.shape = Qube.broadcast(self.inc, self.node, self.rate, self.epoch) + self.frame_id = self._recover_id(frame_id) - self.frame_id = frame_id - self.reference = Frame.as_wayframe(reference) - self.origin = self.reference.origin - self.keys = set() + # Update wayframe and frame_id; register if not temporary + self.register() + self._refresh() + self._cache_id() - self.spin1 = SpinFrame(self.node, self.rate, self.epoch, axis=2, - reference=self.reference) + def _refresh(self): + self.spin1 = SpinFrame(self.node, self.rate, self.epoch, axis=2, + reference=self.reference) self.rotate = Rotation(self.inc, axis=0, reference=self.spin1) + self.rotate.freeze() - self.despin = bool(despin) - if despin: + if self.despin: self.spin2 = SpinFrame(-self.node, -self.rate, self.epoch, axis=2, reference=self.rotate) else: self.spin2 = None - # Update wayframe and frame_id; register if not temporary - self.register(unpickled=unpickled) + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.inc.vals, self.node.vals, self.rate.vals, - self.epoch.vals, self.reference.frame_id, self.despin) - InclinedFrame.FRAME_IDS[key] = self.frame_id + def _frame_key(self): + return (self.inc, self.node, self.rate, self.epoch, self.reference, self.despin) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (self.inc, self.node, self.rate, self.epoch, - Frame.as_primary_frame(self.reference), - self.despin, self.shape) + Frame.as_primary_frame(self.reference), self.despin, self._state_id()) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (inc, node, rate, epoch, reference, despin, shape) = state - if shape == (): - key = (inc.vals, node.vals, rate.vals, epoch.vals, - reference.frame_id, despin) - frame_id = PoleFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (inc, node, rate, epoch, reference, despin, frame_id) = state + self.__init__(inc, node, rate, epoch, reference=reference, despin=despin, + frame_id=frame_id) + Fittable_.freeze(self) - self.__init__(inc, node, rate, epoch, reference, despin, - frame_id=frame_id, unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick=False): - """The Transform into the this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + """ xform = self.spin1.transform_at_time(time) xform = self.rotate.transform_at_time(time).rotate_transform(xform) @@ -121,12 +120,22 @@ def transform_at_time(self, time, quick=False): return xform - #=========================================================================== def node_at_time(self, time): - """The longitude of ascending node at the specified time.""" + """The vector defining the ascending node of this frame's XY plane relative to + the XY frame of its reference. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Vector3): The unit vector pointing in the direction of the ascending node. + """ # Locate the ascending nodes in the reference frame - return (self.node + self.rate * (Scalar.as_scalar(time) - - self.epoch)) % Scalar.TWOPI + time = Scalar.as_scalar(time) + return (self.node + self.rate * (time - self.epoch)) % Scalar.TWOPI -################################################################################ +########################################################################################## diff --git a/oops/frame/laplaceframe.py b/oops/frame/laplaceframe.py index 51fb4072..308fc649 100755 --- a/oops/frame/laplaceframe.py +++ b/oops/frame/laplaceframe.py @@ -1,138 +1,131 @@ -################################################################################ +########################################################################################## # oops/frame/laplaceframe.py: Subclass LaplaceFrame of class Frame -################################################################################ +########################################################################################## import numpy as np from polymath import Matrix3, Qube, Scalar, Vector3 +from oops.cache import Cache +from oops.fittable import Fittable_ from oops.frame import Frame -from oops.frame.poleframe import PoleFrame from oops.transform import Transform + class LaplaceFrame(Frame): """A Frame subclass defined by a Kepler Path and a tilt angle. - The new Z-axis is constructed by rotating the planet's pole by a specified, - fixed angle toward the pole of the orbit. The rotation occurs around the - ascending node of the orbit on the orbit's defined reference plane. + The new Z-axis is constructed by rotating the planet's pole by a specified, fixed + angle toward the pole of the orbit. The rotation occurs around the ascending node of + the orbit on the orbit's defined reference plane. - As an example, use the Kepler Path of Triton, which is defined relative to - Neptune's PoleFrame, to construct tilted Laplace Planes for each of - Neptune's inner satellites. Note, however, that the tilt angles should be - negative because Triton is retrograde, and therefore its orbital ascending - node is the descending node for the orbits of the inner moons. + As an example, use the Kepler Path of Triton, which is defined relative to Neptune's + PoleFrame, to construct tilted Laplace Planes for each of Neptune's inner satellites. + Note, however, that the tilt angles should be negative because Triton is retrograde, + and therefore its orbital ascending node is the descending node for the orbits of the + inner moons. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _FRAME_IDS = {} - #=========================================================================== - def __init__(self, orbit, tilt=0., frame_id='+', cache_size=1000, - unpickled=False): + def __init__(self, orbit, tilt=0., *, frame_id='+', cache_size=100): """Constructor for a LaplaceFrame. - Input: - orbit a Kepler Path object. - - tilt The tilt of the Laplace Plane's pole from the planet's - pole toward or beyond the invariable pole. - - frame_id the ID under which the frame will be registered. None to - leave the frame unregistered. If the value is "+", then - the registered name is the name of the planet's - ring_frame with the suffix "_LAPLACE". Note that this - default ID will not be unique if frames are defined for - multiple Laplace Planes around the same planet. - - cache_size number of transforms to cache. This can be useful - because it avoids unnecessary SPICE calls when the frame - is being used repeatedly at a finite set of times. - - unpickled True if this frame has been read from a pickle file. + Parameters: + orbit (KeplerPath): The orbit of the body for which a Laplace Plane is needed. + tilt (Scalar, array-like, or float): The tilt of the Laplace Plane's pole from + the planet's pole toward or beyond the invariable pole. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. + If the value is "+", then the registered name is the planet's `ring_frame` + ID with "_LAPLACE" appended. + cache_size (int, optional): Number of transforms to cache. This can be useful + because it avoids unnecessary SPICE calls when the frame is being used + repeatedly at a finite set of times. """ self.orbit = orbit - self.planet = self.orbit.planet - - self.orbit_frame = self.orbit.frame.wrt(Frame.J2000) - self.planet_frame = self.planet.frame.wrt(Frame.J2000) + self._planet = self.orbit.planet + self._orbit_frame = self.orbit.frame.wrt(Frame.J2000) + self._planet_frame = self.planet.frame.wrt(Frame.J2000) self.tilt = Scalar.as_scalar(tilt) - self.cos_tilt = self.tilt.cos() - self.sin_tilt = self.tilt.sin() + self._cos_tilt = self.tilt.cos() + self._sin_tilt = self.tilt.sin() + + self._cache_size = cache_size + # Required attributes self.reference = Frame.J2000 self.origin = self.orbit.origin self.shape = Qube.broadcasted_shape(self.orbit.shape, self.tilt) - self.keys = set() - - # Define cache - self.cache = {} - self.trim_size = max(cache_size//10, 1) - self.given_cache_size = cache_size - self.cache_size = cache_size + self.trim_size - self.cache_counter = 0 - self.cached_value_returned = False # Just used for debugging # Fill in the frame ID if frame_id is None: - self.frame_id = Frame.temporary_frame_id() + self.frame_id = Frame._recover_id(frame_id) elif frame_id == '+': - self.frame_id = self.orbit.planet.ring_frame.frame_id + '_LAPLACE' + frame_id = self.orbit.planet.ring_frame.frame_id + '_LAPLACE' elif frame_id.startswith('+'): - self.frame_id = (self.orbit.planet.ring_frame.frame_id + '_' - + frame_id[1:]) + frame_id = (self.orbit.planet.ring_frame.frame_id + '_' + frame_id[1:]) else: self.frame_id = frame_id # Register if necessary - self.register(unpickled=unpickled) + self.register() + self._refresh() + self._cache_id() + + def _refresh(self): + self._cache = Cache(self._cache_size) + + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.orbit.path_id, self.tilt.vals) - LaplaceFrame.FRAME_IDS[key] = self.frame_id + def _frame_key(self): + return (self.orbit, self.tilt) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (Frame.PATH_CLASS.as_primary_path(self.orbit), - self.tilt, self.given_cache_size, self.shape) + Fittable_.refresh(self) + self._cache_id() + return (Frame.PATH_CLASS.as_primary_path(self.orbit), self.tilt, self._state_id(), + self._cache_size) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (orbit, tilt, cache_size, shape) = state - if shape == (): - key = (orbit.path_id, tilt.vals) - frame_id = PoleFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (orbit, tilt, frame_id, cache_size) = state + self.__init__(orbit, tilt, frame_id=frame_id, cache_size=cache_size) + Fittable_.freeze(self) - self.__init__(orbit, tilt, frame_id=frame_id, - cache_size=cache_size, unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick={}): - """The Transform into the this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + """ time = Scalar.as_scalar(time) - # Check cache first if time is a Scalar + # Check cache first if time is shapeless if time.shape == (): - key = time.values - - if key in self.cache: - self.cached_value_returned = True - (count, key, xform) = self.cache[key] - self.cache_counter += 1 - count[0] = self.cache_counter + xform = self._cache[time.vals] + if xform: return xform - self.cached_value_returned = False - # All vectors below are in J2000 coordinates - orbit_ref_xform = self.orbit.frame_wrt_j2000.transform_at_time(time, - quick=quick) + orbit_ref_xform = self.orbit.frame_wrt_j2000.transform_at_time(time, quick=quick) orbit_ref_x_axis = orbit_ref_xform.unrotate(Vector3.XAXIS).wod orbit_ref_y_axis = orbit_ref_xform.unrotate(Vector3.YAXIS).wod orbit_ref_z_axis = orbit_ref_xform.unrotate(Vector3.ZAXIS).wod @@ -143,14 +136,11 @@ def transform_at_time(self, time, quick={}): sin_node = np.sin(node_lon) orbit_node = cos_node * orbit_ref_x_axis + sin_node * orbit_ref_y_axis - # This vector is 90 degrees behind of the node on the orbit reference - # equator - orbit_target = ( sin_node * orbit_ref_x_axis + - -cos_node * orbit_ref_y_axis) + # This vector is 90 degrees behind of the node on the orbit reference equator + orbit_target = sin_node * orbit_ref_x_axis - cos_node * orbit_ref_y_axis # This is the pole of the orbit - orbit_pole = (self.orbit.cos_i * orbit_ref_z_axis + - self.orbit.sin_i * orbit_target) + orbit_pole = self.orbit.cos_i * orbit_ref_z_axis + self.orbit.sin_i * orbit_target # Get the planet's pole in J2000 planet_xform = self.planet_frame.transform_at_time(time, quick=quick) @@ -161,65 +151,50 @@ def transform_at_time(self, time, quick={}): tilt_target = orbit_pole.perp(planet_pole).unit() # Now, rotation is easy - laplace_pole = (self.cos_tilt * planet_pole + - self.sin_tilt * tilt_target) + laplace_pole = self.cos_tilt * planet_pole + self.sin_tilt * tilt_target # We still have to be very careful to match up the orbital longitude. # Angles are measured... # 1. From the reference direction in the orbit's reference frame. - # 2. Along the equator plane of the orbit's reference frame to the - # ascending node of the Laplace plane - # 3. Then along the Laplace plane - - # This vector is at the intersection of the reference plane and the - # Laplace plane - # common_node = orbit_ref_z_axis.cross(laplace_pole) - # HOWEVER, the two vectors are very close (0.11 degrees apart in the - # case of Proteus) so this is a very imprecise calculation. + # 2. Along the equator plane of the orbit's reference frame to the ascending node + # of the Laplace plane. + # 3. Then along the Laplace plane. + + # This vector is at the intersection of the reference plane and the Laplace plane + # common_node = orbit_ref_z_axis.cross(laplace_pole) + # HOWEVER, the two vectors are very close (0.11 degrees apart in the case of + # Proteus) so this is a very imprecise calculation. # - # Instead, we use the orbital node. This is perpendicular to the Z-axis - # of the orbit reference frame and it is _nearly_ perpendicular to the - # Laplace pole; for the Neptune system, the angle is 90.07 degrees. The - # error arising in the longitude by ignoring tilt of the plane by 0.07 - # degrees will be a factor of ~ cos(0.07 deg) ~ one part in 10^6. + # Instead, we use the orbital node. This is perpendicular to the Z-axis of the + # orbit reference frame and it is _nearly_ perpendicular to the Laplace pole; for + # the Neptune system, the angle is 90.07 degrees. The error arising in the + # longitude by ignoring tilt of the plane by 0.07 degrees will be a factor of + # ~ cos(0.07 deg) ~ one part in 10^6. common_node = orbit_node # Create the rotation matrix matrix = Matrix3.twovec(laplace_pole, 2, common_node, 0) - # This matrix rotates coordinates from J2000 to a frame in which the - # Z-axis is along the Laplace pole and the X-axis is at the common node. + # This matrix rotates coordinates from J2000 to a frame in which the Z-axis is + # along the Laplace pole and the X-axis is at the common node. # Get the longitude of the common node in the orbit reference frame common_node_wrt_orbit_ref = orbit_ref_xform.rotate(common_node).wod (x, y, _) = common_node_wrt_orbit_ref.to_scalars() common_node_lon = y.arctan2(x) - # Rotate vectors around the Z-axis in the new frame to so that the - # X-axis falls at this longitude + # Rotate vectors around the Z-axis in the new frame to so that the X-axis falls at + # this longitude matrix = Matrix3.z_rotation(common_node_lon) * matrix # Create the transform - xform = Transform(matrix, Vector3.ZERO, self.wayframe, Frame.J2000, - self.origin) + xform = Transform(matrix, Vector3.ZERO, self.wayframe, Frame.J2000, self.origin) # Cache the transform if necessary - if time.shape == () and self.given_cache_size > 0: - - # Trim the cache, removing the values used least recently - if len(self.cache) >= self.cache_size: - all_keys = list(self.cache.values()) - all_keys.sort() - for (_, old_key, _) in all_keys[:self.trim_size]: - del self.cache[old_key] - - # Insert into the cache - key = time.values - self.cache_counter += 1 - count = np.array([self.cache_counter]) - self.cache[key] = (count, key, xform) + if time.shape == (): + self._cache[time.vals] = xform return xform -################################################################################ +########################################################################################## diff --git a/oops/frame/navigation.py b/oops/frame/navigation.py index 6c3d5b63..489541d3 100755 --- a/oops/frame/navigation.py +++ b/oops/frame/navigation.py @@ -1,85 +1,114 @@ -################################################################################ +########################################################################################## # oops/frame/navigation.py: Fittable subclass Navigation of class Frame -################################################################################ +########################################################################################## import numpy as np -from polymath import Matrix3, Vector, Vector3 +from polymath import Matrix3, Vector3 from oops.fittable import Fittable from oops.frame import Frame from oops.transform import Transform + class Navigation(Frame, Fittable): - """A Frame subclass describing a fittable, fixed offset from another frame, - defined by two or three rotation angles. + """A Frame subclass describing a fittable, fixed offset from another frame, defined by + two or three rotation angles. """ - # Note: Navigation frames are not generally re-used, so their IDs are - # expendable. Frame IDs are not preserved during pickling. - - #=========================================================================== - def __init__(self, angles, reference, frame_id=None, override=False, - _matrix=None): + def __init__(self, arg, /, reference, *, frame_id=None, override=False, _matrix=None): """Constructor for a Navigation Frame. - Input: - angles two or three angles of rotation in radians. The order of - the rotations is about the y, x, and (optionally) z - axes. These angles rotate a vector in the reference - frame into this frame. - reference the frame or frame ID relative to which this rotation is - defined. - frame_id the ID to use; None to use a temporary ID. - override True to override a pre-existing frame with the same ID. - _matrix an optional 3x3 matrix, used internally, to speed up the - copying of Navigation objects. If not None, it must - contain the Matrix3 object that performs the defined - rotation. + Parameters: + arg (array-like or Navigation): Two or three angles of rotation in radians. + The order of the rotations is about the y, x, and (optionally) z axes. + These angles rotate a vector in the reference frame into this frame. + Alternatively, specify another Navigation object and this object will be + linked to that one, meaning that the rotation angles will always match. + reference (Frame or str): The frame or frame ID relative to which this + navigation applies. + frame_id (str, optional): The frame ID to use; None to use a temporary ID. + override (bool, optional): True to override a pre-existing frame with the same + ID. + _matrix (Matrix3, optional): A 3x3 matrix, used internally, to speed up the + copying of Navigation objects. If not None, it must contain the Matrix3 + object that performs the defined rotation. """ - if isinstance(angles, Vector): - angles = angles.vals - - self.angles = np.array(angles) - if self.angles.shape not in ((2,),(3,)): + if isinstance(arg, Navigation): + self.link = arg + self.link.refresh() + self.angles = self.link.angles + self._matrix = self.link._matrix + else: + self.angles = tuple(arg) + self.link = None + self._matrix = _matrix + + self._fittable_nparams = len(self.angles) + if self._fittable_nparams not in {2, 3}: raise ValueError('two or three Navigation angles must be provided') - self.cache = {} - self.param_name = 'angles' - self.nparams = self.angles.shape[0] - - if _matrix is None: - _matrix = Navigation._rotmat(self.angles[0],1) - _matrix = Navigation._rotmat(self.angles[1],0) * _matrix - - if self.nparams > 2 and self.angles[2] != 0.: - _matrix = Navigation._rotmat(self.angles[2], 2) * _matrix - self.reference = Frame.as_wayframe(reference) - self.origin = self.reference.origin - self.frame_id = frame_id - self.shape = self.reference.shape - self.keys = set() + self.origin = self.reference.origin + self.shape = self.reference.shape + self.frame_id = Frame._recover_id(frame_id) # Update wayframe and frame_id; register if not temporary self.register(override=override) # Fill in transform (_after_ registration) - self.transform = Transform(_matrix, Vector3.ZERO, - self, self.reference, self.origin) + self._refresh(matrix=self._matrix) + + ###################################################################################### + # Serialization support + ###################################################################################### - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (self.angles, Frame.as_primary_frame(self.reference)) + self.refresh() + return (self.angles, Frame.as_primary_frame(self.reference), self._state_id()) def __setstate__(self, state): - self.__init__(*state) + (angles, reference, frame_id) = state + self.__init__(angles, reference, frame_id=frame_id) + self.freeze() + + ###################################################################################### + # Fittable interface + ###################################################################################### + + def _set_params(self, params): + """Redefine the navigation angles.""" + + if self.link: + self.link.set_params(params) + self.angles = self.link.angles + else: + self.angles = params + + @property + def _params(self): + return self.angles + + def _refresh(self, matrix=None): + """Update the internals.""" + + if self.link: + self.angles = self.link.angles + self._matrix = self.link._matrix + elif matrix is None: + matrix = Navigation._rotmat(self.angles[0], 1) + matrix = Navigation._rotmat(self.angles[1], 0) * matrix + if self._fittable_nparams > 2 and self.angles[2] != 0.: + matrix = Navigation._rotmat(self.angles[2], 2) * matrix + self._matrix = matrix + + self._transform = Transform(matrix, Vector3.ZERO, self, self.reference, + self.origin) - #=========================================================================== @staticmethod def _rotmat(angle, axis): - """Internal function to return a matrix that performs a rotation about - a single specified axis. + """Internal function to return a matrix that performs a rotation about a single + specified axis. """ axis2 = axis @@ -95,50 +124,31 @@ def _rotmat(angle, axis): return Matrix3(mat) - #=========================================================================== - def transform_at_time(self, time, quick=False): - """The Transform to the given Frame at a specified Scalar of - times. - """ - - return self.transform - - ############################################################################ - # Fittable interface - ############################################################################ - - def set_params_new(self, params): - """Redefines the Fittable object, using this set of parameters. Unlike - method set_params(), this method does not check the cache first. - Override this method if the subclass should use a cache. - - Input: - params a list, tuple or 1-D Numpy array of floating-point - numbers, defining the parameters to be used in the - object returned. - """ - - params = np.array(params).copy() - if self.angles.shape != params.shape: - raise ValueError('new parameter shape does not match original') + ###################################################################################### + # Frame API + ###################################################################################### - self.angles = params + def transform_at_time(self, time, quick=False): + """Transform that rotates coordinates from the reference frame to this frame. - matrix = Navigation._rotmat(self.angles[0],1) - matrix = Navigation._rotmat(self.angles[1],0) * matrix + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. - if self.nparams > 2 and self.angles[2] != 0.: - matrix = Navigation._rotmat(self.angles[2],2) * matrix + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. - self.transform = Transform(matrix, Vector3.ZERO, self, - self.reference, self.origin) + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. - def copy(self): - """A deep copy of the given object. The copy can be safely modified - without affecting the original. + Notes: + Navigation is a fixed frame, so the transform relative to the `reference` + frame is independent of time. """ - return Navigation(self.angles.copy(), self.reference, - matrix=self.transform.matrix.copy()) + return self._transform -################################################################################ +########################################################################################## diff --git a/oops/frame/poleframe.py b/oops/frame/poleframe.py index e5abb401..c12c2f7d 100755 --- a/oops/frame/poleframe.py +++ b/oops/frame/poleframe.py @@ -1,72 +1,61 @@ -################################################################################ +########################################################################################## # oops/frame/poleframe.py: Subclass PoleFrame of class Frame -################################################################################ +########################################################################################## import numpy as np from polymath import Matrix3, Qube, Scalar, Vector3 +from oops.cache import Cache +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform + class PoleFrame(Frame): - """A Frame subclass describing a non-rotating frame centered on the Z-axis - of a body's pole vector. - - This differs from RingFrame in that the pole may precess around a separate, - invariable pole for the system. Because of this behavior, the reference - longitude is defined as the ascending node of the invariable plane rather - than as the ascending node of the ring plane. This frame is recommended for - Neptune in particular. + """A Frame subclass describing a non-rotating frame centered on the Z-axis of a body's + pole vector. + + This differs from RingFrame in that the pole may precess around a separate, invariable + pole for the system. Because of this behavior, the reference longitude is defined as + the ascending node of the invariable plane rather than as the ascending node of the + ring plane. This frame is recommended for Neptune in particular. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _FRAME_IDS = {} - #=========================================================================== - def __init__(self, frame, pole, retrograde=False, aries=False, frame_id='+', - cache_size=1000, unpickled=False): + def __init__(self, frame, pole, *, retrograde=False, aries=False, frame_id='+', + cache_size=100): """Constructor for a PoleFrame. Input: - frame a (possibly) rotating frame, or its ID, describing the - central planet relative to J2000. This is typically a - body's rotating SpiceFrame. - - pole The pole of the invariable plane, about which planet's - pole precesses. This enables the reference longitude to - be defined properly. Defined in J2000 coordinates. - - retrograde True to flip the sign of the Z-axis. Necessary for - retrograde systems like Uranus. - - aries True to use the First Point of Aries as the longitude - reference; False to use the ascending node of the - invariable plane. Note that the former might be - preferred in a situation where the invariable pole is - uncertain, because small changes in the invariable pole - will have only a limited effect on the absolute - reference longitude. - - frame_id the ID under which the frame will be registered. None to - leave the frame unregistered. If the value is "+", then - the registered name is the planet frame's name with the - suffix "_POLE". Note that this default ID will not be - unique if frames are defined for multiple Laplace Planes - around the same planet. - - cache_size number of transforms to cache. This can be useful - because it avoids unnecessary SPICE calls when the frame - is being used repeatedly at a finite set of times. - - unpickled True if this frame has been read from a pickle file. + frame (Frame or str): Frame or frame ID for a (possibly) rotating frame + describing the central planet relative to J2000. This is typically a + body's rotating SpiceFrame. + pole (Vector3 or array-like): The pole of the invariable plane, about which + planet's pole precesses, in J2000 coordinates. This enables the reference + longitude to be defined properly. + retrograde (bool, optional): True to flip the sign of the Z-axis. This is + necessary for retrograde systems like Uranus. + aries (bool, optional): True to use the First Point of Aries as the longitude + reference; False to use the ascending node of the invariable plane. Note + that the former might be preferred in a situation where the invariable + pole is uncertain, because small changes in the invariable pole will have + only a limited effect on the absolute reference longitude. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. + If the value is "+", then the registered name is the planet frame's ID + with "_POLE" appended. If the value is "+" followed by text, the ID is + the planet frame's ID followed by "_" and the given text. + cache_size (int, optional): The number of transforms to cache. This can be + useful because it avoids unnecessary SPICE calls when the frame is being + used repeatedly at a finite set of times. """ # Rotates from J2000 to the invariable frame pole = Vector3.as_vector3(pole) (ra, dec, _) = pole.to_ra_dec_length(recursive=False) - self.invariable_matrix = Matrix3.pole_rotation(ra,dec) - # Rotates J2000 coordinates into a frame where the Z-axis is the - # invariable pole and the X-axis is the ascending node of the - # invariable plane on J2000 + self.invariable_matrix = Matrix3.pole_rotation(ra, dec) + # Rotates J2000 coordinates into a frame where the Z-axis is the invariable pole + # and the X-axis is the ascending node of the invariable plane on J2000 self.invariable_pole = pole self.invariable_node = Vector3.ZAXIS.ucross(pole) @@ -79,24 +68,17 @@ def __init__(self, frame, pole, retrograde=False, aries=False, frame_id='+', self.invariable_node_lon = 0. self.planet_frame = Frame.as_frame(frame).wrt(Frame.J2000) - self.origin = self.planet_frame.origin self.retrograde = bool(retrograde) - self.keys = set() + + self._cache_size = cache_size + + # Required attributes self.reference = Frame.J2000 - self.shape = Qube.broadcasted_shape(self.invariable_pole, - self.planet_frame) - - # Define cache - self.cache = {} - self.trim_size = max(cache_size//10, 1) - self.given_cache_size = cache_size - self.cache_size = cache_size + self.trim_size - self.cache_counter = 0 - self.cached_value_returned = False # Just used for debugging - - # Fill in the frame ID + self.origin = self.planet_frame.origin + self.shape = Qube.broadcasted_shape(self.invariable_pole, self.planet_frame) + if frame_id is None: - self.frame_id = Frame.temporary_frame_id() + self.frame_id = Frame._recover_id(frame_id) elif frame_id == '+': self.frame_id = self.planet_frame.frame_id + '_POLE' elif frame_id.startswith('+'): @@ -105,53 +87,61 @@ def __init__(self, frame, pole, retrograde=False, aries=False, frame_id='+', self.frame_id = frame_id # Register if necessary - self.register(unpickled=unpickled) + self.register() + self._refresh() + self._cache_id() + + def _refresh(self): + self._cache = Cache(self._cache_size) + + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.planet_frame.frame_id, - tuple(self.invariable_pole.vals), - retrograde, aries) - PoleFrame.FRAME_IDS[key] = self.frame_id + def _frame_key(self): + return (self.planet_frame, self.invariable_pole, self.retrograde, self.aries) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (Frame.as_primary_frame(self.planet_frame), - self.invariable_pole, self.retrograde, - self.aries, self.given_cache_size, self.shape) + Fittable_.refresh(self) + self._cache_id() + return (Frame.as_primary_frame(self.planet_frame), self.invariable_pole, + self.retrograde, self.aries, self._state_id(), self._cache_size) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (frame, pole, retrograde, aries, cache_size, shape) = state - if shape == (): - key = (frame.frame_id, tuple(pole.vals), retrograde, aries) - frame_id = PoleFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (frame, pole, retrograde, aries, frame_id, cache_size) = state + self.__init__(frame, pole, retrograde=retrograde, aries=aries, frame_id=frame_id, + cache_size=cache_size) + Fittable_.freeze(self) - self.__init__(frame, pole, retrograde, aries, frame_id=frame_id, - cache_size=cache_size, unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick={}): - """The Transform into the this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + """ time = Scalar.as_scalar(time) # Check cache first if time is a Scalar if time.shape == (): - key = time.values - - if key in self.cache: - self.cached_value_returned = True - (count, key, xform) = self.cache[key] - self.cache_counter += 1 - count[0] = self.cache_counter + xform = self._cache[time.vals] + if xform: return xform - self.cached_value_returned = False - # Calculate the planet frame for the current time in J2000 xform = self.planet_frame.transform_at_time(time, quick=quick) @@ -162,15 +152,13 @@ def transform_at_time(self, time, quick={}): if self.retrograde: z_axis = -z_axis - planet_matrix = Matrix3.twovec(z_axis, 2, - Vector3.ZAXIS.cross(z_axis), 0) + planet_matrix = Matrix3.twovec(z_axis, 2, Vector3.ZAXIS.cross(z_axis), 0) - # This is the RingFrame matrix. It rotates from J2000 to the frame where - # the pole at epoch is along the Z-axis and the ascending node relative - # to the J2000 equator is along the X-axis. + # This is the RingFrame matrix. It rotates from J2000 to the frame where the pole + # at epoch is along the Z-axis and the ascending node relative to the J2000 + # equator is along the X-axis. - # Locate the J2000 ascending node of the RingFrame on the invariable - # plane. + # Locate the J2000 ascending node of the RingFrame on the invariable plane. planet_pole_j2000 = planet_matrix.inverse() * Vector3.ZAXIS joint_node_j2000 = self.invariable_pole.cross(planet_pole_j2000) @@ -185,31 +173,31 @@ def transform_at_time(self, time, quick={}): self.invariable_node_lon) * planet_matrix # Create the transform - xform = Transform(Matrix3(matrix, xform.matrix.mask), Vector3.ZERO, - self.wayframe, self.reference, self.origin) + xform = Transform(Matrix3(matrix, xform.matrix.mask), Vector3.ZERO, self.wayframe, + self.reference, self.origin) # Cache the transform if necessary - if time.shape == () and self.given_cache_size > 0: - - # Trim the cache, removing the values used least recently - if len(self.cache) >= self.cache_size: - all_keys = list(self.cache.values()) - all_keys.sort() - for (_, old_key, _) in all_keys[:self.trim_size]: - del self.cache[old_key] - - # Insert into the cache - key = time.values - self.cache_counter += 1 - count = np.array([self.cache_counter]) - self.cache[key] = (count, key, xform) + if time.shape == (): + self._cache[time.vals] = xform return xform - #=========================================================================== def node_at_time(self, time, quick={}): - """Angle from the frame's X-axis to the ring plane ascending node on the - invariable plane. + """The vector defining the ascending node of this frame's XY plane relative to + the XY frame of its reference. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Vector3): The unit vector pointing in the direction of the ascending node. + + Notes: + TwoVector is a fixed frame, so its node vector relative to the `reference` + frame is independent of time. """ # Calculate the pole for the current time @@ -229,4 +217,4 @@ def node_at_time(self, time, quick={}): node = (y.arctan2(x) + Scalar.HALFPI + self.invariable_node_lon) return node % Scalar.TWOPI -################################################################################ +########################################################################################## diff --git a/oops/frame/postargframe.py b/oops/frame/postargframe.py index 7fbfde2e..67c61d80 100755 --- a/oops/frame/postargframe.py +++ b/oops/frame/postargframe.py @@ -1,33 +1,33 @@ -################################################################################ +########################################################################################## # oops/frame/postargframe.py: Subclass PosTargFrame of class Frame -################################################################################ +########################################################################################## import numpy as np from polymath import Matrix3, Vector3 +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform + class PosTargFrame(Frame): - """A Frame subclass describing a fixed rotation about the X and Y axes, so - the Z-axis of another frame falls at a slightly different position in this - frame. + """A Frame subclass describing a fixed rotation about the X and Y axes, so the Z-axis + of another frame falls at a slightly different position in this frame. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _FRAME_IDS = {} - #=========================================================================== - def __init__(self, xpos, ypos, reference, frame_id=None, unpickled=False): + def __init__(self, xpos, ypos, reference, *, frame_id=None): """Constructor for a PosTarg Frame. - Input: - xpos the X-position of the reference frame's Z-axis in this - frame, in radians. - ypos the Y-position of the reference frame's Z-axis in this - frame, in radians. - reference the frame relative to which this frame is defined. - frame_id the ID to use; None to leave the frame unregistered. - unpickled True if this frame has been read from a pickle file. + Parameters: + xpos (float): The X-position of the reference frame's Z-axis in this frame, in + radians. + ypos (float): The Y-position of the reference frame's Z-axis in this frame, in + radians. + reference (Frame or str): The frame or frame ID relative to which this frame + is defined. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. """ self.xpos = float(xpos) @@ -35,59 +35,69 @@ def __init__(self, xpos, ypos, reference, frame_id=None, unpickled=False): cos_x = np.cos(self.xpos) sin_x = np.sin(self.xpos) - cos_y = np.cos(self.ypos) sin_y = np.sin(self.ypos) - xmat = Matrix3([[1., 0., 0. ], [0., cos_y, sin_y], [0., -sin_y, cos_y]]) - ymat = Matrix3([[ cos_x, 0., sin_x], [ 0., 1., 0. ], [-sin_x, 0., cos_x]]) + self._matrix = ymat * xmat - mat = ymat * xmat - - self.frame_id = frame_id self.reference = Frame.as_wayframe(reference) - self.origin = self.reference.origin - self.shape = self.reference.shape - self.keys = set() + self.origin = self.reference.origin + self.shape = self.reference.shape + self.frame_id = self._recover_id(frame_id) + + self.register() + self._cache_id() - # Update wayframe and frame_id; register if not temporary - self.register(unpickled=unpickled) + self.transform = Transform(self._matrix, Vector3.ZERO, self, self.reference, + self.origin) - # It needs a wayframe before we can define the transform - self.transform = Transform(mat, Vector3.ZERO, - self, self.reference, self.origin) + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.xpos, self.ypos, self.reference.frame_id) - PosTargFrame.FRAME_IDS[key] = self.frame_id + def _frame_key(self): + return (self.xpos, self.ypos, self.reference) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (self.xpos, self.ypos, - Frame.as_primary_frame(self.reference), self.shape) + self._cache_id() + return (self.xpos, self.ypos, Frame.as_primary_frame(self.reference), + self._state_id()) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (xpos, ypos, reference, shape) = state - if shape == (): - key = (xpos, ypos, reference.frame_id) - frame_id = PosTargFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (xpos, ypos, reference, frame_id) = state + self.__init__(xpos, ypos, reference, frame_id=frame_id) + Fittable_.freeze(self) - self.__init__(xpos, ypos, reference, frame_id=frame_id, unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== - def transform_at_time(self, time, quick={}): - """The Transform into the this Frame at a Scalar of times.""" + def transform_at_time(self, time, quick=False): + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + + Notes: + Navigation is a fixed frame, so the transform relative to the `reference` + frame is independent of time. + """ return self.transform -################################################################################ +########################################################################################## diff --git a/oops/frame/ringframe.py b/oops/frame/ringframe.py index f8d55439..2707dd12 100755 --- a/oops/frame/ringframe.py +++ b/oops/frame/ringframe.py @@ -1,131 +1,131 @@ -################################################################################ +########################################################################################## # oops/frame/ringframe.py: Subclass RingFrame of class Frame -################################################################################ +########################################################################################## import numpy as np from polymath import Matrix3, Qube, Scalar, Vector3 +from oops.cache import Cache +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform + class RingFrame(Frame): - """A Frame subclass describing a non-rotating frame centered on the Z-axis - of another frame, but oriented with the X-axis fixed along the ascending - node of the equator within the reference frame. + """A Frame subclass describing a non-rotating frame centered on the Z-axis of another + frame, but oriented with the X-axis fixed along the ascending node of the equator + within the reference frame. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + FRAME_IDS = {} - #=========================================================================== - def __init__(self, frame, epoch=None, retrograde=False, aries=False, - frame_id='+', cache_size=1000, unpickled=False): + def __init__(self, frame, epoch=None, *, retrograde=False, aries=False, frame_id='+', + cache_size=100): """Constructor for a RingFrame Frame. - Input: - frame a frame describing the central planet of the ring plane - relative to J2000. - - epoch the time TDB at which the frame is to be evaluated. If - this is specified, then the frame will be precisely - inertial, based on the orientation of the pole at the - specified epoch. If it is unspecified, then the frame - could wobble or rotate slowly due to precession of the - planet's pole. - - retrograde True to flip the sign of the Z-axis. Necessary for - retrograde systems like Uranus. - - aries True to use the First Point of Aries as the longitude - reference; False to use the ascending node of the ring - plane. Note that the former might be preferred in a - situation where the ring plane is uncertain, wobbles, or - is nearly parallel to the celestial equator. In these - situations, using Aries as a reference will reduce the - uncertainties related to the pole orientation. - - frame_id the ID under which the frame will be registered. None to - leave the frame unregistered. If the value is "+", then - the registered name is the planet frame's name with the - suffix "_DESPUN" if epoch is None, or "_INERTIAL" if an - epoch is specified. - - cache_size number of transforms to cache. This can be useful - because it avoids unnecessary SPICE calls when the frame - is being used repeatedly at a finite set of times. - - unpickled True if this frame has been read from a pickle file. + Parameters: + frame (Frame or str): The frame or frame ID describing the central planet of + the ring plane relative to J2000. + epoch (Scalar or float): The time TDB at which the frame is to be evaluated. + If this is specified, then the frame will be precisely inertial, based on + the orientation of the pole at the specified epoch. If it is unspecified, + then the frame could wobble or rotate slowly due to precession of the + planet's pole. + retrograde (bool, optional): True to flip the sign of the Z-axis. Necessary + for retrograde systems like Uranus. + aries (bool, optional): True to use the First Point of Aries as the longitude + reference; False to use the ascending node of the ring plane. Note that + the former might be preferred in a situation where the ring plane is + uncertain, wobbles, or is nearly parallel to the celestial equator. In + these situations, using Aries as a reference will reduce the uncertainties + related to the pole orientation. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. + If the value is "+", then the registered name is the planet frame's ID + with "_DESPUN" appended if `epoch` is None or "_INERTIAL" if an epoch is + specified. + cache_size (int, optinal): The number of transforms to cache. This can be + useful because it avoids unnecessary SPICE calls when the frame is being + used repeatedly at a finite set of times. """ self.planet_frame = Frame.as_frame(frame).wrt(Frame.J2000) - self.reference = Frame.J2000 self.epoch = None if epoch is None else Scalar.as_scalar(epoch) self.retrograde = bool(retrograde) - self.shape = Qube.broadcasted_shape(self.planet_frame, self.epoch) - self.keys = set() - self.aries = bool(aries) + self._cache_size = cache_size - # The frame might not be exactly inertial due to polar precession, but - # it is good enough + # Required attributes + self.reference = Frame.J2000 + # The frame might not be exactly inertial due to polar precession, but it is close + # enough for all practical purposes. self.origin = None - - # Define cache - self.cache = {} - self.trim_size = max(cache_size//10, 1) - self.given_cache_size = cache_size - self.cache_size = cache_size + self.trim_size - self.cache_counter = 0 - self.cached_value_returned = False # Just used for debugging + self.shape = Qube.broadcasted_shape(self.planet_frame, self.epoch) # Fill in the frame ID - if frame_id is None: - self.frame_id = Frame.temporary_frame_id() - elif frame_id == '+': + if frame_id == '+': if self.epoch is None: - self.frame_id = self.planet_frame.frame_id + "_DESPUN" + self.frame_id = self.planet_frame.frame_id + '_DESPUN' else: - self.frame_id = self.planet_frame.frame_id + "_INERTIAL" + self.frame_id = self.planet_frame.frame_id + '_INERTIAL' else: - self.frame_id = frame_id + self.frame_id = self._recover_id(frame_id) # Register if necessary - self.register(unpickled=unpickled) + self.register() + self._refresh() + self._cache_id() + + def _refresh(self): + self._cache = Cache(self._cache_size) # For a fixed epoch, derive the inertial tranform now self.transform = None if self.epoch is not None: self.transform = self.transform_at_time(self.epoch) - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.planet_frame.frame_id, - None if self.epoch is None else self.epoch.vals, - self.retrograde, self.aries) - RingFrame.FRAME_IDS[key] = self.frame_id + ###################################################################################### + # Serialization support + ###################################################################################### + + def _frame_key(self): + return (self.planet_frame, self.epoch, self.retrograde, self.aries) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (Frame.as_primary_frame(self.planet_frame), self.epoch, - self.retrograde, self.aries, self.given_cache_size, self.shape) + Fittable_.refresh(self) + self._cache_id() + return (Frame.as_primary_frame(self.planet_frame), self.epoch, self.retrograde, + self.aries, self._state_id(), self._cache_size) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (frame, epoch, retrograde, aries, cache_size, shape) = state - if shape == (): - key = (frame.frame_id, - None if epoch is None else epoch.vals, - retrograde, aries) - frame_id = RingFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (frame, epoch, retrograde, aries, frame_id, cache_size) = state + self.__init__(frame, epoch, retrograde=retrograde, aries=aries, + frame_id=frame_id, cache_size=cache_size) + Fittable_.freeze(self) - self.__init__(frame, epoch, retrograde, aries, frame_id=frame_id, - cache_size=cache_size, unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick={}): - """The Transform into the this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + + Notes: + TwoVector is a fixed frame, so the transform relative to the `reference` frame + is independent of time. + """ # For a fixed epoch, return the fixed transform if self.transform is not None: @@ -135,17 +135,10 @@ def transform_at_time(self, time, quick={}): # Check cache first if time is a Scalar if time.shape == (): - key = time.values - - if key in self.cache: - self.cached_value_returned = True - (count, key, xform) = self.cache[key] - self.cache_counter += 1 - count[0] = self.cache_counter + xform = self._cache[time.vals] + if xform: return xform - self.cached_value_returned = False - # Otherwise, calculate it for the current time xform = self.planet_frame.transform_at_time(time, quick=quick) @@ -159,9 +152,9 @@ def transform_at_time(self, time, quick={}): x_axis = Vector3.ZAXIS.cross(z_axis) matrix = Matrix3.twovec(z_axis, 2, x_axis, 0) - # This is the RingFrame matrix. It rotates from J2000 to the frame where - # the pole at epoch is along the Z-axis and the ascending node relative - # to the J2000 equator is along the X-axis. + # This is the RingFrame matrix. It rotates from J2000 to the frame where the pole + # at epoch is along the Z-axis and the ascending node relative to the J2000 + # equator is along the X-axis. if self.aries: (x,y,z) = x_axis.to_scalars() @@ -169,40 +162,35 @@ def transform_at_time(self, time, quick={}): matrix = Matrix3.z_rotation(node_lon) * matrix # Create transform - xform = Transform(matrix, Vector3.ZERO, - self.wayframe, self.reference, None) + xform = Transform(matrix, Vector3.ZERO, self.wayframe, self.reference, None) # Cache the transform if necessary - if time.shape == () and self.given_cache_size > 0: - - # Trim the cache, removing the values used least recently - if len(self.cache) >= self.cache_size: - all_keys = list(self.cache.values()) - all_keys.sort() - for (_, old_key, _) in all_keys[:self.trim_size]: - del self.cache[old_key] - - # Insert into the cache - key = time.values - self.cache_counter += 1 - count = np.array([self.cache_counter]) - self.cache[key] = (count, key, xform) + if time.shape == (): + self._cache[time.vals] = xform return xform - #=========================================================================== def node_at_time(self, time, quick={}): - """Angle from the frame's X-axis to the ring plane ascending node on - the J2000 equator. + """The vector defining the ascending node of this frame's XY plane relative to + the XY frame of its reference. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Vector3): The unit vector pointing in the direction of the ascending node. """ xform = self.transform_at_time(time, quick=quick) z_axis_wrt_j2000 = xform.unrotate(Vector3.ZAXIS) - (x,y,_) = z_axis_wrt_j2000.to_scalars() + (x, y, _) = z_axis_wrt_j2000.to_scalars() - if (x,y) == (0.,0.): + if (x, y) == (0., 0.): return Scalar(0.) return (y.arctan2(x) + np.pi/2.) % Scalar.TWOPI -################################################################################ +########################################################################################## diff --git a/oops/frame/rotation.py b/oops/frame/rotation.py index abfc75f3..e3c1c1dc 100755 --- a/oops/frame/rotation.py +++ b/oops/frame/rotation.py @@ -1,6 +1,6 @@ -################################################################################ +########################################################################################## # oops/frame/rotation.py: Subclass Rotation of class Frame -################################################################################ +########################################################################################## import numpy as np @@ -9,126 +9,126 @@ from oops.frame import Frame from oops.transform import Transform + class Rotation(Frame, Fittable): """A Frame describing a fixed rotation about one axis of another frame.""" - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _XYZDICT = {'X': 0, 'Y': 1, 'Z': 2, 'x': 0, 'y': 1, 'z': 2, 0: 0, 1: 1, 2: 2} - #=========================================================================== - def __init__(self, angle, axis, reference, frame_id=None, unpickled=False): + def __init__(self, arg, /, axis, reference, *, frame_id=None): """Constructor for a Rotation Frame. - Input: - angle the angle of rotation in radians. Can be a Scalar - containing multiple values. - axis the rotation axis: 0 for x, 1 for y, 2 for z. - reference the frame relative to which this rotation is defined. - frame_id the ID to use; None to leave the frame unregistered. - unpickled True if this frame has been read from a pickle file. + Parameters: + arg (Scalar, array-like, float or Rotation): The angle of rotation in radians, + which can be multidimensional. Alternatively, if another Rotation is + given, this object's rotation angle will always match that of the other. + axis (int): The rotation axis: 0 or "X" for x; 1 or "Y" for y, or 2 or "Z" for + z. + reference (frame or str): The frame or frame ID relative to which this + rotation is defined. + frame_id (str, optional): The frame ID to use; None to use a temporary ID. """ - self.angle = Scalar.as_scalar(angle) + if isinstance(arg, Rotation): + self.link = arg + self._angle_shape = self.link.angle_shape + self._fittable_nparams = self.link._fittable_nparams + else: + self.angle = Scalar.as_scalar(arg) + self._angle_shape = self.angle.shape + self._fittable_nparams = self.angle.size + self.link = None - self.axis2 = axis # Most often, the Z-axis + self.axis2 = Rotation._XYZDICT[axis] self.axis0 = (self.axis2 + 1) % 3 self.axis1 = (self.axis2 + 2) % 3 - self.frame_id = frame_id + # Required attributes self.reference = Frame.as_wayframe(reference) - self.origin = self.reference.origin - self.keys = set() - + self.origin = self.reference.origin self.shape = Qube.broadcasted_shape(self.angle, self.reference) + self.frame_id = frame_id - mat = np.zeros(self.shape + (3,3)) - mat[..., self.axis2, self.axis2] = 1. - mat[..., self.axis0, self.axis0] = np.cos(self.angle.vals) - mat[..., self.axis0, self.axis1] = np.sin(self.angle.vals) - mat[..., self.axis1, self.axis1] = mat[..., self.axis0, self.axis0] - mat[..., self.axis1, self.axis0] = -mat[..., self.axis0, self.axis1] - - # Update wayframe and frame_id; register if not temporary - self.register(unpickled=unpickled) - - # We need a wayframe before we can create the transform - self.transform = Transform(Matrix3(mat, self.angle.mask), Vector3.ZERO, - self.wayframe, self.reference, self.origin) + self.register() + self._refresh() - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.angle.vals, self.axis2, self.reference.frame_id) - Rotation.FRAME_IDS[key] = self.frame_id + ###################################################################################### + # Fittable interface + ###################################################################################### - def __getstate__(self): - return (self.angle, self.axis2, - Frame.as_primary_frame(self.reference), self.shape) + def _set_params(self, params): + """Redefine the rotation angle of this Rotation object.""" - def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (angle, axis, reference, shape) = state - if shape == (): - key = (angle.vals, axis, reference.frame_id) - frame_id = Rotation.FRAME_IDS.get(key, None) + if self.link: + self.link.set_params(params) + self.angle = self.link.angle + elif self._angle_shape == (): + self.angle = Scalar(params[0], self.angle.mask) else: - frame_id = None + params = np.array(params).reshape(self._angle_shape) + self.angle = Scalar(params, self.angle.mask) - self.__init__(angle, axis, reference, frame_id=frame_id, unpickled=True) - - #=========================================================================== - def transform_at_time(self, time, quick=False): - """Transform into this Frame at a Scalar of times.""" + @property + def _params(self): + return (self.angle,) - return self.transform + def _refresh(self): + """Update the internals.""" - ############################################################################ - # Fittable interface - ############################################################################ - - def set_params(self, params): - """Redefine the Fittable object, using this set of parameters. - - In this case, params is the set of angles of rotation. + if self.link: + self.angle = self.link.angle + self._matrix = self.link._matrix + else: + mat = np.zeros(self.shape + (3, 3)) + mat[..., self.axis2, self.axis2] = 1. + mat[..., self.axis0, self.axis0] = np.cos(self.angle.vals) + mat[..., self.axis0, self.axis1] = np.sin(self.angle.vals) + mat[..., self.axis1, self.axis1] = mat[..., self.axis0, self.axis0] + mat[..., self.axis1, self.axis0] = -mat[..., self.axis0, self.axis1] + self._matrix = Matrix3(mat, self.angle.mask) - Input: - params a list, tuple or 1-D Numpy array of floating-point - numbers, defining the parameters to be used in the - object returned. - """ + self._transform = Transform(self._matrix, Vector3.ZERO, self, self.reference, + self.origin) - params = Scalar.as_scalar(params) - if params.shape != self.shape: - raise ValueError('new parameter shape does not match original') + ###################################################################################### + # Serialization support + ###################################################################################### - self.angle = params + def __getstate__(self): + self.refresh() + return (self.angle, self.axis2, Frame.as_primary_frame(self.reference), + self._state_id()) - mat = np.zeros(self.shape + (3,3)) - mat[..., self.axis2, self.axis2] = 1. - mat[..., self.axis0, self.axis0] = np.cos(self.angle.vals) - mat[..., self.axis0, self.axis1] = np.sin(self.angle.vals) - mat[..., self.axis1, self.axis1] = mat[..., self.axis0, self.axis0] - mat[..., self.axis1, self.axis0] = -mat[..., self.axis0, self.axis1] + def __setstate__(self, state): + (angle, axis, reference, frame_id) = state + self.__init__(angle, axis, reference, frame_id=frame_id) + self.freeze() - self.transform = Transform(Matrix3(mat, self.angle.mask), Vector3.ZERO, - self.reference, self.origin) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== - def get_params(self): - """The current set of parameters defining this fittable object. + def transform_at_time(self, time, quick=False): + """Transform that rotates coordinates from the reference frame to this frame. - Return: a Numpy 1-D array of floating-point numbers containing - the parameter values defining this object. - """ + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. - return self.angle.vals + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. - #=========================================================================== - def copy(self): - """A deep copy of the given object. + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. - The copy can be safely modified without affecting the original. + Notes: + Rotation is a fixed frame, so the transform relative to the `reference` frame + is independent of time. """ - return Rotation(self.angle.copy(), self.axis, self.reference_id) + return self._transform -################################################################################ +########################################################################################## diff --git a/oops/frame/spiceframe.py b/oops/frame/spiceframe.py index 4664747e..cbd8eb31 100755 --- a/oops/frame/spiceframe.py +++ b/oops/frame/spiceframe.py @@ -1,6 +1,6 @@ -################################################################################ +########################################################################################## # oops/frame/spiceframe.py: Subclass SpiceFrame of class Frame -################################################################################ +########################################################################################## import numpy as np from scipy.interpolate import UnivariateSpline @@ -12,35 +12,30 @@ from oops.transform import Transform import oops.spice_support as spice + class SpiceFrame(Frame): """A Frame defined within the SPICE toolkit.""" - #=========================================================================== - def __init__(self, spice_frame, spice_reference='J2000', frame_id=None, - omega_type='tabulated', omega_dt=1., unpickled=False): + def __init__(self, spice_frame, spice_reference='J2000', frame_id=None, *, + omega_type='tabulated', omega_dt=1.): """Constructor for a SpiceFrame. - Input: - spice_frame the name or integer ID of the destination frame - or of the central body as used in the SPICE toolkit. - - spice_reference the name or integer ID of the reference frame as - used in the SPICE toolkit; 'J2000' by default. - - frame_id the string ID under which the frame will be - registered. By default, this will be the name as - used by the SPICE toolkit. - - omega_type 'tabulated' to take omega directly from the kernel; - 'numerical' to return numerical derivatives. - 'zero' to ignore omega vectors. - - omega_dt default time step in seconds to use for spline-based - numerical derivatives of omega. - - unpickled True if this object was read from a pickle file. If - so, then it will be treated as a duplicate of a - pre-existing SpicePath for the same SPICE ID. + Parameters: + spice_frame (str or int): The name, frame ID, or body ID as used in the SPICE + toolkit. + spice_reference (str or int, optional): The name or ID of the reference frame + as used in the SPICE toolkit. + frame_id (str, optional): The name under which the frame will be registered. + By default, this is the name as used by the SPICE toolkit. + omega_type (str, optional): Options defining how `omega`, the time derivative + of the frame, is calculated: + + * "tabulated" to take `omega` directly from the SPICE kernel; + * "numerical" to derive `omega` via numerical derivatives; + * "zero" to ignore omega vectors. + + omega_dt (float, optional): The default time step in seconds to use when + `omega_type` equals "numerical". """ # Preserve the inputs @@ -58,7 +53,7 @@ def __init__(self, spice_frame, spice_reference='J2000', frame_id=None, # Fill in the Frame ID and save it in the SPICE global dictionary self.frame_id = frame_id or self.spice_frame_name - spice.FRAME_TRANSLATION[self.spice_frame_id] = self.frame_id + spice.FRAME_TRANSLATION[self.spice_frame_id] = self.frame_id spice.FRAME_TRANSLATION[self.spice_frame_name] = self.frame_id # Fill in the reference wayframe @@ -77,71 +72,72 @@ def __init__(self, spice_frame, spice_reference='J2000', frame_id=None, origin_path = Frame.SPICEPATH_CLASS(origin_id) self.origin = origin_path.waypoint - # No shape, no keys + # No shape self.shape = () - self.keys = set() # Save interpolation method if omega_type not in ('tabulated', 'numerical', 'zero'): - raise ValueError('invalid SpiceFrame omega_type: ' - + repr(omega_type)) + raise ValueError('invalid SpiceFrame omega_type: ' + repr(omega_type)) self.omega_tabulated = (omega_type == 'tabulated') self.omega_numerical = (omega_type == 'numerical') self.omega_zero = (omega_type == 'zero') - # Always register a SpiceFrame - # This also fills in the waypoint - self.register(unpickled=unpickled) + # Always register a SpiceFrame. This also fills in the waypoint. + self.register() + + ###################################################################################### + # Serialization support + ###################################################################################### - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): return (self.spice_frame_name, self.spice_reference, self.omega_type, - self.omega_dt) + self.omega_dt, self._state_id()) def __setstate__(self, state): - - (spice_frame_name, spice_reference, omega_type, omega_dt) = state - - # If this is a duplicate of a pre-existing SpiceFrame, make sure it gets - # assigned the pre-existing frame ID and Wayframe. - frame_id = spice.FRAME_TRANSLATION.get(spice_frame_name, None) + (spice_frame_name, spice_reference, omega_type, omega_dt, frame_id) = state + if frame_id is None: + frame_id = spice.FRAME_TRANSLATION.get(spice_frame_name, None) self.__init__(spice_frame_name, spice_reference, frame_id=frame_id, - omega_type=omega_type, omega_dt=omega_dt, unpickled=True) + omega_type=omega_type, omega_dt=omega_dt) + + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick={}): - """A Transform that rotates from the reference frame into this frame. + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. - Input: - time a Scalar time. - quick an optional dictionary of parameter values to use as - overrides to the configured default QuickPath and - QuickFrame parameters; use False to disable the use - of QuickPaths and QuickFrames. + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. - Return: the corresponding Tranform applicable at the - specified time(s). """ time = Scalar.as_scalar(time).as_float() - ######## Handle a single time + # Handle a single time if time.shape == (): # Case 1: omega_type = tabulated if self.omega_tabulated: - matrix6 = cspyce.sxform(self.spice_reference_name, - self.spice_frame_name, + matrix6 = cspyce.sxform(self.spice_reference_name, self.spice_frame_name, time.values) (matrix, omega) = cspyce.xf2rav(matrix6) - return Transform(matrix, omega, self, self.reference) # Case 2: omega_type = zero elif self.omega_zero: - matrix = cspyce.pxform(self.spice_reference_name, - self.spice_frame_name, + matrix = cspyce.pxform(self.spice_reference_name, self.spice_frame_name, time.values) return Transform(matrix, Vector3.ZERO, self, self.reference) @@ -154,8 +150,7 @@ def transform_at_time(self, time, quick={}): for j in range(len(times)): mats[j] = cspyce.pxform(self.spice_reference_name, - self.spice_frame_name, - times[j]) + self.spice_frame_name, times[j]) # Convert three matrices to quaternions quats = Quaternion.as_quaternion(Matrix3(mats)) @@ -163,14 +158,13 @@ def transform_at_time(self, time, quick={}): # Use a Univariate spline to get components of the derivative qdot = np.empty(4) for j in range(4): - spline = UnivariateSpline(times, quats.values[:,j], k=2, - s=0) + spline = UnivariateSpline(times, quats.values[:,j], k=2, s=0) qdot[j] = spline.derivative(1)(et) omega = 2. * (Quaternion(qdot) / quats[1]).values[1:4] return Transform(mats[1], omega, self, self.reference) - ######## Apply the quick_frame if requested + # Apply the quick_frame if requested if isinstance(quick, dict): quick = quick.copy() @@ -178,41 +172,38 @@ def transform_at_time(self, time, quick={}): quick['ignore_quickframe_omega'] = self.omega_zero if self.omega_numerical: - quick['frame_time_step'] = min(quick['frame_time_step'], - self.omega_dt) + quick['frame_time_step'] = min(quick['frame_time_step'], self.omega_dt) frame = self.quick_frame(time, quick) return frame.transform_at_time(time, quick=False) - ######## Handle multiple times + # Handle multiple times # Case 1: omega_type = tabulated if self.omega_tabulated: - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.empty(time.shape + (3,)) for i,t in np.ndenumerate(time.values): matrix6 = cspyce.sxform(self.spice_reference_name, - self.spice_frame_name, - t) + self.spice_frame_name, t) (matrix[i], omega[i]) = cspyce.xf2rav(matrix6) # Case 2: omega_type = zero elif self.omega_zero: - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.zeros(time.shape + (3,)) for i,t in np.ndenumerate(time.values): matrix[i] = cspyce.pxform(self.spice_reference_name, - self.spice_frame_name, - t) + self.spice_frame_name, t) # Case 3: omega_type = numerical - # This procedure calculates each omega using its own UnivariateSpline; - # it could be very slow. A QuickFrame is recommended as it would - # accomplish the same goals much faster. + # This procedure calculates each omega using its own UnivariateSpline; it could be + # very slow. A QuickFrame is recommended as it would accomplish the same goals + # much faster. else: - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.empty(time.shape + (3,)) for i,t in np.ndenumerate(time.values): @@ -224,8 +215,7 @@ def transform_at_time(self, time, quick={}): mats = np.empty((3,3,3)) for j in range(len(times)): mats[j] = cspyce.pxform(self.spice_reference_name, - self.spice_frame_name, - times[j]) + self.spice_frame_name, times[j]) # Convert these three matrices to quaternions quats = Quaternion.as_quaternion(Matrix3(mats)) @@ -233,8 +223,7 @@ def transform_at_time(self, time, quick={}): # Use a Univariate spline to get components of the derivative qdot = np.empty(4) for j in range(4): - spline = UnivariateSpline(times, quats.values[:,j], k=2, - s=0) + spline = UnivariateSpline(times, quats.values[:,j], k=2, s=0) qdot[j] = spline.derivative(1)(t) omega[i] = 2. * (Quaternion(qdot) / quats[1]).values[1:4] @@ -242,27 +231,29 @@ def transform_at_time(self, time, quick={}): return Transform(matrix, omega, self, self.reference) - #=========================================================================== def transform_at_time_if_possible(self, time, quick={}): - """A Transform that rotates from the reference frame into this frame. - - Unlike method transform_at_time(), this variant tolerates times that - raise cspyce errors. It returns a new time Scalar along with the new - Transform, where both objects skip over the times at which the transform - could not be evaluated. - - Input: - time a Scalar time, which must be 0-D or 1-D. - quick an optional dictionary of parameter values to use as - overrides to the configured default QuickPath and - QuickFrame parameters; use False to disable the use - of QuickPaths and QuickFrames. - - Return: (newtimes, transform) - newtimes a Scalar time, possibly containing a subset of the - times given. - transform the corresponding Tranform applicable at the new - time(s). + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Unlike method `transform_at_time`, this variant tolerates times that raise cspyce + errors. It returns a new time Scalar along with the new Transform, where both + objects skip over the times at which the transform could not be evaluated. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (tuple): A tuple with two values: + + * Scalar: The times that are returned, possibly containing a subset of the + original times given. + * Transform: The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. """ time = Scalar.as_scalar(time).as_float() @@ -271,7 +262,7 @@ def transform_at_time_if_possible(self, time, quick={}): if time.shape == (): return (time, self.transform_at_time(time, quick)) - ######## Apply the quick_frame if requested + # Apply the quick_frame if requested if isinstance(quick, dict): quick = quick.copy() @@ -279,13 +270,12 @@ def transform_at_time_if_possible(self, time, quick={}): quick['ignore_quickframe_omega'] = self.omega_zero if self.omega_numerical: - quick['frame_time_step'] = min(quick['frame_time_step'], - self.omega_dt) + quick['frame_time_step'] = min(quick['frame_time_step'], self.omega_dt) frame = self.quick_frame(time, quick) return frame.transform_at_time_if_possible(time, quick=False) - ######## Handle multiple times + # Handle multiple times # Lists used in case of error new_time = [] @@ -296,14 +286,13 @@ def transform_at_time_if_possible(self, time, quick={}): # Case 1: omega_type = tabulated if self.omega_tabulated: - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.empty(time.shape + (3,)) - for i,t in np.ndenumerate(time.values): + for i, t in np.ndenumerate(time.values): try: matrix6 = cspyce.sxform(self.spice_reference_name, - self.spice_frame_name, - t) + self.spice_frame_name, t) (matrix[i], omega[i]) = cspyce.xf2rav(matrix6) new_time.append(t) @@ -317,18 +306,17 @@ def transform_at_time_if_possible(self, time, quick={}): # Case 2: omega_type = zero elif self.omega_zero: - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.zeros(time.shape + (3,)) for i,t in np.ndenumerate(time.values): try: matrix[i] = cspyce.pxform(self.spice_reference_name, - self.spice_frame_name, - t) + self.spice_frame_name, t) new_time.append(t) matrix_list.append(matrix[i]) - omega_list.append((0.,0.,0.)) + omega_list.append((0., 0., 0.)) except (RuntimeError, ValueError, IOError) as e: if len(time.shape) > 1: @@ -336,22 +324,21 @@ def transform_at_time_if_possible(self, time, quick={}): error_found = e # Case 3: omega_type = numerical - # This procedure calculates each omega using its own UnivariateSpline; - # it could be very slow. A QuickFrame is recommended as it would - # accomplish the same goals much faster. + # This procedure calculates each omega using its own UnivariateSpline; it could be + # very slow. A QuickFrame is recommended as it would accomplish the same goals + # much faster. else: - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.empty(time.shape + (3,)) for i,t in np.ndenumerate(time.values): try: times = np.array((t - self.omega_dt, t, t + self.omega_dt)) - mats = np.empty((3,3,3)) + mats = np.empty((3, 3, 3)) for j in range(len(times)): mats[j] = cspyce.pxform(self.spice_reference_name, - self.spice_frame_name, - times[j]) + self.spice_frame_name, times[j]) # Convert three matrices to quaternions quats = Quaternion.as_quaternion(Matrix3(mats)) @@ -359,8 +346,7 @@ def transform_at_time_if_possible(self, time, quick={}): # Use a Univariate spline to get components of the derivative qdot = np.empty(4) for j in range(4): - spline = UnivariateSpline(times, quats.values[:,j], k=2, - s=0) + spline = UnivariateSpline(times, quats.values[:,j], k=2, s=0) qdot[j] = spline.derivative(1)(t) omega[i] = 2. * (Quaternion(qdot) / quats[1]).values[1:4] @@ -388,4 +374,4 @@ def transform_at_time_if_possible(self, time, quick={}): return (time, Transform(matrix, omega, self, self.reference)) -################################################################################ +########################################################################################## diff --git a/oops/frame/spicetype1frame.py b/oops/frame/spicetype1frame.py index c46b5215..1015d292 100755 --- a/oops/frame/spicetype1frame.py +++ b/oops/frame/spicetype1frame.py @@ -1,6 +1,6 @@ -################################################################################ +########################################################################################## # oops/frame/spicetype1frame.py: Subclass SpiceType1Frame of Frame -################################################################################ +########################################################################################## import numpy as np @@ -12,36 +12,23 @@ import oops.spice_support as spice class SpiceType1Frame(Frame): - """A Frame object defined within the SPICE toolkit as a Type 1 (discrete) C - kernel. - """ + """A Frame object defined within the SPICE toolkit as a Type 1 (discrete) C kernel.""" - #=========================================================================== - def __init__(self, spice_frame, spice_host, tick_tolerance, - spice_reference="J2000", frame_id=None, unpickled=False): + def __init__(self, spice_frame, spice_host, tick_tolerance, *, + spice_reference='J2000', frame_id=None): """Constructor for a SpiceType1Frame. Input: - spice_frame the name or integer ID of the destination frame or - of the central body as used in the SPICE toolkit. - - spice_host the name or integer ID of the spacecraft. This is - needed to evaluate the spacecraft clock ticks. - - tick_tolerance a number or string defining the error tolerance in - spacecraft clock ticks for the frame returned. - - spice_reference the name or integer ID of the reference frame as - used in the SPICE toolkit; "J2000" by default. - - frame_id the name or ID under which the frame will be - registered. By default, this will be the value of - spice_id if that is given as a string; otherwise - it will be the name as used by the SPICE toolkit. - - unpickled True if this object was read from a pickle file. If - so, then it will be treated as a duplicate of a - pre-existing SpicePath for the same SPICE ID. + spice_frame (str or int): The name, frame ID, or body ID as used in the SPICE + toolkit. + spice_host (str or int)" The name or integer ID of the spacecraft. This is + needed to evaluate the spacecraft clock ticks. + tick_tolerance (float, int or str): A number or string defining the error + tolerance in spacecraft clock ticks for the frame returned. + spice_reference (str or int, optional): The name or ID of the reference frame + as used in the SPICE toolkit. + frame_id (str, optional): The name under which the frame will be registered. + By default, this is the name as used by the SPICE toolkit. """ # Preserve the inputs @@ -57,13 +44,11 @@ def __init__(self, spice_frame, spice_host, tick_tolerance, (self.spice_reference_id, self.spice_reference_name) = spice.frame_id_and_name(spice_reference) - (self.spice_body_id, - self.spice_body_name) = spice.body_id_and_name(spice_host) + (self.spice_body_id, self.spice_body_name) = spice.body_id_and_name(spice_host) # Fill in the time tolerances if isinstance(tick_tolerance, str): - self.tick_tolerance = cspyce.sctiks(self.spice_body_id, - tick_tolerance) + self.tick_tolerance = cspyce.sctiks(self.spice_body_id, tick_tolerance) else: self.tick_tolerance = tick_tolerance @@ -71,7 +56,7 @@ def __init__(self, spice_frame, spice_host, tick_tolerance, # Fill in the Frame ID and save it in the SPICE global dictionary self.frame_id = frame_id or self.spice_frame_name - spice.FRAME_TRANSLATION[self.spice_frame_id] = self.frame_id + spice.FRAME_TRANSLATION[self.spice_frame_id] = self.frame_id spice.FRAME_TRANSLATION[self.spice_frame_name] = self.frame_id # Fill in the reference wayframe @@ -79,7 +64,7 @@ def __init__(self, spice_frame, spice_host, tick_tolerance, self.reference = Frame.as_wayframe(reference_id) # Fill in the origin waypoint - self.spice_origin_id = cspyce.frinfo(self.spice_frame_id)[0] + self.spice_origin_id = cspyce.frinfo(self.spice_frame_id)[0] self.spice_origin_name = cspyce.bodc2n(self.spice_origin_id) try: @@ -89,34 +74,36 @@ def __init__(self, spice_frame, spice_host, tick_tolerance, origin_path = Frame.SPICEPATH_CLASS(self.spice_origin_id) self.origin = origin_path.waypoint - # No shape, no keys + # No shape self.shape = () - self.keys = set() - # Always register a SpiceType1Frame - # This also fills in the waypoint - self.register(unpickled=unpickled) + # Always register a SpiceFrame. This also fills in the waypoint. + self.register() # Initialize cache - self.cached_shape = None - self.cached_time = None - self.cached_transform = None + self._latest_shape = None + self._latest_time = None + self._latest_transform = None + + ###################################################################################### + # Serialization support + ###################################################################################### def __getstate__(self): return (self.spice_frame, self.spice_host, self.tick_tolerance, - self.spice_reference) + self.spice_reference, self._state_id()) def __setstate__(self, state): + (spice_frame_name, spice_host, tick_tolerance, spice_reference, frame_id) = state + if frame_id is None: + frame_id = spice.FRAME_TRANSLATION.get(spice_frame_name, None) + self.__init__(spice_frame_name, spice_host, tick_tolerance, spice_reference, + frame_id=frame_id) - (spice_frame_name, spice_host, tick_tolerance, spice_reference) = state - - # If this is a duplicate of a pre-existing SpiceType1Frame, make sure it - # gets assigned the pre-existing frame ID and Wayframe. - frame_id = spice.FRAME_TRANSLATION.get(spice_frame_name, None) - self.__init__(spice_frame_name, spice_host, tick_tolerance, - spice_reference, frame_id=frame_id, unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick={}): """A Transform that rotates from the reference frame into this frame. @@ -131,16 +118,15 @@ def transform_at_time(self, time, quick={}): if self.time_tolerance is None: time = Scalar.as_scalar(time) ticks = cspyce.sce2c(self.spice_body_id, time.vals) - ticks_per_sec = cspyce.sce2c(self.spice_body_id, - time.vals + 1.) - ticks + ticks_per_sec = cspyce.sce2c(self.spice_body_id, time.vals + 1.) - ticks self.time_tolerance = self.tick_tolerance / ticks_per_sec # Check to see if the cached transform is adequate time = Scalar.as_scalar(time) - if np.shape(time.vals) == self.cached_shape: + if np.shape(time.vals) == self._cached_shape: diff = np.abs(time.vals - self.cached_time) if np.all(diff < self.time_tolerance): - return self.cached_transform + return self._cached_transform # A single input time can be handled quickly if time.shape == (): @@ -149,14 +135,14 @@ def transform_at_time(self, time, quick={}): self.tick_tolerance, self.spice_reference_name) - self.cached_shape = time.shape - self.cached_time = cspyce.sct2e(self.spice_body_id, true_ticks) - self.cached_transform = Transform(matrix3, Vector3.ZERO, - self.frame_id, self.reference_id) - return self.cached_transform + self._cached_shape = time.shape + self._cached_time = cspyce.sct2e(self.spice_body_id, true_ticks) + self._cached_transform = Transform(matrix3, Vector3.ZERO, self.frame_id, + self.reference_id) + return self._cached_transform # Create the buffers - matrix = np.empty(time.shape + (3,3)) + matrix = np.empty(time.shape + (3, 3)) omega = np.zeros(time.shape + (3,)) true_times = np.empty(time.shape) @@ -174,8 +160,8 @@ def transform_at_time(self, time, quick={}): self.cached_shape = time.shape self.cached_time = true_times - self.cached_transform = Transform(matrix, omega, - self.frame_id, self.reference_id) + self.cached_transform = Transform(matrix, omega, self.frame_id, + self.reference_id) return self.cached_transform # Otherwise, iterate through the array... @@ -187,10 +173,10 @@ def transform_at_time(self, time, quick={}): matrix[i] = matrix3 true_times[i] = cspyce.sct2e(self.spice_body_id, true_ticks) - self.cached_shape = time.shape - self.cached_time = true_times - self.cached_transform = Transform(matrix, omega, - self.frame_id, self.reference_id) - return self.cached_transform + self._latest_shape = time.shape + self._latest_time = true_times + self._latest_transform = Transform(matrix, omega, self.frame_id, + self.reference_id) + return self.latest_transform -################################################################################ +########################################################################################## diff --git a/oops/frame/spinframe.py b/oops/frame/spinframe.py index 8c3585ba..36152b3d 100755 --- a/oops/frame/spinframe.py +++ b/oops/frame/spinframe.py @@ -1,111 +1,118 @@ -################################################################################ +########################################################################################## # oops/frame/spinframe.py: Subclass SpinFrame of class Frame -################################################################################ +########################################################################################## import numpy as np from polymath import Matrix3, Qube, Scalar, Vector3 +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform + class SpinFrame(Frame): - """A Frame subclass describing a frame in uniform rotation about one axis of - another frame. + """A Frame subclass describing a frame in uniform rotation about one axis of another + frame. - It can be created without a frame_id, reference_id or origin_id; in this - case it is not registered and can therefore be used as a component of - another frame. + It can be created without a frame_id, reference_id or origin_id; in this case it is + not registered and can therefore be used as a component of another frame. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _FRAME_IDS = {} - #=========================================================================== - def __init__(self, offset, rate, epoch, axis, reference, frame_id=None, - unpickled=False): + def __init__(self, offset, rate, epoch, axis, reference, frame_id=None): """Constructor for a Spin Frame. Input: - offset the angular offset of the frame at the epoch. - rate the rotation rate of the frame in radians/second. - epoch the time TDB at which the frame is defined. - axis the rotation axis: 0 for x, 1 for y, 2 for z. - reference the frame relative to which this frame is defined. - frame_id the ID under which this frame is to be registered; - None to use a temporary ID. - unpickled True if this frame has been read from a pickle file. - - Note that rate, offset and epoch can be Scalar values, in which case the - shape of the SpinFrame is defined by broadcasting the shapes of these - Scalars. + offset (Scalar, array-like, or float): The angular offset of the frame at the + epoch, in radians. + rate (Scalar, array-like, or float): The rotation rate of the frame in + radians/second. + epoch (Scalar, array-like, or float): The time TDB at which the frame is + defined. + axis (int): The rotation axis: 0 for x, 1 for y, 2 for z. + reference (Frame or str): The frame or ID relative to which this frame is + defined. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. + + Notes: + The rate, offset and epoch can be Scalar values, in which case the shape of + the SpinFrame is defined by broadcasting the shapes of these Scalars. """ self.offset = Scalar.as_scalar(offset) self.rate = Scalar.as_scalar(rate) self.epoch = Scalar.as_scalar(epoch) - self.shape = Qube.broadcasted_shape(self.rate, self.offset, self.epoch) - - self.axis2 = axis # Most often, the Z-axis + self.axis2 = axis self.axis0 = (self.axis2 + 1) % 3 self.axis1 = (self.axis2 + 2) % 3 + self.shape = Qube.broadcasted_shape(self.rate, self.offset, self.epoch) omega_vals = np.zeros(list(self.shape) + [3]) omega_vals[..., self.axis2] = self.rate.vals self.omega = Vector3(omega_vals, self.rate.mask) # Required attributes - self.frame_id = frame_id self.reference = Frame.as_wayframe(reference) - self.origin = self.reference.origin or Frame.PATH_CLASS.SSB - self.keys = set() + self.origin = self.reference.origin or Frame.PATH_CLASS.SSB + self.frame_id = self._recover_id(frame_id) # Update wayframe and frame_id; register if not temporary - self.register(unpickled=unpickled) + self.register() + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.offset.vals, self.rate.vals, self.epoch.vals, - self.axis2, self.reference.frame_id) - SpinFrame.FRAME_IDS[key] = self.frame_id + def _frame_key(self): + return (self.offset, self.rate, self.epoch, self.axis2, self.reference) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (self.offset, self.rate, self.epoch, self.axis2, - Frame.as_primary_frame(self.reference), self.shape) + Frame.as_primary_frame(self.reference), self._state_id()) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (offset, rate, epoch, axis2, reference, shape) = state - if shape == (): - key = (offset.vals, rate.vals, epoch.vals, axis2, - reference.frame_id) - frame_id = SpinFrame.FRAME_IDS.get(key, None) - else: - frame_id = None - - self.__init__(offset, rate, epoch, axis2, reference, frame_id=frame_id, - unpickled=True) - - #=========================================================================== + (offset, rate, epoch, axis, reference, frame_id) = state + self.__init__(offset, rate, epoch, axis, reference, frame_id=frame_id) + Fittable_.freeze(self) + + ###################################################################################### + # Frame API + ###################################################################################### + def transform_at_time(self, time, quick={}): - """The Transform to this Frame at a specified Scalar of times. + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. Ignored for + class SpinFrame. - QuickFrame options are ignored. + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. """ time = Scalar.as_scalar(time) angle = (time - self.epoch) * self.rate + self.offset - mat = np.zeros(list(angle.shape) + [3,3]) + mat = np.zeros(list(angle.shape) + [3, 3]) mat[..., self.axis2, self.axis2] = 1. mat[..., self.axis0, self.axis0] = np.cos(angle.values) mat[..., self.axis1, self.axis1] = mat[..., self.axis0, self.axis0] mat[..., self.axis0, self.axis1] = np.sin(angle.values) - mat[..., self.axis1, self.axis0] = -mat[...,self.axis0,self.axis1] + mat[..., self.axis1, self.axis0] = -mat[...,self.axis0, self.axis1] matrix = Matrix3(mat, angle.mask) - return Transform(matrix, self.omega, self.wayframe, self.reference, - self.origin) + return Transform(matrix, self.omega, self.wayframe, self.reference, self.origin) -################################################################################ +########################################################################################## diff --git a/oops/frame/synchronousframe.py b/oops/frame/synchronousframe.py index c60ad818..9a42d2b2 100755 --- a/oops/frame/synchronousframe.py +++ b/oops/frame/synchronousframe.py @@ -1,27 +1,26 @@ -################################################################################ +########################################################################################## # oops/frame/synchronousframe.py: Subclass SynchronousFrame of class Frame -################################################################################ +########################################################################################## -from polymath import Matrix3, Qube +from polymath import Matrix3 +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform class SynchronousFrame(Frame): - """A Frame subclass describing a a body that always keeps the x-axis pointed - toward a central planet and the y-axis in the negative direction of motion. + """A Frame subclass describing a a body that always keeps the x-axis pointed toward a + central planet and the y-axis in the negative direction of motion. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + FRAME_IDS = {} - #=========================================================================== - def __init__(self, body_path, planet_path, frame_id=None, unpickled=False): + def __init__(self, body_path, planet_path, frame_id=None): """Constructor for a SynchronousFrame. - Input: - body_path the path or path ID followed by the body. - planet_path the path or path ID followed by the central planet. - frame_id the ID to use; None to leave the frame unregistered. - unpickled True if this frame has been read from a pickle file. + Parameters: + body_path (Path or str): The path or path ID of the body. + planet_path (Path or str): The path or path ID of the central planet. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. """ self.body_path = Frame.PATH_CLASS.as_path(body_path) @@ -31,48 +30,59 @@ def __init__(self, body_path, planet_path, frame_id=None, unpickled=False): if self.planet_path.shape: raise ValueError('SynchronousFrame requires a shapeless body path') - self.frame_id = frame_id + # Required attributes self.reference = Frame.as_wayframe(self.planet_path.frame) - self.origin = self.planet_path.origin - self.shape = Qube.broadcasted_shape(self.body_path, - self.planet_path) - self.keys = set() + self.origin = self.planet_path.origin + self.shape = self.body_path.shape + self.frame_id = self._recover_id(frame_id) # Update wayframe and frame_id; register if not temporary - self.register(unpickled=unpickled) + self.register() + self._cache_id() - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.body_path.path_id, self.planet_path.path_id) - SynchronousFrame.FRAME_IDS[key] = self.frame_id + ###################################################################################### + # Serialization support + ###################################################################################### + + def _frame_key(self): + return (self.body_path, self.planet_path) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (Frame.PATH_CLASS.as_primary_path(self.body_path), - Frame.PATH_CLASS.as_primary_path(self.planet_path), self.shape) + Frame.PATH_CLASS.as_primary_path(self.planet_path), self._state_id()) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (body_path, planet_path, shape) = state - if shape == (): - key = (body_path.path_id, planet_path.path_id) - frame_id = SynchronousFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (body_path, planet_path, frame_id) = state + self.__init__(body_path, planet_path, frame_id=frame_id) + Fittable_.freeze(self) + + ###################################################################################### + # Frame API + ###################################################################################### + + def transform_at_time(self, time, *, quick=False): + """Transform that rotates coordinates from the reference frame to this frame. - self.__init__(body_path, planet_path, frame_id=frame_id, - unpickled=True) + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. - #=========================================================================== - def transform_at_time(self, time, quick=False): - """The Transform into the this Frame at a Scalar of times.""" + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + """ event = self.path.event_at_time(time, quick=quick) matrix = Matrix3.twovec(event.pos, 0, event.vel, 1) omega = event.pos.cross(event.vel) / event.pos.dot(event.pos) - return Transform(matrix, omega, self.frame_id, self.reference, - self.body_path) + return Transform(matrix, omega, self.frame_id, self.reference, self.body_path) -################################################################################ +########################################################################################## diff --git a/oops/frame/trackerframe.py b/oops/frame/trackerframe.py index 990518a1..af668ad9 100755 --- a/oops/frame/trackerframe.py +++ b/oops/frame/trackerframe.py @@ -1,119 +1,134 @@ -################################################################################ +########################################################################################## # oops/frame/trackerframe.py: Subclass TrackerFrame of class Frame -################################################################################ +########################################################################################## from polymath import Qube, Scalar, Vector3, Matrix3 +from oops.cache import Cache +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform + class TrackerFrame(Frame): - """A Frame subclass that ensures, via a small rotation, that a designated - target path will remain in a fixed direction. - - The primary use of this frame is for observing moving targets with HST. - Normally, HST images of the same target, obtained during the same visit and - orbit, will have a common pointing offset and can be navigated as a group. - This is not generally true when using the pointing information in the FITS - headers, because that pointing refers to the start time of each frame rather - than the midtime. + """A Frame subclass that ensures, via a small rotation, that a designated target path + will remain in a fixed direction. + + The primary use of this frame is for observing moving targets with HST. Normally, HST + images of the same target, obtained during the same visit and orbit, will have a + common pointing offset and can be navigated as a group. This is not generally true + when using the pointing information in the FITS headers, because that pointing refers + to the start time of each frame rather than the midtime. """ - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + FRAME_IDS = {} - #=========================================================================== - def __init__(self, frame, target, observer, epoch, frame_id=None, - unpickled=False): + def __init__(self, frame, target, observer, epoch, frame_id=None, cache_size=100): """Constructor for a Tracker Frame. Input: - frame the frame that will be modified to enable tracking, or - its frame ID. Must be inertial. - target the target's path or path ID. - observer the observer's path or path ID. - epoch the epoch for which the given frame is defined. - frame_id the ID to use; None to use a temporary ID. - unpickled True if this frame has been read from a pickle file. + frame (Frame or str): The frame or frame ID that defines the initial pointing. + target (Path or str): The target's path or path ID. + observer (Path or str): The observer's path or path ID. + epoch (Scalar, array-like, or float): The epoch for which the given frame is + defined. + frame_id (str, optional): The ID to use; None to leave the frame unregistered. + cache_size (int, optional): Number of transforms to cache. This can be useful + because it avoids unnecessary SPICE calls when the frame is being used + repeatedly at a finite set of times. """ - self.fixed_frame = Frame.as_frame(frame) + self.fixed_frame = Frame.as_frame(frame).wrt(Frame.J2000) self.target_path = Frame.PATH_CLASS.as_path(target) self.observer_path = Frame.PATH_CLASS.as_path(observer) self.epoch = Scalar.as_scalar(epoch) - self.shape = Qube.broadcasted_shape(self.fixed_frame, self.target_path, - self.observer_path, self.epoch) + self._cache_size = cache_size # Required attributes - self.frame_id = frame_id - self.reference = self.fixed_frame.reference - self.origin = self.fixed_frame.origin - self.keys = set() + self.reference = Frame.as_wayframe(self.fixed_frame.reference) + self.origin = self.fixed_frame.origin + self.shape = Qube.broadcasted_shape(self.fixed_frame, self.target_path, + self.observer_path, self.epoch) + self.frame_id = self._recover_id(frame_id) # Update wayframe and frame_id; register if not temporary - self.register(unpickled=unpickled) + self.register() + self._refresh() + self._cache_id() + + def _refresh(self): # Determine the apparent direction to the target path at epoch - obs_event = Frame.EVENT_CLASS(epoch, Vector3.ZERO, self.observer_path, + obs_event = Frame.EVENT_CLASS(self.epoch, Vector3.ZERO, self.observer_path, Frame.J2000) (path_event, obs_event) = self.target_path.photon_to_event(obs_event) - self.trackpoint = obs_event.neg_arr_ap.unit() + self._trackpoint = obs_event.neg_arr_ap.unit() # Determine the transform at epoch fixed_xform = self.fixed_frame.transform_at_time(self.epoch) - self.reference_xform = Transform(fixed_xform.matrix, Vector3.ZERO, - self.wayframe, self.reference, - self.origin) - if fixed_xform.omega != Vector3.ZERO: - raise ValueError('TrackerFrame reference frame must be inertial') + self.reference_xform = Transform(fixed_xform.matrix, Vector3.ZERO, self.wayframe, + self.reference, self.origin) # Convert the matrix to three axis vectors - self.reference_rows = Vector3(self.reference_xform.matrix.values) + self.reference_rows = Vector3(self.reference_xform.matrix.vals) - # Prepare to cache the most recently used transform - self.cached_time = None - self.cached_xform = None - _ = self.transform_at_time(self.epoch) # cache initialized + self._cache = Cache(self._cache_size) - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.fixed_frame.frame_id, self.target_path.path_id, - self.observer_path.path_id, self.epoch.vals) - TrackerFrame.FRAME_IDS[key] = self.frame_id + ###################################################################################### + # Serialization support + ###################################################################################### + + def _frame_key(self): + return (self.fixed_frame, self.target_path, self.observer_path, self.epoch) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (Frame.as_primary_frame(self.fixed_frame), Frame.PATH_CLASS.as_primary_path(self.target_path), Frame.PATH_CLASS.as_primary_path(self.observer_path), - self.epoch, self.shape) + self.epoch, self._state_id(), self._cache_size) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (frame, target, observer, epoch, shape) = state - if shape == (): - key = (frame.frame_id, target.path_id, observer.path_id, - epoch.vals) - frame_id = TrackerFrame.FRAME_IDS.get(key, None) - else: - frame_id = None - + (frame, target, observer, epoch, frame_id, cache_size) = state self.__init__(frame, target, observer, epoch, frame_id=frame_id, - unpickled=True) + cache_size=cache_size) + Fittable_.freeze(self) + + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick=False): - """The Transform into the this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. + """ + + time = Scalar.as_scalar(time) - if time == self.cached_time: - return self.cached_xform + # Check cache first if time is shapeless + if time.shape == (): + xform = self._cache[time.vals] + if xform: + return xform # Determine the needed rotation - obs_event = Frame.EVENT_CLASS(time, Vector3.ZERO, self.observer_path, - Frame.J2000) + obs_event = Frame.EVENT_CLASS(time, Vector3.ZERO, self.observer_path, Frame.J2000) (path_event, obs_event) = self.target_path.photon_to_event(obs_event) newpoint = obs_event.neg_arr_ap.unit() - rotation = self.trackpoint.cross(newpoint) + rotation = self._trackpoint.cross(newpoint) rotation = rotation.reshape(rotation.shape + (1,)) # Rotate the three axis vectors accordingly @@ -122,9 +137,10 @@ def transform_at_time(self, time, quick=False): Vector3.ZERO, # neglect the slow frame rotation self.wayframe, self.reference, self.origin) - # Cache the most recently used transform - self.cached_time = time - self.cached_xform = xform + # Cache the transform if necessary + if time.shape == (): + self._cache[time.vals] = xform + return xform -################################################################################ +########################################################################################## diff --git a/oops/frame/twovectorframe.py b/oops/frame/twovectorframe.py index 54eeff82..07a57764 100755 --- a/oops/frame/twovectorframe.py +++ b/oops/frame/twovectorframe.py @@ -1,73 +1,53 @@ -################################################################################ +########################################################################################## # oops/frame/twovectorframe.py: Subclass TwoVectorFrame of class Frame -################################################################################ +########################################################################################## from polymath import Matrix3, Qube, Vector3 +from oops.fittable import Fittable_ from oops.frame import Frame from oops.transform import Transform + class TwoVectorFrame(Frame): - """A Frame subclass describing a frame that is fixed relative to another - frame. + """A Frame subclass describing a frame that is fixed relative to another frame. - It is described by two vectors. The first vector is one axis of the frame - and the second vector points in the half-plane of another axis. + It is described by two vectors. The first vector is one axis of the frame and the + second vector points in the half-plane of another axis. """ - XYZDICT = {'X': 0, 'Y': 1, 'Z': 2, 'x': 0, 'y': 1, 'z': 2} - - FRAME_IDS = {} # frame_id to use if a frame already exists upon un-pickling + _FRAME_IDS = {} + _XYZDICT = {'X': 0, 'Y': 1, 'Z': 2, 'x': 0, 'y': 1, 'z': 2, 0: 0, 1: 1, 2: 2} - #=========================================================================== - def __init__(self, frame, vector1, axis1, vector2, axis2, frame_id='+', - unpickled=False): + def __init__(self, reference, vector1, axis1, vector2, axis2, frame_id='+'): """Constructor for a TwoVectorFrame. - Input: - frame the frame relative to which this frame is defined. - - vector1 vector describing an axis. - - axis1 'X', 'Y', or 'Z', indicating the axis defined by the - first vector.. - - vector2 a vector which, along with vector1, defines the half - plane in which a second axis falls. - - axis2 'X', 'Y', or 'Z', indicating the axis defined by the - second vector. - - frame_id the ID under which the frame will be registered. None to - leave the frame unregistered. If the value begins with - "+", then the "+" is replaced by an underscore and the - result is appended to the name of the reference frame. - If the name is "+" alone, then the registered name is - that of the reference frame appended with '_TWOVECTOR'. - - unpickled True if this frame has been read from a pickle file. + Parameters: + reference (Frame): The frame relative to which this frame is defined. + vector1 (Vector3 or array-like): Vector describing an axis. + axis1 (str): The axis defined by the first vector: 0 or "X" for x; 1 or "Y" + for y, or 2 or "Z" for z. + vector1 (Vector3 or array-like): A Vector which, along with vector1, defines + the half-plane in which a second axis falls. + axis1 (str): "X", "Y", or "Z", indicating the axis defined by the second + vector. + frame_id (str, optional): The ID under which to register the frame; None to + leave it unregistered. As a special case, at value of "+" alone is + replaced by the ID of `reference` plus "_TWOVECTOR". If text follows the + "+", the new ID is the ID of `reference` followed by "_" and this text. """ self.vector1 = Vector3.as_vector3(vector1) self.vector2 = Vector3.as_vector3(vector2) - self.axis1 = str(axis1).upper() - self.axis2 = str(axis2).upper() - - for axis in (self.axis1, self.axis2): - if axis not in ('X','Y','Z'): - raise ValueError('invalid axis value: ' + repr(axis)) - - self.reference = Frame.as_wayframe(frame) - - self.shape = Qube.broadcasted_shape(self.vector1, self.vector2, - self.reference) - self.keys = set() + self.axis1 = TwoVectorFrame._XYZDICT[axis1] + self.axis2 = TwoVectorFrame._XYZDICT[axis2] + # Required attributes + self.reference = Frame.as_wayframe(reference) self.origin = self.reference.origin + self.shape = Qube.broadcasted_shape(self.vector1, self.vector2, self.reference) - # Fill in the frame ID - self._state_frame_id = frame_id if frame_id is None: - self.frame_id = Frame.temporary_frame_id() + self.frame_id = Frame._recover_frame_id(frame_id) elif frame_id.startswith('+') and len(frame_id) > 1: self.frame_id = self.reference.frame_id + '_' + frame_id[1:] elif frame_id == '+': @@ -76,54 +56,79 @@ def __init__(self, frame, vector1, axis1, vector2, axis2, frame_id='+', self.frame_id = frame_id # Register if necessary - self.register(unpickled=unpickled) - - # Derive the tranform now - matrix = Matrix3.twovec(self.vector1, TwoVectorFrame.XYZDICT[axis1], - self.vector2, TwoVectorFrame.XYZDICT[axis2]) - - self.transform = Transform(matrix, Vector3.ZERO, - self.wayframe, self.reference) + self.register() + self._refresh() + self._cache_id() + def _refresh(self): + matrix = Matrix3.twovec(self.vector1, self.axis1, self.vector2, self.axis2) + self._transform = Transform(matrix, Vector3.ZERO, self, self.reference) z_axis = matrix.row_vector(2, classes=[Vector3]) - self.node = Vector3.ZAXIS.ucross(z_axis) + self._node = Vector3.ZAXIS.ucross(z_axis) + + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.frame_id in Frame.WAYFRAME_REGISTRY): - key = (self.reference.frame_id, - tuple(self.vector1.vals), self.axis1, - tuple(self.vector2.vals), self.axis2) - TwoVectorFrame.FRAME_IDS[key] = self.frame_id + def _frame_key(self): + return (self.reference, self.vector1, self.axis1, self.vector2, self.axis2) - # Unpickled frames will always have temporary IDs to avoid conflicts def __getstate__(self): - return (Frame.as_primary_frame(self.reference), - self.vector1, self.axis1, - self.vector2, self.axis2, self.shape) + Fittable_.refresh(self) + self._cache_id() + return (Frame.as_primary_frame(self.reference), self.vector1, self.axis1, + self.vector2, self.axis2, self._state_id()) def __setstate__(self, state): - # If this frame matches a pre-existing frame, re-use its ID - (frame, vector1, axis1, vector2, axis2, shape) = state - if self.shape == (): - key = (frame.frame_id, tuple(vector1.vals), axis1, - tuple(vector2.vals), axis2) - frame_id = TwoVectorFrame.FRAME_IDS.get(key, None) - else: - frame_id = None + (frame, vector1, axis1, vector2, axis2, frame_id) = state + self.__init__(frame, vector1, axis1, vector2, axis2, frame_id=frame_id) + Fittable_.freeze(self) - self.__init__(frame, vector1, axis1, vector2, axis2, frame_id=frame_id, - unpickled=True) + ###################################################################################### + # Frame API + ###################################################################################### - #=========================================================================== def transform_at_time(self, time, quick={}): - """The Transform into the this Frame at a Scalar of times.""" + """Transform that rotates coordinates from the reference frame to this frame. + + If the frame is rotating, then the coordinates being transformed must be given + relative to the center of rotation. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Transform): The Tranform applicable at the specified time or times. It + rotates vectors from the reference frame to this frame. - return self.transform + Notes: + TwoVector is a fixed frame, so the transform relative to the `reference` frame + is independent of time. + """ + + return self._transform - #=========================================================================== def node_at_time(self, time, quick={}): + """The vector defining the ascending node of this frame's XY plane relative to + the XY frame of its reference. + + Parameters: + time (Scalar, array-like, or float): The time in seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters. + Use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Vector3): The unit vector pointing in the direction of the ascending node. + + Notes: + TwoVector is a fixed frame, so its node vector relative to the `reference` + frame is independent of time. + """ - return self.node + return self._node -################################################################################ +########################################################################################## diff --git a/oops/observation/insitu.py b/oops/observation/insitu.py index 2faf138e..e924d366 100644 --- a/oops/observation/insitu.py +++ b/oops/observation/insitu.py @@ -8,6 +8,7 @@ from oops.cadence import Cadence from oops.cadence.instant import Instant from oops.fov.nullfov import NullFOV +from oops.fittable import Fittable_ from oops.frame import Frame from oops.observation import Observation from oops.path import Path @@ -71,19 +72,30 @@ def __init__(self, cadence, path, **subfields): self.shape = self.cadence.shape self.uv_shape = (1,1) - # Timing - self.time = self.cadence.time - self.midtime = self.cadence.midtime - # Optional subfields self.subfields = {} for key in subfields.keys(): self.insert_subfield(key, subfields[key]) def __getstate__(self): + Fittable_.refresh(self) return (self.cadence, self.path, self.subfields) def __setstate__(self, state): self.__init__(*state[:-1], **state[-1]) + #=========================================================================== + def time_shift(self, dtime): + """A copy of the observation object with a time-shift. + + Input: + dtime the time offset to apply to the observation, in units of + seconds. A positive value shifts the observation later. + + Return: a shallow copy of the object with a new time. + """ + + return self.insitu(self.cadence.time_shift(dtime), self.path, + **self.subfields) + ################################################################################ diff --git a/oops/observation/observation_.py b/oops/observation/observation_.py index 2ce79684..d89c0d69 100755 --- a/oops/observation/observation_.py +++ b/oops/observation/observation_.py @@ -5,10 +5,11 @@ import numpy as np import numbers -from polymath import Scalar, Pair, Vector, Vector3, Qube -from oops.config import LOGGING, PATH_PHOTONS -from oops.event import Event -from oops.meshgrid import Meshgrid +from polymath import Scalar, Pair, Vector, Vector3, Qube +from oops.config import LOGGING, PATH_PHOTONS +from oops.event import Event +from oops.frame.navigation import Navigation +from oops.meshgrid import Meshgrid class Observation(object): """An Observation is an abstract class that defines the timing and pointing @@ -29,12 +30,14 @@ class Observation(object): At minimum, these attributes are used to describe the observation: + cadence a Cadence object defining the timing of the observation. + time a tuple or Pair defining the start time and end time of - the observation overall, in seconds TDB. + the observation overall, in seconds TDB. Inherited from + `cadence`. midtime the mid-time of the observation, in seconds TDB. - - cadence a Cadence object defining the timing of the observation. + Inherited from `cadence`. fov a FOV (field-of-view) object, which describes the field of view including any spatial distortion. It maps @@ -42,8 +45,9 @@ class Observation(object): coordinates (x,y). uv_shape a tuple defining the 2-D shape of the spatial axes of - the data array, in (u,v) order. Note that this may - differ from fov.uv_shape. + the data array, in (u,v) order. Note that this will + differ from fov.uv_shape in cases where the + time-dependence introduces an extra dimension. u_axis, v_axis integers identifying the axes of the data array associated with the u-axis and the v-axis. Use -1 if @@ -91,6 +95,14 @@ def __init__(self): pass + @property + def time(self): + return self.cadence.time + + @property + def midtime(self): + return self.cadence.midtime + #=========================================================================== def uvt(self, indices, remask=False, derivs=True): """Coordinates (u,v) and time t for indices into the data array. @@ -354,6 +366,32 @@ def time_shift(self, dtime): raise NotImplementedError(type(self).__name__ + '.time_shift ' + 'is not implemented') + #=========================================================================== + def navigate(self, angles): + """A copy of this Observation object after two or three rotation angles + of a Navigation object applied. + + Input: + angles two or three angles of rotation in radians. The order of + the rotations is about the y, x, and (optionally) z + axes. These angles rotate a vector in the reference + frame into this frame. + + Return: A new Observation with the navigation applied. + """ + + # Identify the non-navigated frame + if isinstance(self.frame, Navigation): + frame = self.frame.reference + else: + frame = self.frame + + # Copy and update the frame + obs = type(self).__new__() + obs.__dict__ = self.__dict__.copy() + obs.frame = Navigation(angles, reference=frame) + return obs + ############################################################################ # Subfield support methods ############################################################################ diff --git a/oops/observation/pixel.py b/oops/observation/pixel.py index a4bf5c81..cffe782c 100755 --- a/oops/observation/pixel.py +++ b/oops/observation/pixel.py @@ -7,6 +7,7 @@ from polymath import Scalar, Pair, Vector3 from oops.observation import Observation from oops.event import Event +from oops.fittable import Fittable_ from oops.frame import Frame from oops.path import Path @@ -82,18 +83,13 @@ def __init__(self, axes, cadence, fov, path, frame, **subfields): shape_list[self.t_axis] = samples self.shape = tuple(shape_list) - # Timing - self.time = self.cadence.time - self.midtime = self.cadence.midtime - self._scalar_time = (Scalar(self.time[0]), Scalar(self.time[1])) - self._scalar_midtime = Scalar(self.cadence.midtime) - # Optional subfields self.subfields = {} for key in subfields.keys(): self.insert_subfield(key, subfields[key]) def __getstate__(self): + Fittable_.refresh(self) return (self.axes, self.cadence, self.fov, self.path, self.frame, self.subfields) @@ -124,7 +120,7 @@ def uvt(self, indices, remask=False, derivs=True): if tstep is None: # if t_axis < 0 uv = Pair.filled(indices.shape, 0.5) - return (uv, self._scalar_midtime) + return (uv, Scalar(self.cadence.midtime)) time = self.cadence.time_at_tstep(tstep, remask=remask) uv = Pair.filled(time.shape, 0.5, mask=time.mask) @@ -149,7 +145,8 @@ def uvt_range(self, indices, remask=False): """ if self.t_axis < 0: - return (Pair.INT00, self.fov.uv_shape) + self._scalar_time + return (Pair.INT00, self.fov.uv_shape, + Scalar(self.cadence.time[0]), Scalar(self.cadence.time[1])) # Works for a 1-D index or a multi-D index tstep = Observation.scalar_from_indices(indices, self.t_axis) diff --git a/oops/observation/rasterslit1d.py b/oops/observation/rasterslit1d.py index 478896c1..cbea26b0 100755 --- a/oops/observation/rasterslit1d.py +++ b/oops/observation/rasterslit1d.py @@ -8,6 +8,7 @@ from oops.observation import Observation from oops.cadence import Cadence from oops.cadence.metronome import Metronome +from oops.fittable import Fittable_ from oops.frame import Frame from oops.path import Path @@ -112,16 +113,13 @@ def __init__(self, axes, cadence, fov, path, frame, **subfields): else: raise TypeError('Invalid cadence class: ' + type(cadence).__name__) - # Timing - self.time = self.cadence.time - self.midtime = self.cadence.midtime - # Optional subfields self.subfields = {} for key in subfields.keys(): self.insert_subfield(key, subfields[key]) def __getstate__(self): + Fittable_.refresh(self) return (self.axes, self.cadence, self.fov, self.path, self.frame, self.subfields) diff --git a/oops/observation/slit1d.py b/oops/observation/slit1d.py index 90ae4e2c..1a9157d4 100755 --- a/oops/observation/slit1d.py +++ b/oops/observation/slit1d.py @@ -4,11 +4,13 @@ import numpy as np -from polymath import Scalar, Pair -from oops.observation import Observation -from oops.cadence.metronome import Metronome -from oops.frame import Frame -from oops.path import Path +from polymath import Scalar, Pair +from oops.observation import Observation +from oops.cadence import Cadence +from oops.cadence.snapcadence import SnapCadence +from oops.fittable import Fittable_ +from oops.frame import Frame +from oops.path import Path class Slit1D(Observation): """A subclass of Observation consisting of a 1-D slit measurement with no @@ -27,8 +29,11 @@ def __init__(self, axes, tstart, texp, fov, path, frame, **subfields): if any. Only one of 'u' or 'v' can appear in a Slit1D. tstart the start time of the observation in seconds TDB. + Alternatively, a Cadence object with shape (1,) defining + `tstart` and `texp`. - texp exposure duration of the observation in seconds. + texp exposure duration of the observation in seconds. Ignored + if `tstart` is specified as a Cadence. fov a FOV (field-of-view) object, which describes the field of view including any spatial distortion. It maps @@ -89,17 +94,14 @@ def __init__(self, axes, tstart, texp, fov, path, frame, **subfields): self.t_axis = -1 # Cadence - self.cadence = Metronome.for_array0d(tstart, texp) - - # Timing - self.tstart = self.cadence.tstart - self.texp = self.cadence.texp - - self.time = self.cadence.time - self.midtime = self.cadence.midtime - - self._scalar_time = (Scalar(self.time[0]), Scalar(self.time[1])) - self._scalar_midtime = Scalar(self.midtime) + if isinstance(tstart, Cadence): + self.cadence = tstart + if self.cadence.shape != (1,): + raise ValueError("Shape of a Snapshot's cadence must be (1,)") + self.texp = self.cadence.texp + else: + self.cadence = SnapCadence(tstart, texp) + self.texp = texp # Optional subfields self.subfields = {} @@ -107,7 +109,8 @@ def __init__(self, axes, tstart, texp, fov, path, frame, **subfields): self.insert_subfield(key, subfields[key]) def __getstate__(self): - return (self.axes, self.tstart, self.texp, self.fov, self.path, + Fittable_.refresh(self) + return (self.axes, self.cadence, self.texp, self.fov, self.path, self.frame, self.subfields) def __setstate__(self, state): @@ -149,7 +152,7 @@ def uvt(self, indices, remask=False, derivs=True): uv = Pair(uv_vals, mask=slit_coord.mask) # Create time Scalar; shapeless is OK unless there's a mask - time = self._scalar_midtime + time = Scalar(self.cadence.midtime) # Apply mask to time if necessary if remask and np.any(slit_coord.mask): @@ -242,12 +245,9 @@ def time_shift(self, dtime): Return: a (shallow) copy of the object with a new time. """ - obs = Slit1D(axes=self.axes, tstart=self.tstart + dtime, texp=self.texp, - fov=self.fov, path=self.path, frame=self.frame) - - for key in self.subfields.keys(): - obs.insert_subfield(key, self.subfields[key]) - - return obs + cadence = self.cadence.time_shift(dtime) + return Slit1D(axes=self.axes, tstart=cadence, texp=self.texp, + fov=self.fov, path=self.path, frame=self.frame, + **self.subfields) ################################################################################ diff --git a/oops/observation/snapshot.py b/oops/observation/snapshot.py index 598b19ee..96a6e230 100755 --- a/oops/observation/snapshot.py +++ b/oops/observation/snapshot.py @@ -4,14 +4,16 @@ import numpy as np -from polymath import Scalar, Pair, Vector, Vector3, Qube -from oops.observation import Observation -from oops.body import Body -from oops.cadence.metronome import Metronome -from oops.event import Event -from oops.frame import Frame -from oops.path import Path -from oops.path.multipath import MultiPath +from polymath import Scalar, Pair, Vector, Vector3, Qube +from oops.observation import Observation +from oops.body import Body +from oops.cadence import Cadence +from oops.cadence.snapcadence import SnapCadence +from oops.event import Event +from oops.fittable import Fittable_ +from oops.frame import Frame +from oops.path import Path +from oops.path.multipath import MultiPath class Snapshot(Observation): """A Snapshot is an Observation consisting of a 2-D image made up of pixels @@ -33,8 +35,11 @@ def __init__(self, axes, tstart, texp, fov, path, frame, **subfields): an image file in FITS or VICAR format. tstart the start time of the observation in seconds TDB. + Alternatively, a Cadence object with shape (1,) defining + `tstart` and `texp`. - texp exposure duration of the observation in seconds. + texp exposure duration of the observation in seconds. Ignored + if `tstart` is specified as a Cadence. fov a FOV (field-of-view) object, which describes the field of view including any spatial distortion. It maps @@ -75,17 +80,14 @@ def __init__(self, axes, tstart, texp, fov, path, frame, **subfields): self.shape[self.v_axis] = self.uv_shape[1] # Cadence - self.cadence = Metronome.for_array0d(tstart, texp) - - # Timing - self.tstart = self.cadence.tstart - self.texp = self.cadence.texp - - self.time = self.cadence.time - self.midtime = self.cadence.midtime - - self._scalar_time = (Scalar(self.time[0]), Scalar(self.time[1])) - self._scalar_midtime = Scalar(self.midtime) + if isinstance(tstart, Cadence): + self.cadence = tstart + if self.cadence.shape != (1,): + raise ValueError('Shape of Snapshot cadence must be (1,)') + self.texp = self.cadence.texp + else: + self.cadence = SnapCadence(tstart, texp) + self.texp = texp # Optional subfields self.subfields = {} @@ -93,7 +95,8 @@ def __init__(self, axes, tstart, texp, fov, path, frame, **subfields): self.insert_subfield(key, subfields[key]) def __getstate__(self): - return (self.axes, self.tstart, self.texp, self.fov, self.path, + Fittable_.refresh(self) + return (self.axes, self.cadence, self.texp, self.fov, self.path, self.frame, self.subfields) def __setstate__(self, state): @@ -119,7 +122,7 @@ def uvt(self, indices, remask=False, derivs=True): indices = Vector.as_vector(indices, recursive=derivs) uv = indices.to_pair((self.u_axis, self.v_axis)) - time = self._scalar_midtime + time = Scalar(self.cadence.midtime) if remask: is_outside = self.uv_is_outside(uv, inclusive=True) @@ -212,7 +215,7 @@ def time_range_at_uv(self, uv_pair, remask=False): return (time_min, time_max) # Without a mask, it's OK to return shapeless values - return self._scalar_time + return (Scalar(self.cadence.time[0]), Scalar(self.cadence.time[1])) #=========================================================================== def uv_range_at_time(self, time, remask=False): @@ -249,14 +252,10 @@ def time_shift(self, dtime): Return: a (shallow) copy of the object with a new time. """ - obs = Snapshot(axes=self.axes, tstart=self.time[0] + dtime, - texp=self.texp, fov=self.fov, path=self.path, - frame=self.frame) - - for key in self.subfields.keys(): - obs.insert_subfield(key, self.subfields[key]) - - return obs + cadence = self.cadence.time_shift(dtime) + return Snapshot(axes=self.axes, tstart=cadence, texp=self.texp, + fov=self.fov, path=self.path, frame=self.frame, + **self.subfields) ############################################################################ # Overrides of Observation methods diff --git a/oops/observation/timedimage.py b/oops/observation/timedimage.py index 84e2b50d..fdbaea92 100755 --- a/oops/observation/timedimage.py +++ b/oops/observation/timedimage.py @@ -7,6 +7,7 @@ from polymath import Pair, Vector, Qube from oops.observation import Observation from oops.observation.snapshot import Snapshot +from oops.fittable import Fittable_ from oops.frame import Frame from oops.path import Path @@ -25,7 +26,7 @@ class TimedImage(Observation): #=========================================================================== def __init__(self, axes, cadence, fov, path, frame, **subfields): - """Constructor for a Pushframe. + """Constructor for a TimedImage. Input: axes a list or tuple of strings, with one value for each axis @@ -112,10 +113,6 @@ def __init__(self, axes, cadence, fov, path, frame, **subfields): if len(self.cadence.shape) != 2: raise ValueError('TimedImage axes requires 2-D cadence') - # Timing - self.time = self.cadence.time - self.midtime = self.cadence.midtime - # Shape / Size self.shape = len(axes) * [0] self.shape[self.u_axis] = self.fov_shape[0] @@ -182,6 +179,7 @@ def __init__(self, axes, cadence, fov, path, frame, **subfields): self.path, self.frame, **subfields) def __getstate__(self): + Fittable_.refresh(self) return (self.axes, self.cadence, self.fov, self.path, self.frame, self.subfields) diff --git a/oops/path/__init__.py b/oops/path/__init__.py index 97588183..19b86555 100755 --- a/oops/path/__init__.py +++ b/oops/path/__init__.py @@ -11,6 +11,7 @@ from oops.path.keplerpath import KeplerPath from oops.path.linearpath import LinearPath from oops.path.multipath import MultiPath +from oops.path.pathshift import PathShift from oops.path.spicepath import SpicePath ################################################################################ diff --git a/oops/path/circlepath.py b/oops/path/circlepath.py index c8cd776c..8f78ec3c 100755 --- a/oops/path/circlepath.py +++ b/oops/path/circlepath.py @@ -1,45 +1,41 @@ -################################################################################ +########################################################################################## # oops/path/circlepath.py: Subclass CirclePath of class Path -################################################################################ - -from polymath import Qube, Scalar, Vector3 +########################################################################################## +from polymath import Qube, Scalar, Vector3 +from oops.cache import Cache from oops.event import Event +from oops.fittable import Fittable_ from oops.frame.frame_ import Frame from oops.path.path_ import Path + class CirclePath(Path): """A path describing uniform circular motion about another path. - The orientation of the circle is defined by the z-axis of the given - frame. + The orientation of the circle is defined by the z-axis of the given frame. """ - PATH_IDS = {} # path_id to use if a path already exists upon un-pickling + _PATH_IDS = {} - #=========================================================================== - def __init__(self, radius, lon, rate, epoch, origin, frame=None, - path_id=None, unpickled=False): + def __init__(self, radius, lon, rate, epoch, origin, frame=None, path_id=None): """Constructor for a CirclePath. - Input: - radius radius of the path, km. - lon longitude of the path at epoch, measured from the - x-axis of the frame, toward the y-axis, in radians. - rate rate of circular motion, radians/second. - - epoch the time TDB relative to which all orbital elements are - defined. - origin the path or ID of the center of the circle. - frame the frame or ID of the frame in which the circular - motion is defined; None to use the default frame of the - origin path. - path_id the name under which to register the new path; None to - leave the path unregistered. - unpickled True if this path has been read from a pickle file. - - Note: The shape of the Path object returned is defined by broadcasting - together the shapes of all the orbital elements plus the epoch. + Parameters: + radius (Scalar, array-like, or float): Radius of the path, km. + lon (Scalar, array-like, or float): Longitude of the path at epoch, measured + from the x-axis of the frame, toward the y-axis, in radians. + rate (Scalar, array-like, or float): Rate of circular motion, radians/second. + epoch (Scalar, array-like, or float): The time TDB relative to which all + orbital elements are defined. + origin (Path or str): The path or ID of the center of the circle. + frame (Frame or str): The frame or ID of the frame in which the circular + motion is defined; None to use the default frame of the origin path. + path_id (str, optional): The ID to use; None to leave the path unregistered. + + Notes: + The shape of the Path object returned is defined by broadcasting together the + shapes of all the orbital elements plus the epoch. """ # Interpret the elements @@ -49,52 +45,52 @@ def __init__(self, radius, lon, rate, epoch, origin, frame=None, self.rate = Scalar.as_scalar(rate) # Required attributes - self.path_id = path_id - self.origin = Path.as_waypoint(origin) - self.frame = Frame.as_wayframe(frame) or self.origin.frame - self.keys = set() - self.shape = Qube.broadcasted_shape(self.radius, self.lon, - self.rate, self.epoch, - self.origin.shape, - self.frame.shape) - - # Update waypoint and path_id; register only if necessary - self.register(unpickled=unpickled) - - # Save in internal dict for name lookup upon serialization - if (not unpickled and self.shape == () - and self.path_id in Path.WAYPOINT_REGISTRY): - key = (self.radius.vals, self.lon.vals, self.rate.vals, - self.epoch.vals, origin.path_id, frame.frame_id) - CirclePath.PATH_IDS[key] = self.path_id + self.origin = Path.as_waypoint(origin) + self.frame = Frame.as_wayframe(frame) or self.origin.frame + self.shape = Qube.broadcasted_shape(self.radius, self.lon, self.rate, + self.epoch, self.origin.shape, + self.frame.shape) + self.path_id = self._recover_id(path_id) + + self.register() + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.radius, self.lon, self.rate, self.epoch, self.origin, self.frame) def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (self.radius, self.lon, self.rate, self.epoch, Path.as_primary_path(self.origin), - Frame.as_primary_frame(self.frame), self.shape) + Frame.as_primary_frame(self.frame), self._state_id()) def __setstate__(self, state): - # If this path matches a pre-existing path, re-use its ID - (radius, lon, rate, epoch, origin, frame, shape) = state - if shape == (): - key = (radius.vals, lon.vals, rate.vals, epoch.vals, origin.path_id, - frame.frame_id) - path_id = CirclePath.PATH_IDS.get(key, None) - else: - path_id = None - - self.__init__(radius, lon, rate, epoch, origin, frame, path_id=path_id, - unpickled=True) - - #=========================================================================== + (radius, lon, rate, epoch, origin, frame, path_id) = state + self.__init__(radius, lon, rate, epoch, origin, frame, path_id=path_id) + Fittable_.freeze(self) + + ###################################################################################### + # Path API + ###################################################################################### + def event_at_time(self, time, quick=False): """An Event corresponding to a specified time on this path. - Input: - time a time Scalar at which to evaluate the path. + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. - Return: an Event object containing (at least) the time, position - and velocity on the path. + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ lon = self.lon + self.rate * (Scalar.as_scalar(time) - self.epoch) @@ -102,9 +98,8 @@ def event_at_time(self, time, quick=False): r_sin_lon = self.radius * lon.sin() pos = Vector3.from_scalars(r_cos_lon, r_sin_lon, 0.) - vel = Vector3.from_scalars(-r_sin_lon * self.rate, - r_cos_lon * self.rate, 0.) + vel = Vector3.from_scalars(-r_sin_lon * self.rate, r_cos_lon * self.rate, 0.) return Event(time, (pos,vel), self.origin, self.frame) -################################################################################ +########################################################################################## diff --git a/oops/path/coordpath.py b/oops/path/coordpath.py index 429e75d6..66d494e5 100755 --- a/oops/path/coordpath.py +++ b/oops/path/coordpath.py @@ -1,35 +1,32 @@ -################################################################################ +######################################################################################### # oops/path/coordpath.py: Subclass CoordPath of class Path -################################################################################ +######################################################################################### from polymath import Qube, Scalar from oops.event import Event +from oops.fittable import Fittable_ from oops.path.path_ import Path + class CoordPath(Path): """A path defined by fixed coordinates on a specified Surface.""" - # Note: CoordPaths are not generally re-used, so their IDs are expendable. - # Their IDs are not preserved during pickling. + _PATH_IDS = {} - #=========================================================================== - def __init__(self, surface, coords, obs=None, path_id=None): + def __init__(self, surface, coords, obs=None, *, path_id=None): """Constructor for a CoordPath. - Input: - surface a surface. - coords a tuple of 2 or 3 Scalars defining the coordinates on - the surface. - obs optional path of observer, needed to calculate points - on virtual surfaces. - path_id the name under which to register the new path; None to - leave the path unregistered. + Parameters: + surface (Surface): The surface to which the coordinates refer. + coords (tuple): 2 or 3 Scalars defining the coordinates on the surface. + obs (Path or str, optional): Path of observer, needed to calculate points on + virtual surfaces. + path_id (str, optional): The ID to use; None to leave the path unregistered. """ - if surface.IS_VIRTUAL: - raise NotImplementedError('CoordPath cannot be defined for virtual ' - 'surface class ' - + type(surface).__name__) + if surface.IS_VIRTUAL and obs is None: + raise NotImplementedError('CoordPath requires an observation path for ' + 'virtual surface class ' + type(surface).__name__) self.surface = surface self.coords = tuple(Scalar(x) for x in coords) @@ -37,46 +34,61 @@ def __init__(self, surface, coords, obs=None, path_id=None): self.pos = self.surface.vector3_from_coords(self.coords) # Required attributes - self.path_id = path_id - self.origin = self.surface.origin - self.frame = self.origin.frame - self.keys = set() - self.shape = Qube.broadcasted_shape(self.obs_path, *self.coords) + self.origin = self.surface.origin + self.frame = self.origin.frame + self.shape = Qube.broadcasted_shape(self.obs_path, *self.coords) + self.path_id = self._recover_id(path_id) - # Update waypoint and path_id; register only if necessary self.register() + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.radius, self.lon, self.rate, self.epoch, self.origin, self.frame) - # Unpickled paths will always have temporary IDs to avoid conflicts def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (self.surface, self.coords, - None if self.obs_path is None - else Path.as_primary_path(self.obs_path)) + None if self.obs_path is None else Path.as_primary_path(self.obs_path), + self._state_id()) def __setstate__(self, state): - self.__init__(*state) + (surface, coords, obs, path_id) = state + self.__init__(surface, coords, obs, path_id=path_id) + Fittable_.freeze(self) + + ###################################################################################### + # Path API + ###################################################################################### - #=========================================================================== def event_at_time(self, time, quick={}): """An Event corresponding to a specified time on this path. - Input: - time a time Scalar at which to evaluate the path. + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. - Return: an Event object containing (at least) the time, position - and velocity on the path. + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ return Event(time, self.pos, self.origin, self.frame) - #=========================================================================== - def _solve_photon(self, link, sign, derivs=False, guess=None, antimask=None, - quick={}, converge={}): + def _solve_photon(self, link, sign, derivs=False, guess=None, antimask=None, quick={}, + converge={}): """Override of the default method to avoid extra iteration.""" return self.surface._solve_photon_by_coords(link, self.coords, sign, derivs=derivs, guess=guess, - antimask=antimask, - quick=quick, + antimask=antimask, quick=quick, converge=converge) -################################################################################ +######################################################################################### diff --git a/oops/path/fixedpath.py b/oops/path/fixedpath.py index 6c42d48b..f124dd79 100755 --- a/oops/path/fixedpath.py +++ b/oops/path/fixedpath.py @@ -1,31 +1,29 @@ -################################################################################ +########################################################################################## # oops/path/fixedpath.py: Subclass FixedPath of class Path -################################################################################ +########################################################################################## from polymath import Qube, Vector3 from oops.event import Event +from oops.fittable import Fittable_ from oops.frame.frame_ import Frame from oops.path.path_ import Path + class FixedPath(Path): - """A path described by fixed coordinates relative to another path and frame. - """ + """A path described by fixed coordinates relative to another path and frame.""" - # Note: FixedPaths are not generally re-used, so their IDs are expendable. - # Their IDs are not preserved during pickling. + _PATH_IDS = {} - #=========================================================================== def __init__(self, pos, origin, frame, path_id=None): """Constructor for an FixedPath. - Input: - pos a Vector3 of position vectors within the frame and - relative to the specified origin. - origin the path or ID of the reference point. - frame the frame or ID of the frame in which the position is - fixed. - path_id the name under which to register the new path; None to - leave the path unregistered. + Parameters: + pos (Vector3 or array-like): The position vectors within the frame and + relative to the specified origin. + origin (Path or str): The path or ID of the center of the circle. + frame (Frame or str): The frame or ID of the frame in which the fixed + coordinates are defined. + path_id (str, optional): The ID to use; None to leave the path unregistered. """ # Interpret the position @@ -34,35 +32,51 @@ def __init__(self, pos, origin, frame, path_id=None): self.pos = pos.as_readonly() # Required attributes - self.path_id = path_id - self.origin = Path.as_waypoint(origin) - self.frame = Frame.as_wayframe(frame) or self.origin.frame - self.keys = set() - self.shape = Qube.broadcasted_shape(self.pos, self.origin, self.frame) + self.origin = Path.as_waypoint(origin) + self.frame = Frame.as_wayframe(frame) or self.origin.frame + self.shape = Qube.broadcasted_shape(self.pos, self.origin, self.frame) + self.path_id = Path._recover_id(path_id) - # Update waypoint and path_id; register only if necessary self.register() + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.pos, self.origin, self.frame) - # Unpickled paths will always have temporary IDs to avoid conflicts def __getstate__(self): - return (self.pos, - Path.as_primary_path(self.origin), - Frame.as_primary_frame(self.frame)) + Fittable_.refresh(self) + self._cache_id() + return (self.pos, Path.as_primary_path(self.origin), + Frame.as_primary_frame(self.frame), self._state_id()) def __setstate__(self, state): - self.__init__(*state) + (pos, origin, frame, path_id) = state + self.__init__(pos, origin, frame, path_id=path_id) + Fittable_.freeze(self) + + ###################################################################################### + # Path API + ###################################################################################### - #========================================================================== def event_at_time(self, time, quick=False): """An Event corresponding to a specified time on this path. - Input: - time a time Scalar at which to evaluate the path. + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. - Return: an Event object containing (at least) the time, position - and velocity on the path. + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ return Event(time, self.pos, self.origin, self.frame) -################################################################################ +########################################################################################## diff --git a/oops/path/keplerpath.py b/oops/path/keplerpath.py index 698df838..baeb9625 100755 --- a/oops/path/keplerpath.py +++ b/oops/path/keplerpath.py @@ -1,10 +1,11 @@ -################################################################################ +########################################################################################## # oops/path/keplerpath.py: Subclass KeplerPath of class Path. -################################################################################ +########################################################################################## import numpy as np from polymath import Scalar, Vector3, Matrix3 +from oops.cache import Cache from oops.event import Event from oops.fittable import Fittable from oops.frame.frame_ import Frame @@ -28,94 +29,85 @@ NWOBBLES = 3 + class KeplerPath(Path, Fittable): """A Path subclass that defines a fittable Keplerian orbit. - It is accurate to first order in eccentricity and inclination, and is - defined using nine orbital elements. + It is accurate to first order in eccentricity and inclination, and is defined using + nine orbital elements. """ - PATH_IDS = {} + _PATH_IDS = {} - #=========================================================================== def __init__(self, body, epoch, elements=None, observer=None, wobbles=(), - frame=None, path_id=None, unpickled=False): + path_id=None): """Constructor for a KeplerPath. - Input: - body a Body object defining the central planet, including its - gravity and its ring_frame. - epoch the time TDB relative to which all orbital elements are - defined. - - elements a tuple, list or Numpy array containing the orbital - elements and wobble terms: - a mean radius of orbit, km. - lon mean longitude of orbit at epoch, radians. - n mean motion, radians/sec. - - e orbital eccentricity. - peri longitude of pericenter at epoch, radians. - prec pericenter precession rate, radians/sec. - - i inclination, radians. - node longitude of ascending node at epoch, radians. - regr nodal regression rate, radians/sec, NEGATIVE! - - Repeat for each wobble: - - amp amplitude of the first wobble term, radians. - phase0 initial phase of the first wobble term, radians. - dphase_dt rate of change of the first wobble term, radians/s. - - Alternatively, a dictionary containing keys with these - names, or None to leave the object un-initialized. - - observer an optional Path object or ID defining the observation - point. Used for astrometry. If provided, the path is - returned relative to the observer, in J2000 coordinates, - and with light travel time from the central planet - already accounted for. If None (the default), then the - path is defined relative to the central planet in that - planet's ring_frame. - - wobbles a string or tuple of strings containing the name of each - element to which the corresponding wobble applies. Use - 'mean', 'peri' or 'node', 'a', 'e', or 'i', for - individual elements. Use 'e2d' for a forced eccentricity - and 'i2d' for a forced inclination. Use 'pole' for a - Laplace plane offset). - - frame an optional frame in which the orbit is defined. By - default, this is the ring_frame of the planet. Ignored - if observer is defined. - - path_id the name under which to register the path. - - unpickled True if this path has been read from a pickle file. + Parameters: + body (Body or str): The Body object or name of the central planet, including + its gravity and its ring_frame. + + epoch (float): The time TDB relative to which all orbital elements are + defined. + + elements (array-like or dict, optional): The orbital elements and wobble + terms. If an array-like object is provided, this is the order of the + elements: + + * [0]: `a`, mean radius of orbit, km. + * [1]: `lon`, mean longitude of orbit at epoch, radians. + * [2]: `n`, mean motion, radians/sec. + * [3]: `e`, orbital eccentricity. + * [4]: `peri`, longitude of pericenter at epoch, radians. + * [5]: `prec`,pericenter precession rate, radians/sec. + * [6]: `i`, inclination, radians. + * [7]: `node`, longitude of ascending node at epoch, radians. + * [8]: `regr`, nodal regression rate, radians/sec, NEGATIVE! + + You can include additional "wobble" terms, which can describe + non-Keplerian orbital perturbations. Repeat these three elements for each + wobble term: + + * [9, 12, ...] `amp`: amplitude of the term, radians. + * [10, 13, ...] `phase0`: initial phase of the first wobble term, radians. + * [11, 14, ...] `dphase_dt`: rate of change of the first wobble term, + radians/s. + + Alternatively, provide a dictionary containing keys with these names (in + which case only one wobble term is allowed). If the elements are not + provided, the object remains un-initialized until `set_elements` is + called. + + observer (Path or str, optional): Identification of the Path of the observer. + If provided, then `event_at_time` returns positions relative to this + observer in J2000 coordinates and with light travel time already accounted + for; this makes it easy to use this Path object for astrometry and orbit + fitting. If not provided, `event_at_time` returns positions relative to + the planet center and in the planet's `ring_frame`. + + wobbles (str or tuple): The name(s) of each wobble element: + + * "a:: semimajor axis. + * "e": eccentricity. + * "i": inclination. + * "mean": mean motion rate. + * "peri": pericenter. + * "node": ascending node. + * "e2d": forced eccentricity represented by a 2-D vector. + * "i2d": forced inclination represented by a 2-D vector. + * "pole": an offset to the Laplace plane. + + path_id (str, optional): The ID to use; None to leave the path unregistered. """ - global SEMIM, MEAN0, DMEAN, ECCEN, PERI0, DPERI, INCLI, NODE0, DNODE - global NELEMENTS - - global LIBAMP, PHASE0, DPHASE - global NWOBBLES - - if isinstance(wobbles, str): - self.wobbles = (wobbles,) - else: - self.wobbles = wobbles - + self.wobbles = (wobbles,) if isinstance(wobbles, str) else wobbles self.nwobbles = len(wobbles) for name in self.wobbles: - if name not in {'mean', 'peri', 'node', 'a', 'e', 'i', 'e2d', 'i2d', - 'pole'}: - raise ValueError('invalid name for wobble in KeplerPath: ' - + repr(name)) + if name not in {'mean', 'peri', 'node', 'a', 'e', 'i', 'e2d', 'i2d', 'pole'}: + raise ValueError('invalid name for wobble in KeplerPath: ' + repr(name)) - self.nparams = NELEMENTS + self.nwobbles * NWOBBLES - self.param_name = "elements" - self.cache = {} + self.nelements = NELEMENTS + self.nwobbles * NWOBBLES + self._fittable_nparams = self.nelements self.planet = Path.BODY_CLASS.as_body(body) self.center = self.planet.path @@ -124,7 +116,7 @@ def __init__(self, body, epoch, elements=None, observer=None, wobbles=(), if observer is None: self.observer = None self.origin = self.planet.path - self.frame = frame or self.planet.ring_frame + self.frame = self.planet.ring_frame self.to_j2000 = Matrix3.IDENTITY else: self.observer = Path.as_path(observer) @@ -132,11 +124,12 @@ def __init__(self, body, epoch, elements=None, observer=None, wobbles=(), raise ValueError('KeplerPath requires a shapeless observer') self.origin = self.observer - self.frame = frame or Frame.J2000 + self.frame = Frame.J2000 frame = self.frame.wrt(self.planet.ring_frame) self.to_j2000 = frame.transform_at_time(epoch).matrix self.epoch = float(epoch) + self._events = Cache() if elements is None: self.elements = None @@ -158,71 +151,56 @@ def __init__(self, body, epoch, elements=None, observer=None, wobbles=(), elements['phase0'], elements['dphase_dt'] ] - self.set_params(items) + self.set_elements(items) else: - self.set_params(elements) + self.set_elements(elements) - self.path_id = path_id self.shape = () - self.keys = set() - self.register(unpickled=unpickled) + self.path_id = self._recover_id() + self.register() - # Save in internal dict for name lookup upon serialization - if self.path_id in Path.WAYPOINT_REGISTRY: - key = (self.planet.name, self.epoch, tuple(self.elements), - self.observer.path_id if self.observer else None, - self.wobbles, self.frame.frame_id) - KeplerPath.PATH_IDS[key] = self.path_id + ###################################################################################### + # Fittable support + ###################################################################################### - # Unpickled paths will always have temporary IDs to avoid conflicts - def __getstate__(self): - return (self.planet, self.epoch, self.elements, - Path.as_primary_path(self.observer), - self.wobbles, - Frame.as_primary_frame(self.frame)) + def _set_params(self, params): + """Re-define the orbital elements of this KeplerPath.""" + self.set_elements(params) - def __setstate__(self, state): - # If this path matches a pre-existing path, re-use its ID - (body, epoch, elements, observer, wobbles, frame) = state - key = (body.name, epoch, tuple(elements), - observer.path_id if observer else None, - wobbles, frame.frame_id) - path_id = KeplerPath.PATH_IDS.get(key, None) - self.__init__(*state, path_id=path_id, unpickled=True) - - #=========================================================================== - def set_params_new(self, elements): + @property + def _params(self): + return tuple(self.elements) + + def set_elements(self, elements): """Re-define the path given new orbital elements. Part of the Fittable interface. Input: - elements An array or list of orbital elements. In order, they are - [a, mean0, dmean, e, peri0, dperi, i, node0, dnode], - followed by - [amp, phase0, dphase] - for each wobble. - - a semimajor axis (km). - mean0 mean longitude (radians) at the epoch. - dmean mean motion (radians/s). - e eccentricity. - peri0 longitude of pericenter (radians) at the epoch. - dperi pericenter precession rate (rad/s). - i inclination (radians). - node0 ascending node (radians) at the epoch. - dnode nodal regression rate (rad/s, < 0). - - amp amplitude of the wobble, radians. - phase0 phase of the wobble at epoch, radians. - dphase rate of change of the wobble, radians/s. - """ + elements (array-like): The orbital elements and wobble terms, in this order: + + * [0]: `a`, mean radius of orbit, km. + * [1]: `lon`, mean longitude of orbit at epoch, radians. + * [2]: `n`, mean motion, radians/sec. + * [3]: `e`, orbital eccentricity. + * [4]: `peri`, longitude of pericenter at epoch, radians. + * [5]: `prec`,pericenter precession rate, radians/sec. + * [6]: `i`, inclination, radians. + * [7]: `node`, longitude of ascending node at epoch, radians. + * [8]: `regr`, nodal regression rate, radians/sec, NEGATIVE! + + If the orbit has "wobble" terms, which can describe streamlines + of ring particles. Repeat these three elements for each wobble term: + + * [9, 12, ...] `amp`: amplitude of the term, radians. + * [10, 13, ...] `phase0`: initial phase of the first wobble term, radians. + * [11, 14, ...] `dphase_dt`: rate of change of the first wobble term, + radians/s. - global SEMIM, MEAN0, DMEAN, ECCEN, PERI0, DPERI, INCLI, NODE0, DNODE - global NELEMENTS + """ - self.elements = np.asarray(elements, dtype=np.float64) - if self.elements.shape != (self.nparams,): + self.elements = np.array(elements, dtype=np.float64) + if self.elements.shape != (self.nelements,): raise ValueError('revised KeplerPath elements do not match shape ' 'of original') @@ -252,57 +230,59 @@ def set_params_new(self, elements): self.phase0 = 0. self.dphase_dt = 0. - # Empty the cache - self.cached_observation_time = None - self.cached_planet_event = None - - #=========================================================================== - def copy(self): - """A deep copy of the object. Part of the Fittable interface.""" - - return KeplerPath(self.planet, self.epoch, self.get_params().copy(), - self.observer, self.wobbles) - - #=========================================================================== def get_elements(self): """The complete set of orbital elements, including wobbles.""" return self.elements - #=========================================================================== - def xyz_planet(self, time, partials=False): - """Body position and velocity relative to the planet, in planet's frame. + ###################################################################################### + # Serialization support + ###################################################################################### - Results are returned in an inertial frame where the Z-axis is aligned - with the planet's rotation pole. Optionally, it also returns the partial - derivatives of the position vector with respect to the orbital elements, - on the the assumption that all orbital elements are independent. The - The coordinates are only accurate to first order in (e,i) and in the - wobbles. The derivatives are precise relative to the definitions of - these elements. However, partials are not provided for the wobbles. + def _path_key(self): + return (self.planet, self.epoch, self.elements, self.observer, self.wobbles) - Input: - time time (seconds) as a Scalar. - partials True to include partial derivatives of the position with - respect to the elements. + def __getstate__(self): + self.refresh() + self._cache_id() + return (self.planet, self.epoch, self.elements, + Path.as_primary_path(self.observer), self.wobbles, self._state_id()) - Return: (pos, vel) - pos a Vector3 of position vectors. - vel a Vector3 of velocity vectors. - """ + def __setstate__(self, state): + (body, epoch, elements, observer, wobbles, path_id) = state + self.__init__(body, epoch, elements, observer, wobbles=wobbles, path_id=path_id) + self.freeze() - global SEMIM, MEAN0, DMEAN, ECCEN, PERI0, DPERI, INCLI, NODE0, DNODE - global NELEMENTS + ###################################################################################### + # Orbit calculation relative to planet + ###################################################################################### - global LIBAMP, PHASE0, DPHASE - global NWOBBLES + def xyz_planet(self, time, partials=False): + """Body position and velocity relative to the planet, in planet's frame. + + Results are returned in an inertial frame where the Z-axis is aligned with the + planet's rotation pole. Optionally, it also returns the partial derivatives of the + position vector with respect to the orbital elements, on the the assumption that + all orbital elements are independent. The The coordinates are only accurate to + first order in (e,i) and in the wobbles. The derivatives are precise relative to + the definitions of these elements. However, partials are not provided for the + wobbles. + + Parameters: + time (Scalar or float): Time in seconds TDB. + partials (bool, optional): True to include partial derivatives of the position + with respect to the elements. + + Returns: + (tuple): (position, velocity), each represented by a Vector3. + """ # Convert to array if necessary time = Scalar.as_scalar(time) t = time.vals - self.epoch if partials: - partials_shape = time.shape + (self.nparams,) + partials_shape = time.shape + (self.nelements,) dmean_delem = np.zeros(partials_shape) dperi_delem = np.zeros(partials_shape) dnode_delem = np.zeros(partials_shape) @@ -310,12 +290,12 @@ def xyz_planet(self, time, partials=False): de_delem = np.zeros(partials_shape) di_delem = np.zeros(partials_shape) - ######################################################################## + ################################################################################## # Determine three angles and their time derivatives # mean = mean0 + t * dmean_dt # peri = peri0 + t * dperi_dt # node = node0 + t * dnode_dt - ######################################################################## + ################################################################################## mean = self.mean0 + t * self.dmean_dt peri = self.peri0 + t * self.dperi_dt @@ -344,9 +324,9 @@ def xyz_planet(self, time, partials=False): de_delem[..., ECCEN] = 1. di_delem[..., INCLI] = 1. - ######################################################################## + ################################################################################## # Apply the wobbles - ######################################################################## + ################################################################################## # For Laplace planes laplace_plane = False @@ -526,9 +506,9 @@ def xyz_planet(self, time, partials=False): laplace_sin_node = sin_arg laplace_cos_node = cos_arg - ######################################################################## + ################################################################################## # Evaluate some derived elements - ######################################################################## + ################################################################################## ae = a * e cos_i = np.cos(i) @@ -550,11 +530,11 @@ def xyz_planet(self, time, partials=False): dcosi_delem[...,INCLI] = -sin_i dsini_delem[...,INCLI] = cos_i - ######################################################################## + ################################################################################## # Determine moon polar coordinates in orbit plane # r = a - a * e * cos(mean - peri)) # theta = mean - 2 * e * sin(mean - peri) - ######################################################################## + ################################################################################## mp = mean - peri cos_mp = np.cos(mp) @@ -583,13 +563,13 @@ def xyz_planet(self, time, partials=False): (e[...,np.newaxis] * dsinmp_delem + de_delem * sin_mp[...,np.newaxis]) - ######################################################################## + ################################################################################## # Locate body on an inclined orbit, in a frame where X is along the # ascending node # asc[X] = r cos(theta - node) # asc[Y] = r sin(theta - node) cos(i) # asc[Z] = r sin(theta - node) sin(i) - ######################################################################## + ################################################################################## tn = theta - node cos_tn = np.cos(tn) @@ -631,12 +611,12 @@ def xyz_planet(self, time, partials=False): r[...,np.newaxis,np.newaxis] * dunit1_delem # shape is (..., 9, 3) - ######################################################################## + ################################################################################## # Rotate the ascending node back into position in our inertial frame # xyz[X] = asc[X] * cos(node) - asc[Y] * sin(node) # xyz[Y] = asc[X] * sin(node) + asc[Y] * cos(node) # xyz[Z] = asc[Z] - ######################################################################## + ################################################################################## cos_node = np.cos(node) sin_node = np.sin(node) @@ -680,12 +660,12 @@ def xyz_planet(self, time, partials=False): asc[...,np.newaxis,np.newaxis,:], axis=-1) # shape = (..., 9, 3) - ######################################################################## + ################################################################################## # Apply Laplace Plane # asc[X] = r cos(theta - node) # asc[Y] = r sin(theta - node) cos(i) # asc[Z] = r sin(theta - node) sin(i) - ######################################################################## + ################################################################################## if laplace_plane: node = np.array([laplace_cos_node, laplace_sin_node, 0.]) @@ -717,9 +697,9 @@ def xyz_planet(self, time, partials=False): dxyz_delem = np.sum(rotate[...,np.newaxis,:,:] * dxyz_delem[...,np.newaxis,:], axis=-1) - ######################################################################## + ################################################################################## # Return results - ######################################################################## + ################################################################################## pos = Vector3(xyz) vel = Vector3(dxyz_dt) @@ -730,57 +710,25 @@ def xyz_planet(self, time, partials=False): return (pos, vel) - #=========================================================================== - def xyz_observed(self, time, quick={}, partials=False, planet_event=None): - """Body position and velocity relative to the observer in J2000 frame. - - Input: - time time (seconds) as a Scalar. - partials True to include partial derivatives of the position with - quick False to disable QuickPaths; a dictionary to override - specific options. - respect to the elements. - planet_event - the corresponding event of the photon leaving the - planet; None to calculate this from the time. Note that - this can be calculated once using planet_event(), - avoiding the need to re-calculate this quantity for - repeated calls using the same time(s). - - Return: - pos a Vector3 of position vectors. - vel a Vector3 of velocity vectors. - """ - - if planet_event is None: - observer_event = Event(time, Vector3.ZERO, - self.observer, self.frame) - planet_event = self.center.photon_to_event(observer_event, - quick=quick)[0] - - (pos, vel) = self.xyz_planet(planet_event.time, partials) - - pos_j2000 = self.to_j2000.rotate(pos) + planet_event.pos - vel_j2000 = self.to_j2000.rotate(vel) + planet_event.vel - - return (pos_j2000, vel_j2000) + ###################################################################################### + # Path API + ###################################################################################### - #=========================================================================== - def event_at_time(self, time, quick={}, partials=False, planet_event=None): + def event_at_time(self, time, *, quick={}, partials=False): """An Event object corresponding to a Scalar time on this path. - Input: - time a time Scalar at which to evaluate the path. - quick False to disable QuickPaths; a dictionary to override - specific options. - partials True to include the derivatives of position with respect - to the orbital elements; False otherwise. - planet_event optional event of the photon leaving the center of the - planet. Saves a re-calculation if the time is re-used. - Only relevant when an observer is defined. - - Return: an Event object containing the time, position and - velocity of the paths. + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. For KeplerPaths, this is actually the time at the planet's + center. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. + partials (bool, optional): True to include the derivatives of position with + respect to the orbital elements. + + Returns: + (Event): Event containing the time, position, and velocity on the path. """ # Without an observer, return event in the planet frame @@ -788,27 +736,40 @@ def event_at_time(self, time, quick={}, partials=False, planet_event=None): (pos, vel) = self.xyz_planet(time, partials=partials) return Event(time, (pos, vel), self.origin, self.frame) - # Otherwise, return the event WRT the observer - (pos, vel) = self.xyz_observed(time, quick, partials, planet_event) - return Event(time, (pos, vel), self.observer, Frame.J2000) + # With an observer, return event in J2000, with the light time accounted for + planet_event = self._photon_from_planet(time, quick=quick)[0] + (pos, vel) = self.xyz_planet(planet_event.time, partials=partials) + pos_j2000 = self.to_j2000.rotate(pos) + planet_event.pos + vel_j2000 = self.to_j2000.rotate(vel) + planet_event.vel + return Event(time, (pos_j2000, vel_j2000), self.observer, Frame.J2000) + + def _photon_from_planet(self, time, *, derivs=False, guess=None, antimask=None, + quick={}, converge={}): + + # Check the cache for the planet event + events = self._events[time] + if events is None: + obs_event = Event(time, Vector3.ZERO, self.observer, self.frame) + events = self.center.photon_to_event(obs_event, derivs=derivs, guess=guess, + antimask=antimask, quick=quick, + converge=converge) + if np.size(time) == 1: + self._events[time] = events + + return events - #=========================================================================== def node_at_time(self, time): """The longitude of ascending node at the specified time. - Wobbles are ignored. The angle is a positive rotation about the planet's - ring frame. + Wobbles are ignored. The angle is a positive rotation about the planet's ring + frame. """ - global NODE0, DNODE - time = Scalar.as_scalar(time) - return self.elements[NODE0] + (time-self.epoch) * self.elements[DNODE] + return self.elements[NODE0] + (time - self.epoch) * self.elements[DNODE] - #=========================================================================== def pole_at_time(self, time): - """The J2000 vector pointing toward the orbit's pole at the specified - time. + """The J2000 vector pointing toward the orbit's pole at the specified time. Wobbles are ignored. """ @@ -823,47 +784,45 @@ def pole_at_time(self, time): sin_node = np.sin(node) # This vector is 90 degrees behind of the node in the reference equator - target_in_j2000 = ( sin_node * x_axis_in_j2000 + - -cos_node * y_axis_in_j2000) + target_in_j2000 = (sin_node * x_axis_in_j2000 - cos_node * y_axis_in_j2000) return self.cos_i * z_axis_in_j2000 + self.sin_i * target_in_j2000 - ############################################################################ + ###################################################################################### # Override for the case where observer != None - ############################################################################ + ###################################################################################### - def photon_to_event(self, arrival, derivs=False, guess=None, quick={}, - converge={}, partials=False): + def photon_to_event(self, arrival, derivs=False, *, guess=None, antimask=None, + quick={}, converge={}, partials=False): """The photon departure event from this path to match the arrival event. + + This is an override of the default method, provided to support the partial + derivatives. """ if self.observer is None: (path_event, - link_event) = super(KeplerPath, - self).photon_to_event(arrival, derivs, guess, - quick=quick, - converge=converge) + obs_event) = Path.photon_to_event(self, arrival, derivs=derivs, guess=guess, + antimask=antimask, quick=quick, + converge=converge) if partials: - (pos, vel) = self.xyz_planet(link_event.time, partials=True) + (pos, vel) = self.xyz_planet(path_event.time, partials=True) path_event.pos.insert_deriv('elements', pos.d_delements) - return (path_event, link_event) + return (path_event, obs_event) (planet_event, - link_event) = self.center.photon_to_event(arrival, derivs, guess, - quick, converge) - - path_event = self.event_at_time(planet_event.time, quick, partials, - planet_event) + obs_event) = self._photon_from_planet(arrival.time, derivs=derivs, guess=guess, + antimask=antimask, quick=quick, + converge=converge) - path_event.dep_lt = planet_event.time - link_event.time - path_event.dep_j2000 = path_event.pos_j2000 - link_event.pos_j2000 + path_event = self.event_at_time(planet_event.time, quick=quick, partials=partials) + path_event.dep_lt = path_event.time - obs_event.time + path_event.dep_j2000 = path_event.pos_j2000 - obs_event.pos_j2000 - link_event = Event(link_event.time, link_event.state, - link_event.origin, link_event.frame) - link_event.arr_lt = path_event.dep_lt - link_event.arr_j2000 = path_event.dep_j2000 + obs_event.arr_lt = path_event.dep_lt + obs_event.arr_j2000 = path_event.dep_j2000 - return (path_event, link_event) + return (path_event, obs_event) -################################################################################ +########################################################################################## diff --git a/oops/path/linearcoordpath.py b/oops/path/linearcoordpath.py index 19cb421b..07ee3798 100755 --- a/oops/path/linearcoordpath.py +++ b/oops/path/linearcoordpath.py @@ -1,39 +1,34 @@ -################################################################################ +########################################################################################## # oops/path/linearcoordpath.py: Subclass LinearCoordPath of class Path -################################################################################ +########################################################################################## from polymath import Qube, Scalar from oops.event import Event +from oops.fittable import Fittable_ from oops.path.path_ import Path + class LinearCoordPath(Path): - """A path defined by coordinates changing linearly on a specified Surface. - """ + """A path defined by coordinates changing linearly on a specified Surface.""" - # Note: LinearCoordPaths are not generally re-used, so their IDs are - # expendable. Their IDs are not preserved during pickling. + _PATH_IDS = {} - #=========================================================================== def __init__(self, surface, coords, coords_dot, epoch, obs=None, path_id=None): """Constructor for a LinearCoordPath. - Input: - surface a surface. - coords a tuple of 2 or 3 Scalars defining the coordinates on - the surface. - coords_dot the time-derivative of the coords. - epoch the epoch at which the coords are defined, seconds TDB. - obs optional path of observer, needed to calculate points - on virtual surfaces. - path_id the name under which to register the new path; None to - leave the path unregistered. + Parameters: + surface (Surface): The surface to which the coordinates refer. + coords (tuple): 2 or 3 Scalars defining the coordinates on the surface. + coords_dot (tuple): The time-derivatives of `coords`. + obs (Path or str, optional): Path of observer, needed to calculate points on + virtual surfaces. + path_id (str, optional): The ID to use; None to leave the path unregistered. """ - if self.surface.IS_VIRTUAL: - raise NotImplementedError('LinearCoordPath cannot be defined for ' - 'virtual surface class ' - + type(self.surface).__name__) + if surface.IS_VIRTUAL and obs is None: + raise NotImplementedError('LinearCoordPath requires an observation path for ' + 'virtual surface class ' + type(surface).__name__) self.surface = surface self.coords = [Scalar.as_scalar(c) for c in coords] @@ -42,34 +37,49 @@ def __init__(self, surface, coords, coords_dot, epoch, obs=None, self.obs_path = Path.as_path(obs) # Required attributes - self.path_id = path_id - self.origin = self.surface.origin - self.frame = self.origin.frame - self.keys = set() - self.shape = Qube.broadcasted_shape(self.surface, *self.coords, - *self.coords_dot, self.epoch, - self.obs_path) - - # Update waypoint and path_id; register only if necessary + self.origin = self.surface.origin + self.frame = self.origin.frame + self.shape = Qube.broadcasted_shape(self.surface, *self.coords, *self.coords_dot, + self.epoch, self.obs_path) + self.path_id = self._recover_id(path_id) + self.register() + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.surface, *self.coords, *self.coords_deriv, self.epoch, self.obs) - # Unpickled paths will always have temporary IDs to avoid conflicts def __getstate__(self): + Fittable_.refresh(self) + self._cache_id() return (self.surface, self.coords, self.coords_dot, self.epoch, - Path.as_primary_path(self.obs_path)) + Path.as_primary_path(self.obs_path), self._state_id()) def __setstate__(self, state): - self.__init__(*state) + self.__init__(*state[:-1], path_id=state[-1]) + Fittable_.freeze(self) + + ###################################################################################### + # Path API + ###################################################################################### - #=========================================================================== def event_at_time(self, time, quick={}): """An Event corresponding to a specified time on this path. - Input: - time a time Scalar at which to evaluate the path. + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. - Return: an Event object containing (at least) the time, position - and velocity on the path. + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ new_coords = [] @@ -82,4 +92,4 @@ def event_at_time(self, time, quick={}): pos = self.surface.vector3_from_coords(new_coords, derivs=True) return Event(time, pos, self.origin, self.frame) -################################################################################ +########################################################################################## diff --git a/oops/path/linearpath.py b/oops/path/linearpath.py index 40c2b58a..f2b675f5 100755 --- a/oops/path/linearpath.py +++ b/oops/path/linearpath.py @@ -1,44 +1,41 @@ -################################################################################ +########################################################################################## # oops/path/linearpath.py: Subclass LinearPath of class Path -################################################################################ +########################################################################################## from polymath import Qube, Scalar, Vector3 from oops.event import Event +from oops.fittable import Fittable_ from oops.frame.frame_ import Frame from oops.path.path_ import Path class LinearPath(Path): """A path defining linear motion relative to another path and frame.""" - # Note: LinearPaths are not generally re-used, so their IDs are expendable. - # Their IDs are not preserved during pickling. + _PATH_IDS = {} - #=========================================================================== def __init__(self, pos, epoch, origin, frame=None, path_id=None): """Constructor for a LinearPath. Input: - pos a Vector3 of position vectors. The velocity should be - defined via a derivative 'd_dt'. Alternatively, it can - be specified as a tuple of two Vector3 objects, - (position, velocity). - epoch time Scalar relative to which the motion is defined, - seconds TDB - origin the path or path ID of the reference point. - frame the frame or frame ID of the coordinate system; None for - the frame used by the origin path. - path_id the name under which to register the new path; None to - leave the path unregistered. + pos (Vector3, array-like, or tuple): Position vector(s). The velocity is + defined via a derivative 'd_dt'. Alternatively, provide (pos, vel) as a + tuple of two Vector3 or array-like values. + epoch (Scalar, array-like, or float): The time TDB relative to which all + orbital elements are defined. + origin (Path or str): The path or ID of the center of the circle. + frame (Frame or str): The frame or ID of the frame in which the fixed + coordinates are defined. + path_id (str, optional): The ID to use; None to leave the path unregistered. """ # Interpret the position - if isinstance(pos, (tuple,list)) and len(pos) == 2: + if isinstance(pos, (tuple, list)) and len(pos) == 2: self.pos = Vector3.as_vector3(pos[0]).wod.as_readonly() self.vel = Vector3.as_vector3(pos[1]).wod.as_readonly() else: pos = Vector3.as_vector3(pos) - if hasattr('d_dt', pos): + if hasattr(pos, 'd_dt'): self.vel = pos.d_dt.as_readonly() else: self.vel = Vector3.ZERO @@ -48,38 +45,53 @@ def __init__(self, pos, epoch, origin, frame=None, path_id=None): self.epoch = Scalar.as_scalar(epoch) # Required attributes - self.path_id = path_id - self.origin = Path.as_waypoint(origin) - self.frame = Frame.as_wayframe(frame) or self.origin.frame - self.keys = set() - self.shape = Qube.broadcasted_shape(self.pos, self.vel, - self.epoch, - self.origin, self.frame) - - # Update waypoint and path_id; register only if necessary + self.origin = Path.as_waypoint(origin) + self.frame = Frame.as_wayframe(frame) or self.origin.frame + self.shape = Qube.broadcasted_shape(self.pos, self.vel, self.epoch, self.origin, + self.frame) + self.path_id = self._recover_id(path_id) + self.register() + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.pos, self.vel, self.epoch, self.origin, self.frame) - # Unpickled paths will always have temporary IDs to avoid conflicts def __getstate__(self): - return (self.pos, self.epoch, - Path.as_primary_path(self.origin), - Frame.as_primary_frame(self.frame)) + Fittable_.refresh(self) + self._cache_id() + return (self.pos, self.vel, self.epoch, Path.as_primary_path(self.origin), + Frame.as_primary_frame(self.frame), self._state_id()) def __setstate__(self, state): - self.__init__(*state) + (pos, vel, epoch, origin, frame, path_id) = state + self.__init__((pos, vel), epoch, origin, frame, path_id=path_id) + Fittable_.freeze(self) - #=========================================================================== - def event_at_time(self, time, quick=None): + ###################################################################################### + # Path API + ###################################################################################### + + def event_at_time(self, time, quick=False): """An Event corresponding to a specified time on this path. - Input: - time a time Scalar at which to evaluate the path. + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. - Return: an Event object containing (at least) the time, position - and velocity on the path. + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ return Event(time, (self.pos + (time-self.epoch) * self.vel, self.vel), self.origin, self.frame) -################################################################################ +############################################################################################ diff --git a/oops/path/multipath.py b/oops/path/multipath.py index d49d92ca..df180266 100755 --- a/oops/path/multipath.py +++ b/oops/path/multipath.py @@ -1,95 +1,98 @@ -################################################################################ +########################################################################################## # oops/path/multipath.py: Subclass MultiPath of class Path -################################################################################ +########################################################################################## import numpy as np from polymath import Qube, Scalar from oops.event import Event +from oops.fittable import Fittable_ from oops.frame.frame_ import Frame from oops.path.path_ import Path + class MultiPath(Path): """Gathers a set of paths into a single 1-D Path object.""" - PATH_IDS = {} + _PATH_IDS = {} - #=========================================================================== - def __init__(self, paths, origin=None, frame=None, path_id=None, - unpickled=False): + def __init__(self, paths, origin=None, frame=None, path_id=None): """Constructor for a MultiPath Path. - Input: - paths a tuple, list or 1-D ndarray of paths or path IDs. - origin a path or path ID identifying the common origin of all - paths. None to use the SSB. - frame a frame or frame ID identifying the reference frame. - None to use the default frame of the origin path. - path_id the name or ID under which this path will be registered. - A single '+' is changed to the ID of the first path with - a '+' appended. None to leave the path unregistered. - unpickled True if this path has been read from a pickle file. + Parameters: + paths (tuple or list): Paths or path IDs to include in this MultiPath. + origin (Path or str, optional): Path or ID identifying the common origin of + all paths. None to use the SSB. + frame (Frame or str, optional): Frame or ID identifying the reference frame. + None to use the default frame of the `origin` path. + path_id (str, optional): The ID to use; None to leave the path unregistered. + Use '+' for the names of all the paths appended with "+". """ # Interpret the inputs self.origin = Path.as_waypoint(origin) or Path.SSB - self.frame = Frame.as_wayframe(frame) or self.origin.frame + self.frame = Frame.as_wayframe(frame) or self.origin.frame self.paths = np.array(paths, dtype='object').ravel() self.shape = self.paths.shape - self.keys = set() for (index, path) in np.ndenumerate(self.paths): self.paths[index] = Path.as_path(path).wrt(self.origin, self.frame) # Fill in the path_id - self.path_id = path_id - - if self.path_id == '+': - self.path_id = self.paths[0].path_id + '+others' + if path_id == '+': + if len(self.paths) == 1: + self.path_id = self.paths[0].path_id + '+' + else: + self.path_id = '+'.join(p.path_id for p in self.paths) + else: + self.path_id = self._recover_id(path_id) + + self.register() + self._cache_id() + + # Support indexing by integer and numeric range + def __getitem__(self, i): + paths = self.paths[i] + if np.shape(paths) == (): + return paths + return MultiPath(paths, self.origin, self.frame, path_id=None) - # Update waypoint and path_id; register only if necessary - self.register(unpickled=unpickled) + ###################################################################################### + # Serialization support + ###################################################################################### - # Save in internal dict for name lookup upon serialization - if not unpickled and self.path_id in Path.WAYPOINT_REGISTRY: - key = tuple([path.path_id for path in self.paths]) - MultiPath.PATH_IDS[key] = self.path_id + def _path_key(self): + return list(self.paths) + [self.origin, self.frame] - # Unpickled paths will always have temporary IDs to avoid conflicts def __getstate__(self): - return (self.paths, - Path.as_primary_path(self.origin), - Frame.as_primary_frame(self.frame)) + Fittable_.refresh(self) + self._cache_id() + return (self.paths, Path.as_primary_path(self.origin), + Frame.as_primary_frame(self.frame), self._state_id()) def __setstate__(self, state): - # If this path matches a pre-existing path, re-use its ID - (paths, origin, frame) = state - key = tuple([path.path_id for path in paths]) - path_id = MultiPath.PATH_IDS.get(key, None) - self.__init__(paths, origin, frame, path_id=path_id, unpickled=True) + (paths, origin, frame, path_id) = state + self.__init__(paths, origin, frame, path_id=path_id) + Fittable_.freeze(self) - #=========================================================================== - def __getitem__(self, i): - slice = self.paths[i] - if np.shape(slice) == (): - return slice - return MultiPath(slice, self.origin, self.frame, path_id=None) + ###################################################################################### + # Path API + ###################################################################################### - #=========================================================================== def event_at_time(self, time, quick={}): - """An Event object corresponding to a specified Scalar time on this - path. - - The times are broadcasted across the shape of the MultiPath. - - Input: - time a time Scalar at which to evaluate the path. - quick False to disable QuickPaths; a dictionary to override - specific options. - - Return: an Event object containing the time, position and - velocity of the paths. + """An Event corresponding to a specified time on this path. + + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ # Broadcast everything to the same shape @@ -102,37 +105,35 @@ def event_at_time(self, time, quick={}): mask[...] = time.mask for (index, path) in np.ndenumerate(self.paths): - event = path.event_at_time(time.values[...,index], quick=quick) - pos[...,index,:] = event.pos.values - vel[...,index,:] = event.vel.values - mask[...,index] |= (event.pos.mask | event.vel.mask) + event = path.event_at_time(time.values[..., index], quick=quick) + pos[..., index, :] = event.pos.values + vel[..., index, :] = event.vel.values + mask[..., index] |= (event.pos.mask | event.vel.mask) if not np.any(mask): mask = False elif np.all(mask): mask = True - return Event(Scalar(time.values, mask), (pos,vel), - self.origin, self.frame) + return Event(Scalar(time.values, mask), (pos,vel), self.origin, self.frame) - #=========================================================================== def quick_path(self, time, quick={}): - """Override of the default quick_path method to return a MultiPath of - quick_paths. + """Override of the default quick_path method to return a MultiPath of quick_paths. A QuickPath operates by sampling the given path and then setting up an - interpolation grid to evaluate in its place. It can substantially speed - up performance when the same path must be evaluated many times, e.g., - for every pixel of an image. - - Input: - time a Scalar defining the set of times at which the frame is - to be evaluated. Alternatively, a tuple (minimum time, - maximum time, number of times) - quick if None or False, no QuickPath is created and self is - returned; if another dictionary, then the values - provided override the values in the default dictionary - QUICK.dictionary, and the merged dictionary is used. + interpolation grid to evaluate in its place. It can substantially speed up + performance when the same path must be evaluated many times, e.g., for every pixel + of an image. + + Parameters: + time (Scalar or array-like): The times at which the frame is to be evaluated. + Alternatively, a tuple (minimum time, maximum time, number of times) + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. + quick (dict or bool, optional): If False, no QuickPath is created and self is + returned. If a dictionary is given, its values override the values in the + default dictionary QUICK.dictionary and the merged dictionary is used. """ new_paths = [] @@ -142,4 +143,4 @@ def quick_path(self, time, quick={}): return MultiPath(new_paths, self.origin, self.frame) -################################################################################ +########################################################################################## diff --git a/oops/path/path_.py b/oops/path/path_.py index d16f544b..d2e9b418 100755 --- a/oops/path/path_.py +++ b/oops/path/path_.py @@ -6,8 +6,10 @@ import scipy.interpolate as interp from polymath import Qube, Scalar, Vector3 +from oops.cache import Cache from oops.config import QUICK, PATH_PHOTONS, LOGGING, PICKLE_CONFIG from oops.event import Event +from oops.fittable import Fittable_ from oops.frame.frame_ import Frame import oops.constants as constants @@ -186,7 +188,7 @@ def reset_registry(): Path.initialize_registry() #=========================================================================== - def register(self, shortcut=None, override=False, unpickled=False): + def register(self, shortcut=None, override=False): """Register a Path's definition. A shortcut makes it possible to calculate the state of one SPICE body @@ -199,10 +201,6 @@ def register(self, shortcut=None, override=False, unpickled=False): definition of any previous path with the same name. The old path will still exist, but it will not be available from the registry. - If unpickled is True and a path with the same ID is already in the - registry, then this path is not registered. Instead, its will share its - waypoint with the existing, registered path of the same name. - If the path ID is None, blank, or begins with '.', this is treated as a temporary path and is not registered. """ @@ -212,6 +210,8 @@ def register(self, shortcut=None, override=False, unpickled=False): Path.initialize_registry() path_id = self.path_id + if not hasattr(self, 'keys'): + self.keys = set() # Handle a shortcut if shortcut is not None: @@ -290,24 +290,21 @@ def register(self, shortcut=None, override=False, unpickled=False): if not hasattr(self, 'waypoint') or self.waypoint is None: self.waypoint = Path.WAYPOINT_REGISTRY[path_id] - # If this is not an unpickled path, make it the path returned by - # any of the standard keys. - if not unpickled: - # Cache (self.waypoint, self.origin); overwrite if necessary - key = (self.waypoint, self.origin) - if key in Path.PATH_CACHE: # remove an old version - Path.PATH_CACHE[key].keys -= {key} + # Cache (self.waypoint, self.origin); overwrite if necessary + key = (self.waypoint, self.origin) + if key in Path.PATH_CACHE: # remove an old version + Path.PATH_CACHE[key].keys -= {key} - Path.PATH_CACHE[key] = self - self.keys |= {key} + Path.PATH_CACHE[key] = self + self.keys |= {key} - # Cache (self.waypoint, self.origin, self.frame) - key = (self.waypoint, self.origin, self.frame) - if key in Path.PATH_CACHE: # remove an old version - Path.PATH_CACHE[key].keys -= {key} + # Cache (self.waypoint, self.origin, self.frame) + key = (self.waypoint, self.origin, self.frame) + if key in Path.PATH_CACHE: # remove an old version + Path.PATH_CACHE[key].keys -= {key} - Path.PATH_CACHE[key] = self - self.keys |= {key} + Path.PATH_CACHE[key] = self + self.keys |= {key} #=========================================================================== @staticmethod @@ -380,6 +377,70 @@ def is_registered(self): return (self.path_id in Path.WAYPOINT_REGISTRY) + #=========================================================================== + @staticmethod + def id_is_registered(path_id): + """True if the given path ID is registered.""" + + return (path_id in Path.WAYPOINT_REGISTRY) + + #=========================================================================== + @staticmethod + def id_is_temporary(frame_id): + """True if this is a temporary frame ID.""" + + return frame_id.startswith('TEMPORARY_') + + ############################################################################ + # Serialization support + ############################################################################ + + def _cache_id(self): + """Save this object's path ID in a class dictionary `_PATH_IDS`. + + This dictionary is keyed by a tuple of attributes of the object, as + returned by the method `_path_key`. It returns the path ID. + + If an object is constructed with a default paath ID, and an existing + path with the same key already exists, the path ID is reused (although + it will still be different, unique object). + """ + + if self.shape != (): # shapeless + return + if not Path.id_is_temporary(self.path_id): # permanent id + return + if self.path_id in self._PATH_IDS.values(): # don't overwrite + return + if not Fittable_.is_frozen(self): # frozen + return + + key = Cache.clean_key(self._path_key()) + if key in self._PATH_IDS: # don't overwrite + return + + self._PATH_IDS[key] = self.path_id + + def _recover_id(self, path_id=None): + """If the given path ID is None, check the class's `_PATH_IDS` + dictionary for a matching object and use its ID if found. + """ + + if path_id is not None: + return path_id + + if hasattr(self, '_PATH_IDS'): + key = Cache.clean_key(self._path_key()) + if key in self._PATH_IDS: + return self._PATH_IDS[key] + + return None + + def _state_id(self): + if Path.id_is_temporary(self.path_id): + return None + return self.path_id + ############################################################################ # Event operations ############################################################################ @@ -1314,7 +1375,7 @@ def __init__(self, path, interval, quickdict): self.slowpath = path self.waypoint = path.waypoint - self.path_id = path.path_id + self.path_id = path.path_id self.origin = path.origin self.frame = path.frame self.shape = () diff --git a/oops/path/pathshift.py b/oops/path/pathshift.py new file mode 100755 index 00000000..370c4f42 --- /dev/null +++ b/oops/path/pathshift.py @@ -0,0 +1,90 @@ +######################################################################################### +# oops/path/pathshift.py: Subclass PathShift of class Path +######################################################################################### + +from polymath import Scalar +from oops.fittable import Fittable +from oops.path.path_ import Path + + +class PathShift(Path, Fittable): + """A path defined by a time-shift along another path. + + PLACEHOLDER CODE. "CONCEPTUALLY" CORRECT BUT NOT YET TESTED. + """ + + _PATH_IDS = {} + + def __init__(self, dt, /, path, *, path_id=None): + """Constructor for a PathShift. + + Parameters: + dt (float): The initial time shift in seconds. + path (Path or str): The Path or ID to which the time shift applies. + path_id (str, optional): The new path ID to use; None to leave this path + unregistered. + """ + + self.dt = dt + self.path = path + + # Required attributes + self.origin = self.path.origin + self.frame = self.path.frame + self.shape = self.path.shape + self.path_id = self._recover_id(path_id) + + self.register() + self._cache_id() + + ###################################################################################### + # Fittable interface + ###################################################################################### + + def _set_params(self, params): + self.dt = params[0] + + @property + def _params(self): + return (self.dt,) + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.dt, self.path) + + def __getstate__(self): + self.refresh(self) + self._cache_id() + return (self.dt, Path.as_primary_path(self.path), self._state_id()) + + def __setstate__(self, state): + (dt, path, path_id) = state + self.__init__(dt, path, path_id=path_id) + self.freeze() + + ###################################################################################### + # Path API + ###################################################################################### + + def event_at_time(self, time, quick={}): + """An Event corresponding to a specified time on this path. + + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. + """ + + time = Scalar.as_scalar(time) + return self.path.event_at_time(time + self.dt, quick=quick) + +######################################################################################### diff --git a/oops/path/spicepath.py b/oops/path/spicepath.py index 74cd8c01..5ae5c054 100755 --- a/oops/path/spicepath.py +++ b/oops/path/spicepath.py @@ -1,10 +1,9 @@ -################################################################################ +########################################################################################## # oops/path_/spicepath.py: Subclass SpicePath of class Path -################################################################################ - -import numpy as np +########################################################################################## import cspyce +import numpy as np from polymath import Scalar, Vector3 from oops.event import Event @@ -12,6 +11,7 @@ from oops.path.path_ import Path, ReversedPath, RotatedPath import oops.spice_support as spice + class SpicePath(Path): """A Path subclass that returns information based on an SPICE SP kernel. @@ -19,35 +19,32 @@ class SpicePath(Path): a single origin. """ - # Set False to confirm that SpicePaths return the same results without - # shortcuts and with shortcuts + # Set False to confirm that SpicePaths return the same results without shortcuts and + # with shortcuts USE_SPICEPATH_SHORTCUTS = True - #=========================================================================== - def __init__(self, spice_id, spice_origin="SSB", spice_frame="J2000", - path_id=None, shortcut=None, unpickled=False): + _PATH_IDS = {} + + def __init__(self, spice_id, spice_origin='SSB', spice_frame='J2000', path_id=None, + shortcut=None): """Constructor for a SpicePath object. - Input: - spice_id the name or integer ID of the target body as used - in the SPICE toolkit. - spice_origin the name or integer ID of the origin body as - used in the SPICE toolkit; "SSB" for the Solar - System Barycenter by default. It may also be the - registered name of another SpicePath. - spice_frame the name or integer ID of the reference frame or of - the a body with which the frame is primarily - associated, as used in the SPICE toolkit. - path_id the name or ID under which the path will be - registered. By default, this will be the value of - spice_id if that is given as a string; otherwise - it will be the name as used by the SPICE toolkit. - shortcut If a shortcut is specified, then this is registered - as a shortcut definition; the other registered path - definitions are unchanged. - unpickled True if this object was read from a pickle file. If - so, then it will be treated as a duplicate of a - pre-existing SpicePath for the same SPICE ID. + Parameters: + spice_id (str or int): The SPICE toolkit identification of the target body. + spice_origin (str or int, optional): The SPICE toolkit identification of the + origin body; "SSB" for the Solar System Barycenter by default. It may also + be the registered name of another SpicePath. + spice_frame (str or int, optional): The SPICE toolkit identification of the + reference frame or of the a body with which the frame is primarily + associated. + path_id (str, optional): The name or ID under which the path will be + registered. By default, this will be the value of spice_id if that is + given as a string; otherwise it will be the name as used by the SPICE + toolkit. + shortcut (str, optional): If a shortcut name is specified, then this object is + registered as a shortcut definition; the other registered path definitions + are unchanged. Note that most shortcut definitions are handled + automatically, so users should not need to use this option. """ # Preserve the inputs @@ -59,10 +56,8 @@ def __init__(self, spice_id, spice_origin="SSB", spice_frame="J2000", # Interpret the SPICE IDs (self.spice_target_id, self.spice_target_name) = spice.body_id_and_name(spice_id) - (self.spice_origin_id, self.spice_origin_name) = spice.body_id_and_name(spice_origin) - self.spice_frame_name = spice.frame_id_and_name(spice_frame)[1] # Fill in the Path ID and save it in the global dictionary @@ -74,9 +69,9 @@ def __init__(self, spice_id, spice_origin="SSB", spice_frame="J2000", else: self.path_id = path_id - # Only save info in the PATH_TRANSLATION dictionary if it is not already - # there. We do not want to overwrite original definitions with those - # just read from pickle files. + # Only save info in the PATH_TRANSLATION dictionary if it is not already there. We + # do not want to overwrite original definitions with those just read from pickle + # files. if not shortcut: if self.spice_target_id not in spice.PATH_TRANSLATION: spice.PATH_TRANSLATION[self.spice_target_id] = self.path_id @@ -91,37 +86,47 @@ def __init__(self, spice_id, spice_origin="SSB", spice_frame="J2000", frame_id = spice.FRAME_TRANSLATION[self.spice_frame_name] self.frame = Frame.as_wayframe(frame_id) - # No shape, no keys + # No shape self.shape = () - self.keys = set() - self.shortcut = shortcut # Register the SpicePath; fill in the waypoint - self.register(shortcut, unpickled=unpickled) + self.register(shortcut=shortcut) + self._cache_id() + + ###################################################################################### + # Serialization support + ###################################################################################### + + def _path_key(self): + return (self.spice_target_id, self.spice_origin_id, self.spice_frame_name) def __getstate__(self): - return (self.spice_target_id, self.spice_origin_id, - self.spice_frame_name) + return (self.spice_target_id, self.spice_origin_id, self.spice_frame_name, + self._state_id()) def __setstate__(self, state): + (spice_target_id, spice_origin_id, spice_frame_name, path_id) = state + if path_id is None: + path_id = spice.PATH_TRANSLATION.get(spice_target_id, None) + self.__init__(spice_target_id, spice_origin_id, spice_frame_name, path_id=path_id) - (spice_target_id, spice_origin_id, spice_frame_name) = state - - # If this is a duplicate of a pre-existing SpicePath, make sure it gets - # assigned the pre-existing path ID and Waypoint - path_id = spice.PATH_TRANSLATION.get(spice_target_id, None) - self.__init__(spice_target_id, spice_origin_id, spice_frame_name, - path_id=path_id, unpickled=True) + ###################################################################################### + # Path API + ###################################################################################### - #=========================================================================== def event_at_time(self, time, quick={}): - """An Event corresponding to a specified Scalar time on this path. - - Input: - time a time Scalar at which to evaluate the path. - - Return: an Event object containing (at least) the time, position - and velocity of the path. + """An Event corresponding to a specified time on this path. + + Parameters: + time (Scalar, array-like, or float): Time at which to evaluate the path, in + seconds TDB. + quick (dict or bool, optional): A dictionary of parameter values to use as + overrides to the configured default QuickPath and QuickFrame parameters; + use False to disable the use of QuickPaths and QuickFrames. + + Returns: + (Event): Event object containing (at least) the time, position, and velocity + on the path. """ time = Scalar.as_scalar(time).as_float() @@ -132,14 +137,12 @@ def event_at_time(self, time, quick={}): # A single unmasked time can be handled quickly if time.shape == (): - (state, - lighttime) = cspyce.spkez(self.spice_target_id, - time.vals, - self.spice_frame_name, - 'NONE', - self.spice_origin_id) - - return Event(time, (state[0:3],state[3:6]), self.origin, self.frame) + (state, lighttime) = cspyce.spkez(self.spice_target_id, + time.vals, + self.spice_frame_name, + 'NONE', + self.spice_origin_id) + return Event(time, (state[0:3], state[3:6]), self.origin, self.frame) # Use a QuickPath if warranted, possibly making a recursive call if isinstance(quick, dict): @@ -152,7 +155,6 @@ def event_at_time(self, time, quick={}): self.spice_frame_name, 'NONE', self.spice_origin_id)[0] - pos = np.zeros(time.shape + (3,)) vel = np.zeros(time.shape + (3,)) pos[time.antimask] = state[...,0:3] @@ -170,17 +172,16 @@ def event_at_time(self, time, quick={}): # Convert to an Event and return return Event(time, (pos,vel), self.origin, self.frame) - #=========================================================================== def wrt(self, origin, frame=None): """Construct a path pointing from an origin to this target in any frame. - SpicePath overrides the default method to create quicker "shortcuts" - between SpicePaths. + SpicePath overrides the default method to create quicker "shortcuts" between + SpicePaths. - Input: - origin an origin Path object or its registered name. - frame a frame object or its registered ID. Default is to use - the frame of the origin's path. + Parameters: + origin (Path or str): The origin Path object or its ID. + frame (Frame or str, optional): The frame to use. Default is to use the frame + of the origin's path. """ # Use the slow method if necessary, for debugging @@ -208,7 +209,7 @@ def wrt(self, origin, frame=None): spice_frame_name = frame.spice_frame_name uses_spiceframe = True else: - uses_spiceframe = False # not a SpiceFrame + uses_spiceframe = False # not a SpiceFrame spice_frame_name = 'J2000' if uses_spiceframe: @@ -216,21 +217,18 @@ def wrt(self, origin, frame=None): else: frame_id = 'J2000' - shortcut = ('SPICE_SHORTCUT[' + str(self.path_id) + ',' + - str(origin_id) + ',' + - str(frame_id) + ']') - - result = SpicePath(self.spice_target_id, spice_origin_id, - spice_frame_name, self.path_id, shortcut) + shortcut = ('SPICE_SHORTCUT[' + str(self.path_id) + ',' + str(origin_id) + + ',' + str(frame_id) + ']') + result = SpicePath(self.spice_target_id, spice_origin_id, spice_frame_name, + self.path_id, shortcut=shortcut) # If the path uses a non-spice frame, add a rotated version if not uses_spiceframe: - shortcut = ('SHORTCUT_' + str(self.path_id) + '_' + - str(origin_id) + '_' + - str(frame.frame_id)) + shortcut = ('SHORTCUT_' + str(self.path_id) + '_' + str(origin_id) + + '_' + str(frame.frame_id)) result = RotatedPath(result, frame) - result.register(shortcut) + result.register(shortcut=shortcut) return result -################################################################################ +########################################################################################## diff --git a/tests/frame/test_frame.py b/tests/frame/test_frame.py index dbe126e4..1894157a 100755 --- a/tests/frame/test_frame.py +++ b/tests/frame/test_frame.py @@ -105,7 +105,7 @@ def runTest(self): self.assertEqual(xform.matrix.vals[2,2], 1) # Attempt to register a frame defined relative to an unregistered frame - self.assertRaises(ValueError, Rotation, -np.pi, 2, rot_neg60, 'NEG180') + self.assertRaises(ValueError, Rotation, -np.pi, 2, rot_neg60, frame_id='NEG180') # Link unregistered frame to registered frame identity = rot_neg120.wrt('J2000') diff --git a/tests/frame/test_poleframe.py b/tests/frame/test_poleframe.py index dae9781a..9b7e3c45 100755 --- a/tests/frame/test_poleframe.py +++ b/tests/frame/test_poleframe.py @@ -181,47 +181,6 @@ def runTest(self): diffs = node_vecs - node_vecs[0] self.assertTrue(diffs.norm().max() < 0.02) - # Test cache - poleframe = PoleFrame(planet, pole, cache_size=3) - self.assertTrue(poleframe.cache_size == 4) - self.assertTrue(poleframe.trim_size == 1) - self.assertTrue(len(poleframe.cache) == 0) - - pole_vecs = poleframe.transform_at_time(times).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 0) # don't cache vectors - self.assertFalse(poleframe.cached_value_returned) - - pole_vecs = poleframe.transform_at_time(100.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 1) - self.assertTrue(100. in poleframe.cache) - self.assertFalse(poleframe.cached_value_returned) - - pole_vecs = poleframe.transform_at_time(100.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 1) - self.assertTrue(poleframe.cached_value_returned) - - pole_vecs = poleframe.transform_at_time(200.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 2) - - pole_vecs = poleframe.transform_at_time(300.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 3) - - pole_vecs = poleframe.transform_at_time(400.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 4) - - pole_vecs = poleframe.transform_at_time(500.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 4) - self.assertTrue(100. not in poleframe.cache) - - pole_vecs = poleframe.transform_at_time(200.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 4) - self.assertTrue(poleframe.cached_value_returned) - - pole_vecs = poleframe.transform_at_time(100.).unrotate(Vector3.ZAXIS) - self.assertTrue(len(poleframe.cache) == 4) - self.assertFalse(poleframe.cached_value_returned) - self.assertTrue(300. not in poleframe.cache) - ######################################## if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tests/path/test_keplerpath.py b/tests/path/test_keplerpath.py index 2e37146c..aedde895 100755 --- a/tests/path/test_keplerpath.py +++ b/tests/path/test_keplerpath.py @@ -20,14 +20,16 @@ def _xyz_planet_derivative_test(kep, t, delta=1.e-7): pos_norm = xyz.norm().vals # Create new Kepler objects for tweaking the parameters - khi = kep.copy() - klo = kep.copy() + khi = KeplerPath(kep.planet, kep.epoch, kep.elements.copy(), kep.observer, + kep.wobbles) + klo = KeplerPath(kep.planet, kep.epoch, kep.elements.copy(), kep.observer, + kep.wobbles) - params = kep.get_params() + params = kep.get_elements() # Loop through parameters... - errors = np.zeros(np.shape(t) + (3,kep.nparams)) - for e in range(kep.nparams): + errors = np.zeros(np.shape(t) + (3,kep.nelements)) + for i,e in enumerate(range(kep.nelements)): # Tweak one parameter hi = params.copy() @@ -69,14 +71,16 @@ def _pos_derivative_test(kep, t, delta=1.e-5): pos_norm = event.pos.norm().vals # Create new Kepler objects for tweaking the parameters - khi = kep.copy() - klo = kep.copy() + khi = KeplerPath(kep.planet, kep.epoch, kep.elements.copy(), kep.observer, + kep.wobbles) + klo = KeplerPath(kep.planet, kep.epoch, kep.elements.copy(), kep.observer, + kep.wobbles) - params = kep.get_params() + params = kep.get_elements() # Loop through parameters... - errors = np.zeros(np.shape(t) + (3,kep.nparams)) - for e in range(kep.nparams): + errors = np.zeros(np.shape(t) + (3,kep.nelements)) + for e in range(kep.nelements): # Tweak one parameter hi = params.copy() diff --git a/tests/path/test_multipath.py b/tests/path/test_multipath.py index 3ceea796..4ff69851 100755 --- a/tests/path/test_multipath.py +++ b/tests/path/test_multipath.py @@ -33,7 +33,7 @@ def runTest(self): test = MultiPath([sun,earth,moon], "SSB", path_id='+') - self.assertEqual(test.path_id, "SUN+others") + self.assertEqual(test.path_id, "SUN+EARTH+MOON") self.assertEqual(test.shape, (3,)) # Single time diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100755 index 00000000..6c1c858d --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,148 @@ +########################################################################################## +# test/test_cache.py +########################################################################################## + +import unittest +import numpy as np +from oops.cache import Cache +from oops.frame import Rotation +from oops.path import LinearPath +from polymath import Scalar, Vector + + +class Test_Cache(unittest.TestCase): + + def test_clean_key(self): + + clean_key = Cache.clean_key + + key = 1 + self.assertEqual(clean_key(key), 1) + self.assertIsInstance(clean_key(key), int) + + key = 2. + self.assertEqual(clean_key(key), 2.) + self.assertIsInstance(clean_key(key), float) + + key = True + self.assertEqual(clean_key(key), True) + self.assertIsInstance(clean_key(key), bool) + + key = False + self.assertEqual(clean_key(key), False) + self.assertIsInstance(clean_key(key), bool) + + key = 'abc' + self.assertEqual(clean_key(key), 'abc') + + key = None + self.assertIs(clean_key(key), None) + + key = [1] + self.assertEqual(clean_key(key), (1,)) + + key = [2, 3., 'four'] + self.assertEqual(clean_key(key), (2, 3., 'four')) + + key = np.array(4.) + self.assertEqual(clean_key(key), ((), (4.,))) + self.assertIsInstance(clean_key(key)[1][0], np.float64) + + key = np.array([[1,2],[3,4]]) + self.assertEqual(clean_key(key), ((2,2), (1,2,3,4))) + self.assertIsInstance(clean_key(key)[-1][-1], np.int64) + + key = Scalar(3.14) + self.assertEqual(clean_key(key), ('Scalar', (), 3.14, False)) + + key = Scalar((2.718, 3.14)) + self.assertEqual(clean_key(key), ('Scalar', (2,), (2.718, 3.14), False)) + + key = Scalar((2.718, 3.14), True) + self.assertEqual(clean_key(key), ('Scalar', (2,), (2.718, 3.14), True)) + + key = Scalar((2.718, 3.14), (False,True)) + self.assertEqual(clean_key(key), ('Scalar', (2,), (2.718, 3.14), (False,True))) + + key = Vector([[1,2],[3,4]]) + self.assertEqual(clean_key(key), ('Vector', (2,), (1,2,3,4), False)) + + key = Vector([[1,2],[3,4]], (False,True)) + self.assertEqual(clean_key(key), ('Vector', (2,), (1,2,3,4), (False,True))) + + key = Vector([[1,2],[3,4]], drank=1) + self.assertEqual(clean_key(key), ('Vector', (), (1,2,3,4), False)) + + path = LinearPath((0,0,0), 0., 'SSB') + self.assertEqual(clean_key(path), path.waypoint) + test = {path.waypoint} # TypeError if unhashable + + frame = Rotation(1., 2, 'J2000') + self.assertEqual(clean_key(frame), frame.wayframe) + test = {frame.wayframe} # TypeError if unhashable + + key = (1, Vector([[1,2],[3,4]]), path, frame) + self.assertEqual(clean_key(key), (1, ('Vector', (2,), (1, 2, 3, 4), False), + path.waypoint, frame.wayframe)) + test = {key} # TypeError if unhashable + + def test_Cache(self): + + cache = Cache() + self.assertEqual(cache._maxsize, 100) + self.assertEqual(cache._extras, 10) + self.assertEqual(cache._limit, 110) + + for key in range(110): + cache[key] = str(key) + + self.assertEqual(len(cache), 110) + self.assertIn(0, cache) + self.assertIn(109, cache) + self.assertEqual(cache[0], '0') + self.assertEqual(cache[109], '109') + self.assertEqual(cache[-1], None) + + cache[110] = '110' + self.assertEqual(len(cache), 100) + self.assertEqual(cache[0], '0') + self.assertEqual(cache[1], None) + self.assertEqual(cache[11], None) + self.assertEqual(cache[12], '12') + self.assertEqual(cache[110], '110') + self.assertIn(0, cache) + self.assertNotIn(1, cache) + self.assertNotIn(11, cache) + self.assertIn(12, cache) + + # maxsize = 0 + cache = Cache(maxsize=0) + self.assertEqual(len(cache), 0) + cache['pi'] = 3.14 + self.assertEqual(len(cache), 0) + self.assertEqual(cache['pi'], None) + + # maxsize = 2 + cache = Cache(maxsize=2) + self.assertEqual(cache._maxsize, 2) + self.assertEqual(cache._extras, 3) + self.assertEqual(cache._limit, 5) + self.assertEqual(len(cache), 0) + + cache['pi'] = 3.14 + cache['e'] = 2.718 + cache['c'] = 3.e8 + cache['avogadro'] = 6.e23 + cache['h-bar'] = 1.054e-34 + self.assertEqual(len(cache), 5) + + ignore = cache['e'] + cache['G'] = 6.67e-11 + self.assertEqual(len(cache), 2) + self.assertIn('e', cache) + self.assertIn('G', cache) + +######################################## +if __name__ == '__main__': + unittest.main(verbosity=2) +########################################################################################## diff --git a/tests/test_fittable.py b/tests/test_fittable.py new file mode 100755 index 00000000..2cc185fb --- /dev/null +++ b/tests/test_fittable.py @@ -0,0 +1,179 @@ +########################################################################################## +# tests/test_fittable.py +########################################################################################## + +import unittest + +from oops.fittable import Fittable, Fittable_ + + +class A(Fittable): + def __init__(self, x): + self.x = x + self._refresh() + + def _refresh(self): + self.x_squared = self.x**2 + + def _set_params(self, params): + self.x = params[0] + + @property + def _params(self): + return (self.x,) + +class B: + def __init__(self, x, a): + self.x = x + self.a = a + self._refresh() + + def _refresh(self): + self.x_plus_a2 = self.x + self.a.x_squared + + +class C(Fittable): + def __init__(self, x, a): + self.x = x + self.a = a + self.c = self + self._refresh() + + def _set_params(self, params): + self.x = params[0] + + @property + def _params(self): + return (self.x,) + + def _refresh(self): + self.x_plus_a2_plus_cx_plus_ccx = (self.x + self.a.x_squared + self.c.x + + self.c.c.x) + + +class D: + def __init__(self, x): + self.x = x + + +class Test_Fittable(unittest.TestCase): + + def runTest(self): + + x = () + self.assertFalse(Fittable_.is_fittable(x)) + self.assertEqual(Fittable_.fittables(x), []) + self.assertEqual(Fittable_.version(x), 0) + + a = A(7) + self.assertEqual(a.x_squared, 49) + self.assertEqual(Fittable_.get_params(a), (7,)) + self.assertIsInstance(Fittable_.get_params(a), tuple) + self.assertEqual(Fittable_.get_params(a, as_dict=True), {'':(7,)}) + self.assertTrue(Fittable_.is_fittable(a)) + self.assertEqual(Fittable_.fittables(a), []) + self.assertEqual(Fittable_.version(a), 0) + + self.assertEqual(a.get_params(), (7,)) + self.assertEqual(a.get_params(as_dict=True), {'':(7,)}) + self.assertEqual(a.version(), 0) + + a.set_params([5]) + self.assertEqual(Fittable_.get_params(a), (5,)) + self.assertEqual(a.x_squared, 25) + self.assertEqual(Fittable_.fittables(a), []) + self.assertTrue(Fittable_.is_fittable(a)) + self.assertEqual(Fittable_.version(a), 1) + + b = B(1, a) + self.assertEqual(b.x_plus_a2, 26) + self.assertEqual(Fittable_.get_params(b), ()) + self.assertEqual(Fittable_.get_params(b, as_dict=True), {'a':(5,)}) + + self.assertTrue(Fittable_.is_fittable(b)) + Fittable_.set_params(b, {'a':7}) + self.assertEqual(b.x_plus_a2, 50) + self.assertEqual(Fittable_.fittables(b), ['a']) + self.assertEqual(Fittable_.fittables(b, frozen=True), ['a']) + self.assertEqual(Fittable_.fittables(b, frozen=False), ['a']) + + Fittable_.freeze(a) + self.assertEqual(Fittable_.fittables(b, frozen=True), ['a']) + self.assertEqual(Fittable_.fittables(b, frozen=False), []) + + a = A(5) + c = C(1, a) + self.assertEqual(Fittable_.fittables(c), ['a', 'c']) + self.assertTrue(Fittable_.is_fittable(c)) + self.assertEqual(c.x_plus_a2_plus_cx_plus_ccx, 28) + + a.set_params([6]) + self.assertTrue(c.refresh()) + self.assertEqual(c.x_plus_a2_plus_cx_plus_ccx, 39) + self.assertFalse(c.refresh()) + self.assertFalse(c.refresh()) + + a = A(5) + c = C(1, a) + self.assertEqual(c.x_plus_a2_plus_cx_plus_ccx, 28) + c.set_params({'':1, 'a':6}) + self.assertEqual(c.x_plus_a2_plus_cx_plus_ccx, 39) + self.assertFalse(c.refresh()) + self.assertEqual(Fittable_.get_params(a), (6,)) + self.assertEqual(Fittable_.get_params(a, as_dict=True), {'': (6,)}) + self.assertEqual(Fittable_.get_params(c), (1,)) + self.assertEqual(Fittable_.get_params(c, as_dict=True), {'': (1,), 'a': (6,)}) + + self.assertFalse(Fittable_.is_frozen(a)) + self.assertFalse(Fittable_.is_frozen(c)) + self.assertFalse(a.is_frozen()) + self.assertFalse(c.is_frozen()) + + Fittable_.freeze(a) + self.assertTrue(Fittable_.is_frozen(a)) + self.assertFalse(Fittable_.is_frozen(c)) + self.assertRaises(ValueError, Fittable_.set_params, a, 2) + + self.assertEqual(Fittable_.fittables(c), ['c']) + self.assertEqual(Fittable_.fittables(c, frozen=True), ['a', 'c']) + self.assertTrue(Fittable_.is_fittable(a)) + self.assertTrue(Fittable_.is_fittable(c)) + self.assertEqual(Fittable_.get_params(c), (1,)) + self.assertEqual(Fittable_.get_params(c, as_dict=True), {'': (1,)}) + self.assertEqual(Fittable_.get_params(c, as_dict=True, frozen=True), + {'': (1,), 'a': (6,)}) + self.assertEqual(c.get_params(), (1,)) + self.assertEqual(c.get_params(as_dict=True), {'': (1,)}) + self.assertEqual(c.get_params(as_dict=True, frozen=True), + {'': (1,), 'a': (6,)}) + + a = A(5) + c = C(1, a) + Fittable_.freeze(c) + self.assertTrue(Fittable_.is_frozen(a)) + self.assertTrue(Fittable_.is_frozen(c)) + self.assertTrue(a.is_frozen()) + self.assertTrue(c.is_frozen()) + self.assertRaises(ValueError, Fittable_.set_params, a, 2) + self.assertRaises(ValueError, Fittable_.set_params, c, 2) + + self.assertEqual(Fittable_.fittables(c), []) + self.assertEqual(Fittable_.fittables(c, frozen=True), ['a', 'c']) + self.assertTrue(Fittable_.is_fittable(a)) + self.assertTrue(Fittable_.is_fittable(c)) + + # type class has __data__ but is immutable + d = D(int) + self.assertEqual(len(Fittable_._FROZEN_IDS), 0) + self.assertEqual(Fittable_.get_params(d), ()) + self.assertTrue(Fittable_.is_frozen(d)) + self.assertEqual(len(Fittable_._FROZEN_IDS), 1) + self.assertRaises(ValueError, Fittable_.set_params, d, ()) + + d = D(float) + self.assertIs(Fittable_.freeze(d), None) # tests TypeError check in freeze() + +######################################### +if __name__ == '__main__': + unittest.main(verbosity=2) +########################################################################################## diff --git a/tests/unittester.py b/tests/unittester.py index 897b68b7..062e8911 100755 --- a/tests/unittester.py +++ b/tests/unittester.py @@ -13,7 +13,9 @@ from tests.path.unittester import * from tests.surface.unittester import * from tests.test_body import * +from tests.test_cache import * from tests.test_event import * +from tests.test_fittable import * from tests.test_transform import * from tests.test_utils import * diff --git a/tests/unittester_with_hosts.py b/tests/unittester_with_hosts.py index 480fd93e..6f2d6654 100755 --- a/tests/unittester_with_hosts.py +++ b/tests/unittester_with_hosts.py @@ -1,5 +1,5 @@ ################################################################################ -# tests/unittester.py +# tests/unittester_with_hosts.py ################################################################################ import unittest @@ -13,7 +13,9 @@ from tests.path.unittester import * from tests.surface.unittester import * from tests.test_body import * +from tests.test_cache import * from tests.test_event import * +from tests.test_fittable import * from tests.test_transform import * from tests.test_utils import * From b97fb7faf6d3c23e7b67d436040ce3024d598509 Mon Sep 17 00:00:00 2001 From: Robert French Date: Mon, 24 Nov 2025 10:21:23 -0800 Subject: [PATCH 2/2] Remove Python 3.9 --- .github/workflows/run-tests.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 8ee50c77..f4b4b55d 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -18,8 +18,6 @@ jobs: matrix: # MacOS: Python 3.8-3.10 does not currently work on MacOS. include: - - os: self-hosted-linux - python-version: "3.9" - os: self-hosted-linux python-version: "3.10" - os: self-hosted-linux @@ -34,8 +32,6 @@ jobs: python-version: "3.12" - os: self-hosted-macos python-version: "3.13" - - os: self-hosted-windows - python-version: "3.9" - os: self-hosted-windows python-version: "3.10" - os: self-hosted-windows