Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added TODO_linax.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
287 changes: 287 additions & 0 deletions commit_changes_visualization_NOA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""Visualization utilities for plotting time series and spatiotemporal data."""

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np


def plot_time_series(U_lst,
t=None,
time_series_labels=None,
line_formats = None,
state_var_names=None,
t_lim = None,
figsize = (20,8),
x_label = r'$t$',
title = None,
**plot_kwargs):
"""Plot time series data with separate panels for each state variable.

Parameters
----------
U_lst : 2D array or list of 2D arrays
If a 2D array, shape should be (Nt, Nu) where Nu is the number of state
variables and Nt is the number of time points.If a list of 2D arrays,
each array should have shape (Nt, Nu) and represent different time series.
t : 1D array, optional
1D array of time points. If None, the time points will be assumed to be
evenly spaced from 0 to Nt-1.
time_series_labels : list of strings, optional
List of strings containing the labels for each time series to be shown in
a legend. If None, no labels will be shown.
line_formats : list of strings, optional
List of strings containing the line formats for each time series. If None,
default line format will be used.
state_var_names : list of strings, optional
List of strings containing the names of the state variables. If None,
no y-axis labels will be shown.
t_lim : tuple, optional
Limit for the x-axis. If None, the x-axis will be set to the full
range of time points.
figsize : tuple, optional
Size of the figure to be created. Default is (20, 8).
x_label : string, optional
Label for the x-axis. Default is r'$t$'.
title : string, optional
Title of the plot. If None, no title is shown.
plot_kwargs : dict, optional
Additional arguments to pass to the plot function.
"""
# Input validation
if not isinstance(U_lst, list):
if not isinstance(U_lst, jnp.ndarray | np.ndarray) or U_lst.ndim != 2:
raise TypeError("U_lst must be a 2D JAX or NumPy array or a list of \
2D JAX/NumPy arrays.")
U_lst = [U_lst]
else:
if not all(
isinstance(U, jnp.ndarray | np.ndarray) and U.ndim == 2 for U in U_lst):
raise TypeError("All elements in U_lst must be 2D JAX or NumPy arrays.")
if not all(U.shape == U_lst[0].shape for U in U_lst):
raise ValueError("All arrays in U_lst must have the same shape.")


Nu = U_lst[0].shape[1]
Nt = U_lst[0].shape[0]

if t is not None:
if not isinstance(t, jnp.ndarray | np.ndarray) or t.ndim != 1:
raise TypeError("t must be a 1D JAX or NumPy array.")
if len(t) != Nt:
raise ValueError(f"Length of t ({len(t)}) must match the number of time\
points in U_lst ({Nt}).")

if time_series_labels is not None:
if not isinstance(time_series_labels, list):
raise TypeError("time_series_labels must be a list of strings.")
if len(time_series_labels) != len(U_lst):
raise ValueError(f"Length of time_series_labels ({len(time_series_labels)})\
must match the number of time series ({len(U_lst)}).")

if line_formats is not None:
if not isinstance(line_formats, list):
raise TypeError("line_formats must be a list of strings.")
if len(line_formats) != len(U_lst):
raise ValueError(f"Length of line_formats ({len(line_formats)}) must \
match the number of time series ({len(U_lst)}).")

if state_var_names is not None:
if not isinstance(state_var_names, list):
raise TypeError("state_var_names must be a list of strings.")
if len(state_var_names) != Nu:
raise ValueError(f"Length of state_var_names ({len(state_var_names)}) \
must match the number of state variables ({Nu}).")

if t_lim is not None and not isinstance(t_lim, int | float):
raise TypeError("t_lim must be a number (int or float).")

# defaults
plot_kwargs.setdefault('linewidth', 2)

# setup time vectors
if t is None:
t = jnp.arange(Nt)
if t_lim is None:
t_lim = t[-1]

# handle optional inputs
if time_series_labels is None:
time_series_labels = [None for _ in range(len(U_lst))]
if line_formats is None:
line_formats = ['-' for _ in range(len(U_lst))]

# plot
fig, axs = plt.subplots(Nu, figsize = figsize)
# Ensure axs is always iterable, even if Nu=1
if Nu == 1:
axs = [axs]
for i in range(Nu):
for j, Y in enumerate(U_lst):
axs[i].plot(t, Y[:, i], line_formats[j], label=time_series_labels[j],
**plot_kwargs)
axs[i].set_xlim([0, t_lim])
if state_var_names is not None:
axs[i].set(ylabel=state_var_names[i])
if time_series_labels[0] is not None:
axs[0].legend(loc='upper right')
axs[-1].set(xlabel=x_label)
if title is not None:
axs[0].set_title(title, fontsize=14)
plt.show()

def imshow_1D_spatiotemp(U,
tN,
domain=(0,1),
figsize=(20, 6),
title = None,
x_label = r'$t$',
**imshow_kwargs):
"""
Plot 1D spatiotemporal data using imshow.

