Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
419 changes: 419 additions & 0 deletions examples/classification.ipynb

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions src/orc/classifier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Classification with Reservoir Computers.
"""Classification with Reservoir Computers."""

This module is currently a placeholder for future classifier implementations.
"""
from orc.classifier.base import RCClassifierBase
from orc.classifier.models import ESNClassifier
from orc.classifier.train import train_ESNClassifier

__all__ = []
__all__ = [
"RCClassifierBase",
"ESNClassifier",
"train_ESNClassifier",
]
237 changes: 233 additions & 4 deletions src/orc/classifier/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,235 @@
"""Base classes for Reservoir Computer Classifiers.
"""Defines base classes for Reservoir Computer Classifiers."""

This module is currently a placeholder for future classifier implementations.
"""
from abc import ABC

# TODO: Implement RCClassifierBase and related classes
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from orc.drivers import DriverBase
from orc.embeddings import EmbedBase
from orc.readouts import ReadoutBase


class RCClassifierBase(eqx.Module, ABC):
"""Base class for reservoir computer classifiers.

Defines the interface for the reservoir computer classifier which includes
the driver, readout and embedding layers. The classifier forces input sequences
through the reservoir, extracts a feature vector from the reservoir states,
and applies a trained readout to produce class probabilities.

Attributes
----------
driver : DriverBase
Driver layer of the reservoir computer.
readout : ReadoutBase
Readout layer of the reservoir computer. Output dimension equals n_classes.
embedding : EmbedBase
Embedding layer of the reservoir computer.
in_dim : int
Dimension of the input data.
out_dim : int
Dimension of the output data (equals n_classes).
res_dim : int
Dimension of the reservoir.
n_classes : int
Number of classification classes.
state_repr : str
Reservoir state representation for classification.
"final" uses the last reservoir state, "mean" averages states after spinup.
dtype : type
Data type of the reservoir computer (jnp.float64 is highly recommended).
seed : int
Random seed for generating the PRNG key for the reservoir computer.

Methods
-------
force(in_seq, res_state)
Teacher forces the reservoir with the input sequence.
classify(in_seq, res_state)
Classify an input sequence, returning class probabilities.
set_readout(readout)
Replaces the readout layer of the reservoir computer.
set_embedding(embedding)
Replaces the embedding layer of the reservoir computer.
"""

driver: DriverBase
readout: ReadoutBase
embedding: EmbedBase
in_dim: int
out_dim: int
res_dim: int
n_classes: int
state_repr: str = "final"
dtype: Float = jnp.float64
seed: int = 0

def __init__(
self,
driver: DriverBase,
readout: ReadoutBase,
embedding: EmbedBase,
n_classes: int,
state_repr: str = "final",
dtype: Float = jnp.float64,
seed: int = 0,
) -> None:
"""Initialize RCClassifier Base.

Parameters
----------
driver : DriverBase
Driver layer of the reservoir computer.
readout : ReadoutBase
Readout layer of the reservoir computer.
embedding : EmbedBase
Embedding layer of the reservoir computer.
n_classes : int
Number of classification classes.
state_repr : str
Reservoir state representation for classification.
"final" uses the last reservoir state, "mean" averages states after spinup.
dtype : type
Data type of the reservoir computer (jnp.float64 is highly recommended).
seed : int
Random seed for generating the PRNG key for the reservoir computer.
"""
if state_repr not in ("final", "mean"):
raise ValueError(
f"state_repr must be 'final' or 'mean', got '{state_repr}'."
)
self.driver = driver
self.readout = readout
self.embedding = embedding
self.in_dim = self.embedding.in_dim
self.out_dim = self.readout.out_dim
self.res_dim = self.driver.res_dim
self.n_classes = n_classes
self.state_repr = state_repr
self.dtype = dtype
self.seed = seed

@eqx.filter_jit
def force(self, in_seq: Array, res_state: Array) -> Array:
"""Teacher forces the reservoir with the input sequence.

Parameters
----------
in_seq : Array
Input sequence to force the reservoir, (shape=(seq_len, in_dim)).
res_state : Array
Initial reservoir state, (shape=(res_dim,)).

Returns
-------
Array
Forced reservoir sequence, (shape=(seq_len, res_dim)).
"""

def scan_fn(state, in_vars):
proj_vars = self.embedding.embed(in_vars)
res_state = self.driver.advance(proj_vars, state)
return (res_state, res_state)

_, res_seq = jax.lax.scan(scan_fn, res_state, in_seq)
return res_seq

@eqx.filter_jit
def classify(
self, in_seq: Array, res_state: Array | None = None, spinup: int = 0
) -> Array:
"""Classify an input sequence.

Forces the reservoir with the input sequence, extracts a feature vector
from the reservoir states, and returns softmax class probabilities.

Parameters
----------
in_seq : Array
Input sequence to classify, (shape=(seq_len, in_dim)).
res_state : Array
Initial reservoir state, (shape=(res_dim,)).
spinup : int
Number of initial reservoir states to discard before extracting
features. Only used when state_repr="mean".

Returns
-------
Array
Class probabilities, (shape=(n_classes,)).
"""
if res_state is None:
res_state = jnp.zeros(self.res_dim)

res_seq = self.force(in_seq, res_state)

if self.state_repr == "final":
feature = res_seq[-1]
else: # "mean"
feature = jnp.mean(res_seq[spinup:], axis=0)

logits = self.readout.readout(feature)
return jax.nn.softmax(logits)

def __call__(self, in_seq: Array, res_state: Array, spinup: int = 0) -> Array:
"""Classify an input sequence, wrapper for `classify` method.

