diff --git a/TODO_linax.png b/TODO_linax.png new file mode 100644 index 0000000..d0ae837 Binary files /dev/null and b/TODO_linax.png differ diff --git a/commit_changes_visualization_NOA.py b/commit_changes_visualization_NOA.py new file mode 100644 index 0000000..88a0ef1 --- /dev/null +++ b/commit_changes_visualization_NOA.py @@ -0,0 +1,287 @@ +"""Visualization utilities for plotting time series and spatiotemporal data.""" + +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + + +def plot_time_series(U_lst, + t=None, + time_series_labels=None, + line_formats = None, + state_var_names=None, + t_lim = None, + figsize = (20,8), + x_label = r'$t$', + title = None, + **plot_kwargs): + """Plot time series data with separate panels for each state variable. + + Parameters + ---------- + U_lst : 2D array or list of 2D arrays + If a 2D array, shape should be (Nt, Nu) where Nu is the number of state + variables and Nt is the number of time points.If a list of 2D arrays, + each array should have shape (Nt, Nu) and represent different time series. + t : 1D array, optional + 1D array of time points. If None, the time points will be assumed to be + evenly spaced from 0 to Nt-1. + time_series_labels : list of strings, optional + List of strings containing the labels for each time series to be shown in + a legend. If None, no labels will be shown. + line_formats : list of strings, optional + List of strings containing the line formats for each time series. If None, + default line format will be used. + state_var_names : list of strings, optional + List of strings containing the names of the state variables. If None, + no y-axis labels will be shown. + t_lim : tuple, optional + Limit for the x-axis. If None, the x-axis will be set to the full + range of time points. + figsize : tuple, optional + Size of the figure to be created. Default is (20, 8). + x_label : string, optional + Label for the x-axis. Default is r'$t$'. + title : string, optional + Title of the plot. If None, no title is shown. + plot_kwargs : dict, optional + Additional arguments to pass to the plot function. 
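+
+  Examples
+  --------
+  A minimal usage sketch with synthetic data (illustrative only; opens a
+  matplotlib figure window):
+
+  >>> import numpy as np
+  >>> t = np.linspace(0, 10, 500)
+  >>> U = np.column_stack((np.sin(t), np.cos(t)))
+  >>> plot_time_series(U, t, state_var_names=["$u_1$", "$u_2$"], title="Toy data")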
+ """ + # Input validation + if not isinstance(U_lst, list): + if not isinstance(U_lst, jnp.ndarray | np.ndarray) or U_lst.ndim != 2: + raise TypeError("U_lst must be a 2D JAX or NumPy array or a list of \ + 2D JAX/NumPy arrays.") + U_lst = [U_lst] + else: + if not all( + isinstance(U, jnp.ndarray | np.ndarray) and U.ndim == 2 for U in U_lst): + raise TypeError("All elements in U_lst must be 2D JAX or NumPy arrays.") + if not all(U.shape == U_lst[0].shape for U in U_lst): + raise ValueError("All arrays in U_lst must have the same shape.") + + + Nu = U_lst[0].shape[1] + Nt = U_lst[0].shape[0] + + if t is not None: + if not isinstance(t, jnp.ndarray | np.ndarray) or t.ndim != 1: + raise TypeError("t must be a 1D JAX or NumPy array.") + if len(t) != Nt: + raise ValueError(f"Length of t ({len(t)}) must match the number of time\ + points in U_lst ({Nt}).") + + if time_series_labels is not None: + if not isinstance(time_series_labels, list): + raise TypeError("time_series_labels must be a list of strings.") + if len(time_series_labels) != len(U_lst): + raise ValueError(f"Length of time_series_labels ({len(time_series_labels)})\ + must match the number of time series ({len(U_lst)}).") + + if line_formats is not None: + if not isinstance(line_formats, list): + raise TypeError("line_formats must be a list of strings.") + if len(line_formats) != len(U_lst): + raise ValueError(f"Length of line_formats ({len(line_formats)}) must \ + match the number of time series ({len(U_lst)}).") + + if state_var_names is not None: + if not isinstance(state_var_names, list): + raise TypeError("state_var_names must be a list of strings.") + if len(state_var_names) != Nu: + raise ValueError(f"Length of state_var_names ({len(state_var_names)}) \ + must match the number of state variables ({Nu}).") + + if t_lim is not None and not isinstance(t_lim, int | float): + raise TypeError("t_lim must be a number (int or float).") + + # defaults + plot_kwargs.setdefault('linewidth', 2) + + # setup time vectors + if t is None: + t = jnp.arange(Nt) + if t_lim is None: + t_lim = t[-1] + + # handle optional inputs + if time_series_labels is None: + time_series_labels = [None for _ in range(len(U_lst))] + if line_formats is None: + line_formats = ['-' for _ in range(len(U_lst))] + + # plot + fig, axs = plt.subplots(Nu, figsize = figsize) + # Ensure axs is always iterable, even if Nu=1 + if Nu == 1: + axs = [axs] + for i in range(Nu): + for j, Y in enumerate(U_lst): + axs[i].plot(t, Y[:, i], line_formats[j], label=time_series_labels[j], + **plot_kwargs) + axs[i].set_xlim([0, t_lim]) + if state_var_names is not None: + axs[i].set(ylabel=state_var_names[i]) + if time_series_labels[0] is not None: + axs[0].legend(loc='upper right') + axs[-1].set(xlabel=x_label) + if title is not None: + axs[0].set_title(title, fontsize=14) + plt.show() + +def imshow_1D_spatiotemp(U, + tN, + domain=(0,1), + figsize=(20, 6), + title = None, + x_label = r'$t$', + **imshow_kwargs): + """ + Plot 1D spatiotemporal data using imshow. 
+
+  Parameters
+  ----------
+  U : 2D array
+    Data to be plotted, shape should be (Nt, Nx) where Nt is the number of time
+    points and Nx is the number of spatial points.
+  tN : float
+    Final time of the simulation.
+  domain : tuple of length 2
+    Bounds of the spatial domain, default is (0, 1).
+  figsize : tuple
+    Size of the figure to be created, default is (20, 6).
+  title : string, optional
+    Title of the plot. If None, no title is shown.
+  x_label : string, optional
+    Label for the x-axis, default is r'$t$'.
+  **imshow_kwargs : additional arguments to pass to imshow.
+  """
+  # Input validation
+  if not isinstance(U, jnp.ndarray | np.ndarray) or U.ndim != 2:
+    raise TypeError("U must be a 2D JAX or NumPy array.")
+  if not isinstance(domain, tuple) or len(domain) != 2:
+    raise TypeError("domain must be a tuple of length 2.")
+  if not all(isinstance(x, int | float) for x in domain):
+    raise TypeError("Both elements of domain must be numbers (int or float).")
+
+  # set defaults for imshow
+  imshow_kwargs.setdefault('aspect', 'auto')
+  imshow_kwargs.setdefault('origin', 'lower')
+  imshow_kwargs.setdefault('cmap', 'RdGy')
+  imshow_kwargs.setdefault('extent', [0, tN, domain[0], domain[1]])
+
+  plt.figure(figsize=figsize, dpi=200)
+  plt.imshow(U.T, **imshow_kwargs)
+  plt.ylabel('x')
+  plt.xlabel(x_label)
+  if title is not None:
+    plt.title(title)
+  plt.colorbar(pad=0.01, label=r'$u$')
+  plt.show()
+
+
+# TODO: add a plot_attractor function to visualize 2D attractors in state space
+
+
+def plot_in_3D_state_space(U_lst,
+                           time_series_labels=None,
+                           line_formats=None,
+                           state_var_names=None,
+                           figsize=(20, 8),
+                           title=None,
+                           **plot_kwargs):
+  """Plot time series data to visualize 3D attractors in state space.
+
+  Parameters
+  ----------
+  U_lst : 2D array or list of 2D arrays
+    If a 2D array, shape should be (Nt, 3) with 3 state variables, where Nt is
+    the number of time points. If a list of 2D arrays, each array should have
+    shape (Nt, 3) and represent different time series.
+  time_series_labels : list of strings, optional
+    List of strings containing the labels for each time series to be shown in
+    a legend. If None, no labels will be shown.
+  line_formats : list of strings, optional
+    List of strings containing the line formats for each time series. If None,
+    a default line format will be used.
+  state_var_names : list of strings, optional
+    List of strings containing the names of the state variables. If None,
+    no axis labels will be shown.
+  figsize : tuple, optional
+    Size of the figure to be created. Default is (20, 8).
+  title : string, optional
+    Title of the plot. If None, no title is shown.
+  plot_kwargs : dict, optional
+    Additional arguments to pass to the plot function.
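+
+  Examples
+  --------
+  A minimal usage sketch with a synthetic (Nt, 3) trajectory (illustrative
+  only; opens a matplotlib figure window):
+
+  >>> import numpy as np
+  >>> t = np.linspace(0, 20, 2000)
+  >>> U = np.column_stack((np.sin(t), np.cos(t), np.sin(0.5 * t)))
+  >>> plot_in_3D_state_space(U, state_var_names=["$x$", "$y$", "$z$"])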
+ """ + + # **STEP 3** + # Input validation + if not isinstance(U_lst, list): + if not isinstance(U_lst, jnp.ndarray | np.ndarray) or U_lst.ndim != 2: + raise TypeError("U_lst must be a 2D JAX or NumPy array or a list of \ + 2D JAX/NumPy arrays.") + U_lst = [U_lst] + else: + if not all( + isinstance(U, jnp.ndarray | np.ndarray) and U.ndim == 2 for U in U_lst): + raise TypeError("All elements in U_lst must be 2D JAX or NumPy arrays.") + if not all(U.shape == U_lst[0].shape for U in U_lst): + raise ValueError("All arrays in U_lst must have the same shape.") + + Nu = U_lst[0].shape[1] + + if time_series_labels is not None: + if not isinstance(time_series_labels, list): + raise TypeError("time_series_labels must be a list of strings.") + if len(time_series_labels) != len(U_lst): + raise ValueError(f"Length of time_series_labels ({len(time_series_labels)})\ + must match the number of time series ({len(U_lst)}).") + + if line_formats is not None: + if not isinstance(line_formats, list): + raise TypeError("line_formats must be a list of strings.") + if len(line_formats) != len(U_lst): + raise ValueError(f"Length of line_formats ({len(line_formats)}) must \ + match the number of time series ({len(U_lst)}).") + + if state_var_names is not None: + if not isinstance(state_var_names, list): + raise TypeError("state_var_names must be a list of strings.") + if len(state_var_names) != Nu: + raise ValueError(f"Length of state_var_names ({len(state_var_names)}) \ + must match the number of state variables ({Nu}).") + + # **STEP 4** + # defaults + plot_kwargs.setdefault('linewidth', 2) + + # **STEP 5** + # handle optional inputs + if time_series_labels is None: + time_series_labels = [None for _ in range(len(U_lst))] + if line_formats is None: + line_formats = ['-' for _ in range(len(U_lst))] + + # **STEP 5 - CODE:** + # plot + fig, axs = plt.subplots(subplot_kw={"projection": "3d"}, figsize = figsize) + # Ensure axs is always iterable, even if Nu=1 + if Nu == 1: + axs = [axs] + for j, Y in enumerate(U_lst): + axs.plot(Y[:, 0], Y[:, 1], Y[:, 2], line_formats[j], + label=time_series_labels[j], + **plot_kwargs) + if state_var_names is not None: + axs.set(xlabel=state_var_names[0]) + axs.set(ylabel=state_var_names[1]) + axs.set(zlabel=state_var_names[2]) + if time_series_labels[0] is not None: + axs.legend(loc='upper right') + if title is not None: + axs.set_title(title, fontsize=14) + plt.show() \ No newline at end of file diff --git a/how to commit changes.txt b/how to commit changes.txt new file mode 100644 index 0000000..3fa603e --- /dev/null +++ b/how to commit changes.txt @@ -0,0 +1,58 @@ +0. activate miniconda environment ("myenv"): +conda activate myenv + + +1. go to directory: +cd C:\Users\nk675\Documents\GitHub\OpenReservoirComputing + + +2. run tests locally - in cmd: +pytest +ruff check .py + +2. run tests locally - in vscode (with miniconda virtual environment): +.... + + +3. repeat tests until all tests pass + + +4. go to path: +cd C:\Users\nk675\Documents\GitHub\OpenReservoirComputing + + +5. switch to branch - if new branch: +git checkout main +git pull +git checkout -b + +5. switch to branch - if existing branch: +git checkout + + +6. stage changes - in vscode: + -->> Source control (Ctrl+Shift+G) + -->> press "+" by files with changes to commit + -->> in "Message" write a short summary of changes + -->> press "Commit" + +6. stage changes - in cmd: +git add .py + + +7. push - if new branch: +git push origin + +7. push - if existing branch: +git push + + +8. 
pull - in GitHub browser: + -->> go to "Pull request" tab + -->> press "New Pull request" + -->> choose the base branch to commit to (left): usually "main". choose the compare branch to commit (right): + -->> title + description of change + -->> press "Create pull request" + + + diff --git a/installs_NOA.txt b/installs_NOA.txt new file mode 100644 index 0000000..87d2543 --- /dev/null +++ b/installs_NOA.txt @@ -0,0 +1,24 @@ +##install packages for a specific Python 3.11: + +python3.11 -m pip install + + +##works even if isn’t on your PATH: + +python -m + + + +##create new miniconda environment: + +conda create -n python=3.11 + + +##activate miniconda environment: + +conda activate + + +##install packages in conda (after activating the environment): +conda install + diff --git a/src/orc/models/esn.py b/src/orc/models/esn.py index 5ad363b..ab2badd 100644 --- a/src/orc/models/esn.py +++ b/src/orc/models/esn.py @@ -110,7 +110,7 @@ def __init__( key_driver, key_readout, key_embedding = jax.random.split(key, 3) # init in embedding, driver and readout - embedding = LinearEmbedding( + embedding = LinearEmbedding( ## NOA - get W_in in_dim=data_dim, res_dim=res_dim, seed=key_embedding[0], @@ -119,7 +119,7 @@ def __init__( locality=locality, periodic=periodic, ) - driver = ESNDriver( + driver = ESNDriver( ## NOA - get W res_dim=res_dim, seed=key_driver[0], leak=leak_rate, @@ -135,7 +135,7 @@ def __init__( out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks ) else: - readout = LinearReadout( + readout = LinearReadout( ## NOA - get init W_out (zeros) out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks ) @@ -402,6 +402,7 @@ def train_ESNForecaster( else: tot_seq = jnp.vstack((train_seq, target_seq[-1:])) + tot_res_seq = model.force(tot_seq, initial_res_state) res_seq = tot_res_seq[:-1] if isinstance(model.readout, NonlinearReadout): @@ -409,23 +410,11 @@ def train_ESNForecaster( else: res_seq_train = res_seq - if batch_size is None: - cmat = _solve_all_ridge_reg( - res_seq_train[spinup:], - target_seq[spinup:].reshape( - res_seq[spinup:].shape[0], res_seq.shape[1], -1 - ), - beta, - ) - else: - cmat = _solve_all_ridge_reg_batched( - res_seq_train[spinup:], - target_seq[spinup:].reshape( - res_seq[spinup:].shape[0], res_seq.shape[1], -1 - ), - beta, - batch_size, - ) + cmat = _solve_all_ridge_reg( + res_seq_train[spinup:], + target_seq[spinup:].reshape(res_seq[spinup:].shape[0], res_seq.shape[1], -1), + beta, + ) def where(m): return m.readout.wout diff --git a/test_RC_background.py b/test_RC_background.py new file mode 100644 index 0000000..6e09d2c --- /dev/null +++ b/test_RC_background.py @@ -0,0 +1,151 @@ +# imports +import equinox as eqx +import jax +import jax.numpy as jnp + +import orc +from orc.utils.regressions import ridge_regression + + + +# define Embedding function - two-layer MLP with an ELU activation function: u --ELU(W1u+b1)>> u1 --ELU(W2u1+b2)>> r +class ELUEmbedding(orc.embeddings.EmbedBase): + """ + orc.embeddings.EmbedBase expects us to specify: + input dimension in_dim, + reservoir dimension res_dim, + embed - define a method that maps from in_dim to res_dim + """ + + W1: jnp.ndarray # weight matrix for first layer of ELU MLP + W2: jnp.ndarray # weight matrix for second layer of ELU MLP + b1: jnp.ndarray # bias for first layer of ELU MLP + b2: jnp.ndarray # bias for second layer of ELU MLP + + def __init__(self, in_dim, res_dim, seed=0): + super().__init__(in_dim=in_dim, res_dim=res_dim) + rkey = jax.random.key(seed) + W1key, W2key, b1key, b2key = 
jax.random.split(rkey, 4)
+        # random initialization of parameters of ELU MLP
+        self.W1 = jax.random.normal(W1key, shape=(res_dim // 2, in_dim)) / jnp.sqrt((res_dim // 2) * in_dim)
+        self.W2 = jax.random.normal(W2key, shape=(res_dim, res_dim // 2)) / jnp.sqrt(res_dim * (res_dim // 2))
+        self.b1 = jax.random.normal(b1key, shape=(res_dim // 2,)) / jnp.sqrt(res_dim // 2)
+        self.b2 = jax.random.normal(b2key, shape=(res_dim,)) / jnp.sqrt(res_dim)
+
+    def embed(self, in_state):
+        in_state = self.W1 @ in_state + self.b1
+        in_state = jax.nn.elu(in_state)
+        in_state = self.W2 @ in_state + self.b2
+        in_state = jax.nn.elu(in_state)
+        return in_state
+
+
+# define driver function - update equations of a gated recurrent unit (GRU), already implemented in equinox
+class GRUDriver(orc.drivers.DriverBase):
+
+    gru: eqx.Module
+
+    def __init__(self, res_dim, seed=0):
+        super().__init__(res_dim=res_dim)
+        key = jax.random.key(seed)
+        self.gru = eqx.nn.GRUCell(res_dim, res_dim, key=key)
+
+    def advance(self, res_state, in_state):
+        return self.gru(in_state, res_state)
+
+
+# define readout function - W_O r (initialize W_O with zeros)
+class Readout(orc.readouts.ReadoutBase):
+    res_dim: int
+    out_dim: int
+    W_O: jnp.ndarray
+
+    def __init__(self, out_dim, res_dim):
+        super().__init__(out_dim, res_dim)
+        self.W_O = jnp.zeros((out_dim, res_dim))
+
+    def readout(self, res_state):
+        return self.W_O @ res_state
+
+
+# define RC - embedding + driver + readout
+class Forecaster(orc.rc.RCForecasterBase):
+    driver: orc.drivers.DriverBase
+    readout: orc.readouts.ReadoutBase
+    embedding: orc.embeddings.EmbedBase
+
+
+##########################################################################
+
+# CODE:
+
+
+# integrate Rossler system
+tN = 200
+dt = 0.01
+u0 = jnp.array([-10, 2, 1], dtype=jnp.float64)
+U, t = orc.data.rossler(tN=tN, dt=dt, u0=u0)
+split_idx = int(U.shape[0] * 0.8)
+U_train = U[:split_idx]
+U_test = U[split_idx:]
+t_test = t[split_idx:]
+# orc.utils.visualization.plot_time_series(
+#     U,
+#     t,
+#     state_var_names=["$u_1$", "$u_2$", "$u_3$"],
+#     title="Rossler Data",
+#     x_label="$t$",
+# )
+
+
+# create RC
+Nr = 500
+Nu = 3
+driver = GRUDriver(Nr)
+embedding = ELUEmbedding(Nu, Nr)
+readout = Readout(Nu, Nr)
+model = Forecaster(driver, readout, embedding)
+
+
+# teacher force the reservoir
+forced_seq = model.force(U_train[:-1], res_state=jnp.zeros(Nr))
+
+# shift the indices of the target sequence of training data
+target_seq = U_train[1:]
+
+# set transient to discard
+spinup = 200
+
+# learn W_O
+readout_mat = ridge_regression(forced_seq[spinup:], target_seq[spinup:], beta=1e-7)
+
+
+# under the hood many ORC objects are instances of equinox.Module, which are immutable.
Thus, we need to create a new Forecaster object with readout.W_O set to readout_mat: +# define where in the forecaster model we need to update +def where(model: Forecaster): + return model.readout.W_O +model = eqx.tree_at(where, model, readout_mat) + + +# perform forecast +U_pred = model.forecast_from_IC(fcast_len=U_test.shape[0], spinup_data=U_train[-spinup:]) + +# plot forecast +orc.utils.visualization.plot_time_series( + [U_test, U_pred], + (t_test - t_test[0]) * 0.07, + state_var_names=["$u_1$", "$u_2$", "$u_3$"], + time_series_labels=["True", "Predicted"], + line_formats=["-", "r--"], + x_label= r"$\lambda _1 t$", +) + + +###################################################################### + +orc.utils.visualization.plot_in_3D_state_space( + [U_test, U_pred], + state_var_names=["$u_1$", "$u_2$", "$u_3$"], + time_series_labels=["True", "Predicted"], + line_formats=["-", "r--"], +) \ No newline at end of file diff --git a/test_lorenz_NOA.py b/test_lorenz_NOA.py new file mode 100644 index 0000000..0770c1f --- /dev/null +++ b/test_lorenz_NOA.py @@ -0,0 +1,478 @@ +import numpy as np +import scipy as sp +from scipy.integrate import solve_ivp +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D # Import for 3D plotting +import matplotlib.gridspec as gridspec +import pandas as pd +import seaborn as sns +import scipy.optimize as optimize + + + +# Lorenz System Definition + +def get_Lorenz_data(timesteps, dt): + # # Generate Lorenz data: + + # Lorenz System parameters: + sigma = 10 + rho = 28 + beta = 8/3 + + # Time span for the solution + t_span = [0, timesteps] + # t_eval = np.linspace(t_span[0], t_span[1], num_of_dt) # a sequence of time points for which to solve for + # dt = t_eval[1] - t_eval[0] + + # solve ODE + sol = solve_ivp(lorenz_system, t_span, s0, args=(sigma, rho, beta), method='RK45', t_eval=np.arange(0, timesteps, dt),rtol=1e-12) + X = sol.y + + # # Plot Lorenz data: + # x, y, z = sol.y + # ax.scatter(x, y, z, s=0.5) # s is point size + # fig = plt.figure(figsize=(10, 6)) + # ax = fig.add_subplot(111, projection='3d') + # ax.set_title("Lorenz Attractor") + # ax.set_xlabel("X") + # ax.set_ylabel("Y") + # ax.set_zlabel("Z") + # plt.tight_layout() + # plt.show() + + return X, dt + +def lorenz_system(t, xyz, sigma, rho, beta): + x, y, z = xyz + dxdt = sigma * (y - x) + dydt = x * (rho - z) - y + dzdt = x * y - beta * z + return [dxdt, dydt, dzdt] + + +def get_W(n_reservoir, sparsity, rng, spectral_radius): + '''generate full rank sparse matrix, normalized so the largest eigenvalue is the spectral_radius''' + # while True: + W = rng.random((n_reservoir, n_reservoir)) - 0.5 # Generate dense random matrix with values in [-0.5, 0.5] + mask = rng.random(W.shape) < sparsity # Apply sparsity mask + W *= mask + # Check full rank + # if np.linalg.matrix_rank(W) == n_reservoir: + largest_EV_abs = np.max(np.abs(np.linalg.eigvals(W))) # largest eigenvalue of W + W = W * spectral_radius/largest_EV_abs # normalizing W so the largest eigenvalue is the spectral_radius + return W + +def get_W_diagonal(n_reservoir, rng, spectral_radius): + '''generate full rank diagonal matrix, normalized so the largest eigenvalue is the spectral_radius''' + diag_values = rng.uniform(-1, 1, size=n_reservoir) + W_diag = np.diag(diag_values) + largest_EV_abs = np.max(diag_values) # largest eigenvalue + W_diag = W_diag * spectral_radius/largest_EV_abs # normalizing + return W_diag + +def StepForward(u_k, state_k, time): + if u_k.ndim == 1: + u_k = u_k.reshape(feature_num,1) # convert u_k to a 
column vector + state_k = alpha* np.tanh(W.dot(state_k) + W_in.dot(u_k) + sigma_b * np.ones([n_reservoir,1])) + (1 - alpha*leak)*state_k # update reservoir state taking one step forward + time += 1 + return state_k, time + +def ridge_reg_loss_func(Y_gt, Y_pred, forecast_horizon_ind=None): + forecast_horizon_ind = len(Y_gt) if forecast_horizon_ind == None else forecast_horizon_ind + Y_gt = Y_gt[:,:forecast_horizon_ind] + Y_pred = Y_pred[:,:forecast_horizon_ind] + ridge_reg_loss = np.sqrt(np.mean((Y_pred - Y_gt) ** 2)) + ridge_reg_loss_time = np.sqrt(np.mean((Y_pred - Y_gt) ** 2, axis=0)) + return ridge_reg_loss, ridge_reg_loss_time + +def smape_loss_func_t(Y_gt, Y_pred, t, i_start, i_end): + """ + Computes the sMAPE between time indices i_start and i_end. + Y_gt, Y_pred: shape (dim, time) + t: 1D array of time values + """ + dt = t[1] - t[0] # assuming uniform time grid + y_p = Y_pred[:, i_start:i_end+1] + y_t = Y_gt[:, i_start:i_end+1] + numer = np.abs(y_p - y_t) + denom = np.abs(y_p) + np.abs(y_t) + coef = 2 / (t[i_end] - t[i_start]) + smape = coef * np.sum(np.mean(numer / denom, axis=0)) * dt + return smape + +def forecast_horizon_sMAPE(Y_gt, Y_pred, t, epsilon): + i_start = 0 + n = Y_gt.shape[1] + horizon_index = n - 1 # default - entire time series + smape_loss = 0 + for i_end in range(i_start + 1, n): + smape = smape_loss_func_t(Y_gt, Y_pred, t, i_start, i_end) + if smape > epsilon: + horizon_index = i_end + smape_loss = smape + break + return t[horizon_index], horizon_index, smape_loss + + +def plot_3D(Y, Y_pred, train_test): + fig = plt.figure(figsize=(10, 6)) + ax = fig.add_subplot(111, projection='3d') + ax.scatter(*Y, s=0.5, color='orange', label='Lorenz Train') # s is point size + ax.scatter(*Y_pred, s=0.5, color='blue', label='RC Prediction Train') + ax.set_title(f"Lorenz Attractor prediction - {train_test}") + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + plt.tight_layout() + plt.show() + + +def plot_xyz(Y, Y_pred, t, train_test): + fig, ax = plt.subplots(3,1,figsize=(8,6)) + fig.suptitle(f'Forecasting L63, N_res = {n_reservoir} ({train_test})', fontsize=16) + + ax[0].plot(t, Y[0, :], 'k-', linewidth=3, label='Truth') + ax[0].plot(t, Y_pred[0,:],'r--', linewidth=2, label='Forecast') + ax[0].set_xlabel('$t$') + ax[0].set_ylabel('$x$') + #ax[0].set_ylim([-20,20]) + + ax[1].plot(t, Y[1, :], 'k-', linewidth=3) + ax[1].plot(t, Y_pred[1,:],'r--',linewidth=2) + ax[1].set_xlabel('$ t$') + ax[1].set_ylabel('$y$') + + ax[2].plot(t, Y[2, :], 'k-', linewidth=3) + ax[2].plot(t, Y_pred[2,:],'r--',linewidth=2) + + ax[2].set_xlabel('$t$') + ax[2].set_ylabel('$z$') + plt.tight_layout() + ax[0].legend() + plt.show() + +def plot_EV(W, cont_discrete = ""): + # Compute eigenvalues + eigvals_W = np.linalg.eigvals(W) + spectral_radius_actual = np.max(np.abs(eigvals_W)) + + # Create plot + fig, ax = plt.subplots(figsize=(6, 6)) + ax.scatter(eigvals_W.real, eigvals_W.imag, s=20, label="Eigenvalues") + + # Add unit circle (if spectral_radius=1) or actual spectral radius + circle = plt.Circle((0, 0), spectral_radius_actual, color='blue', fill=False, linestyle='--', label=f"Actual spectral Radius = {spectral_radius_actual:.2f}") + ax.add_artist(circle) + + # Axis formatting + ax.axhline(0, color='gray', linestyle='--', linewidth=1) + ax.axvline(0, color='gray', linestyle='--', linewidth=1) + ax.set_xlabel("Real Part") + ax.set_ylabel("Imaginary Part") + ax.set_title(f"Eigenvalues of {cont_discrete} dynamical system (Reservoir)") + ax.grid(True) + ax.axis('equal') + ax.legend() + 
plt.tight_layout() + plt.show() + +def plot_rc_results(Y, Y_pred, t, train_test, params=None): + fig = plt.figure(figsize=(12, 8)) + ridge_loss_str = f"{params['ridge_loss']:.2f}" if params else "" + fig.suptitle(f'Forecasting L63, N_res = {n_reservoir} ({train_test}), ridge_loss = {ridge_loss_str}', fontsize=16) + params_str = ", ".join(f"{name} = {p:.2f}" for name, p in params.items()) if params else "" + fig.text(0.5, 0.93, params_str, + ha='center', fontsize=12, style='italic', color='gray') + + # Create 3 rows x 2 columns grid + gs = gridspec.GridSpec(3, 2, width_ratios=[2, 3]) # right column is wider + + # Time series x (ax0) + ax0 = fig.add_subplot(gs[0, 0]) + ax0.plot(t, Y[0, :], 'k-', linewidth=2, label='Truth') + ax0.plot(t, Y_pred[0, :], 'r--', linewidth=1.5, label='Forecast') + ax0.set_ylabel('$x$') + ax0.set_xlabel('$t$') + ax0.legend() + + # Time series y (ax1) + ax1 = fig.add_subplot(gs[1, 0]) + ax1.plot(t, Y[1, :], 'k-', linewidth=2) + ax1.plot(t, Y_pred[1, :], 'r--', linewidth=1.5) + ax1.set_ylabel('$y$') + ax1.set_xlabel('$t$') + + # Time series z (ax2) + ax2 = fig.add_subplot(gs[2, 0]) + ax2.plot(t, Y[2, :], 'k-', linewidth=2) + ax2.plot(t, Y_pred[2, :], 'r--', linewidth=1.5) + ax2.set_ylabel('$z$') + ax2.set_xlabel('$t$') + + # 3D plot spans all rows in right column (ax3) + ax3 = fig.add_subplot(gs[:, 1], projection='3d') + ax3.scatter(*Y, s=0.5, color='orange', label='Lorenz') + ax3.scatter(*Y_pred, s=0.5, color='blue', label='RC Prediction') + ax3.set_title("Lorenz Attractor") + ax3.set_xlabel("X") + ax3.set_ylabel("Y") + ax3.set_zlabel("Z") + ax3.legend() + + if params and "forecast_horizon" in params: + fh = params["forecast_horizon"] + for ax in [ax0, ax1, ax2]: + ax.axvline(x=fh, color='gray', linestyle='--', linewidth=1) + ax.text(fh, ax.get_ylim()[1]*0.9, f'{fh:.2f}', rotation=90, + va='top', ha='right', fontsize=8, color='gray') + + plt.tight_layout(rect=[0, 0, 1, 0.95]) # leave space for suptitle + plt.show() + +####################################### + +def create_RC(feature_num, n_reservoir, is_W_diagonal): + # Generate Reservoir Weights + rng = np.random.default_rng(0) # random seed + W_in = (rng.random((n_reservoir, feature_num)) - 0.5) * 2 * input_scaling # random input weight matrix between [-input_scaling,input_scaling]. (nodes in input with nodes in reservoir) + if is_W_diagonal: + W = get_W_diagonal(n_reservoir, rng, spectral_radius) + else: + W = get_W(n_reservoir, sparsity, rng, spectral_radius) # reservoir ajacency matrix - random recurrent weight matrix between [-0.5,0.5]. keep only a "sparsity" percent number of connections. 
normalizing W so the largest eigenvalue is the spectral_radius + return W, W_in + + +def Train_RC(): + # Training + sp = int(num_of_dt * 0.008) + + #initiate state: + state_k = np.zeros([n_reservoir,1]) + time = 0 + + #propagate states for transient indices: + for i in range(sp): + u_k = U_train[:,i] + state_k, time = StepForward(u_k, state_k, time) + + # Run Reservoir over training data (get state of reservoir at all times t without sp) + X_train = np.zeros([n_reservoir, len(U_train[0,sp+1:])]) + train_count = 0 + for i in range(sp,U_train.shape[1] - 1): + u_k = U_train[:,i] + state_k, time = StepForward(u_k, state_k, time) + X_train[:, train_count] = state_k[:,0] + train_count += 1 + if np.linalg.norm(state_k) > 1e10 or np.isnan(np.linalg.norm(state_k)): + print(f'Reservoir trajectory blew up after {str(i)} steps in training') + break + x_mat = X_train + u_mat = U_train[:, sp+1:] + + # train output weights using regularized least squares + lhs = (x_mat.dot(x_mat.T) + ridge_lambda * np.eye(n_reservoir)) + rhs = x_mat.dot(u_mat.T) + W_out = np.linalg.lstsq(lhs, rhs, rcond=None)[0].T + + # For Ploting the Train trajectory: + y_train_inds = list(range(train_ind_start+sp,train_ind_end-1)) + y_GT = X[:,y_train_inds] + Y_train_teach_pred = W_out.dot(X_train) + t = np.linspace(0, (len(y_train_inds))*dt, len(y_train_inds)) + + return time, state_k, W_out, train_len, train_ind_end, y_GT, Y_train_teach_pred, t + + +def predict_RC(time, state_k, W_out, train_len, train_ind_end): + # number of time steps to forecast + forecast_len = int(train_len * 0.2) + + # Run Reservoir over test data: and predict + test_count = 0 + X_test = np.zeros([n_reservoir, forecast_len]) + for i in range(forecast_len): + temp_u = W_out.dot(state_k) + state_k, time = StepForward(temp_u, state_k, time) + X_test[:, test_count] = state_k[:,0] + test_count += 1 + + if np.linalg.norm(state_k) > 1e10 or np.isnan(np.linalg.norm(state_k)): + print(f'Forecasted trajectory blew up after {str(i)} steps') + break + + # For Ploting the Prediction trajectory: + y_test_inds = list(range(train_ind_end,train_ind_end+forecast_len)) + y_GT = X[:, y_test_inds] + y_test_PRED = W_out.dot(X_test) + t = np.linspace(0, forecast_len*dt, forecast_len) * 0.9 ### characteristic Lyapunov timescale, 0.9 = lambda1 the largest Lyapunov exponent + + return y_GT, y_test_PRED, t + + +def get_fixed_point(cont_discrete): + def f_r(r): + '''continuous: + f(r) = (r(t+dt) - r(t)) / dt + discrete: + f(r) = r(t+dt) + ''' + r = r.reshape(-1, 1) # ensure column vector + if cont_discrete == "continuous": + return ((alpha * np.tanh(W @ r + W_in @ (W_out @ r) + sigma_b * np.ones((n_reservoir, 1))) - alpha * leak * r)/dt).flatten() # return as flat vector + elif cont_discrete == "discrete": + return (alpha * np.tanh(W @ r + W_in @ (W_out @ r) + sigma_b * np.ones((n_reservoir, 1))) + (1 - alpha * leak) * r).flatten() # return as flat vector + # Initial guess + x0 = np.zeros(n_reservoir) + # Find the root + sol = optimize.root(f_r, x0) + r_fixed = sol.x + return r_fixed + +def compute_jacobian(r, W, W_in, W_out, alpha, leak, sigma_b, cont_discrete): + '''continuous: + A_tilde = [[d(f_i(r))/d(r_j(t))]] = [[d((r_i(t+dt) + r_i(t))/dt)/d(r_j(t))]] + discrete: + A_tilde = [[d(f_i(r))/d(r_j(t))]] = [[d(r_i(t+dt))/d(r_j(t))]] + ''' + r = r.reshape(-1, 1) # ensure column vector + u = W @ r + W_in @ (W_out @ r) + sigma_b * np.ones((n_reservoir, 1)) + tanh_u = np.tanh(u) + sech2_u = 1 - tanh_u**2 # derivative of tanh + D = np.diagflat(sech2_u) # diagonal matrix + A = W + W_in @ W_out 
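+    # D holds the elementwise tanh'(u) = 1 - tanh(u)^2 terms on its diagonal, and
+    # A = W + W_in @ W_out is the effective linear feedback seen by the reservoir
+    # once the readout prediction is fed back as the input.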
+ if cont_discrete == "continuous": + J = (alpha * A @ D + ( - alpha * leak) * np.eye(n_reservoir)) / dt + elif cont_discrete == "discrete": + J = alpha * D @ A + (1 - alpha * leak) * np.eye(n_reservoir) + return J + + + + + + +########################################################################################### +########################################################################################### + +# INPUT: + +input_scaling = 0.08 # Scales how strongly the input affects the reservoir + +# set initial condition: +s0 = [-10.0, 1.0, 10.0] # [-0.1, 1, 1.05] ,[1, 1, 1] # Initial values for x,y,z + +# Timesteps +timesteps = 100 +num_of_dt = 10000 + + +# Reservoir Parameters +n_reservoir = 400 # Number of neurons in reservoir +sparsity = 0.02 # Fraction of reservoir weights that are non-zero (controls sparsity) +input_scaling = 0.084 # Scales how strongly the input affects the reservoir +spectral_radius = 0.8 # Controls the echo state property (reservoir stability - stable if smaller than 1) +is_W_diagonal = False + +leak = 0.6 +sigma_b = 1.6 # Input bias + +# ridge regression parameter +ridge_lambda = 8e-8 # Tikhonov regularization parameter - small vlaue ensures stability in case X not invertible [det(X)=0 ; (X^T X)^(-1)->inf] + + +# # state propagation in time parameters +# alpha_range = np.linspace(0, 2, 10) +# sigma_b_range = np.linspace(0.5, 2.5, 10) # Input bias + + +############################################################## + +#CODE: + +# integrate the Lorenz system: +X, dt = get_Lorenz_data(timesteps=100, dt=0.01) + +# train split: +feature_num = X.shape[0] +train_len = int(X.shape[1] * 0.8) +train_ind_start = 0 +train_ind_end = train_len +train_inds = list(range(train_ind_start,train_ind_end)) +U_train = X[:,train_inds] + + +results = [] +results_smape = [] +results_horizon = [] +results_score_temp = [] +max_forecast_horizon = 0 +for alpha in [0.67]: # [0.67] + print(f"alpha: {alpha}") + for sigma_b in [0.94]: # [0.94] + print(f"sigma_b: {sigma_b}") + + # Initialize and train the ESN: + W, W_in = create_RC(feature_num, n_reservoir, is_W_diagonal) + time, state_k, W_out, train_len, train_ind_end, y_train_GT, Y_train_teach_pred, t_train = Train_RC() + + # Forecast: + y_test_GT, y_test_PRED, t_test = predict_RC(time, state_k, W_out, train_len, train_ind_end) + +# Plot +plot_rc_results(y_test_GT, y_test_PRED, t_test, "Test", None) + + +###################################################### + +# # forecast horizon with sMAPE: +# forecast_horizon, forecast_horizon_ind, smape_loss = forecast_horizon_sMAPE(y_test_GT, y_test_PRED, t_test, epsilon) +# # ridge regression: +# test_loss, test_loss_time = ridge_reg_loss_func(y_test_GT, y_test_PRED, forecast_horizon_ind) +# print(f"test loss: {test_loss}") + +# #for param tuning +# results.append({"alpha": alpha, "sigma_b": sigma_b, "ridge_loss": test_loss}) +# results_smape.append({"alpha": alpha, "sigma_b": sigma_b, "ridge_loss": test_loss, "smape_loss": smape_loss}) +# results_horizon.append({"alpha": alpha, "sigma_b": sigma_b, "ridge_loss": test_loss, "forecast_horizon": forecast_horizon}) +# #combined smape and horizon +# results_score_temp.append((alpha, sigma_b, test_loss, smape_loss, forecast_horizon)) #loss function is last param for sorting +# max_forecast_horizon = forecast_horizon if forecast_horizon > max_forecast_horizon else max_forecast_horizon + +# #get results_score: +# loss_scores = [float("{:.2f}".format(lambda_score * smape_loss - (1 - lambda_score) * forecast_horizon / max_forecast_horizon)) for _, _, _, 
smape_loss, forecast_horizon in results_score_temp] +# print(f"loss scores: {loss_scores}") +# results_score = [] +# for i in range(len(loss_scores)): +# alpha, sigma_b, test_loss, smape_loss, forecast_horizon = results_score_temp[i] +# results_score.append({"alpha": alpha, "sigma_b": sigma_b, "ridge_loss": test_loss, "smape_loss": smape_loss, "forecast_horizon": forecast_horizon, "loss_score": loss_scores[i]}) + + +# #Plot eigenvalues of W (reservoir matrix): +# # plot_EV(W) + +# #Plot eigenvalues of linear approx of states equation (reservoir matrix): +# cont_discrete = "continuous" # "continuous", "discrete" + +# dt = t_test[1] - t_test[0] +# r_fixed = get_fixed_point(cont_discrete) +# print(f"r_fixed: {r_fixed}") +# J_at_fixed = compute_jacobian(r_fixed, W, W_in, W_out, alpha, leak, sigma_b, cont_discrete) +# print(f"Jacobian at fixed point: {J_at_fixed}") +# plot_EV(J_at_fixed, cont_discrete) + + + + +# ######################################################## +# # FIND OPTIMAL PARAMS: + +# params_best = min(results_score, key=lambda x: x[test_loss_func]) # sort by the loss function + + + + + + + + diff --git a/test_lorenz_ORC.py b/test_lorenz_ORC.py new file mode 100644 index 0000000..9049351 --- /dev/null +++ b/test_lorenz_ORC.py @@ -0,0 +1,30 @@ +import orc + +# integrate the Lorenz system +U,t = orc.data.lorenz63(tN=100, dt=0.01) + +# train-test split +test_perc = 0.2 +split_idx = int((1 - test_perc) * U.shape[0]) +U_train = U[:split_idx, :] +t_train = t[:split_idx] +U_test = U[split_idx:, :] +t_test = t[split_idx:] + +# Initialize and train the ESN +esn = orc.models.ESNForecaster(data_dim=3, res_dim=400) # ## NOA - get W_in and W, init W_out +esn, R = orc.models.train_ESNForecaster(esn, U_train) ## NOA - get W_out (drive states through reservoir, find W_out using Solution to ridge regression) + +# Forecast! 
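+# (closed-loop forecast: at each step the readout prediction is fed back into the
+#  reservoir as the next input, starting from the last reservoir state seen in training)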
+U_pred = esn.forecast(fcast_len=U_test.shape[0], res_state=R[-1]) # feed in the last reservoir state seen in training ## NOA - and drive states through reservoir using W_out*u as input + + +# Visualize +orc.utils.visualization.plot_time_series( + [U_test, U_pred], + (t_test - t_test[0]), # start time at 0 + state_var_names=["$u_1$", "$u_2$", "$u_3$"], + time_series_labels=["True", "Predicted"], + line_formats=["-", "r--"], + x_label= r"$t$", +) \ No newline at end of file diff --git a/tests/utils/test_visualization.py b/tests/utils/test_visualization.py index f92588f..b203246 100644 --- a/tests/utils/test_visualization.py +++ b/tests/utils/test_visualization.py @@ -24,6 +24,15 @@ def sample_spatiotemporal(): U = np.sin(T) * np.cos(2 * np.pi * X) return U +@pytest.fixture +def sample_3d_series(): + # Create a simple 3D time series: shape (Nt=100) + t = np.linspace(0, 10, 100) + x = np.sin(t) + y = np.cos(t) + z = np.sin(t) * np.cos(0.5 * t) + return np.column_stack((x, y, z)) + @patch('matplotlib.pyplot.show') def test_plot_time_series_basic(mock_show, sample_time_series): # Test with basic parameters @@ -73,16 +82,62 @@ def test_imshow_1D_spatiotemp_with_options(mock_show, sample_spatiotemporal): ) mock_show.assert_called_once() +@patch('matplotlib.pyplot.show') +def test_plot_in_3d_state_space_basic(mock_show, sample_3d_series): + # Test with basic parameters + vis.plot_in_3D_state_space(sample_3d_series) + mock_show.assert_called_once() + +@patch('matplotlib.pyplot.show') +def test_plot_in_3d_state_space_with_options(mock_show, sample_3d_series): + # Test with optional parameters + vis.plot_in_3D_state_space( + [sample_3d_series, sample_3d_series], + time_series_labels=["Data 1", "Data 2"], + line_formats=['-', '--'], + state_var_names=["x1", "x2", "x3"], + title="3D Attractor", + linewidth=1.5, + ) + mock_show.assert_called_once() + +@patch('matplotlib.pyplot.show') +def test_plot_in_3d_state_space_with_jax(mock_show): + # Test with JAX arrays + t = jnp.linspace(0, 10, 100) + x = jnp.sin(t) + y = jnp.cos(t) + z = jnp.sin(t) * jnp.cos(0.5 * t) + data = jnp.column_stack((x, y, z)) + vis.plot_in_3D_state_space(data) + mock_show.assert_called_once() + def test_input_validation(): - # Test input validation for both functions + # Test input validation for all three functions with pytest.raises(TypeError): vis.plot_time_series("not an array") with pytest.raises(TypeError): vis.plot_time_series(np.array([1, 2, 3])) # 1D array + with pytest.raises(ValueError): + a = np.zeros((10, 3)) + b = np.zeros((9, 3)) + vis.plot_time_series([a, b]) # mismatched shapes + with pytest.raises(TypeError): vis.imshow_1D_spatiotemp("not an array", 10) with pytest.raises(TypeError): vis.imshow_1D_spatiotemp(np.array([1, 2, 3]), 10) # 1D array + + with pytest.raises(TypeError): + vis.plot_in_3D_state_space("not an array") + + with pytest.raises(TypeError): + vis.plot_in_3D_state_space(np.array([1, 2, 3])) # 1D array + + with pytest.raises(ValueError): + a = np.zeros((10, 3)) + b = np.zeros((9, 3)) + vis.plot_in_3D_state_space([a, b]) # mismatched shapes
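+
+    with pytest.raises(ValueError):
+        # assumes plot_in_3D_state_space rejects trajectories without exactly 3 columns
+        vis.plot_in_3D_state_space(np.zeros((10, 2)))  # wrong number of state variables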