diff --git a/src/orc/utils/visualization.py b/src/orc/utils/visualization.py index 78cb567..8a4742c 100644 --- a/src/orc/utils/visualization.py +++ b/src/orc/utils/visualization.py @@ -178,4 +178,100 @@ def imshow_1D_spatiotemp(U, plt.colorbar(pad = 0.01, label = r'$u$') plt.show() -# TODO: add plot_attrator function to visualize 2D/3D attractors in state space +def plot_in_3D_state_space(U_lst, + time_series_labels=None, + line_formats=None, + state_var_names=None, + figsize = (20,8), + title = None, + **plot_kwargs): + """Plot time series data to visualize 3D attractors in state space. + + Parameters + ---------- + U_lst : 2D array or list of 2D arrays + If a 2D array, shape should be (Nt, 3) with 3 state + variables, where Nt is the number of time points. If a list of 2D arrays, + each array should have shape (Nt, 3) and represent different time series. + time_series_labels : list of strings, optional + List of strings containing the labels for each time series to be shown in + a legend. If None, no labels will be shown. + line_formats : list of strings, optional + List of strings containing the line formats for each time series. If None, + default line format will be used. + state_var_names : list of strings, optional + List of strings containing the names of the state variables. If None, + no axis labels will be shown. + figsize : tuple, optional + Size of the figure to be created. Default is (20, 8). + title : string, optional + Title of the plot. If None, no title is shown. + plot_kwargs : dict, optional + Additional arguments to pass to the plot function. + """ + # Input validation + if not isinstance(U_lst, list): + if not isinstance(U_lst, jnp.ndarray | np.ndarray) or U_lst.ndim != 2: + raise TypeError("U_lst must be a 2D JAX or NumPy array or a list of \ + 2D JAX/NumPy arrays.") + U_lst = [U_lst] + else: + if not all( + isinstance(U, jnp.ndarray | np.ndarray) and U.ndim == 2 for U in U_lst): + raise TypeError("All elements in U_lst must be 2D JAX or NumPy arrays.") + if not all(U.shape == U_lst[0].shape for U in U_lst): + raise ValueError("All arrays in U_lst must have the same shape.") + + Nu = U_lst[0].shape[1] + + if time_series_labels is not None: + if not isinstance(time_series_labels, list): + raise TypeError("time_series_labels must be a list of strings.") + if len(time_series_labels) != len(U_lst): + raise ValueError(f"Length of time_series_labels ({len(time_series_labels)})\ + must match the number of time series ({len(U_lst)}).") + + if line_formats is not None: + if not isinstance(line_formats, list): + raise TypeError("line_formats must be a list of strings.") + if len(line_formats) != len(U_lst): + raise ValueError(f"Length of line_formats ({len(line_formats)}) must \ + match the number of time series ({len(U_lst)}).") + + if state_var_names is not None: + if not isinstance(state_var_names, list): + raise TypeError("state_var_names must be a list of strings.") + if len(state_var_names) != Nu: + raise ValueError(f"Length of state_var_names ({len(state_var_names)}) \ + must match the number of state variables ({Nu}).") + + # defaults + plot_kwargs.setdefault('linewidth', 2) + + # handle optional inputs + if time_series_labels is None: + time_series_labels = [None for _ in range(len(U_lst))] + if line_formats is None: + line_formats = ['-' for _ in range(len(U_lst))] + + # plot + fig, axs = plt.subplots(subplot_kw={"projection": "3d"}, figsize = figsize) + # Ensure axs is always iterable, even if Nu=1 + if Nu == 1: + axs = [axs] + for j, Y in enumerate(U_lst): + axs.plot(Y[:, 0], Y[:, 1], Y[:, 2], line_formats[j], + label=time_series_labels[j], **plot_kwargs) + if state_var_names is not None: + axs.set(xlabel=state_var_names[0]) + axs.set(ylabel=state_var_names[1]) + axs.set(zlabel=state_var_names[2]) + if time_series_labels[0] is not None: + axs.legend(loc='upper right') + if title is not None: + axs.set_title(title, fontsize=14) + plt.show() + + + +# TODO: add plot_attrator function to visualize 2D attractors in state space