Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ hdf5-metno = { version = "0.12", default-features = false }
ndarray = "0.17"
quanticsgrids = { git = "https://github.com/tensor4all/quanticsgrids-rs", rev = "a76b8fb" }
hdf5-rt = { git = "https://github.com/tensor4all/hdf5-rt", default-features = false }
tenferro = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-algebra = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-prims = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro-tensor-compute = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-algebra = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-prims = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
tenferro-tensor-compute = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
4 changes: 2 additions & 2 deletions crates/tensor4all-core/src/defaults/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ pub use contract::{
};
pub use index::{DefaultIndex, DefaultTagSet, DynId, DynIndex, Index, TagSet};
pub use tensordynlen::{
compute_permutation_from_indices, diag_tensor_dyn_len, is_diag_tensor, unfold_split,
RandomScalar, TensorAccess, TensorDynLen,
compute_permutation_from_indices, diag_tensor_dyn_len, unfold_split, RandomScalar,
TensorAccess, TensorDynLen,
};

// Re-export linear algebra functions and types
Expand Down
60 changes: 52 additions & 8 deletions crates/tensor4all-core/src/defaults/qr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use crate::truncation::TruncationParams;
use crate::{unfold_split, TensorDynLen};
use num_complex::ComplexFloat;
use tensor4all_tensorbackend::{
native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
qr_native_tensor, reshape_col_major_native_tensor,
dense_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major,
native_tensor_primal_to_dense_f64_col_major, qr_native_tensor, reshape_col_major_native_tensor,
TensorElement,
};
use thiserror::Error;

Expand Down Expand Up @@ -112,6 +113,28 @@ where
Ok(r.max(1))
}

/// Keep only the leading `keep_cols` columns of a column-major `rows x ?`
/// matrix, returning a dense native tensor of shape `[rows, keep_cols]`.
///
/// In column-major storage the first `rows * keep_cols` elements are exactly
/// the first `keep_cols` columns, so the truncation is a plain prefix slice —
/// no element reshuffling is needed.
///
/// # Errors
///
/// Returns an error if `data` holds fewer than `rows * keep_cols` elements
/// (instead of panicking on the slice, since this function is already
/// fallible) or if the backend tensor construction fails.
fn truncate_matrix_cols<T: TensorElement>(
    data: &[T],
    rows: usize,
    keep_cols: usize,
) -> anyhow::Result<tenferro::Tensor> {
    let needed = rows * keep_cols;
    anyhow::ensure!(
        data.len() >= needed,
        "column truncation needs {needed} elements but only {} are available",
        data.len()
    );
    dense_native_tensor_from_col_major(&data[..needed], &[rows, keep_cols])
}

/// Keep only the leading `keep_rows` rows of a column-major `rows x cols`
/// matrix, returning a dense native tensor of shape `[keep_rows, cols]`.
///
/// Rows are interleaved in column-major storage, so the kept prefix of each
/// column is copied into a freshly assembled buffer before handing it to the
/// backend.
///
/// # Errors
///
/// Returns an error if `keep_rows > rows`, if `data` holds fewer than
/// `rows * cols` elements (instead of panicking on the per-column slice), or
/// if the backend tensor construction fails.
fn truncate_matrix_rows<T: TensorElement>(
    data: &[T],
    rows: usize,
    cols: usize,
    keep_rows: usize,
) -> anyhow::Result<tenferro::Tensor> {
    anyhow::ensure!(
        keep_rows <= rows,
        "cannot keep {keep_rows} rows from a matrix with only {rows} rows"
    );
    anyhow::ensure!(
        data.len() >= rows * cols,
        "row truncation needs {} elements but only {} are available",
        rows * cols,
        data.len()
    );
    let mut truncated = Vec::with_capacity(keep_rows * cols);
    for col in 0..cols {
        // Each column occupies `rows` contiguous elements; keep its prefix.
        let start = col * rows;
        truncated.extend_from_slice(&data[start..start + keep_rows]);
    }
    dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
}

/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
///
/// This function uses the global default rtol for truncation.
Expand Down Expand Up @@ -227,12 +250,33 @@ pub fn qr_with<T>(
}
};
if r < k {
q_native = q_native.take_prefix(1, r).map_err(|e| {
QrError::ComputationError(anyhow::anyhow!("native QR truncation on Q failed: {e}"))
})?;
r_native = r_native.take_prefix(0, r).map_err(|e| {
QrError::ComputationError(anyhow::anyhow!("native QR truncation on R failed: {e}"))
})?;
match q_native.scalar_type() {
tenferro::ScalarType::F64 => {
let q_values = native_tensor_primal_to_dense_f64_col_major(&q_native)
.map_err(QrError::ComputationError)?;
let r_values = native_tensor_primal_to_dense_f64_col_major(&r_native)
.map_err(QrError::ComputationError)?;
q_native =
truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
r_native =
truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
}
tenferro::ScalarType::C64 => {
let q_values = native_tensor_primal_to_dense_c64_col_major(&q_native)
.map_err(QrError::ComputationError)?;
let r_values = native_tensor_primal_to_dense_c64_col_major(&r_native)
.map_err(QrError::ComputationError)?;
q_native =
truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
r_native =
truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
}
other => {
return Err(QrError::ComputationError(anyhow::anyhow!(
"native QR returned unsupported scalar type {other:?}"
)));
}
}
}