Parameters
----------
U: 2D array
Data to be plotted, shape should be (Nt, Nx) where Nt is the number of time
points and Nx is the number of spatial points
tN: float
Final time of the simulation
domain: tuple of length 2
Bounds of the spatial domain, default is (0, 1)
figsize: tuple
Size of the figure to be created, default is (20, 6)
title: string, optional
Title of the plot, if None no title is shown
x_label: string, optional
Label for the x-axis, default is r'$t$'
**imshow_kwargs: additional arguments to pass to imshow
"""
# Input validation
if not isinstance(U, jnp.ndarray | np.ndarray) or U.ndim != 2:
raise TypeError("U must be a 2D JAX or NumPy array.")
if not isinstance(domain, tuple) or len(domain) != 2:
raise TypeError("domain must be a tuple of length 2.")
if not all(isinstance(x, int | float) for x in domain):
raise TypeError("Both elements of domain must be numbers (int or float).")

#set defaults for imshow
imshow_kwargs.setdefault('aspect', 'auto')
imshow_kwargs.setdefault('origin', 'lower')
imshow_kwargs.setdefault('cmap', 'RdGy')
imshow_kwargs.setdefault('extent', [0, tN, domain[0], domain[1]])

plt.figure(figsize=figsize, dpi=200)
plt.imshow(U.T, **imshow_kwargs)
plt.ylabel('x')
plt.xlabel(x_label)
if title is not None:
plt.title(title)
plt.colorbar(pad = 0.01, label = r'$u$')
plt.show()

# TODO: add plot_attrator function to visualize 2D/3D attractors in state space



def plot_in_3D_state_space(U_lst,
time_series_labels=None,
line_formats=["-", "r--"],

Check failure on line 188 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (B006)

commit_changes_visualization_NOA.py:188:41: B006 Do not use mutable data structures for argument defaults

Check failure on line 188 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (B006)

commit_changes_visualization_NOA.py:188:41: B006 Do not use mutable data structures for argument defaults
state_var_names=None,
figsize = (20,8),
title = None,
**plot_kwargs):
""" **STEP 1 - clear, one-line summary:** Plot time series data to visualize 3D attractors in state space.

Check failure on line 193 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (E501)

commit_changes_visualization_NOA.py:193:89: E501 Line too long (110 > 88)

Check failure on line 193 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (E501)

commit_changes_visualization_NOA.py:193:89: E501 Line too long (110 > 88)

Parameters
----------
**STEP 2 -

Check failure on line 197 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:197:15: W291 Trailing whitespace

