Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions docs/user-guide/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,74 @@ model = CESNForecaster(
)
```

## ESN Controller

The `ESNController` implements an Echo State Network for control tasks, where the model tracks a reference trajectory by optimizing a control input sequence.

### Basic Usage

```python
from orc.control import ESNController, train_ESNController
import jax.numpy as jnp

# Create ESN controller
model = ESNController(
data_dim=3, # System input/output dimension
control_dim=1, # Control input dimension
res_dim=500, # Reservoir dimension
seed=42
)

# Train with input and control sequences
trained_model, res_states = train_ESNController(
model,
train_seq, # shape: (seq_len, data_dim)
control_seq, # shape: (seq_len, control_dim)
target_seq=target_seq,
spinup=100,
beta=8e-8
)

# Compute optimal control to track a reference trajectory
control_opt = trained_model.compute_control(
control_seq=initial_guess, # shape: (fcast_len, control_dim)
res_state=res_states[-1],
ref_traj=reference # shape: (fcast_len, data_dim)
)
```

### Control Penalty Weights

The controller optimizes a penalty function with three terms, each weighted by a tunable parameter:

- `alpha_1` (default 100): **Trajectory deviation** — penalizes squared error between the controlled forecast and the reference trajectory. Higher values enforce tighter tracking.

- `alpha_2` (default 1): **Control magnitude** — penalizes the squared norm of control inputs. Higher values favor smaller control effort.

- `alpha_3` (default 5): **Control smoothness** — penalizes the squared norm of consecutive control input differences. Higher values produce smoother control signals.

```python
# Custom penalty weights
model = ESNController(
data_dim=3,
control_dim=1,
res_dim=500,
alpha_1=200, # Stricter trajectory tracking
alpha_2=0.5, # Allow larger control inputs
alpha_3=10, # Enforce smoother control
seed=42
)
```

### Key Parameters

In addition to the standard ESN parameters (`leak_rate`, `bias`, `embedding_scaling`, `Wr_density`, `Wr_spectral_radius`), the controller accepts:

- `control_dim`: Dimension of the control input
- `alpha_1`, `alpha_2`, `alpha_3`: Control penalty weights (see above)

For detailed API documentation, see the [Control API Reference](../api/control.md).

## Training Functions

Both ESN models use ridge regression for training the readout layer.
Expand Down
18 changes: 18 additions & 0 deletions src/orc/control/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class ESNController(RCControllerBase):
Trainable linear readout layer.
embedding : LinearEmbedding
Untrainable linear embedding layer for [input, control] concatenation.
alpha_1 : float
Weight for trajectory deviation penalty in control optimization.
alpha_2 : float
Weight for control magnitude penalty in control optimization.
alpha_3 : float
Weight for control derivative penalty in control optimization.

Methods
-------
Expand Down Expand Up @@ -57,6 +63,9 @@ def __init__(
Wr_density: float = 0.02,
Wr_spectral_radius: float = 0.8,
dtype: type = jnp.float64,
alpha_1: float = 100,
alpha_2: float = 1,
alpha_3: float = 5,
seed: int = 0,
quadratic: bool = False,
use_sparse_eigs: bool = True,
Expand Down Expand Up @@ -84,6 +93,12 @@ def __init__(
Largest eigenvalue of the reservoir adjacency matrix Wr.
dtype : type
Data type of the model (jnp.float64 is highly recommended).
alpha_1 : float
Weight for trajectory deviation penalty in control optimization.
alpha_2 : float
Weight for control magnitude penalty in control optimization.
alpha_3 : float
Weight for control derivative penalty in control optimization.
seed : int
Random seed for generating the PRNG key for the reservoir computer.
quadratic : bool
Expand Down Expand Up @@ -136,4 +151,7 @@ def __init__(
control_dim=control_dim,
dtype=dtype,
seed=seed,
alpha_1=alpha_1,
alpha_2=alpha_2,
alpha_3=alpha_3
)