Parameters
----------
in_seq : Array
Input sequence to classify, (shape=(seq_len, in_dim)).
res_state : Array
Initial reservoir state, (shape=(res_dim,)).
spinup : int
Number of initial reservoir states to discard before extracting
features. Only used when state_repr="mean".

Returns
-------
Array
Class probabilities, (shape=(n_classes,)).
"""
return self.classify(in_seq, res_state, spinup)

def set_readout(self, readout: ReadoutBase) -> "RCClassifierBase":
"""Replace readout layer.

Parameters
----------
readout : ReadoutBase
New readout layer.

Returns
-------
RCClassifierBase
Updated model with new readout layer.
"""

def where(m: "RCClassifierBase"):
return m.readout

new_model = eqx.tree_at(where, self, readout)
return new_model

def set_embedding(self, embedding: EmbedBase) -> "RCClassifierBase":
"""Replace embedding layer.

Parameters
----------
embedding : EmbedBase
New embedding layer.

Returns
-------
RCClassifierBase
Updated model with new embedding layer.
"""

def where(m: "RCClassifierBase"):
return m.embedding

new_model = eqx.tree_at(where, self, embedding)
return new_model
141 changes: 137 additions & 4 deletions src/orc/classifier/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,139 @@
"""Classifier models based on Reservoir Computing.
"""Discrete ESN classifier implementation."""

This module is currently a placeholder for future classifier implementations.
"""
import jax
import jax.numpy as jnp

# TODO: Implement classifier models (e.g., ESNClassifier)
from orc.classifier.base import RCClassifierBase
from orc.drivers import ESNDriver
from orc.embeddings import LinearEmbedding
from orc.readouts import LinearReadout, QuadraticReadout

jax.config.update("jax_enable_x64", True)


class ESNClassifier(RCClassifierBase):
"""
Basic implementation of ESN for classification tasks.

Attributes
----------
res_dim : int
Reservoir dimension.
data_dim : int
Input data dimension.
n_classes : int
Number of classification classes.
driver : ESNDriver
Driver implementing the Echo State Network dynamics.
readout : ReadoutBase
Trainable linear readout layer with out_dim=n_classes.
embedding : LinearEmbedding
Untrainable linear embedding layer.

Methods
-------
force(in_seq, res_state)
Teacher forces the reservoir with the input sequence.
classify(in_seq, res_state)
Classify an input sequence, returning class probabilities.
set_readout(readout)
Replace readout layer.
set_embedding(embedding)
Replace embedding layer.
"""

res_dim: int
data_dim: int

def __init__(
self,
data_dim: int,
n_classes: int,
res_dim: int,
leak_rate: float = 0.6,
bias: float = 1.6,
embedding_scaling: float = 0.08,
Wr_density: float = 0.02,
Wr_spectral_radius: float = 0.8,
dtype: type = jnp.float64,
seed: int = 0,
quadratic: bool = False,
use_sparse_eigs: bool = True,
state_repr: str = "final",
) -> None:
"""
Initialize the ESN classifier.

Parameters
----------
data_dim : int
Dimension of the input data.
n_classes : int
Number of classification classes.
res_dim : int
Dimension of the reservoir adjacency matrix Wr.
leak_rate : float
Integration leak rate of the reservoir dynamics.
bias : float
Bias term for the reservoir dynamics.
embedding_scaling : float
Scaling factor for the embedding layer.
Wr_density : float
Density of the reservoir adjacency matrix Wr.
Wr_spectral_radius : float
Largest eigenvalue of the reservoir adjacency matrix Wr.
dtype : type
Data type of the model (jnp.float64 is highly recommended).
seed : int
Random seed for generating the PRNG key for the reservoir computer.
quadratic : bool
Use quadratic nonlinearity in output, default False.
use_sparse_eigs : bool
Whether to use sparse eigensolver for setting the spectral radius of wr.
Default is True, which is recommended to save memory and compute time. If
False, will use dense eigensolver which may be more accurate.
state_repr : str
Reservoir state representation for classification.
"final" uses the last reservoir state, "mean" averages states after spinup.
"""
# Initialize the random key and reservoir dimension
self.res_dim = res_dim
self.seed = seed
self.data_dim = data_dim
key = jax.random.PRNGKey(seed)
key_driver, key_readout, key_embedding = jax.random.split(key, 3)

embedding = LinearEmbedding(
in_dim=data_dim,
res_dim=res_dim,
seed=key_embedding[0],
scaling=embedding_scaling,
)
driver = ESNDriver(
res_dim=res_dim,
seed=key_driver[0],
leak=leak_rate,
bias=bias,
density=Wr_density,
spectral_radius=Wr_spectral_radius,
dtype=dtype,
use_sparse_eigs=use_sparse_eigs,
)
if quadratic:
readout = QuadraticReadout(
out_dim=n_classes, res_dim=res_dim, seed=key_readout[0]
)
else:
readout = LinearReadout(
out_dim=n_classes, res_dim=res_dim, seed=key_readout[0]
)

super().__init__(
driver=driver,
readout=readout,
embedding=embedding,
n_classes=n_classes,
state_repr=state_repr,
dtype=dtype,
seed=seed,
)
Loading