Check failure on line 197 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:197:15: W291 Trailing whitespace
<param_name> : <type and optional?>
<description and shape>**
U_lst : 2D array or list of 2D arrays
If a 2D array, shape should be (Nt, 3) with 3 state
variables, where Nt is the number of time points. If a list of 2D arrays,
each array should have shape (Nt, 3) and represent different time series.
time_series_labels : list of strings, optional
List of strings containing the labels for each time series to be shown in
a legend. If None, no labels will be shown.
line_formats : list of strings, optional
List of strings containing the line formats for each time series. If None,
default line format will be used.
state_var_names : list of strings, optional
List of strings containing the names of the state variables. If None,
no axis labels will be shown.
figsize : tuple, optional
Size of the figure to be created. Default is (20, 8).
title : string, optional
Title of the plot. If None, no title is shown.
plot_kwargs : dict, optional
Additional arguments to pass to the plot function.
"""

Check failure on line 219 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (D210)

commit_changes_visualization_NOA.py:193:5: D210 No whitespaces allowed surrounding docstring text

Check failure on line 219 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (D202)

commit_changes_visualization_NOA.py:193:5: D202 No blank lines allowed after function docstring (found 1)

Check failure on line 219 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (D210)

commit_changes_visualization_NOA.py:193:5: D210 No whitespaces allowed surrounding docstring text

Check failure on line 219 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (D202)

commit_changes_visualization_NOA.py:193:5: D202 No blank lines allowed after function docstring (found 1)

# **STEP 3**

Check failure on line 221 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:221:17: W291 Trailing whitespace

Check failure on line 221 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:221:17: W291 Trailing whitespace
# Input validation
if not isinstance(U_lst, list):
if not isinstance(U_lst, jnp.ndarray | np.ndarray) or U_lst.ndim != 2:
raise TypeError("U_lst must be a 2D JAX or NumPy array or a list of \
2D JAX/NumPy arrays.")
U_lst = [U_lst]
else:
if not all(
isinstance(U, jnp.ndarray | np.ndarray) and U.ndim == 2 for U in U_lst):
raise TypeError("All elements in U_lst must be 2D JAX or NumPy arrays.")
if not all(U.shape == U_lst[0].shape for U in U_lst):
raise ValueError("All arrays in U_lst must have the same shape.")

Check failure on line 234 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (W293)

commit_changes_visualization_NOA.py:234:1: W293 Blank line contains whitespace

Check failure on line 234 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (W293)

commit_changes_visualization_NOA.py:234:1: W293 Blank line contains whitespace
Nu = U_lst[0].shape[1]

if time_series_labels is not None:
if not isinstance(time_series_labels, list):
raise TypeError("time_series_labels must be a list of strings.")
if len(time_series_labels) != len(U_lst):
raise ValueError(f"Length of time_series_labels ({len(time_series_labels)})\
must match the number of time series ({len(U_lst)}).")

if line_formats is not None:
if not isinstance(line_formats, list):
raise TypeError("line_formats must be a list of strings.")
if len(line_formats) != len(U_lst):
raise ValueError(f"Length of line_formats ({len(line_formats)}) must \
match the number of time series ({len(U_lst)}).")

if state_var_names is not None:
if not isinstance(state_var_names, list):
raise TypeError("state_var_names must be a list of strings.")
if len(state_var_names) != Nu:
raise ValueError(f"Length of state_var_names ({len(state_var_names)}) \
must match the number of state variables ({Nu}).")

# **STEP 4**

Check failure on line 258 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:258:17: W291 Trailing whitespace

Check failure on line 258 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:258:17: W291 Trailing whitespace
# defaults
plot_kwargs.setdefault('linewidth', 2)

# **STEP 5**

Check failure on line 262 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:262:17: W291 Trailing whitespace

Check failure on line 262 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:262:17: W291 Trailing whitespace
# handle optional inputs
if time_series_labels is None:
time_series_labels = [None for _ in range(len(U_lst))]
if line_formats is None:
line_formats = ['-' for _ in range(len(U_lst))]

# **STEP 5 - CODE:**

Check failure on line 269 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.12, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:269:25: W291 Trailing whitespace

Check failure on line 269 in commit_changes_visualization_NOA.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, ubuntu-latest)

Ruff (W291)

commit_changes_visualization_NOA.py:269:25: W291 Trailing whitespace
# plot
fig, axs = plt.subplots(subplot_kw={"projection": "3d"}, figsize = figsize)
# Ensure axs is always iterable, even if Nu=1
if Nu == 1:
axs = [axs]
for j, Y in enumerate(U_lst):
axs.plot(Y[:, 0], Y[:, 1], Y[:, 2], line_formats[j],
label=time_series_labels[j],
**plot_kwargs)
if state_var_names is not None:
axs.set(xlabel=state_var_names[0])
axs.set(ylabel=state_var_names[1])
axs.set(zlabel=state_var_names[2])
if time_series_labels[0] is not None:
axs.legend(loc='upper right')
if title is not None:
axs.set_title(title, fontsize=14)
plt.show()
58 changes: 58 additions & 0 deletions how to commit changes.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
0. activate miniconda environment ("myenv"):
conda activate myenv


1. go to directory:
cd C:\Users\nk675\Documents\GitHub\OpenReservoirComputing


2. run tests locally - in cmd:
pytest
ruff check <relative_path_to_file_in_OpenReservoirComputing>.py

2. run tests locally - in vscode (with miniconda virtual environment):
....


3. repeat tests until all tests pass


4. go to path:
cd C:\Users\nk675\Documents\GitHub\OpenReservoirComputing


5. switch to branch - if new branch:
git checkout main
git pull
git checkout -b <new_branch_name>

5. switch to branch - if existing branch:
git checkout <existing_branch_name>


6. stage changes - in vscode:
-->> Source control (Ctrl+Shift+G)
-->> press "+" by files with changes to commit
-->> in "Message" write a short summary of changes
-->> press "Commit"

6. stage changes - in cmd:
git add <changed_file_name>.py


7. push - if new branch:
git push origin <new_branch_name>

7. push - if existing branch:
git push


8. pull - in GitHub browser:
-->> go to "Pull request" tab
-->> press "New Pull request"
-->> choose the base branch to commit to (left): usually "main". choose the compare branch to commit (right): <branch_name>
-->> title + description of change
-->> press "Create pull request"



24 changes: 24 additions & 0 deletions installs_NOA.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
##install packages for a specific Python 3.11:

python3.11 -m pip install <package_name>


##works even if <command> isn’t on your PATH:

python -m <command>



##create new miniconda environment:

conda create -n <env_name> python=3.11


##activate miniconda environment:

conda activate <env_name>


##install packages in conda (after activating the environment):
conda install <package_name>

Loading
Loading