diff --git a/docs/user-guide/models.md b/docs/user-guide/models.md index 6515c15..56e0e09 100644 --- a/docs/user-guide/models.md +++ b/docs/user-guide/models.md @@ -176,6 +176,74 @@ model = CESNForecaster( ) ``` +## ESN Controller + +The `ESNController` implements an Echo State Network for control tasks, where the model tracks a reference trajectory by optimizing a control input sequence. + +### Basic Usage + +```python +from orc.control import ESNController, train_ESNController +import jax.numpy as jnp + +# Create ESN controller +model = ESNController( + data_dim=3, # System input/output dimension + control_dim=1, # Control input dimension + res_dim=500, # Reservoir dimension + seed=42 +) + +# Train with input and control sequences +trained_model, res_states = train_ESNController( + model, + train_seq, # shape: (seq_len, data_dim) + control_seq, # shape: (seq_len, control_dim) + target_seq=target_seq, + spinup=100, + beta=8e-8 +) + +# Compute optimal control to track a reference trajectory +control_opt = trained_model.compute_control( + control_seq=initial_guess, # shape: (fcast_len, control_dim) + res_state=res_states[-1], + ref_traj=reference # shape: (fcast_len, data_dim) +) +``` + +### Control Penalty Weights + +The controller optimizes a penalty function with three terms, each weighted by a tunable parameter: + +- `alpha_1` (default 100): **Trajectory deviation** — penalizes squared error between the controlled forecast and the reference trajectory. Higher values enforce tighter tracking. + +- `alpha_2` (default 1): **Control magnitude** — penalizes the squared norm of control inputs. Higher values favor smaller control effort. + +- `alpha_3` (default 5): **Control smoothness** — penalizes the squared norm of consecutive control input differences. Higher values produce smoother control signals. + +```python +# Custom penalty weights +model = ESNController( + data_dim=3, + control_dim=1, + res_dim=500, + alpha_1=200, # Stricter trajectory tracking + alpha_2=0.5, # Allow larger control inputs + alpha_3=10, # Enforce smoother control + seed=42 +) +``` + +### Key Parameters + +In addition to the standard ESN parameters (`leak_rate`, `bias`, `embedding_scaling`, `Wr_density`, `Wr_spectral_radius`), the controller accepts: + +- `control_dim`: Dimension of the control input +- `alpha_1`, `alpha_2`, `alpha_3`: Control penalty weights (see above) + +For detailed API documentation, see the [Control API Reference](../api/control.md). + ## Training Functions Both ESN models use ridge regression for training the readout layer. diff --git a/src/orc/control/models.py b/src/orc/control/models.py index 9f9a40e..a6c933e 100644 --- a/src/orc/control/models.py +++ b/src/orc/control/models.py @@ -29,6 +29,12 @@ class ESNController(RCControllerBase): Trainable linear readout layer. embedding : LinearEmbedding Untrainable linear embedding layer for [input, control] concatenation. + alpha_1 : float + Weight for trajectory deviation penalty in control optimization. + alpha_2 : float + Weight for control magnitude penalty in control optimization. + alpha_3 : float + Weight for control derivative penalty in control optimization. Methods ------- @@ -57,6 +63,9 @@ def __init__( Wr_density: float = 0.02, Wr_spectral_radius: float = 0.8, dtype: type = jnp.float64, + alpha_1: float = 100, + alpha_2: float = 1, + alpha_3: float = 5, seed: int = 0, quadratic: bool = False, use_sparse_eigs: bool = True, @@ -84,6 +93,12 @@ def __init__( Largest eigenvalue of the reservoir adjacency matrix Wr. dtype : type Data type of the model (jnp.float64 is highly recommended). + alpha_1 : float + Weight for trajectory deviation penalty in control optimization. + alpha_2 : float + Weight for control magnitude penalty in control optimization. + alpha_3 : float + Weight for control derivative penalty in control optimization. seed : int Random seed for generating the PRNG key for the reservoir computer. quadratic : bool @@ -136,4 +151,7 @@ def __init__( control_dim=control_dim, dtype=dtype, seed=seed, + alpha_1=alpha_1, + alpha_2=alpha_2, + alpha_3=alpha_3 )