Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/svdrom/dmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from dask.utils import parse_bytes
from pydmd import BOPDMD

from svdrom.svdrom_base import DecompositionModel
from svdrom.logger import setup_logger

logger = setup_logger("DMD", "dmd.log")


class OptDMD:
class OptDMD(DecompositionModel):
def __init__(
self,
n_modes: int = -1,
Expand Down Expand Up @@ -73,7 +74,8 @@ def __init__(
This class is a wrapper of the `BOPDMD.fit_econ()` method, which fits
an approximate Optimized DMD on an array X by operating on the SVD of X.
"""
if n_modes != -1 and n_modes < 1:
super().__init__(n_components=n_modes)
if self.n_components != -1 and self.n_components < 1:
msg = "'n_modes' must be a positive integer or -1."
logger.exception(msg)
raise ValueError(msg)
Expand Down
87 changes: 7 additions & 80 deletions src/svdrom/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import numpy as np
import xarray as xr

from svdrom.svdrom_base import DecompositionModel
from svdrom.logger import setup_logger

logger = setup_logger("SVD", "svd.log")


class TruncatedSVD:
class TruncatedSVD(DecompositionModel):
def __init__(
self,
n_components: int,
Expand Down Expand Up @@ -83,16 +84,14 @@ def __init__(
function. The 'randomized' algorithm is implemented via Dask's
`dask.array.linalg.svd_compressed` function.
"""

self._n_components = n_components
super().__init__(n_components=n_components) # Inherit from baseclass

self._algorithm = algorithm
self._compute_u = compute_u
self._compute_v = compute_v
self._compute_var_ratio = compute_var_ratio
self._rechunk = rechunk
self._u: xr.DataArray | None = None
self._s: np.ndarray | None = None
self._v: xr.DataArray | None = None
# Note: _u, _s, _v are initialized in super().__init__ but typed here for clarity if needed
self._explained_var_ratio: np.ndarray | da.Array | None = None

@property
Expand Down Expand Up @@ -364,77 +363,5 @@ def transform(self, X: xr.DataArray) -> xr.DataArray:
)
logger.exception(msg)
raise ValueError(msg) from e
return self._singular_vectors_to_dataarray(X_da_transformed, X)

def reconstruct_snapshot(
self,
snapshot: int | str,
snapshot_dim: str = "time",
) -> xr.DataArray:
"""Reconstruct a snapshot or group of snapshots from
the left singular vectors, singular values, and right
singular vectors.

Parameters
----------
snapshot: int | str
The index or label of the snapshot to reconstruct.
If it's an integer, it's interpreted as an index. If it's
a string, it's interpreted as a label.
snapshot_dim: str, (default 'time')
The dimension along which the snapshots are indexed.

Returns
-------
xr.DataArray: The reconstructed snapshot/s as an Xarray DataArray.

Examples
--------
# Reconstructs the first snapshot
>>> tsvd.reconstruct_snapshot(0)

# Reconstructs all snapshots with label '2017-01-01'
>>> tsvd.reconstruct_snapshot("2017-01-01")
"""

if not (
isinstance(self._u, xr.DataArray)
and isinstance(self._v, xr.DataArray)
and isinstance(self._s, np.ndarray)
):
msg = (
"Computed left and right singular vectors and "
"singular values are required before calling reconstruct()."
)
logger.exception(msg)
raise ValueError(msg)

if isinstance(snapshot, int):
try:
if snapshot_dim in self._v.dims:
return self._u @ (self._s * self._v[:, snapshot].T)
if snapshot_dim in self._u.dims:
return (self._u[snapshot_dim, :] * self._s) @ self._v
msg = f"Snapshot dimension '{snapshot_dim}' does not exist."
logger.exception(msg)
raise ValueError(msg)
except IndexError as e:
msg = (
f"Snapshot index {snapshot} is out of bounds for the right "
f"singular vectors with shape {self._v.shape}."
)
logger.exception(msg)
raise IndexError(msg) from e
else:
try:
if snapshot_dim in self._v.dims:
return self._u @ (self._s * self._v.loc[:, snapshot].T)
if snapshot_dim in self._u.dims:
return (self._u.loc[snapshot_dim, :] * self._s) @ self._v
msg = f"Snapshot dimension '{snapshot_dim}' does not exist."
logger.exception(msg)
raise ValueError(msg)
except KeyError as e:
msg = f"Snapshot '{snapshot}' not found in the right singular vectors."
logger.exception(msg)
raise KeyError(msg) from e

return self._singular_vectors_to_dataarray(X_da_transformed, X)
183 changes: 183 additions & 0 deletions src/svdrom/svdrom_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from abc import ABC, abstractmethod
from typing import Any, List

import numpy as np
import xarray as xr
import dask.array as da

from svdrom.logger import setup_logger

logger = setup_logger("Base", "base.log")


class DecompositionModel(ABC):
    """Abstract Base Class for all SVD-based Reduced Order Models.

    Enforces a common interface for SVD, POD, DMD, SPOD, etc.

    Subclasses must implement :meth:`fit` and populate the factor
    attributes ``_u``, ``_s`` and ``_v``; a default :meth:`reconstruct`
    built on those factors is provided here and may be overridden.
    """

    def __init__(self, n_components: int):
        """
        Parameters
        ----------
        n_components : int
            The number of components/modes to keep.
        """
        self._n_components = n_components
        # SVD factors, filled in by a subclass's `fit`:
        self._u: xr.DataArray | None = None  # left singular vectors (spatial)
        self._s: np.ndarray | None = None  # singular values
        self._v: xr.DataArray | None = None  # right singular vectors (temporal)

    @property
    def n_components(self) -> int:
        """Number of components/modes (read-only)."""
        return self._n_components

    @abstractmethod
    def fit(self, *args, **kwargs) -> None:
        """Fit the model to the data.

        Parameters
        ----------
        *args : list
            Variable length argument list.
        **kwargs : dict
            Arbitrary keyword arguments.
        """

    def _check_is_fitted(self, attributes: list[str]) -> None:
        """Checks if the model is fitted by verifying the existence
        of specific attributes.

        Parameters
        ----------
        attributes : list[str]
            List of attribute names to check (e.g. ['_u', '_s']).

        Raises
        ------
        RuntimeError
            If any of the attributes is missing or still None.
        """
        for attr in attributes:
            if not hasattr(self, attr) or getattr(self, attr) is None:
                msg = (
                    f"This {self.__class__.__name__} instance is not fitted yet. "
                    "Call 'fit' with appropriate arguments before using this estimator."
                )
                logger.exception(msg)
                raise RuntimeError(msg)

    # NOTE: an earlier revision also declared `reconstruct` as an
    # @abstractmethod stub above; the concrete definition below shadowed it,
    # silently dropping the abstract requirement. The dead stub has been
    # removed — this default implementation is the single definition.
    def reconstruct(
        self,
        time: slice | int | str | None = None,
        time_dimension: str = "time",
        memory_limit_bytes: float = 1e9,
    ) -> xr.DataArray:
        """Reconstruct the data using the fitted model components (U, S, V).

        This method computes the approximation $X \\approx U \\Sigma V^T$.
        It supports reconstructing specific snapshots, slices of time, or the
        entire dataset. It automatically handles memory management by switching
        to Dask if the estimated reconstruction size exceeds the limit.

        Parameters
        ----------
        time: slice | int | str | None, optional
            The time span over which to perform the reconstruction.
            - If None: Reconstructs the entire training dataset.
            - If int: Reconstructs a single snapshot by integer index.
            - If str: Reconstructs a single snapshot by coordinate label.
            - If slice: Reconstructs a range. Slices containing integers are
            treated as indices; slices containing strings are treated as labels.
        time_dimension: str, default 'time'
            The name of the time dimension in the right singular vectors (_v).
        memory_limit_bytes: float, default 1e9 (1 GB)
            The memory threshold. If the estimated reconstruction size exceeds
            this, Dask is used for lazy computation. Otherwise, NumPy is used.

        Returns
        -------
        xr.DataArray
            The reconstructed data.

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet.
        ValueError
            If `time_dimension` is absent from the right singular vectors,
            or `time` has an unsupported type.
        KeyError
            If label-based selection with `time` fails.
        """
        self._check_is_fitted(["_u", "_s", "_v"])

        if time_dimension not in self._v.dims:
            msg = f"Dimension '{time_dimension}' not found in right singular vectors."
            logger.exception(msg)
            raise ValueError(msg)

        # Select the requested temporal subset of V. Integer (and integer-slice)
        # selections are positional (`isel`); string-based ones are label
        # lookups (`sel`).
        try:
            if time is None:
                v_subset = self._v
            elif isinstance(time, int):
                v_subset = self._v.isel({time_dimension: time})
            elif isinstance(time, str):
                v_subset = self._v.sel({time_dimension: time})
            elif isinstance(time, slice):
                is_index_slice = True
                if (isinstance(time.start, str)) or (isinstance(time.stop, str)):
                    is_index_slice = False

                if is_index_slice:
                    v_subset = self._v.isel({time_dimension: time})
                else:
                    v_subset = self._v.sel({time_dimension: time})
            else:
                msg = "Parameter 'time' must be a slice, int, str, or None."
                logger.exception(msg)
                raise ValueError(msg)
        except KeyError as e:
            msg = f"Could not slice the time dimension '{time_dimension}' with key {time}."
            logger.exception(msg)
            raise KeyError(msg) from e

        # Estimate Reconstruction Size
        # Shape approximation: U (spatial) * V_subset (temporal) * 8 bytes (float64)
        # U shape is (n_features, n_components)
        n_features = self._u.shape[0]
        # Handle case where v_subset is 1D (single snapshot) vs 2D (multiple):
        # integer/label selection of a single snapshot drops the time dimension.
        if time_dimension in v_subset.dims:
            n_time = v_subset.sizes[time_dimension]
        else:
            n_time = 1

        estimated_bytes = n_features * n_time * 8

        msg = f"Estimated reconstruction size is {estimated_bytes/1e6:.2f} MB."
        logger.info(msg)

        use_dask = estimated_bytes > memory_limit_bytes

        # We pre-multiply U * S for efficiency: X = (U*S) @ V.T
        lhs = self._u * self._s
        rhs = v_subset

        if use_dask:
            logger.info("Memory limit exceeded. Using Dask for reconstruction.")
            if not isinstance(lhs.data, da.Array):
                lhs = lhs.chunk("auto")
            if not isinstance(rhs.data, da.Array):
                rhs = rhs.chunk("auto")
        else:
            logger.info("Within memory limit. Using NumPy for reconstruction.")
            # if arrays were lazy, bring them into memory to avoid dask overhead for small ops
            if isinstance(lhs.data, da.Array):
                lhs = lhs.compute()
            if isinstance(rhs.data, da.Array):
                rhs = rhs.compute()

        logger.info("Computing reconstruction...")
        # NOTE(review): the contraction dimension name "components" is
        # hard-coded here — subclasses must label the mode axis of U and V
        # accordingly; confirm against the fitting code.
        reconstruction = lhs.dot(rhs, dim="components")

        logger.info("Done.")

        return reconstruction
6 changes: 3 additions & 3 deletions tests/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ def test_transform(matrix_type):


@pytest.mark.parametrize("matrix_type", ["tall-and-skinny", "short-and-fat"])
def test_reconstruct_snapshot(matrix_type):
"""Test the reconstruct_snapshot method of TruncatedSVD."""
def test_reconstruct(matrix_type):
"""Test the reconstruct method of TruncatedSVD."""
X = make_dataarray(matrix_type)
n_components = 10
tsvd = TruncatedSVD(n_components=n_components)
tsvd.fit(X)

X_r = tsvd.reconstruct_snapshot(0)
X_r = tsvd.reconstruct(0)
assert isinstance(
X_r, xr.DataArray
), f"Reconstructed snapshot should be an xarray DataArray, got {type(X_r)}."
Expand Down
Loading