let bond_index = DynIndex::new_bond(r)
Expand Down
75 changes: 57 additions & 18 deletions crates/tensor4all-core/src/defaults/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use crate::index_like::IndexLike;
use crate::truncation::{HasTruncationParams, TruncationParams};
use crate::{unfold_split, TensorDynLen};
use tensor4all_tensorbackend::{
dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
reshape_col_major_native_tensor, svd_native_tensor,
reshape_col_major_native_tensor, svd_native_tensor, TensorElement,
};
use thiserror::Error;

Expand Down Expand Up @@ -134,6 +135,28 @@ fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, Sv
}
}

/// Keep only the leading `keep_cols` columns of a column-major `rows x ?`
/// matrix, returning a dense native tensor of shape `[rows, keep_cols]`.
///
/// In column-major storage the first `rows * keep_cols` elements are exactly
/// the first `keep_cols` columns, so the truncation is a plain prefix slice —
/// no element reshuffling is needed.
///
/// # Errors
///
/// Returns an error if `data` holds fewer than `rows * keep_cols` elements
/// (instead of panicking on the slice, since this function is already
/// fallible) or if the backend tensor construction fails.
fn truncate_matrix_cols<T: TensorElement>(
    data: &[T],
    rows: usize,
    keep_cols: usize,
) -> anyhow::Result<tenferro::Tensor> {
    let needed = rows * keep_cols;
    anyhow::ensure!(
        data.len() >= needed,
        "column truncation needs {needed} elements but only {} are available",
        data.len()
    );
    dense_native_tensor_from_col_major(&data[..needed], &[rows, keep_cols])
}

/// Keep only the leading `keep_rows` rows of a column-major `rows x cols`
/// matrix, returning a dense native tensor of shape `[keep_rows, cols]`.
///
/// Rows are interleaved in column-major storage, so the kept prefix of each
/// column is copied into a freshly assembled buffer before handing it to the
/// backend.
///
/// # Errors
///
/// Returns an error if `keep_rows > rows`, if `data` holds fewer than
/// `rows * cols` elements (instead of panicking on the per-column slice), or
/// if the backend tensor construction fails.
fn truncate_matrix_rows<T: TensorElement>(
    data: &[T],
    rows: usize,
    cols: usize,
    keep_rows: usize,
) -> anyhow::Result<tenferro::Tensor> {
    anyhow::ensure!(
        keep_rows <= rows,
        "cannot keep {keep_rows} rows from a matrix with only {rows} rows"
    );
    anyhow::ensure!(
        data.len() >= rows * cols,
        "row truncation needs {} elements but only {} are available",
        rows * cols,
        data.len()
    );
    let mut truncated = Vec::with_capacity(keep_rows * cols);
    for col in 0..cols {
        // Each column occupies `rows` contiguous elements; keep its prefix.
        let start = col * rows;
        truncated.extend_from_slice(&data[start..start + keep_rows]);
    }
    dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
}

type SvdTruncatedNativeResult = (
tenferro::Tensor,
tenferro::Tensor,
Expand Down Expand Up @@ -167,17 +190,35 @@ fn svd_truncated_native(
r = r.min(max_rank);
}
if r < k {
u_native = u_native.take_prefix(1, r).map_err(|e| {
SvdError::ComputationError(anyhow::anyhow!("native SVD truncation on U failed: {e}"))
})?;
s_native = s_native.take_prefix(0, r).map_err(|e| {
SvdError::ComputationError(anyhow::anyhow!(
"native SVD truncation on singular values failed: {e}"
))
})?;
vt_native = vt_native.take_prefix(0, r).map_err(|e| {
SvdError::ComputationError(anyhow::anyhow!("native SVD V^T truncation failed: {e}"))
})?;
match u_native.scalar_type() {
tenferro::ScalarType::F64 => {
let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native)
.map_err(SvdError::ComputationError)?;
let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native)
.map_err(SvdError::ComputationError)?;
u_native =
truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
vt_native = truncate_matrix_rows(&vt_values, k, n, r)
.map_err(SvdError::ComputationError)?;
}
tenferro::ScalarType::C64 => {
let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native)
.map_err(SvdError::ComputationError)?;
let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native)
.map_err(SvdError::ComputationError)?;
u_native =
truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
vt_native = truncate_matrix_rows(&vt_values, k, n, r)
.map_err(SvdError::ComputationError)?;
}
other => {
return Err(SvdError::ComputationError(anyhow::anyhow!(
"native SVD returned unsupported singular-vector scalar type {other:?}"
)));
}
}
s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r])
.map_err(SvdError::ComputationError)?;
}

let bond_index = DynIndex::new_bond(r)
Expand Down Expand Up @@ -225,9 +266,8 @@ pub fn svd_with<T>(
let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;

let s_indices = vec![bond_index.clone(), bond_index.sim()];
let s_diag = s_native.diag_embed(2).map_err(|e| {
SvdError::ComputationError(anyhow::anyhow!("native SVD diagonal embedding failed: {e}"))
})?;
let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
.map_err(SvdError::ComputationError)?;
let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;

let mut vh_indices = vec![bond_index.clone()];
Expand Down Expand Up @@ -273,9 +313,8 @@ pub(crate) fn svd_for_factorize(
let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;

let s_indices = vec![bond_index.clone(), bond_index.sim()];
let s_diag = s_native.diag_embed(2).map_err(|e| {
SvdError::ComputationError(anyhow::anyhow!("native SVD diagonal embedding failed: {e}"))
})?;
let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
.map_err(SvdError::ComputationError)?;
let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;

let mut vh_indices = vec![bond_index.clone()];
Expand Down
Loading
Loading