Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion src/orc/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,100 @@ def imshow_1D_spatiotemp(U,
plt.colorbar(pad = 0.01, label = r'$u$')
plt.show()

# TODO: add plot_attrator function to visualize 2D/3D attractors in state space
def plot_in_3D_state_space(U_lst,
time_series_labels=None,
line_formats=None,
state_var_names=None,
figsize = (20,8),
title = None,
**plot_kwargs):
"""Plot time series data to visualize 3D attractors in state space.

Parameters
----------
U_lst : 2D array or list of 2D arrays
If a 2D array, shape should be (Nt, 3) with 3 state
variables, where Nt is the number of time points. If a list of 2D arrays,
each array should have shape (Nt, 3) and represent different time series.
time_series_labels : list of strings, optional
List of strings containing the labels for each time series to be shown in
a legend. If None, no labels will be shown.
line_formats : list of strings, optional
List of strings containing the line formats for each time series. If None,
default line format will be used.
state_var_names : list of strings, optional
List of strings containing the names of the state variables. If None,
no axis labels will be shown.
figsize : tuple, optional
Size of the figure to be created. Default is (20, 8).
title : string, optional
Title of the plot. If None, no title is shown.
plot_kwargs : dict, optional
Additional arguments to pass to the plot function.
"""
# Input validation
if not isinstance(U_lst, list):
if not isinstance(U_lst, jnp.ndarray | np.ndarray) or U_lst.ndim != 2:
raise TypeError("U_lst must be a 2D JAX or NumPy array or a list of \
2D JAX/NumPy arrays.")
U_lst = [U_lst]
else:
if not all(
isinstance(U, jnp.ndarray | np.ndarray) and U.ndim == 2 for U in U_lst):
raise TypeError("All elements in U_lst must be 2D JAX or NumPy arrays.")
if not all(U.shape == U_lst[0].shape for U in U_lst):
raise ValueError("All arrays in U_lst must have the same shape.")

Nu = U_lst[0].shape[1]

if time_series_labels is not None:
if not isinstance(time_series_labels, list):
raise TypeError("time_series_labels must be a list of strings.")
if len(time_series_labels) != len(U_lst):
raise ValueError(f"Length of time_series_labels ({len(time_series_labels)})\
must match the number of time series ({len(U_lst)}).")

if line_formats is not None:
if not isinstance(line_formats, list):
raise TypeError("line_formats must be a list of strings.")
if len(line_formats) != len(U_lst):
raise ValueError(f"Length of line_formats ({len(line_formats)}) must \
match the number of time series ({len(U_lst)}).")

if state_var_names is not None:
if not isinstance(state_var_names, list):
raise TypeError("state_var_names must be a list of strings.")
if len(state_var_names) != Nu:
raise ValueError(f"Length of state_var_names ({len(state_var_names)}) \
must match the number of state variables ({Nu}).")

# defaults
plot_kwargs.setdefault('linewidth', 2)

# handle optional inputs
if time_series_labels is None:
time_series_labels = [None for _ in range(len(U_lst))]
if line_formats is None:
line_formats = ['-' for _ in range(len(U_lst))]

# plot
fig, axs = plt.subplots(subplot_kw={"projection": "3d"}, figsize = figsize)
# Ensure axs is always iterable, even if Nu=1
if Nu == 1:
axs = [axs]
for j, Y in enumerate(U_lst):
axs.plot(Y[:, 0], Y[:, 1], Y[:, 2], line_formats[j],
label=time_series_labels[j], **plot_kwargs)
if state_var_names is not None:
axs.set(xlabel=state_var_names[0])
axs.set(ylabel=state_var_names[1])
axs.set(zlabel=state_var_names[2])
if time_series_labels[0] is not None:
axs.legend(loc='upper right')
if title is not None:
axs.set_title(title, fontsize=14)
plt.show()



# TODO: add plot_attrator function to visualize 2D attractors in state space