Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@
pub mod data;
use data::{Nmrc, Sprs, Symb};

#[derive(Copy, Clone, Debug)]
pub enum Error {
/// Cholesky factorization failed (not positive definite)
NotPositiveDefinite,
/// LU factorization failed (no pivot found)
NoPivot,
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoPivot => write!(f, "Could not find a pivot"),
Self::NotPositiveDefinite => write!(f, "Could not complete Cholesky factorization. Please provide a positive definite matrix"),
}
}
}

impl std::error::Error for Error {}

// --- Public functions --------------------------------------------------------

/// C = alpha * A + beta * B
Expand Down Expand Up @@ -253,7 +272,7 @@ pub fn add(a: &Sprs, b: &Sprs, alpha: f64, beta: f64) -> Sprs {
///
/// See: `schol(...)`
///
pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
pub fn chol(a: &Sprs, s: &mut Symb) -> Result<Nmrc, Error> {
let mut top;
let mut d;
let mut lki;
Expand Down Expand Up @@ -302,7 +321,7 @@ pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
// --- Compute L(k,k) -----------------------------------------------
if d <= 0. {
// not pos def
panic!("Could not complete Cholesky factorization. Please provide a positive definite matrix");
return Err(Error::NotPositiveDefinite);
}
let p = w[wc + k];
w[wc + k] += 1;
Expand All @@ -311,7 +330,7 @@ pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
}
n_mat.l.p[n] = s.cp[n]; // finalize L

n_mat
Ok(n_mat)
}

/// A\b solver using Cholesky factorization.
Expand Down Expand Up @@ -351,16 +370,18 @@ pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
/// println!("{:?}", &b);
/// ```
///
pub fn cholsol(a: &Sprs, b: &mut [f64], order: i8) {
pub fn cholsol(a: &Sprs, b: &mut [f64], order: i8) -> Result<(), Error> {
let n = a.n;
let mut s = schol(a, order); // ordering and symbolic analysis
let n_mat = chol(a, &mut s); // numeric Cholesky factorization
let n_mat = chol(a, &mut s)?; // numeric Cholesky factorization
let mut x = vec![0.; n];

ipvec(n, &s.pinv, b, &mut x[..]); // x = P*b
lsolve(&n_mat.l, &mut x); // x = L\x
ltsolve(&n_mat.l, &mut x); // x = L'\x
pvec(n, &s.pinv, &x[..], &mut b[..]); // b = P'*x

Ok(())
}

/// Generalized A times X Plus Y
Expand Down Expand Up @@ -490,7 +511,7 @@ pub fn ltsolve(l: &Sprs, x: &mut [f64]) {
///
/// See: `sqr(...)`
///
pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Result<Nmrc, Error> {
let n = a.n;
let mut col;
let mut top;
Expand Down Expand Up @@ -556,7 +577,7 @@ pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
}
}
if ipiv == -1 || a_f <= 0. {
panic!("Could not find a pivot");
return Err(Error::NoPivot);
}
if n_mat.pinv.as_ref().unwrap()[col] < 0 && f64::abs(x[col]) >= a_f * tol {
ipiv = col as isize;
Expand Down Expand Up @@ -592,7 +613,7 @@ pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
n_mat.l.quick_trim();
n_mat.u.quick_trim();

n_mat
Ok(n_mat)
}

/// A\b solver using LU factorization.
Expand Down Expand Up @@ -642,16 +663,17 @@ pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
/// ```

///
pub fn lusol(a: &Sprs, b: &mut [f64], order: i8, tol: f64) {
pub fn lusol(a: &Sprs, b: &mut [f64], order: i8, tol: f64) -> Result<(), Error> {
let mut x = vec![0.; a.n];
let mut s;
s = sqr(a, order, false); // ordering and symbolic analysis
let n = lu(a, &mut s, tol); // numeric LU factorization
let n = lu(a, &mut s, tol)?; // numeric LU factorization

ipvec(a.n, &n.pinv, b, &mut x[..]); // x = P*b
lsolve(&n.l, &mut x); // x = L\x
usolve(&n.u, &mut x[..]); // x = U\x
ipvec(a.n, &s.q, &x[..], &mut b[..]); // b = Q*x
Ok(())
}

/// C = A * B
Expand Down
32 changes: 16 additions & 16 deletions tests/solver_tests.rs

Large diffs are not rendered by default.