Skip to content
33 changes: 16 additions & 17 deletions src/svdrom/dmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import svdrom.config as config
from svdrom.logger import setup_logger
from svdrom.svdrom_base import DecompositionModel

logger = setup_logger("DMD", "dmd.log")


class OptDMD:
class OptDMD(DecompositionModel):
def __init__(
self,
n_modes: int = -1,
Expand Down Expand Up @@ -100,7 +101,8 @@ def __init__(
logger.exception(msg)
raise ValueError(msg)

self._n_modes = n_modes
super().__init__(n_components=n_modes)

self._time_dimension = time_dimension
self._time_units = time_units
self._input_time_units = input_time_units or time_units
Expand All @@ -127,7 +129,7 @@ def __init__(
@property
def n_modes(self) -> int:
"""Number of DMD modes (read-only)."""
return self._n_modes
return self._n_components

@property
def time_dimension(self) -> str:
Expand Down Expand Up @@ -604,22 +606,26 @@ def fit(
of the Hankel pre-processed matrix via the TruncatedSVD class.
"""
self._check_svd_inputs(u, s, v)
if self._n_modes > len(s):
if self._n_components > len(s):
msg = (
"The requested number of DMD modes exceeds the number "
"of available SVD components."
)
logger.exception(msg)
raise ValueError(msg)
if self._n_modes == -1:
self._n_modes = len(s)
u, s, v = u[:, : self._n_modes], s[: self._n_modes], v[: self._n_modes, :]
if self._n_components == -1:
self._n_components = len(s)
u, s, v = (
u[:, : self._n_components],
s[: self._n_components],
v[: self._n_components, :],
)
if config.get("hankel_coord_name") in u.coords:
self._hankel_d = len(np.unique(u[config.get("hankel_coord_name")].values))
self._hankel_time_mapping = v.attrs[config.get("hankel_time_mapping_attr")]

bopdmd = BOPDMD(
svd_rank=self._n_modes,
svd_rank=self._n_components,
use_proj=True,
proj_basis=u.data,
num_trials=self._num_trials,
Expand Down Expand Up @@ -1046,10 +1052,7 @@ def forecast(
Xarrays are NumPy-backed or Dask-backed depending on the
'memory_limit_bytes' parameter.
"""
if self._solver is None:
msg = "The OptDMD model must be fitted before forecasting."
logger.exception(msg)
raise RuntimeError(msg)
self._check_is_fitted(["_solver"])
try:
t_forecast, time_forecast = self._generate_forecast_time_vector(
forecast_span, dt
Expand Down Expand Up @@ -1280,11 +1283,7 @@ def reconstruct(
Reconstruct the whole training dataset, which could be huge:
>>> optdmd.reconstruct()
"""
if self._solver is None:
msg = "The OptDMD model must be fitted before reconstructing."
logger.exception(msg)
raise RuntimeError(msg)

self._check_is_fitted(["_solver"])
try:
t_reconstruct, time_reconstruct, lags = (
self._generate_reconstruct_time_vector(t)
Expand Down
64 changes: 26 additions & 38 deletions src/svdrom/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import svdrom.config as config
from svdrom.logger import setup_logger
from svdrom.svdrom_base import DecompositionModel

logger = setup_logger("SVD", "svd.log")


class TruncatedSVD:
class TruncatedSVD(DecompositionModel):
def __init__(
self,
n_components: int,
Expand Down Expand Up @@ -84,8 +85,8 @@ def __init__(
function. The 'randomized' algorithm is implemented via Dask's
`dask.array.linalg.svd_compressed` function.
"""
super().__init__(n_components=n_components) # Inherit from baseclass

self._n_components = n_components
self._algorithm = algorithm
self._compute_u = compute_u
self._compute_v = compute_v
Expand Down Expand Up @@ -159,21 +160,21 @@ def _check_array(self, X: xr.DataArray):
"The input array must be 2-dimensional. "
f"Got a {X.ndim}-dimensional array."
)
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)
if self._n_components >= X.shape[1]:
msg = (
"n_components must be less than n_features. "
f"Got n_components: {self.n_components}, n_features: {X.shape[1]}."
)
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)
if not isinstance(X.data, da.Array):
msg = (
f"The {self.__class__.__name__} class only supports Dask-backed "
f"Xarray DataArrays. Got {type(X.data)} instead."
)
logger.exception(msg)
logger.error(msg)
raise TypeError(msg)

def _singular_vectors_to_dataarray(
Expand Down Expand Up @@ -213,7 +214,7 @@ def _singular_vectors_to_dataarray(
"Cannot transform singular vectors into Xarray DataArray. "
"Shape of singular_vectors does not match X."
)
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)
return xr.DataArray(
singular_vectors,
Expand All @@ -227,7 +228,7 @@ def fit(
self,
X: xr.DataArray,
**kwargs,
) -> None:
) -> "TruncatedSVD":
"""Fit the SVD model to the input array.

Parameters
Expand All @@ -244,7 +245,7 @@ def fit(
f"Unsupported algorithm: {self._algorithm}. "
"Supported algorithms are 'tsqr' and 'randomized'."
)
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)

self._check_array(X)
Expand Down Expand Up @@ -301,14 +302,14 @@ def fit(
self._v = self._singular_vectors_to_dataarray(v, X)
self._explained_var_ratio = explained_var_ratio

return self

def compute_u(self) -> None:
"""Compute left singular vectors if they are
still a lazy Dask collection.
"""
if self._u is None:
msg = "You must call fit() before calling compute_u()."
logger.exception(msg)
raise ValueError(msg)
self._check_is_fitted(["_u"])
assert self._u is not None # needed for mypy check
msg = "Computing left singular vectors..."
logger.info(msg)
self._u = self._u.compute()
Expand All @@ -319,10 +320,8 @@ def compute_v(self) -> None:
"""Compute right singular vectors if they are
still a lazy Dask collection.
"""
if self._v is None:
msg = "You must call fit() before calling compute_v()."
logger.exception(msg)
raise ValueError(msg)
self._check_is_fitted(["_v"])
assert self._v is not None # needed for mypy check
msg = "Computing right singular vectors..."
logger.info(msg)
self._v = self._v.compute()
Expand All @@ -333,10 +332,7 @@ def compute_var_ratio(self) -> None:
"""Compute the ratio of explained variance if it is
still a lazy Dask collection.
"""
if self._explained_var_ratio is None:
msg = "You must call fit() before calling compute_var_ratio()."
logger.exception(msg)
raise ValueError(msg)
self._check_is_fitted(["_explained_var_ratio"])
msg = "Computing explained variance ratio..."
logger.info(msg)
if isinstance(self._explained_var_ratio, da.Array):
Expand All @@ -363,7 +359,7 @@ def transform(self, X: xr.DataArray) -> xr.DataArray:
"Computed right singular vectors are "
"required in order to call transform()."
)
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)
X_da = X.data
try:
Expand All @@ -383,7 +379,7 @@ def transform(self, X: xr.DataArray) -> xr.DataArray:
raise ValueError(msg) from e
return self._singular_vectors_to_dataarray(X_da_transformed, X)

def reconstruct_snapshot(
def reconstruct(
self,
snapshot: int | str,
snapshot_dim: str = "time",
Expand All @@ -408,23 +404,15 @@ def reconstruct_snapshot(
Examples
--------
# Reconstructs the first snapshot
>>> tsvd.reconstruct_snapshot(0)
>>> tsvd.reconstruct(0)

# Reconstructs all snapshots with label '2017-01-01'
>>> tsvd.reconstruct_snapshot("2017-01-01")
>>> tsvd.reconstruct("2017-01-01")
"""

if not (
isinstance(self._u, xr.DataArray)
and isinstance(self._v, xr.DataArray)
and isinstance(self._s, np.ndarray)
):
msg = (
"Computed left and right singular vectors and "
"singular values are required before calling reconstruct()."
)
logger.exception(msg)
raise ValueError(msg)
self._check_is_fitted(["_u", "_v", "_s"])
assert self._u is not None # needed for mypy check
assert self._v is not None
assert self._s is not None

if isinstance(snapshot, int):
try:
Expand All @@ -433,7 +421,7 @@ def reconstruct_snapshot(
if snapshot_dim in self._u.dims:
return (self._u[snapshot_dim, :] * self._s) @ self._v
msg = f"Snapshot dimension '{snapshot_dim}' does not exist."
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)
except IndexError as e:
msg = (
Expand All @@ -449,7 +437,7 @@ def reconstruct_snapshot(
if snapshot_dim in self._u.dims:
return (self._u.loc[snapshot_dim, :] * self._s) @ self._v
msg = f"Snapshot dimension '{snapshot_dim}' does not exist."
logger.exception(msg)
logger.error(msg)
raise ValueError(msg)
except KeyError as e:
msg = f"Snapshot '{snapshot}' not found in the right singular vectors."
Expand Down
65 changes: 65 additions & 0 deletions src/svdrom/svdrom_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from abc import ABC, abstractmethod

import xarray as xr

from svdrom.logger import setup_logger

# Module-level logger used by DecompositionModel to report errors
# (e.g. from _check_is_fitted). NOTE(review): "base.log" is presumably the
# log file name -- confirm against svdrom.logger.setup_logger.
logger = setup_logger("Base", "base.log")


class DecompositionModel(ABC):
    """Abstract base class for all SVD-based reduced order models.

    Enforces a common interface for SVD, POD, DMD, SPOD, etc.: every
    concrete model stores the requested number of components and must
    implement ``fit`` and ``reconstruct``.
    """

    def __init__(self, n_components: int) -> None:
        """Store the number of components requested for the model.

        Parameters
        ----------
        n_components : int
            The number of components/modes to keep.
        """
        self._n_components = n_components

    @abstractmethod
    def fit(self, *args, **kwargs) -> "DecompositionModel":
        """Fit the model to the data.

        Parameters
        ----------
        *args : list
            Variable length argument list.
        **kwargs : dict
            Arbitrary keyword arguments.

        Returns
        -------
        DecompositionModel
            The model instance (implementations are expected to
            return ``self`` so that calls can be chained).
        """

    @abstractmethod
    def reconstruct(
        self, *args, **kwargs
    ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]:
        """Reconstruct the data using the fitted model.

        Parameters
        ----------
        *args : list
            Variable length argument list.
        **kwargs : dict
            Arbitrary keyword arguments.

        Returns
        -------
        xr.DataArray | tuple[xr.DataArray, xr.DataArray]
            The reconstructed data (deterministic or probabilistic).
        """

    def _check_is_fitted(self, attributes: list[str] | str) -> None:
        """Check that the model is fitted by verifying that specific
        attributes exist and are not None.

        Parameters
        ----------
        attributes : list[str] | str
            Attribute name(s) to check (e.g. ['_u', '_s']). A single
            name may be passed as a plain string.

        Raises
        ------
        RuntimeError
            If any of the requested attributes is missing or is None.
        """
        # A bare string is itself iterable: without this guard "_u" would be
        # checked as the attributes "_" and "u", and a fitted model would
        # always be reported as unfitted.
        names = [attributes] if isinstance(attributes, str) else attributes
        for attr in names:
            # Missing attributes and attributes still set to None are both
            # treated as "not fitted".
            if getattr(self, attr, None) is None:
                msg = (
                    f"This {self.__class__.__name__} instance is not fitted yet. "
                    "Call 'fit' with appropriate arguments before using this estimator."
                )
                logger.error(msg)
                raise RuntimeError(msg)
6 changes: 3 additions & 3 deletions tests/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ def test_transform(matrix_type):


@pytest.mark.parametrize("matrix_type", ["tall-and-skinny", "short-and-fat"])
def test_reconstruct_snapshot(matrix_type):
"""Test the reconstruct_snapshot method of TruncatedSVD."""
def test_reconstruct(matrix_type):
"""Test the reconstruct method of TruncatedSVD."""
X = make_dataarray(matrix_type)
n_components = 10
tsvd = TruncatedSVD(n_components=n_components)
tsvd.fit(X)

X_r = tsvd.reconstruct_snapshot(0)
X_r = tsvd.reconstruct(0)
assert isinstance(
X_r, xr.DataArray
), f"Reconstructed snapshot should be an xarray DataArray, got {type(X_r)}."
Expand Down
Loading