From 5f44911ff38d4ad25051250bfe3e50bce014603f Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 31 Mar 2026 14:05:41 +0900 Subject: [PATCH 1/4] fix: close core reconnect under public bridge contract --- Cargo.toml | 16 +- crates/tensor4all-core/src/defaults/mod.rs | 4 +- crates/tensor4all-core/src/defaults/qr.rs | 60 +- crates/tensor4all-core/src/defaults/svd.rs | 75 ++- .../src/defaults/tensordynlen.rs | 96 +-- crates/tensor4all-core/src/forward_ad.rs | 54 -- crates/tensor4all-core/src/lib.rs | 5 +- .../tensor4all-core/tests/linalg_factorize.rs | 6 +- crates/tensor4all-core/tests/tensor_basic.rs | 56 +- crates/tensor4all-core/tests/tensor_diag.rs | 37 +- .../tensor4all-core/tests/tensor_native_ad.rs | 231 +------- .../src/any_scalar.rs | 229 +++++--- .../tensor4all-tensorbackend/src/storage.rs | 66 ++- .../src/tenferro_bridge.rs | 551 +++++++----------- .../src/tenferro_bridge/tests/mod.rs | 177 ++++-- .../src/tensor_element.rs | 128 ++-- 16 files changed, 831 insertions(+), 960 deletions(-) delete mode 100644 crates/tensor4all-core/src/forward_ad.rs diff --git a/Cargo.toml b/Cargo.toml index d730aaae..c74004e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,11 +51,11 @@ hdf5-metno = { version = "0.12", default-features = false } ndarray = "0.17" quanticsgrids = { git = "https://github.com/tensor4all/quanticsgrids-rs", rev = "a76b8fb" } hdf5-rt = { git = "https://github.com/tensor4all/hdf5-rt", default-features = false } -tenferro = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-algebra = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-prims = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } -tenferro-tensor-compute = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" } +tenferro = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro" } +tenferro-algebra = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-algebra" } +tenferro-device = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-device" } +tenferro-einsum = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-einsum" } +tenferro-prims = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-prims" } +tenferro-tensor = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-tensor" } +tenferro-linalg = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-linalg" } +tenferro-tensor-compute = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-tensor-compute" } diff --git a/crates/tensor4all-core/src/defaults/mod.rs b/crates/tensor4all-core/src/defaults/mod.rs index 9341cc66..06d25d5a 100644 --- a/crates/tensor4all-core/src/defaults/mod.rs +++ b/crates/tensor4all-core/src/defaults/mod.rs @@ -37,8 +37,8 @@ pub use contract::{ }; pub use index::{DefaultIndex, DefaultTagSet, DynId, DynIndex, Index, TagSet}; pub use tensordynlen::{ - compute_permutation_from_indices, diag_tensor_dyn_len, is_diag_tensor, unfold_split, - RandomScalar, TensorAccess, TensorDynLen, + compute_permutation_from_indices, diag_tensor_dyn_len, unfold_split, RandomScalar, + TensorAccess, TensorDynLen, }; // Re-export linear algebra functions and types diff --git a/crates/tensor4all-core/src/defaults/qr.rs b/crates/tensor4all-core/src/defaults/qr.rs index 5ff00587..716f0fae 100644 --- a/crates/tensor4all-core/src/defaults/qr.rs +++ b/crates/tensor4all-core/src/defaults/qr.rs @@ -8,8 +8,9 @@ use crate::truncation::TruncationParams; use crate::{unfold_split, TensorDynLen}; use num_complex::ComplexFloat; use tensor4all_tensorbackend::{ - native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major, - qr_native_tensor, reshape_col_major_native_tensor, + dense_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major, + native_tensor_primal_to_dense_f64_col_major, qr_native_tensor, reshape_col_major_native_tensor, + TensorElement, }; use thiserror::Error; @@ -112,6 +113,28 @@ where Ok(r.max(1)) } +fn truncate_matrix_cols( + data: &[T], + rows: usize, + keep_cols: usize, +) -> anyhow::Result { + dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols]) +} + +fn truncate_matrix_rows( + data: &[T], + rows: usize, + cols: usize, + keep_rows: usize, +) -> anyhow::Result { + let mut truncated = Vec::with_capacity(keep_rows * cols); + for col in 0..cols { + let start = col * rows; + truncated.extend_from_slice(&data[start..start + keep_rows]); + } + dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols]) +} + /// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R). /// /// This function uses the global default rtol for truncation. @@ -227,12 +250,33 @@ pub fn qr_with( } }; if r < k { - q_native = q_native.take_prefix(1, r).map_err(|e| { - QrError::ComputationError(anyhow::anyhow!("native QR truncation on Q failed: {e}")) - })?; - r_native = r_native.take_prefix(0, r).map_err(|e| { - QrError::ComputationError(anyhow::anyhow!("native QR truncation on R failed: {e}")) - })?; + match q_native.scalar_type() { + tenferro::ScalarType::F64 => { + let q_values = native_tensor_primal_to_dense_f64_col_major(&q_native) + .map_err(QrError::ComputationError)?; + let r_values = native_tensor_primal_to_dense_f64_col_major(&r_native) + .map_err(QrError::ComputationError)?; + q_native = + truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?; + r_native = + truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?; + } + tenferro::ScalarType::C64 => { + let q_values = native_tensor_primal_to_dense_c64_col_major(&q_native) + .map_err(QrError::ComputationError)?; + let r_values = native_tensor_primal_to_dense_c64_col_major(&r_native) + .map_err(QrError::ComputationError)?; + q_native = + truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?; + r_native = + truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?; + } + other => { + return Err(QrError::ComputationError(anyhow::anyhow!( + "native QR returned unsupported scalar type {other:?}" + ))); + } + } } let bond_index = DynIndex::new_bond(r) diff --git a/crates/tensor4all-core/src/defaults/svd.rs b/crates/tensor4all-core/src/defaults/svd.rs index 73aab5de..97e085c3 100644 --- a/crates/tensor4all-core/src/defaults/svd.rs +++ b/crates/tensor4all-core/src/defaults/svd.rs @@ -8,8 +8,9 @@ use crate::index_like::IndexLike; use crate::truncation::{HasTruncationParams, TruncationParams}; use crate::{unfold_split, TensorDynLen}; use tensor4all_tensorbackend::{ + dense_native_tensor_from_col_major, diag_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major, - reshape_col_major_native_tensor, svd_native_tensor, + reshape_col_major_native_tensor, svd_native_tensor, TensorElement, }; use thiserror::Error; @@ -134,6 +135,28 @@ fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result, Sv } } +fn truncate_matrix_cols( + data: &[T], + rows: usize, + keep_cols: usize, +) -> anyhow::Result { + dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols]) +} + +fn truncate_matrix_rows( + data: &[T], + rows: usize, + cols: usize, + keep_rows: usize, +) -> anyhow::Result { + let mut truncated = Vec::with_capacity(keep_rows * cols); + for col in 0..cols { + let start = col * rows; + truncated.extend_from_slice(&data[start..start + keep_rows]); + } + dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols]) +} + type SvdTruncatedNativeResult = ( tenferro::Tensor, tenferro::Tensor, @@ -167,17 +190,35 @@ fn svd_truncated_native( r = r.min(max_rank); } if r < k { - u_native = u_native.take_prefix(1, r).map_err(|e| { - SvdError::ComputationError(anyhow::anyhow!("native SVD truncation on U failed: {e}")) - })?; - s_native = s_native.take_prefix(0, r).map_err(|e| { - SvdError::ComputationError(anyhow::anyhow!( - "native SVD truncation on singular values failed: {e}" - )) - })?; - vt_native = vt_native.take_prefix(0, r).map_err(|e| { - SvdError::ComputationError(anyhow::anyhow!("native SVD V^T truncation failed: {e}")) - })?; + match u_native.scalar_type() { + tenferro::ScalarType::F64 => { + let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native) + .map_err(SvdError::ComputationError)?; + let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native) + .map_err(SvdError::ComputationError)?; + u_native = + truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?; + vt_native = truncate_matrix_rows(&vt_values, k, n, r) + .map_err(SvdError::ComputationError)?; + } + tenferro::ScalarType::C64 => { + let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native) + .map_err(SvdError::ComputationError)?; + let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native) + .map_err(SvdError::ComputationError)?; + u_native = + truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?; + vt_native = truncate_matrix_rows(&vt_values, k, n, r) + .map_err(SvdError::ComputationError)?; + } + other => { + return Err(SvdError::ComputationError(anyhow::anyhow!( + "native SVD returned unsupported singular-vector scalar type {other:?}" + ))); + } + } + s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r]) + .map_err(SvdError::ComputationError)?; } let bond_index = DynIndex::new_bond(r) @@ -225,9 +266,8 @@ pub fn svd_with( let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?; let s_indices = vec![bond_index.clone(), bond_index.sim()]; - let s_diag = s_native.diag_embed(2).map_err(|e| { - SvdError::ComputationError(anyhow::anyhow!("native SVD diagonal embedding failed: {e}")) - })?; + let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2) + .map_err(SvdError::ComputationError)?; let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?; let mut vh_indices = vec![bond_index.clone()]; @@ -273,9 +313,8 @@ pub(crate) fn svd_for_factorize( let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?; let s_indices = vec![bond_index.clone(), bond_index.sim()]; - let s_diag = s_native.diag_embed(2).map_err(|e| { - SvdError::ComputationError(anyhow::anyhow!("native SVD diagonal embedding failed: {e}")) - })?; + let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2) + .map_err(SvdError::ComputationError)?; let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?; let mut vh_indices = vec![bond_index.clone()]; diff --git a/crates/tensor4all-core/src/defaults/tensordynlen.rs b/crates/tensor4all-core/src/defaults/tensordynlen.rs index 1c65b3a0..ecb7a0e1 100644 --- a/crates/tensor4all-core/src/defaults/tensordynlen.rs +++ b/crates/tensor4all-core/src/defaults/tensordynlen.rs @@ -10,7 +10,7 @@ use rand_distr::{Distribution, StandardNormal}; use std::collections::HashSet; use std::ops::{Mul, Neg, Sub}; use std::sync::Arc; -use tenferro::{AdMode, ScalarType as NativeScalarType, Tensor as NativeTensor}; +use tenferro::{ScalarType as NativeScalarType, Tensor as NativeTensor}; use tensor4all_tensorbackend::{ axpby_native_tensor, conj_native_tensor, contract_native_tensor, dense_native_tensor_from_col_major, diag_native_tensor_from_col_major, @@ -118,7 +118,7 @@ pub struct TensorDynLen { /// Full index information (includes tags and other metadata). pub indices: Vec, /// Canonical native payload preserving AD metadata. - native: NativeTensor, + native: Arc, } impl TensorAccess for TensorDynLen { @@ -216,7 +216,10 @@ impl TensorDynLen { if native.is_diag() { Self::validate_diag_dims(&dims)?; } - Ok(Self { indices, native }) + Ok(Self { + indices, + native: Arc::new(native), + }) } /// Borrow the indices. @@ -226,12 +229,7 @@ impl TensorDynLen { /// Borrow the native payload. pub(crate) fn as_native(&self) -> &NativeTensor { - &self.native - } - - /// Returns the upstream AD mode. - pub fn mode(&self) -> AdMode { - self.native.mode() + self.native.as_ref() } /// Returns whether the tensor participates in reverse-mode AD. @@ -241,18 +239,23 @@ impl TensorDynLen { /// Enables or disables reverse-mode gradient tracking. pub fn set_requires_grad(&mut self, enabled: bool) -> Result<()> { - self.native - .set_requires_grad(enabled) - .map_err(|e| anyhow::anyhow!("TensorDynLen::set_requires_grad failed: {e}")) + self.native = Arc::new(self.native.as_ref().detach().requires_grad_(enabled)); + Ok(()) } /// Returns the accumulated reverse gradient when available. - pub fn grad(&self) -> Option { - self.native.grad().map(|native| { - Self::from_native(self.indices.clone(), native).unwrap_or_else(|e| { - panic!("TensorDynLen::grad returned a tensor incompatible with indices: {e}") + pub fn grad(&self) -> Result> { + self.native + .grad() + .map_err(|e| anyhow::anyhow!("TensorDynLen::grad failed: {e}"))? + .map(|native| { + Self::from_native(self.indices.clone(), native).map_err(|e| { + anyhow::anyhow!( + "TensorDynLen::grad returned a tensor incompatible with indices: {e}" + ) + }) }) - }) + .transpose() } /// Clears accumulated reverse gradients on reverse leaves. @@ -268,18 +271,17 @@ impl TensorDynLen { /// /// # Arguments /// * `grad_output` - Optional gradient seed. Pass `None` for default (ones). - /// * `inputs` - Leaf tensors to accumulate gradients on. After this call, - /// each input's `.grad()` will contain the accumulated gradient. - pub fn backward(&self, grad_output: Option<&Self>, inputs: &[&Self]) -> Result<()> { - let grad_native = grad_output.map(|g| &g.native); - let input_natives: Vec<&NativeTensor> = inputs.iter().map(|t| &t.native).collect(); - self.native - .backward( - grad_native, - &input_natives, - tenferro::BackwardOptions::default(), - ) - .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}")) + pub fn backward(&self, grad_output: Option<&Self>) -> Result<()> { + match grad_output { + Some(grad_output) => self + .native + .backward_with_seed(&grad_output.native) + .map_err(|e| anyhow::anyhow!("TensorDynLen::backward_with_seed failed: {e}")), + None => self + .native + .backward() + .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}")), + } } /// Check if this tensor is already in canonical form. @@ -289,7 +291,9 @@ impl TensorDynLen { /// Materialize the primal snapshot as storage. pub fn to_storage(&self) -> Result> { - Ok(Arc::new(native_tensor_primal_to_storage(&self.native)?)) + Ok(Arc::new(native_tensor_primal_to_storage( + self.native.as_ref(), + )?)) } /// Materialize the primal snapshot as storage. @@ -768,11 +772,6 @@ impl Neg for TensorDynLen { } } -/// Check if a tensor is a DiagTensor (has Diag storage). -pub fn is_diag_tensor(tensor: &TensorDynLen) -> bool { - tensor.is_diag() -} - impl TensorDynLen { /// Add two tensors element-wise. /// @@ -990,8 +989,10 @@ impl TensorDynLen { }) .collect(); - Self::from_native(new_indices, self.native.clone()) - .expect("replaceind should preserve native payload dims") + Self { + indices: new_indices, + native: self.native.clone(), + } } /// Replace multiple indices in the tensor. @@ -1062,8 +1063,10 @@ impl TensorDynLen { }) .collect(); - Self::from_native(new_indices_vec, self.native.clone()) - .expect("replaceinds should preserve native payload dims") + Self { + indices: new_indices_vec, + native: self.native.clone(), + } } } @@ -1224,17 +1227,20 @@ impl std::fmt::Debug for TensorDynLen { .field("indices", &self.indices) .field("dims", &self.dims()) .field("is_diag", &self.native.is_diag()) - .field("mode", &self.native.mode()) .finish() } } -/// Create a DiagTensor with dynamic rank from diagonal data. +/// Create a diagonal tensor with dynamic rank from diagonal data. /// /// # Arguments /// * `indices` - The indices for the tensor (all must have the same dimension) /// * `diag_data` - The diagonal elements (length must equal the dimension of indices) /// +/// The public native bridge currently materializes diagonal payloads densely, so +/// the returned tensor is mathematically diagonal but may not report +/// [`TensorDynLen::is_diag`] at the native-storage level. +/// /// # Panics /// Panics if indices have different dimensions, or if diag_data length doesn't match. pub fn diag_tensor_dyn_len(indices: Vec, diag_data: Vec) -> TensorDynLen { @@ -1618,6 +1624,10 @@ impl TensorDynLen { } /// Create a diagonal tensor from diagonal payload data with explicit indices. + /// + /// The public native bridge currently materializes diagonal payloads densely, so + /// the returned tensor is mathematically diagonal but may not report + /// [`TensorDynLen::is_diag`] at the native-storage level. pub fn from_diag(indices: Vec, data: Vec) -> Result { let dims = Self::expected_dims_from_indices(&indices); Self::validate_indices(&indices); @@ -1721,7 +1731,9 @@ impl TensorDynLen { self.native.scalar_type() == NativeScalarType::F64 } - /// Check if the tensor uses diagonal structured storage. + /// Check whether the tensor currently uses native diagonal structured storage. + /// + /// This is a storage-level predicate, not a semantic diagonality check. pub fn is_diag(&self) -> bool { self.native.is_diag() } diff --git a/crates/tensor4all-core/src/forward_ad.rs b/crates/tensor4all-core/src/forward_ad.rs deleted file mode 100644 index 9b346841..00000000 --- a/crates/tensor4all-core/src/forward_ad.rs +++ /dev/null @@ -1,54 +0,0 @@ -//! Forward-mode AD helpers for tensor4all tensors. - -use anyhow::{anyhow, ensure, Result}; - -use crate::TensorDynLen; - -/// Scoped forward-mode builder mirroring `tenferro::forward_ad::DualLevel`. -pub struct DualLevel<'a> { - inner: &'a tenferro::forward_ad::DualLevel, -} - -impl<'a> DualLevel<'a> { - /// Creates a dual tensor from a primal tensor and its tangent seed. - pub fn make_dual(&self, primal: &TensorDynLen, tangent: &TensorDynLen) -> Result { - ensure!( - primal.indices() == tangent.indices(), - "forward_ad::make_dual requires matching indices, got {:?} vs {:?}", - primal.indices(), - tangent.indices() - ); - let native = self - .inner - .make_dual(primal.as_native(), tangent.as_native()) - .map_err(|e| anyhow!("forward_ad::make_dual failed: {e}"))?; - TensorDynLen::from_native(primal.indices().to_vec(), native) - } - - /// Unpacks a dual tensor into its detached primal value and optional tangent. - pub fn unpack_dual( - &self, - value: &TensorDynLen, - ) -> Result<(TensorDynLen, Option)> { - let (primal, tangent) = self - .inner - .unpack_dual(value.as_native()) - .map_err(|e| anyhow!("forward_ad::unpack_dual failed: {e}"))?; - let primal = TensorDynLen::from_native(value.indices().to_vec(), primal)?; - let tangent = tangent - .map(|native| TensorDynLen::from_native(value.indices().to_vec(), native)) - .transpose()?; - Ok((primal, tangent)) - } -} - -/// Runs a scoped forward-mode computation. -pub fn dual_level(f: impl for<'a> FnOnce(&DualLevel<'a>) -> Result) -> Result { - tenferro::forward_ad::dual_level(|inner| { - let wrapper = DualLevel { inner }; - f(&wrapper).map_err(|e| tenferro::Error::InvalidAdTensor { - message: format!("tensor4all forward_ad wrapper failed: {e}"), - }) - }) - .map_err(|e| anyhow!("forward_ad::dual_level failed: {e}")) -} diff --git a/crates/tensor4all-core/src/lib.rs b/crates/tensor4all-core/src/lib.rs index 41b578f3..c9e3d21e 100644 --- a/crates/tensor4all-core/src/lib.rs +++ b/crates/tensor4all-core/src/lib.rs @@ -27,7 +27,6 @@ pub mod col_major_array; pub use col_major_array::{ColMajorArray, ColMajorArrayMut, ColMajorArrayRef}; // Common (tags, utilities, scalar) -pub mod forward_ad; pub mod global_default; pub mod index_like; pub mod scalar; @@ -88,8 +87,8 @@ pub use defaults::tensordynlen as tensor; pub use any_scalar::AnyScalar; pub use defaults::tensordynlen::{ - compute_permutation_from_indices, diag_tensor_dyn_len, is_diag_tensor, unfold_split, - RandomScalar, TensorAccess, TensorDynLen, + compute_permutation_from_indices, diag_tensor_dyn_len, unfold_split, RandomScalar, + TensorAccess, TensorDynLen, }; pub use storage::{make_mut_storage, mindim, Storage, StructuredStorage, SumFromStorage}; pub use tensor4all_tensorbackend::TensorElement; diff --git a/crates/tensor4all-core/tests/linalg_factorize.rs b/crates/tensor4all-core/tests/linalg_factorize.rs index 4ed86c7f..c44cc127 100644 --- a/crates/tensor4all-core/tests/linalg_factorize.rs +++ b/crates/tensor4all-core/tests/linalg_factorize.rs @@ -287,8 +287,10 @@ fn test_diag_dense_contraction_svd_internals() { let (u, s, v) = svd::(&tensor, std::slice::from_ref(&i)).expect("SVD should succeed"); - // Verify S is diagonal storage - assert!(s.is_diag()); + // The public bridge materializes diagonal payloads densely at the native layer. + assert!(!s.is_diag()); + assert_eq!(s.dims().len(), 2); + assert_eq!(s.dims()[0], s.dims()[1]); // Verify S and V share a common index let common_found = s diff --git a/crates/tensor4all-core/tests/tensor_basic.rs b/crates/tensor4all-core/tests/tensor_basic.rs index 5f3bcb73..3f16577b 100644 --- a/crates/tensor4all-core/tests/tensor_basic.rs +++ b/crates/tensor4all-core/tests/tensor_basic.rs @@ -22,6 +22,17 @@ fn make_tensor_c64(indices: Vec, data: Vec) -> TensorDynLen TensorDynLen::from_dense(indices, data).unwrap() } +fn assert_from_diag_dense_constructor_contract(data: Vec) +where + T: TensorElement, +{ + let i = Index::new_dyn(3); + let j = Index::new_dyn(3); + let tensor = TensorDynLen::from_diag(vec![i, j], data).unwrap(); + assert_eq!(tensor.dims(), vec![3, 3]); + assert!(!tensor.is_diag()); +} + #[test] fn test_storage_dense_f64() { // Create a zero-initialized tensor with 10 elements @@ -665,39 +676,18 @@ fn test_from_dense_generic_rejects_length_mismatch() { #[test] fn test_from_diag_generic_supports_all_supported_element_types() { - let i = Index::new_dyn(3); - let j = Index::new_dyn(3); - - assert!( - TensorDynLen::from_diag(vec![i.clone(), j.clone()], vec![1.0_f32, 2.0, 3.0]) - .unwrap() - .is_diag() - ); - assert!( - TensorDynLen::from_diag(vec![i.clone(), j.clone()], vec![1.0_f64, 2.0, 3.0]) - .unwrap() - .is_diag() - ); - assert!(TensorDynLen::from_diag( - vec![i.clone(), j.clone()], - vec![ - Complex32::new(1.0, 0.0), - Complex32::new(2.0, 0.5), - Complex32::new(3.0, -0.5), - ], - ) - .unwrap() - .is_diag()); - assert!(TensorDynLen::from_diag( - vec![i, j], - vec![ - Complex64::new(1.0, 0.0), - Complex64::new(2.0, 0.5), - Complex64::new(3.0, -0.5), - ], - ) - .unwrap() - .is_diag()); + assert_from_diag_dense_constructor_contract::(vec![1.0, 2.0, 3.0]); + assert_from_diag_dense_constructor_contract::(vec![1.0, 2.0, 3.0]); + assert_from_diag_dense_constructor_contract::(vec![ + Complex32::new(1.0, 0.0), + Complex32::new(2.0, 0.5), + Complex32::new(3.0, -0.5), + ]); + assert_from_diag_dense_constructor_contract::(vec![ + Complex64::new(1.0, 0.0), + Complex64::new(2.0, 0.5), + Complex64::new(3.0, -0.5), + ]); } #[test] diff --git a/crates/tensor4all-core/tests/tensor_diag.rs b/crates/tensor4all-core/tests/tensor_diag.rs index f5c54901..ccdcd5bc 100644 --- a/crates/tensor4all-core/tests/tensor_diag.rs +++ b/crates/tensor4all-core/tests/tensor_diag.rs @@ -1,7 +1,7 @@ use num_complex::Complex64; use tensor4all_core::index::DefaultIndex as Index; use tensor4all_core::TensorLike; -use tensor4all_core::{diag_tensor_dyn_len, is_diag_tensor, AnyScalar, TensorDynLen}; +use tensor4all_core::{diag_tensor_dyn_len, AnyScalar, TensorDynLen}; #[test] fn test_diag_tensor_creation() { @@ -11,7 +11,15 @@ fn test_diag_tensor_creation() { let tensor = diag_tensor_dyn_len(vec![i.clone(), j.clone()], diag_data.clone()); assert_eq!(tensor.dims(), vec![3, 3]); - assert!(is_diag_tensor(&tensor)); + assert!(!tensor.is_diag()); + assert_eq!( + tensor.to_vec::().unwrap(), + vec![ + 1.0, 0.0, 0.0, // + 0.0, 2.0, 0.0, // + 0.0, 0.0, 3.0, + ] + ); } #[test] @@ -37,14 +45,14 @@ fn test_diag_tensor_sum() { } #[test] -fn test_diag_tensor_scale_preserves_diag_structure() { +fn test_diag_tensor_scale_preserves_diagonal_values() { let i = Index::new_dyn(3); let j = Index::new_dyn(3); let tensor = diag_tensor_dyn_len(vec![i.clone(), j.clone()], vec![1.0, -2.0, 4.0]); let scaled = tensor.scale(AnyScalar::new_real(-0.5)).unwrap(); - assert!(is_diag_tensor(&scaled)); + assert!(!scaled.is_diag()); let expected = diag_tensor_dyn_len(vec![i, j], vec![-0.5, 1.0, -2.0]); assert!(scaled.isapprox(&expected, 1e-12, 0.0)); } @@ -100,7 +108,7 @@ fn test_diag_tensor_contract_diag_diag_partial() { let result = tensor_a.contract(&tensor_b); assert_eq!(result.dims(), vec![3, 3]); - assert!(is_diag_tensor(&result)); + assert!(!result.is_diag()); // Result diagonal should be element-wise product: [1*4, 2*5, 3*6] = [4, 10, 18] let expected = diag_tensor_dyn_len(vec![i, k], vec![4.0, 10.0, 18.0]); @@ -160,7 +168,7 @@ fn test_diag_tensor_rank3() { let tensor = diag_tensor_dyn_len(vec![i.clone(), j.clone(), k.clone()], diag_data.clone()); assert_eq!(tensor.dims(), vec![2, 2, 2]); - assert!(is_diag_tensor(&tensor)); + assert!(!tensor.is_diag()); // Sum should work let sum: AnyScalar = tensor.sum(); @@ -176,7 +184,16 @@ fn test_diag_tensor_complex() { let tensor = TensorDynLen::from_diag(vec![i.clone(), j.clone()], diag_data.clone()).unwrap(); assert_eq!(tensor.dims(), vec![2, 2]); - assert!(is_diag_tensor(&tensor)); + assert!(!tensor.is_diag()); + assert_eq!( + tensor.to_vec::().unwrap(), + vec![ + Complex64::new(1.0, 0.5), + Complex64::new(0.0, 0.0), + Complex64::new(0.0, 0.0), + Complex64::new(2.0, 1.0), + ] + ); // Sum should work let sum: AnyScalar = tensor.sum(); @@ -187,7 +204,7 @@ fn test_diag_tensor_complex() { } #[test] -fn test_diag_tensor_complex_axpby_preserves_diag_structure() { +fn test_diag_tensor_complex_axpby_preserves_diagonal_values() { let i = Index::new_dyn(2); let j = Index::new_dyn(2); let diag_a = vec![Complex64::new(1.0, 0.5), Complex64::new(-2.0, 1.0)]; @@ -200,7 +217,7 @@ fn test_diag_tensor_complex_axpby_preserves_diag_structure() { let b = AnyScalar::new_complex(-0.5, 1.0); let result = tensor_a.axpby(a, &tensor_b, b).unwrap(); - assert!(is_diag_tensor(&result)); + assert!(!result.is_diag()); let b_c = Complex64::new(-0.5, 1.0); let expected_diag: Vec = diag_a .iter() @@ -228,7 +245,7 @@ fn test_diag_tensor_contract_rank3() { let result = tensor_a.contract(&tensor_b); assert_eq!(result.dims(), vec![2, 2, 2]); - assert!(is_diag_tensor(&result)); + assert!(!result.is_diag()); // Result diagonal should be element-wise product: [1*3, 2*4] = [3, 8] let expected = diag_tensor_dyn_len(vec![i, j, l], vec![3.0, 8.0]); diff --git a/crates/tensor4all-core/tests/tensor_native_ad.rs b/crates/tensor4all-core/tests/tensor_native_ad.rs index f2916f4f..ce8ed52a 100644 --- a/crates/tensor4all-core/tests/tensor_native_ad.rs +++ b/crates/tensor4all-core/tests/tensor_native_ad.rs @@ -1,190 +1,10 @@ -use num_complex::Complex64; -use tenferro::AdMode; -use tensor4all_core::{ - factorize, forward_ad, is_diag_tensor, qr, svd, AnyScalar, Canonical, FactorizeOptions, Index, - Storage, TensorDynLen, -}; +use tensor4all_core::{contract_multi, AllowedPairs, Index, Storage, TensorDynLen}; -fn forward_tensor(primal: TensorDynLen, tangent: TensorDynLen) -> TensorDynLen { - forward_ad::dual_level(|fw| fw.make_dual(&primal, &tangent)).unwrap() -} - -fn assert_same_tensor_data(lhs: &TensorDynLen, rhs: &TensorDynLen) { - assert_eq!(lhs.dims(), rhs.dims()); - assert_eq!(lhs.is_diag(), rhs.is_diag()); - assert_eq!(lhs.is_f64(), rhs.is_f64()); - assert_eq!(lhs.is_complex(), rhs.is_complex()); - if lhs.is_f64() { - assert_eq!(lhs.to_vec::().unwrap(), rhs.to_vec::().unwrap()); - } else { - assert_eq!( - lhs.to_vec::().unwrap(), - rhs.to_vec::().unwrap() - ); - } -} - -#[test] -fn sum_preserves_forward_payload_via_dual_level() { - let i = Index::new_dyn(2); - let tensor = forward_tensor( - TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap(), - TensorDynLen::from_dense(vec![i], vec![0.25, -0.75]).unwrap(), - ); - - let sum = tensor.sum(); - - assert_eq!(sum.mode(), AdMode::Forward); - assert_eq!(sum.primal().as_f64(), Some(3.0)); - assert_eq!(sum.tangent().and_then(|x| x.as_f64()), Some(-0.5)); -} - -#[test] -fn only_preserves_forward_payload_via_dual_level() { - let tensor = forward_tensor( - TensorDynLen::from_dense(vec![], vec![2.5]).unwrap(), - TensorDynLen::from_dense(vec![], vec![0.75]).unwrap(), - ); - - let only = tensor.only(); - - assert_eq!(only.mode(), AdMode::Forward); - assert_eq!(only.primal().as_f64(), Some(2.5)); - assert_eq!(only.tangent().and_then(|x| x.as_f64()), Some(0.75)); -} - -#[test] -fn inner_product_preserves_forward_payload_via_dual_level() { - let i = Index::new_dyn(2); - let lhs = forward_tensor( - TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap(), - TensorDynLen::from_dense(vec![i.clone()], vec![0.1, 0.2]).unwrap(), - ); - let rhs = forward_tensor( - TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap(), - TensorDynLen::from_dense(vec![i], vec![1.0, -1.0]).unwrap(), - ); - - let inner = lhs.inner_product(&rhs).unwrap(); - - assert_eq!(inner.mode(), AdMode::Forward); - assert_eq!(inner.primal().as_f64(), Some(11.0)); - let tangent = inner.tangent().and_then(|x| x.as_f64()).unwrap(); - assert!( - (tangent - 0.1).abs() < 1e-12, - "unexpected tangent: {tangent}" - ); -} - -#[test] -fn qr_preserves_forward_payload_via_dual_level() { - let i = Index::new_dyn(2); - let j = Index::new_dyn(2); - let tensor = forward_tensor( - TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![3.0, 0.0, 0.0, 2.0]).unwrap(), - TensorDynLen::from_dense(vec![i.clone(), j], vec![0.5, 0.0, 0.0, -0.25]).unwrap(), - ); - - let (q, r) = qr::(&tensor, std::slice::from_ref(&i)).unwrap(); - - assert_eq!(q.mode(), AdMode::Forward); - assert_eq!(r.mode(), AdMode::Forward); - assert!( - (q.sum() - .tangent() - .and_then(|x| x.as_f64()) - .unwrap_or_default()) - .abs() - < 1e-12 - ); - let r_tangent = r.sum().tangent().and_then(|x| x.as_f64()).unwrap(); - assert!( - (r_tangent - 0.25).abs() < 1e-12, - "unexpected QR tangent: {r_tangent}" - ); -} - -#[test] -fn svd_preserves_forward_payload_via_dual_level() { - let i = Index::new_dyn(2); - let j = Index::new_dyn(2); - let tensor = forward_tensor( - TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![3.0, 0.0, 0.0, 2.0]).unwrap(), - TensorDynLen::from_dense(vec![i.clone(), j], vec![0.5, 0.0, 0.0, -0.25]).unwrap(), - ); - - let (u, s, v) = svd::(&tensor, std::slice::from_ref(&i)).unwrap(); - - assert_eq!(u.mode(), AdMode::Forward); - assert_eq!(s.sum().mode(), AdMode::Forward); - assert_eq!(v.mode(), AdMode::Forward); - let s_tangent = s.sum().tangent().and_then(|x| x.as_f64()).unwrap(); - assert!( - (s_tangent - 0.25).abs() < 1e-12, - "unexpected SVD tangent: {s_tangent}" - ); -} - -#[test] -fn factorize_svd_preserves_forward_payload_via_dual_level() { - let i = Index::new_dyn(2); - let j = Index::new_dyn(2); - let tensor = forward_tensor( - TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![3.0, 0.0, 0.0, 2.0]).unwrap(), - TensorDynLen::from_dense(vec![i.clone(), j], vec![0.5, 0.0, 0.0, -0.25]).unwrap(), - ); - - let result = factorize( - &tensor, - std::slice::from_ref(&i), - &FactorizeOptions::svd().with_canonical(Canonical::Left), - ) - .unwrap(); - - assert_eq!(result.left.mode(), AdMode::Forward); - assert_eq!(result.right.mode(), AdMode::Forward); - let right_tangent = result - .right - .sum() - .tangent() - .and_then(|x| x.as_f64()) - .unwrap(); - assert!( - (right_tangent - 0.25).abs() < 1e-12, - "unexpected factorize tangent: {right_tangent}" - ); -} - -#[test] -fn forward_ad_unpack_dual_restores_primal_and_tangent() { - let i = Index::new_dyn(2); - let primal = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap(); - let tangent = TensorDynLen::from_dense(vec![i], vec![0.5, -0.25]).unwrap(); - - let (unpacked_primal, unpacked_tangent) = forward_ad::dual_level(|fw| { - let dual = fw.make_dual(&primal, &tangent)?; - fw.unpack_dual(&dual) - }) - .unwrap(); - - assert_same_tensor_data(&unpacked_primal, &primal); - assert_same_tensor_data(&unpacked_tangent.unwrap(), &tangent); -} - -#[test] -fn rank1_native_snapshots_stay_dense() { - let i = Index::new_dyn(3); - let tensor = forward_tensor( - TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0, 3.0]).unwrap(), - TensorDynLen::from_dense(vec![i], vec![0.0, 0.0, 0.0]).unwrap(), - ); - - let scaled = tensor.scale(AnyScalar::new_real(2.0)).unwrap(); - let snapshot = scaled.storage(); - - assert!(snapshot.is_dense()); - assert!(!snapshot.is_diag()); - assert_eq!(scaled.to_vec::().unwrap(), vec![2.0, 4.0, 6.0]); +fn with_runtime(f: impl FnOnce() -> R) -> R { + let _guard = tenferro::set_default_runtime(tenferro::RuntimeContext::Cpu( + tenferro_prims::CpuContext::new(1), + )); + f() } #[test] @@ -198,12 +18,12 @@ fn plain_dense_storage_auto_seeds_native_payload() { ) .unwrap(); - assert_eq!(tensor.mode(), AdMode::Primal); assert_eq!(tensor.to_vec::().unwrap(), vec![1.0, 2.0]); + assert!(!tensor.requires_grad()); } #[test] -fn plain_diag_storage_auto_seeds_native_diag_payload() { +fn plain_diag_storage_auto_seeds_native_dense_payload() { let i = Index::new_dyn(3); let j = Index::new_dyn(3); let tensor = TensorDynLen::from_storage( @@ -214,21 +34,15 @@ fn plain_diag_storage_auto_seeds_native_diag_payload() { ) .unwrap(); - assert_eq!(tensor.mode(), AdMode::Primal); - assert!(is_diag_tensor(&tensor)); -} - -// --------------------------------------------------------------------------- -// Backward (reverse-mode) AD tests -// --------------------------------------------------------------------------- - -use tensor4all_core::{contract_multi, AllowedPairs}; - -fn with_runtime(f: impl FnOnce() -> R) -> R { - let _guard = tenferro::set_default_runtime(tenferro::RuntimeContext::Cpu( - tenferro_prims::CpuContext::new(1), - )); - f() + assert_eq!( + tensor.to_vec::().unwrap(), + vec![ + 1.0, 0.0, 0.0, // + 0.0, 2.0, 0.0, // + 0.0, 0.0, 3.0, + ] + ); + assert!(!tensor.is_diag()); } #[test] @@ -241,16 +55,15 @@ fn backward_ad_contraction_accumulates_gradient() { let ones = TensorDynLen::from_dense(vec![i], vec![1.0, 1.0, 1.0]).unwrap(); - // f = contract(a, ones) = dot product = 1+2+3 = 6; df/da = [1,1,1] let result = contract_multi(&[&a, &ones], AllowedPairs::All).unwrap(); assert!( result.indices().is_empty(), "contraction result should be rank-0" ); - result.backward(None, &[&a]).unwrap(); + result.backward(None).unwrap(); - let grad = a.grad().expect("gradient missing after backward"); + let grad = a.grad().unwrap().expect("gradient missing after backward"); let grad_vec = grad.to_vec::().unwrap(); for (j, &g) in grad_vec.iter().enumerate() { assert!((g - 1.0).abs() < 1e-10, "grad[{j}] = {g}, expected 1.0"); @@ -265,17 +78,15 @@ fn backward_ad_gradient_matches_finite_diff() { let j = Index::new_dyn(2); let eps = 1e-6; - // f(A) = contract(A, A) over all indices = sum(|a_ij|^2) - // df/da_ij = 2 * conj(a_ij) = 2 * a_ij (real case) let data = vec![1.0, 2.0, 3.0, 4.0]; let mut a = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data.clone()).unwrap(); a.set_requires_grad(true).unwrap(); let result = contract_multi(&[&a, &a], AllowedPairs::All).unwrap(); - result.backward(None, &[&a]).unwrap(); + result.backward(None).unwrap(); - let grad = a.grad().expect("gradient missing"); + let grad = a.grad().unwrap().expect("gradient missing"); let grad_vec = grad.to_vec::().unwrap(); for idx in 0..data.len() { diff --git a/crates/tensor4all-tensorbackend/src/any_scalar.rs b/crates/tensor4all-tensorbackend/src/any_scalar.rs index 37abaf3f..ded606d8 100644 --- a/crates/tensor4all-tensorbackend/src/any_scalar.rs +++ b/crates/tensor4all-tensorbackend/src/any_scalar.rs @@ -5,11 +5,11 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use anyhow::{anyhow, Result}; use num_complex::{Complex32, Complex64}; use num_traits::{One, Zero}; -use tenferro::{AdMode, ScalarType, ScalarValue as TfScalarValue, Tensor as NativeTensor}; +use tenferro::{ScalarType, Tensor as NativeTensor}; use crate::storage::{Storage, SumFromStorage}; +use crate::tenferro_bridge::with_default_runtime; use crate::tensor_element::TensorElement; -use crate::{tangent_native_tensor, tenferro_bridge::with_default_runtime}; #[derive(Clone, Copy, Debug, PartialEq)] enum ScalarValue { @@ -83,14 +83,27 @@ fn scalar_value_from_storage(storage: &Storage) -> ScalarValue { } fn scalar_value_from_native(native: &NativeTensor) -> ScalarValue { - match native - .try_scalar_value() - .unwrap_or_else(|e| panic!("failed to read scalar tensor value: {e}")) - { - TfScalarValue::F32(value) => ScalarValue::F32(value), - TfScalarValue::F64(value) => ScalarValue::F64(value), - TfScalarValue::C32(value) => ScalarValue::C32(value), - TfScalarValue::C64(value) => ScalarValue::C64(value), + match native.scalar_type() { + ScalarType::F32 => ScalarValue::F32( + native + .try_get::(&[]) + .unwrap_or_else(|e| panic!("failed to read f32 scalar tensor value: {e}")), + ), + ScalarType::F64 => ScalarValue::F64( + native + .try_get::(&[]) + .unwrap_or_else(|e| panic!("failed to read f64 scalar tensor value: {e}")), + ), + ScalarType::C32 => ScalarValue::C32( + native + .try_get::(&[]) + .unwrap_or_else(|e| panic!("failed to read c32 scalar tensor value: {e}")), + ), + ScalarType::C64 => ScalarValue::C64( + native + .try_get::(&[]) + .unwrap_or_else(|e| panic!("failed to read c64 scalar tensor value: {e}")), + ), } } @@ -110,17 +123,54 @@ fn scalar_native_op( } fn neg_native(native: &NativeTensor) -> Result { - scalar_native_op("neg", || native.scale(&rank0_real_tensor(-1.0))) + Ok(match scalar_value_from_native(native) { + ScalarValue::F32(value) => Scalar::from_value(-value).native, + ScalarValue::F64(value) => Scalar::from_value(-value).native, + ScalarValue::C32(value) => Scalar::from_value(-value).native, + ScalarValue::C64(value) => Scalar::from_value(-value).native, + }) } -fn promote_scalar_native(native: &NativeTensor, target: ScalarType) -> Result { - scalar_native_op("to_scalar_type", || native.to_scalar_type(target)) +pub(crate) fn promote_scalar_native( + native: &NativeTensor, + target: ScalarType, +) -> Result { + let promoted = match (scalar_value_from_native(native), target) { + (ScalarValue::F32(value), ScalarType::F32) => Scalar::from_value(value), + (ScalarValue::F32(value), ScalarType::F64) => Scalar::from_value(value as f64), + (ScalarValue::F32(value), ScalarType::C32) => { + Scalar::from_value(Complex32::new(value, 0.0)) + } + (ScalarValue::F32(value), ScalarType::C64) => { + Scalar::from_value(Complex64::new(value as f64, 0.0)) + } + (ScalarValue::F64(value), ScalarType::F32) => Scalar::from_value(value as f32), + (ScalarValue::F64(value), ScalarType::F64) => Scalar::from_value(value), + (ScalarValue::F64(value), ScalarType::C32) => { + Scalar::from_value(Complex32::new(value as f32, 0.0)) + } + (ScalarValue::F64(value), ScalarType::C64) => { + Scalar::from_value(Complex64::new(value, 0.0)) + } + (ScalarValue::C32(value), ScalarType::F32) => Scalar::from_value(value.re), + (ScalarValue::C32(value), ScalarType::F64) => Scalar::from_value(value.re as f64), + (ScalarValue::C32(value), ScalarType::C32) => Scalar::from_value(value), + (ScalarValue::C32(value), ScalarType::C64) => { + Scalar::from_value(Complex64::new(value.re as f64, value.im as f64)) + } + (ScalarValue::C64(value), ScalarType::F32) => Scalar::from_value(value.re as f32), + (ScalarValue::C64(value), ScalarType::F64) => Scalar::from_value(value.re), + (ScalarValue::C64(value), ScalarType::C32) => { + Scalar::from_value(Complex32::new(value.re as f32, value.im as f32)) + } + (ScalarValue::C64(value), ScalarType::C64) => Scalar::from_value(value), + }; + Ok(promoted.native) } /// Dynamic scalar used across tensor4all backends. /// /// This is a tensor4all-owned rank-0 wrapper over tenferro's dynamic tensor. -#[derive(Clone)] pub struct Scalar { native: NativeTensor, } @@ -179,22 +229,12 @@ impl Scalar { Self::from_complex(re, im) } - /// Returns the upstream AD mode. - pub fn mode(&self) -> AdMode { - self.native.mode() - } - /// Returns the detached primal value as a scalar. pub fn primal(&self) -> Self { Self::wrap_native(self.native.detach()) .unwrap_or_else(|e| panic!("Scalar::primal returned a non-scalar tensor: {e}")) } - /// Returns the detached forward tangent when present. - pub fn tangent(&self) -> Option { - tangent_native_tensor(&self.native).and_then(|native| Self::from_native(native).ok()) - } - /// Returns whether the scalar participates in reverse-mode AD. pub fn requires_grad(&self) -> bool { self.native.requires_grad() @@ -202,14 +242,15 @@ impl Scalar { /// Enables or disables reverse-mode gradient tracking. pub fn set_requires_grad(&mut self, enabled: bool) -> Result<()> { - self.native - .set_requires_grad(enabled) - .map_err(|e| anyhow!("Scalar::set_requires_grad failed: {e}")) + let placeholder = rank0_real_tensor(0.0); + let native = std::mem::replace(&mut self.native, placeholder); + self.native = native.requires_grad_(enabled); + Ok(()) } /// Returns accumulated reverse-mode gradient when available. pub fn grad(&self) -> Option { - self.native.grad().map(|native| { + self.native.grad().ok().flatten().map(|native| { Self::wrap_native(native) .unwrap_or_else(|e| panic!("Scalar::grad returned a non-scalar tensor: {e}")) }) @@ -224,17 +265,16 @@ impl Scalar { /// Accumulates reverse-mode gradients into `inputs`. pub fn backward(&self, grad_output: Option<&Self>, inputs: &[&Self]) -> Result<()> { - let grad_output_native = grad_output.map(|value| value.as_native()); - let input_native: Vec<&NativeTensor> = - inputs.iter().map(|value| value.as_native()).collect(); - with_default_runtime("backward", || { - self.native - .backward( - grad_output_native, - &input_native, - tenferro::BackwardOptions::default(), - ) - .map_err(|e| anyhow!("Scalar::backward failed: {e}")) + let _ = inputs; + with_default_runtime("backward", || match grad_output { + Some(seed) => self + .native + .backward_with_seed(seed.as_native()) + .map_err(|e| anyhow!("Scalar::backward failed: {e}")), + None => self + .native + .backward() + .map_err(|e| anyhow!("Scalar::backward failed: {e}")), }) } @@ -288,59 +328,64 @@ impl Scalar { /// Returns the complex conjugate. pub fn conj(&self) -> Self { - Self::wrap_native(self.native.conj()) - .unwrap_or_else(|e| panic!("Scalar::conj returned a non-scalar tensor: {e}")) + match self.value() { + ScalarValue::F32(value) => Self::from_value(value), + ScalarValue::F64(value) => Self::from_value(value), + ScalarValue::C32(value) => Self::from_value(value.conj()), + ScalarValue::C64(value) => Self::from_value(value.conj()), + } } /// Returns the real part as a scalar, preserving scalar semantics. pub fn real_part(&self) -> Self { - scalar_tensor_result( - "real_part", - scalar_native_op("real_part", || self.native.real_part()), - ) + Self::from_real(self.real()) } /// Returns the imaginary part as a scalar, preserving scalar semantics. pub fn imag_part(&self) -> Self { - scalar_tensor_result( - "imag_part", - scalar_native_op("imag_part", || self.native.imag_part()), - ) + Self::from_real(self.imag()) } /// Compose a complex scalar from real-valued parts. pub fn compose_complex(real: Self, imag: Self) -> Result { - let native = scalar_native_op("compose_complex", || { - NativeTensor::compose_complex(real.native, imag.native) - }) - .map_err(|e| anyhow!("Scalar::compose_complex failed: {e}"))?; - Self::wrap_native(native) + if !real.is_real() || !imag.is_real() { + return Err(anyhow!( + "compose_complex requires real-valued inputs, got real={:?}, imag={:?}", + real.native.scalar_type(), + imag.native.scalar_type() + )); + } + Ok(Self::from_complex(real.real(), imag.real())) } /// Square root, preserving AD metadata. pub fn sqrt(&self) -> Self { - let native = if self.is_complex() || self.real() < 0.0 { - let promoted = promote_scalar_native(&self.native, ScalarType::C64); - promoted.and_then(|value| scalar_native_op("sqrt", || value.sqrt())) + if self.is_complex() || self.real() < 0.0 { + let value = self.value().into_complex().sqrt(); + if value.im == 0.0 { + Self::from_real(value.re) + } else { + Self::from_value(value) + } } else { - scalar_native_op("sqrt", || self.native.sqrt()) - }; - scalar_tensor_result("sqrt", native) + Self::from_real(self.real().sqrt()) + } } /// Real exponent power, preserving AD metadata. pub fn powf(&self, exponent: f64) -> Self { let needs_complex_promotion = self.is_complex() || (self.real() < 0.0 && exponent.fract() != 0.0); - let native = if needs_complex_promotion { - let promoted = promote_scalar_native(&self.native, ScalarType::C64); - promoted.and_then(|value| { - scalar_native_op("powf", || value.pow(&rank0_real_tensor(exponent))) - }) + if needs_complex_promotion { + let value = self.value().into_complex().powf(exponent); + if value.im == 0.0 { + Self::from_real(value.re) + } else { + Self::from_value(value) + } } else { - scalar_native_op("powf", || self.native.pow(&rank0_real_tensor(exponent))) - }; - scalar_tensor_result("powf", native) + Self::from_real(self.real().powf(exponent)) + } } /// Integer exponent power, preserving AD metadata. @@ -408,10 +453,13 @@ impl Add for Scalar { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - scalar_tensor_result( - "add", - scalar_native_op("add", || self.native.add(&rhs.native)), - ) + match (self.value(), rhs.value()) { + (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs + rhs), + (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => { + Self::from_value(lhs.into_complex() + rhs.into_complex()) + } + (lhs, rhs) => Self::from_real(lhs.real() + rhs.real()), + } } } @@ -419,9 +467,7 @@ impl Sub for Scalar { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - let neg_rhs = neg_native(&rhs.native) - .unwrap_or_else(|err| panic!("Scalar::sub failed while negating rhs: {err}")); - scalar_tensor_result("sub", scalar_native_op("sub", || self.native.add(&neg_rhs))) + self + (-rhs) } } @@ -429,10 +475,13 @@ impl Mul for Scalar { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - scalar_tensor_result( - "mul", - scalar_native_op("mul", || self.native.scale(&rhs.native)), - ) + match (self.value(), rhs.value()) { + (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs * rhs), + (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => { + Self::from_value(lhs.into_complex() * rhs.into_complex()) + } + (lhs, rhs) => Self::from_real(lhs.real() * rhs.real()), + } } } @@ -440,10 +489,13 @@ impl Div for Scalar { type Output = Self; fn div(self, rhs: Self) -> Self::Output { - scalar_tensor_result( - "div", - scalar_native_op("div", || self.native.div_scalar(&rhs.native)), - ) + match (self.value(), rhs.value()) { + (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => Self::from_value(lhs / rhs), + (lhs, rhs) if lhs.is_complex() || rhs.is_complex() => { + Self::from_value(lhs.into_complex() / rhs.into_complex()) + } + (lhs, rhs) => Self::from_real(lhs.real() / rhs.real()), + } } } @@ -503,9 +555,7 @@ impl One for Scalar { impl PartialEq for Scalar { fn eq(&self, other: &Self) -> bool { - self.mode() == other.mode() - && self.native.scalar_type() == other.native.scalar_type() - && self.value() == other.value() + self.native.scalar_type() == other.native.scalar_type() && self.value() == other.value() } } @@ -535,12 +585,17 @@ impl fmt::Display for Scalar { impl fmt::Debug for Scalar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Scalar") - .field("mode", &self.mode()) .field("scalar_type", &self.native.scalar_type()) .field("value", &self.value()) .finish() } } +impl Clone for Scalar { + fn clone(&self) -> Self { + self.primal() + } +} + #[cfg(test)] mod tests; diff --git a/crates/tensor4all-tensorbackend/src/storage.rs b/crates/tensor4all-tensorbackend/src/storage.rs index ad98e0bb..e67a4609 100644 --- a/crates/tensor4all-tensorbackend/src/storage.rs +++ b/crates/tensor4all-tensorbackend/src/storage.rs @@ -352,6 +352,48 @@ impl StructuredStorage { } } +impl StructuredStorage { + /// Materializes the logical tensor as a contiguous column-major dense buffer. + /// + /// Repeated entries in `axis_classes` encode equality constraints between + /// logical axes. Logical indices that violate those constraints are + /// structural zeros in the dense materialization. + pub fn logical_dense_col_major_vec(&self) -> Vec { + let logical_dims = self.logical_dims(); + let logical_len: usize = logical_dims.iter().product(); + if logical_len == 0 { + return Vec::new(); + } + if let Some(view) = self.dense_col_major_view_if_contiguous() { + return view.to_vec(); + } + if self.is_dense() { + return self.payload_col_major_vec(); + } + + let payload_rank = self.payload_dims.len(); + (0..logical_len) + .map(|linear| { + let logical_index = col_major_multi_index(linear, &logical_dims); + let mut payload_index = vec![0usize; payload_rank]; + let mut seen = vec![false; payload_rank]; + for (&value, &class_id) in logical_index.iter().zip(self.axis_classes.iter()) { + if seen[class_id] { + if payload_index[class_id] != value { + return T::default(); + } + } else { + payload_index[class_id] = value; + seen[class_id] = true; + } + } + let offset = offset_from_strides(&payload_index, &self.strides); + self.data[offset] + }) + .collect() + } +} + /// Storage backend for tensor data. /// /// Public callers interact with this opaque wrapper through constructors and @@ -649,17 +691,7 @@ impl Storage { logical_dims, structured_dims )); } - if let Some(view) = v.dense_col_major_view_if_contiguous() { - Ok(view.to_vec()) - } else if v.is_dense() { - Ok(v.payload_col_major_vec()) - } else { - let native = - crate::tenferro_bridge::storage_to_native_tensor(self, logical_dims) - .map_err(|err| err.to_string())?; - crate::tenferro_bridge::native_tensor_primal_to_dense_f64_col_major(&native) - .map_err(|err| err.to_string()) - } + Ok(v.logical_dense_col_major_vec()) } StorageRepr::C64(_) => { Err("expected f64 storage when materializing dense f64 values".to_string()) @@ -681,17 +713,7 @@ impl Storage { logical_dims, structured_dims )); } - if let Some(view) = v.dense_col_major_view_if_contiguous() { - Ok(view.to_vec()) - } else if v.is_dense() { - Ok(v.payload_col_major_vec()) - } else { - let native = - crate::tenferro_bridge::storage_to_native_tensor(self, logical_dims) - .map_err(|err| err.to_string())?; - crate::tenferro_bridge::native_tensor_primal_to_dense_c64_col_major(&native) - .map_err(|err| err.to_string()) - } + Ok(v.logical_dense_col_major_vec()) } StorageRepr::F64(_) => { Err("expected Complex64 storage when materializing dense c64 values".to_string()) diff --git a/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs b/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs index 0c0eb34c..1db15d50 100644 --- a/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs +++ b/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs @@ -10,20 +10,14 @@ use std::time::{Duration, Instant}; use anyhow::{anyhow, Result}; use num_complex::Complex64; -use tenferro::{set_default_runtime, snapshot, RuntimeContext, Tensor as NativeTensor}; -use tenferro_algebra::{Conjugate, Scalar as TfScalar, Standard}; -use tenferro_device::LogicalMemorySpace; -use tenferro_einsum::{ - einsum_binary_with_subscripts as tenferro_einsum_binary_with_subscripts, - einsum_with_subscripts as tenferro_einsum_with_subscripts, Subscripts, -}; -use tenferro_linalg::{qr as tenferro_qr, svd as tenferro_svd}; -use tenferro_prims::{CpuBackend, CpuContext}; +use tenferro::{set_default_runtime, RuntimeContext, ScalarType, Tensor as NativeTensor}; +use tenferro_prims::CpuContext; use tenferro_tensor::{MemoryOrder, Tensor as TypedTensor}; +use crate::any_scalar::promote_scalar_native; #[cfg(test)] use crate::storage::StorageRepr; -use crate::storage::{col_major_strides, NativePayload, Storage}; +use crate::storage::{NativePayload, Storage}; use crate::tensor_element::TensorElement; use crate::AnyScalar; #[cfg(test)] @@ -51,8 +45,6 @@ struct CachedCpuContextLease { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] enum NativeEinsumPath { - TypedBinaryEinsum, - TypedNaryEinsum, FrontendFallback, } @@ -62,7 +54,6 @@ struct NativeOperandSignature { ids: Vec, is_dense: bool, is_diag: bool, - is_primal: bool, } #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -176,17 +167,11 @@ pub(crate) fn set_native_einsum_profile_enabled_for_tests(enabled: bool) { } fn native_operand_signature(tensor: &NativeTensor, ids: &[u32]) -> NativeOperandSignature { - let is_primal = match tensor { - NativeTensor::F64(value) => value.mode() == tenferro::AdMode::Primal, - NativeTensor::C64(value) => value.mode() == tenferro::AdMode::Primal, - _ => false, - }; NativeOperandSignature { dims: tensor.dims().to_vec(), ids: ids.to_vec(), is_dense: tensor.is_dense(), is_diag: tensor.is_diag(), - is_primal, } } @@ -248,8 +233,8 @@ pub fn print_and_reset_native_einsum_profile() { ); for operand in signature.operands { eprintln!( - " dims={:?} ids={:?} dense={} diag={} primal={}", - operand.dims, operand.ids, operand.is_dense, operand.is_diag, operand.is_primal, + " dims={:?} ids={:?} dense={} diag={}", + operand.dims, operand.ids, operand.is_dense, operand.is_diag, ); } } @@ -349,12 +334,8 @@ fn native_tensor_from_f64_payload(payload: NativePayload) -> Result NativeTensor::with_axis_classes(native, &axis_classes) - .map_err(|e| anyhow!("failed to build structured f64 tensor from storage: {e}")), - None => Ok(native), - } + let _ = payload.axis_classes; + Ok(NativeTensor::from(typed)) } fn native_tensor_from_c64_payload(payload: NativePayload) -> Result { @@ -364,12 +345,8 @@ fn native_tensor_from_c64_payload(payload: NativePayload) -> Result NativeTensor::with_axis_classes(native, &axis_classes) - .map_err(|e| anyhow!("failed to build structured c64 tensor from storage: {e}")), - None => Ok(native), - } + let _ = payload.axis_classes; + Ok(NativeTensor::from(typed)) } /// Build a native dense tensor from column-major boundary data. @@ -388,108 +365,6 @@ pub fn diag_native_tensor_from_col_major( T::diag_native_tensor_from_col_major(data, logical_rank) } -fn dense_f64_storage_from_col_major( - tensor: &TypedTensor, - logical_dims: &[usize], -) -> Result { - let data = materialize_col_major_values(tensor, "f64 dense snapshot materialization")?; - Storage::from_dense_col_major(data, logical_dims) -} - -fn dense_c64_storage_from_col_major( - tensor: &TypedTensor, - logical_dims: &[usize], -) -> Result { - let data = materialize_col_major_values(tensor, "c64 dense snapshot materialization")?; - Storage::from_dense_col_major(data, logical_dims) -} - -fn materialize_col_major_values(tensor: &TypedTensor, op: &'static str) -> Result> -where - T: TfScalar + Conjugate + Copy, -{ - let col_major = tensor.contiguous(MemoryOrder::ColumnMajor); - let is_conjugated = col_major.is_conjugated(); - let col_major = if col_major.logical_memory_space() == LogicalMemorySpace::MainMemory { - col_major - } else { - col_major - .to_memory_space_async(LogicalMemorySpace::MainMemory) - .map_err(|e| anyhow!("{op}: failed to move tensor to host memory: {e}"))? - }; - let offset = usize::try_from(col_major.offset()) - .map_err(|_| anyhow!("{op}: negative offset {}", col_major.offset()))?; - let len = col_major.len(); - let slice = col_major - .buffer() - .as_slice() - .and_then(|values| values.get(offset..offset + len)) - .ok_or_else(|| anyhow!("{op}: expected host-accessible contiguous tensor buffer"))?; - if is_conjugated { - Ok(slice.iter().copied().map(Conjugate::conj).collect()) - } else { - Ok(slice.to_vec()) - } -} - -fn snapshot_f64_to_storage(snap: &snapshot::DynTensor) -> Result { - if snap.is_diag() && snap.dims().len() >= 2 { - let payload = snap - .payload_f64() - .ok_or_else(|| anyhow!("expected f64 diagonal payload"))?; - let data = materialize_col_major_values(payload, "f64 diagonal snapshot materialization")?; - return Storage::from_diag_col_major(data, snap.dims().len()); - } - - if snap.is_dense() { - let payload = snap - .payload_f64() - .ok_or_else(|| anyhow!("expected f64 dense payload"))?; - dense_f64_storage_from_col_major(payload, snap.dims()) - } else { - let payload = snap - .payload_f64() - .ok_or_else(|| anyhow!("expected f64 structured payload"))?; - let data = - materialize_col_major_values(payload, "f64 structured snapshot materialization")?; - Storage::new_structured::( - data, - payload.dims().to_vec(), - col_major_strides(payload.dims()), - snap.axis_classes().to_vec(), - ) - } -} - -fn snapshot_c64_to_storage(snap: &snapshot::DynTensor) -> Result { - if snap.is_diag() && snap.dims().len() >= 2 { - let payload = snap - .payload_c64() - .ok_or_else(|| anyhow!("expected c64 diagonal payload"))?; - let data = materialize_col_major_values(payload, "c64 diagonal snapshot materialization")?; - return Storage::from_diag_col_major(data, snap.dims().len()); - } - - if snap.is_dense() { - let payload = snap - .payload_c64() - .ok_or_else(|| anyhow!("expected c64 dense payload"))?; - dense_c64_storage_from_col_major(payload, snap.dims()) - } else { - let payload = snap - .payload_c64() - .ok_or_else(|| anyhow!("expected c64 structured payload"))?; - let data = - materialize_col_major_values(payload, "c64 structured snapshot materialization")?; - Storage::new_structured::( - data, - payload.dims().to_vec(), - col_major_strides(payload.dims()), - snap.axis_classes().to_vec(), - ) - } -} - fn labels_to_notation(inputs: &[Vec], output: &[usize]) -> Result { let mut id_to_char = HashMap::new(); let mut next_code = 'a' as u32; @@ -589,6 +464,10 @@ fn build_binary_einsum_ids( /// Convert legacy [`Storage`] into a primal-mode [`tenferro::Tensor`]. pub fn storage_to_native_tensor(storage: &Storage, logical_dims: &[usize]) -> Result { + if !storage.is_dense() { + let dense = storage.to_dense_storage(logical_dims); + return storage_to_native_tensor(&dense, logical_dims); + } if storage.is_c64() { native_tensor_from_c64_payload(storage.native_payload_c64(logical_dims)?) } else { @@ -600,12 +479,36 @@ pub fn storage_to_native_tensor(storage: &Storage, logical_dims: &[usize]) -> Re /// /// AD metadata is intentionally dropped at this bridge boundary. pub fn native_tensor_primal_to_storage(tensor: &NativeTensor) -> Result { - match tensor.primal_snapshot() { - snapshot::DynTensor::F32(_) | snapshot::DynTensor::C32(_) => Err(anyhow!( + match tensor.scalar_type() { + ScalarType::F32 | ScalarType::C32 => Err(anyhow!( "tensor4all native bridge currently supports only f64/Complex64 tensors" )), - snap @ snapshot::DynTensor::F64(_) => snapshot_f64_to_storage(&snap), - snap @ snapshot::DynTensor::C64(_) => snapshot_c64_to_storage(&snap), + ScalarType::F64 => { + if tensor.is_diag() && tensor.ndim() >= 2 { + Storage::from_diag_col_major( + ::diag_values_from_native_temp(tensor)?, + tensor.ndim(), + ) + } else { + Storage::from_dense_col_major( + ::dense_values_from_native_col_major(tensor)?, + tensor.dims(), + ) + } + } + ScalarType::C64 => { + if tensor.is_diag() && tensor.ndim() >= 2 { + Storage::from_diag_col_major( + ::diag_values_from_native_temp(tensor)?, + tensor.ndim(), + ) + } else { + Storage::from_dense_col_major( + ::dense_values_from_native_col_major(tensor)?, + tensor.dims(), + ) + } + } } } @@ -653,12 +556,8 @@ pub fn native_tensor_primal_to_diag_c64(tensor: &NativeTensor) -> Result Option { - match tensor { - NativeTensor::F32(value) => value.tangent().cloned().map(NativeTensor::from_tensor), - NativeTensor::F64(value) => value.tangent().cloned().map(NativeTensor::from_tensor), - NativeTensor::C32(value) => value.tangent().cloned().map(NativeTensor::from_tensor), - NativeTensor::C64(value) => value.tangent().cloned().map(NativeTensor::from_tensor), - } + let _ = tensor; + None } /// Reshape a native tensor using tensor4all's column-major semantics. @@ -666,43 +565,23 @@ pub fn reshape_col_major_native_tensor( tensor: &NativeTensor, new_dims: &[usize], ) -> Result { - tensor - .contiguous(MemoryOrder::ColumnMajor) - .map_err(|e| anyhow!("native column-major contiguous conversion failed: {e}"))? - .reshape(new_dims) - .map_err(|e| anyhow!("native reshape failed: {e}")) + match tensor.scalar_type() { + ScalarType::F32 | ScalarType::C32 => Err(anyhow!( + "tensor4all native bridge currently supports only f64/Complex64 tensors" + )), + ScalarType::F64 => dense_native_tensor_from_col_major( + &native_tensor_primal_to_dense_col_major::(tensor)?, + new_dims, + ), + ScalarType::C64 => dense_native_tensor_from_col_major( + &native_tensor_primal_to_dense_col_major::(tensor)?, + new_dims, + ), + } } /// Compute native QR while preserving AD metadata when supported by upstream. pub fn qr_native_tensor(tensor: &NativeTensor) -> Result<(NativeTensor, NativeTensor)> { - match tensor { - NativeTensor::F64(value) - if value.mode() == tenferro::AdMode::Primal && value.is_dense() => - { - return with_tenferro_ctx("native_qr", |ctx| { - let out = tenferro_qr(ctx, value.primal()) - .map_err(|e| anyhow!("native qr failed: {e}"))?; - Ok(( - NativeTensor::from_tensor(out.q), - NativeTensor::from_tensor(out.r), - )) - }); - } - NativeTensor::C64(value) - if value.mode() == tenferro::AdMode::Primal && value.is_dense() => - { - return with_tenferro_ctx("native_qr", |ctx| { - let out = tenferro_qr(ctx, value.primal()) - .map_err(|e| anyhow!("native qr failed: {e}"))?; - Ok(( - NativeTensor::from_tensor(out.q), - NativeTensor::from_tensor(out.r), - )) - }); - } - _ => {} - } - with_default_runtime("native_qr", || { let out = tensor.qr().map_err(|e| anyhow!("native qr failed: {e}"))?; Ok((out.q, out.r)) @@ -713,39 +592,9 @@ pub fn qr_native_tensor(tensor: &NativeTensor) -> Result<(NativeTensor, NativeTe pub fn svd_native_tensor( tensor: &NativeTensor, ) -> Result<(NativeTensor, NativeTensor, NativeTensor)> { - match tensor { - NativeTensor::F64(value) - if value.mode() == tenferro::AdMode::Primal && value.is_dense() => - { - return with_tenferro_ctx("native_svd", |ctx| { - let out = tenferro_svd(ctx, value.primal(), None) - .map_err(|e| anyhow!("native svd failed: {e}"))?; - Ok(( - NativeTensor::from_tensor(out.u), - NativeTensor::from_tensor(out.s), - NativeTensor::from_tensor(out.vt), - )) - }); - } - NativeTensor::C64(value) - if value.mode() == tenferro::AdMode::Primal && value.is_dense() => - { - return with_tenferro_ctx("native_svd", |ctx| { - let out = tenferro_svd(ctx, value.primal(), None) - .map_err(|e| anyhow!("native svd failed: {e}"))?; - Ok(( - NativeTensor::from_tensor(out.u), - NativeTensor::from_tensor(out.s), - NativeTensor::from_tensor(out.vt), - )) - }); - } - _ => {} - } - with_default_runtime("native_svd", || { let out = tensor - .svd() + .svd(None) .map_err(|e| anyhow!("native svd failed: {e}"))?; Ok((out.u, out.s, out.vt)) }) @@ -761,13 +610,91 @@ pub fn sum_native_tensor(tensor: &NativeTensor) -> Result { }) } +fn supported_bridge_scalar_type(tensor: &NativeTensor, op: &'static str) -> Result { + match tensor.scalar_type() { + ScalarType::F64 | ScalarType::C64 => Ok(tensor.scalar_type()), + ScalarType::F32 | ScalarType::C32 => Err(anyhow!( + "{op}: tensor4all native bridge currently supports only f64/Complex64 tensors" + )), + } +} + +fn shared_tensor_scalar_type( + lhs: &NativeTensor, + rhs: &NativeTensor, + op: &'static str, +) -> Result { + let lhs_ty = supported_bridge_scalar_type(lhs, op)?; + let rhs_ty = supported_bridge_scalar_type(rhs, op)?; + if lhs_ty != rhs_ty { + return Err(anyhow!( + "{op}: native tensors must share a dtype, got lhs={lhs_ty:?}, rhs={rhs_ty:?}" + )); + } + Ok(lhs_ty) +} + +fn common_operand_scalar_type( + operands: &[(&NativeTensor, &[usize])], + op: &'static str, +) -> Result { + let mut target = ScalarType::F64; + for (tensor, _) in operands { + match supported_bridge_scalar_type(tensor, op)? { + ScalarType::C64 => target = ScalarType::C64, + ScalarType::F64 => {} + ScalarType::F32 | ScalarType::C32 => unreachable!("unsupported dtype filtered above"), + } + } + Ok(target) +} + +fn promote_detached_tensor_to_dtype( + tensor: &NativeTensor, + target: ScalarType, + op: &'static str, +) -> Result { + let source = supported_bridge_scalar_type(tensor, op)?; + if source == target { + return Err(anyhow!( + "{op}: internal promotion bug, source and target dtypes already match" + )); + } + if tensor.requires_grad() { + return Err(anyhow!( + "{op}: cannot promote a reverse-tracked tensor from {source:?} to {target:?} without losing autodiff metadata" + )); + } + match (source, target) { + (ScalarType::F64, ScalarType::C64) => { + if tensor.is_diag() && tensor.ndim() >= 2 { + let diag = native_tensor_primal_to_diag_f64(tensor)? + .into_iter() + .map(|value| Complex64::new(value, 0.0)) + .collect::>(); + diag_native_tensor_from_col_major(&diag, tensor.ndim()) + } else { + let dense = native_tensor_primal_to_dense_f64_col_major(tensor)? + .into_iter() + .map(|value| Complex64::new(value, 0.0)) + .collect::>(); + dense_native_tensor_from_col_major(&dense, tensor.dims()) + } + } + (lhs, rhs) => Err(anyhow!( + "{op}: unsupported dtype promotion from {lhs:?} to {rhs:?}" + )), + } +} + /// Scale a native tensor with a tensor4all scalar while preserving AD metadata. pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result { - with_default_runtime("native_scale", || { - tensor - .scale(scalar.as_native()) - .map_err(|e| anyhow!("native scale failed: {e}")) - }) + let scalar_native = promote_scalar_native( + scalar.as_native(), + supported_bridge_scalar_type(tensor, "native_scale")?, + )?; + let output_ids = (0..tensor.ndim()).collect::>(); + einsum_native_tensors(&[(&scalar_native, &[]), (tensor, &output_ids)], &output_ids) } /// Compute `a * lhs + b * rhs` on native tensors while preserving AD metadata. @@ -777,8 +704,20 @@ pub fn axpby_native_tensor( rhs: &NativeTensor, b: &AnyScalar, ) -> Result { + let target = shared_tensor_scalar_type(lhs, rhs, "native_axpby")?; + let a_native = promote_scalar_native(a.as_native(), target)?; + let b_native = promote_scalar_native(b.as_native(), target)?; + let lhs_scaled = { + let output_ids = (0..lhs.ndim()).collect::>(); + einsum_native_tensors(&[(&a_native, &[]), (lhs, &output_ids)], &output_ids)? + }; + let rhs_scaled = { + let output_ids = (0..rhs.ndim()).collect::>(); + einsum_native_tensors(&[(&b_native, &[]), (rhs, &output_ids)], &output_ids)? + }; with_default_runtime("native_axpby", || { - lhs.axpby(a.as_native(), rhs, b.as_native()) + lhs_scaled + .add(&rhs_scaled) .map_err(|e| anyhow!("native axpby failed: {e}")) }) } @@ -792,127 +731,32 @@ pub fn einsum_native_tensors( return Err(anyhow!("native einsum requires at least one operand")); } + let target = common_operand_scalar_type(operands, "native_einsum")?; + let promoted = operands + .iter() + .map(|(tensor, _)| { + let source = supported_bridge_scalar_type(tensor, "native_einsum")?; + if source == target { + Ok(None) + } else { + promote_detached_tensor_to_dtype(tensor, target, "native_einsum").map(Some) + } + }) + .collect::>>()?; + let input_ids_u32: Vec> = operands .iter() .map(|(_, ids)| labels_to_u32(ids, "native_einsum")) .collect::>()?; let output_ids_u32 = labels_to_u32(output_ids, "native_einsum")?; - let input_refs_u32: Vec<&[u32]> = input_ids_u32.iter().map(Vec::as_slice).collect(); - let subscripts = Subscripts::new(&input_refs_u32, &output_ids_u32); let profile_started = native_einsum_profile_enabled().then(Instant::now); - match operands[0].0 { - NativeTensor::F64(first) - if first.mode() == tenferro::AdMode::Primal && first.is_dense() => - { - let mut typed_operands = Vec::with_capacity(operands.len()); - for (tensor, _) in operands { - match tensor { - NativeTensor::F64(value) - if value.mode() == tenferro::AdMode::Primal && value.is_dense() => - { - typed_operands.push(value.primal()); - } - _ => { - typed_operands.clear(); - break; - } - } - } - if !typed_operands.is_empty() { - return with_tenferro_ctx("native_einsum", |ctx| { - let out = if typed_operands.len() == 2 { - tenferro_einsum_binary_with_subscripts::, CpuBackend>( - ctx, - &subscripts, - typed_operands[0], - typed_operands[1], - None, - ) - } else { - tenferro_einsum_with_subscripts::, CpuBackend>( - ctx, - &subscripts, - &typed_operands, - None, - ) - } - .map_err(|e| anyhow!("native einsum failed: {e}"))?; - if let Some(started) = profile_started.as_ref() { - record_native_einsum_profile( - if typed_operands.len() == 2 { - NativeEinsumPath::TypedBinaryEinsum - } else { - NativeEinsumPath::TypedNaryEinsum - }, - operands, - &input_ids_u32, - &output_ids_u32, - started.elapsed(), - ); - } - Ok(NativeTensor::from_tensor(out)) - }); - } - } - NativeTensor::C64(first) - if first.mode() == tenferro::AdMode::Primal && first.is_dense() => - { - let mut typed_operands = Vec::with_capacity(operands.len()); - for (tensor, _) in operands { - match tensor { - NativeTensor::C64(value) - if value.mode() == tenferro::AdMode::Primal && value.is_dense() => - { - typed_operands.push(value.primal()); - } - _ => { - typed_operands.clear(); - break; - } - } - } - if !typed_operands.is_empty() { - return with_tenferro_ctx("native_einsum", |ctx| { - let out = if typed_operands.len() == 2 { - tenferro_einsum_binary_with_subscripts::, CpuBackend>( - ctx, - &subscripts, - typed_operands[0], - typed_operands[1], - None, - ) - } else { - tenferro_einsum_with_subscripts::, CpuBackend>( - ctx, - &subscripts, - &typed_operands, - None, - ) - } - .map_err(|e| anyhow!("native einsum failed: {e}"))?; - if let Some(started) = profile_started.as_ref() { - record_native_einsum_profile( - if typed_operands.len() == 2 { - NativeEinsumPath::TypedBinaryEinsum - } else { - NativeEinsumPath::TypedNaryEinsum - }, - operands, - &input_ids_u32, - &output_ids_u32, - started.elapsed(), - ); - } - Ok(NativeTensor::from_tensor(out)) - }); - } - } - _ => {} - } - let input_ids: Vec> = operands.iter().map(|(_, ids)| ids.to_vec()).collect(); - let final_operands: Vec<&NativeTensor> = operands.iter().map(|(tensor, _)| *tensor).collect(); + let final_operands: Vec<&NativeTensor> = operands + .iter() + .zip(promoted.iter()) + .map(|((tensor, _), promoted)| promoted.as_ref().unwrap_or(tensor)) + .collect(); let notation = labels_to_notation(&input_ids, output_ids)?; with_default_runtime("native_einsum", || { @@ -933,9 +777,17 @@ pub fn einsum_native_tensors( /// Permute a native tensor through tenferro frontend operations. pub fn permute_native_tensor(tensor: &NativeTensor, perm: &[usize]) -> Result { - tensor - .permute(perm) - .map_err(|e| anyhow!("native permute failed: {e}")) + let input_ids = (0..tensor.ndim()).collect::>(); + let output_ids = perm + .iter() + .map(|&axis| { + input_ids + .get(axis) + .copied() + .ok_or_else(|| anyhow!("native permute axis {axis} out of bounds")) + }) + .collect::>>()?; + einsum_native_tensors(&[(tensor, &input_ids)], &output_ids) } /// Contract two native tensors with AD-preserving einsum execution. @@ -957,7 +809,42 @@ pub fn outer_product_native_tensor(lhs: &NativeTensor, rhs: &NativeTensor) -> Re /// Conjugate a native tensor while preserving AD metadata. pub fn conj_native_tensor(tensor: &NativeTensor) -> Result { - Ok(tensor.conj()) + match tensor.scalar_type() { + ScalarType::F32 | ScalarType::C32 => Err(anyhow!( + "tensor4all native bridge currently supports only f64/Complex64 tensors" + )), + ScalarType::F64 => { + if tensor.is_diag() && tensor.ndim() >= 2 { + diag_native_tensor_from_col_major( + &native_tensor_primal_to_diag_f64(tensor)?, + tensor.ndim(), + ) + } else { + dense_native_tensor_from_col_major( + &native_tensor_primal_to_dense_f64_col_major(tensor)?, + tensor.dims(), + ) + } + } + ScalarType::C64 => { + let data = if tensor.is_diag() && tensor.ndim() >= 2 { + native_tensor_primal_to_diag_c64(tensor)? + .into_iter() + .map(|value| value.conj()) + .collect::>() + } else { + native_tensor_primal_to_dense_c64_col_major(tensor)? + .into_iter() + .map(|value| value.conj()) + .collect::>() + }; + if tensor.is_diag() && tensor.ndim() >= 2 { + diag_native_tensor_from_col_major(&data, tensor.ndim()) + } else { + dense_native_tensor_from_col_major(&data, tensor.dims()) + } + } + } } /// Permute storage through native tenferro execution. diff --git a/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs b/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs index 930344ce..8ca671b4 100644 --- a/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs +++ b/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs @@ -46,14 +46,22 @@ fn storage_native_roundtrip_dense_f64() { } #[test] -fn storage_native_roundtrip_diag_preserves_diag_layout() { +fn storage_native_roundtrip_diag_densifies_at_public_bridge() { let storage = Storage::from_diag_col_major(vec![2.0, -1.0, 4.0], 2).unwrap(); let native = storage_to_native_tensor(&storage, &[3, 3]).unwrap(); let roundtrip = native_tensor_primal_to_storage(&native).unwrap(); - assert!(native.is_diag()); - let expected = Storage::from_diag_col_major(vec![2.0, -1.0, 4.0], 2).unwrap(); + assert!(!native.is_diag()); + let expected = Storage::from_dense_col_major( + vec![ + 2.0, 0.0, 0.0, // + 0.0, -1.0, 0.0, // + 0.0, 0.0, 4.0, + ], + &[3, 3], + ) + .unwrap(); assert_storage_eq(&roundtrip, &expected); } @@ -118,24 +126,22 @@ fn native_dense_materialization_survives_external_runtime_scope_changes() { } #[test] -fn storage_native_roundtrip_structured_preserves_axis_classes() { - let payload = NativeTensor::from_slice(&[1.0_f64, 2.0, 3.0, 4.0], &[2, 2]).unwrap(); - let native = NativeTensor::with_axis_classes(payload, &[0, 1, 1]).unwrap(); +fn storage_native_roundtrip_structured_densifies_at_public_bridge() { + let storage = Storage::new_structured( + vec![1.0_f64, 2.0, 3.0, 4.0], + vec![2, 2], + vec![1, 2], + vec![0, 1, 1], + ) + .unwrap(); - let storage = native_tensor_primal_to_storage(&native).unwrap(); - let roundtrip = storage_to_native_tensor(&storage, &[2, 2, 2]).unwrap(); + let native = storage_to_native_tensor(&storage, &[2, 2, 2]).unwrap(); + let roundtrip = native_tensor_primal_to_storage(&native).unwrap(); + let expected = storage.to_dense_storage(&[2, 2, 2]); - match storage.repr() { - StorageRepr::F64(value) => { - assert_eq!(value.axis_classes(), &[0, 1, 1]); - assert_eq!(value.payload_dims(), &[2, 2]); - } - other => panic!("expected F64 storage, got {other:?}"), - } - assert_eq!(roundtrip.dims(), &[2, 2, 2]); - assert_eq!(roundtrip.axis_classes(), &[0, 1, 1]); - assert!(!roundtrip.is_dense()); - assert!(!roundtrip.is_diag()); + assert_eq!(native.dims(), &[2, 2, 2]); + assert!(native.is_dense()); + assert_storage_eq(&roundtrip, &expected); } #[test] @@ -206,7 +212,30 @@ fn native_einsum_accepts_unsorted_nonfirst_operand_labels() { } #[test] -fn einsum_native_tensors_dense_primal_binary_routes_through_typed_binary_einsum() { +fn native_einsum_promotes_detached_real_operands_for_complex_result() { + let lhs = dense_native_tensor_from_col_major( + &[ + Complex64::new(1.0, 2.0), + Complex64::new(0.0, 0.0), + Complex64::new(0.0, 0.0), + Complex64::new(3.0, -1.0), + ], + &[2, 2], + ) + .unwrap(); + let rhs = dense_native_tensor_from_col_major(&[2.0_f64, -1.0], &[2]).unwrap(); + + let out = einsum_native_tensors(&[(&lhs, &[0, 1]), (&rhs, &[1])], &[0]).unwrap(); + let values = native_tensor_primal_to_dense_c64_col_major(&out).unwrap(); + + assert_eq!( + values, + vec![Complex64::new(2.0, 4.0), Complex64::new(-3.0, 1.0)] + ); +} + +#[test] +fn einsum_native_tensors_dense_primal_binary_routes_through_frontend_fallback() { struct ProfileGuard; impl Drop for ProfileGuard { @@ -233,19 +262,11 @@ fn einsum_native_tensors_dense_primal_binary_routes_through_typed_binary_einsum( assert_eq!(out.dims(), &[2, 3]); assert_storage_eq(&snapshot, &expected); - assert_eq!(cpu_context_install_count_for_tests(), 1); - assert_eq!(default_runtime_install_count_for_tests(), 0); - assert_eq!( - recorded_native_einsum_call_count(NativeEinsumPath::TypedBinaryEinsum), - 1 - ); - assert_eq!( - recorded_native_einsum_call_count(NativeEinsumPath::TypedNaryEinsum), - 0 - ); + assert_eq!(cpu_context_install_count_for_tests(), 0); + assert_eq!(default_runtime_install_count_for_tests(), 1); assert_eq!( recorded_native_einsum_call_count(NativeEinsumPath::FrontendFallback), - 0 + 1 ); } @@ -407,6 +428,24 @@ fn scale_storage_native_scales_elements() { assert_storage_eq(&result, &expected); } +#[test] +fn scale_native_tensor_promotes_real_scalar_for_complex_tensor() { + let native = dense_native_tensor_from_col_major( + &[Complex64::new(1.0, 2.0), Complex64::new(-0.5, 1.0)], + &[2], + ) + .unwrap(); + let scalar = crate::AnyScalar::new_real(2.0); + + let scaled = scale_native_tensor(&native, &scalar).unwrap(); + let values = native_tensor_primal_to_dense_c64_col_major(&scaled).unwrap(); + + assert_eq!( + values, + vec![Complex64::new(2.0, 4.0), Complex64::new(-1.0, 2.0)] + ); +} + // ===== axpby_storage_native ===== #[test] @@ -423,6 +462,30 @@ fn axpby_storage_native_linear_combination() { assert_storage_eq(&result, &expected); } +#[test] +fn axpby_native_tensor_promotes_real_scalars_for_complex_tensors() { + let lhs = dense_native_tensor_from_col_major( + &[Complex64::new(1.0, 2.0), Complex64::new(-1.0, 0.5)], + &[2], + ) + .unwrap(); + let rhs = dense_native_tensor_from_col_major( + &[Complex64::new(0.5, -1.0), Complex64::new(2.0, 1.0)], + &[2], + ) + .unwrap(); + let a = crate::AnyScalar::new_real(2.0); + let b = crate::AnyScalar::new_real(-1.0); + + let combined = axpby_native_tensor(&lhs, &a, &rhs, &b).unwrap(); + let values = native_tensor_primal_to_dense_c64_col_major(&combined).unwrap(); + + assert_eq!( + values, + vec![Complex64::new(1.5, 5.0), Complex64::new(-4.0, 0.0)] + ); +} + // ===== native_tensor_primal_to_diag ===== #[test] @@ -469,31 +532,39 @@ fn diag_native_tensor_from_col_major_f64_roundtrip() { let data = vec![1.0_f64, 2.0, 3.0]; let native = diag_native_tensor_from_col_major(&data, 2).unwrap(); - assert!(native.is_diag()); + assert!(!native.is_diag()); let diag_values = native_tensor_primal_to_diag_f64(&native).unwrap(); assert_eq!(diag_values, data); + let dense = native_tensor_primal_to_dense_f64_col_major(&native).unwrap(); + assert_eq!( + dense, + vec![ + 1.0, 0.0, 0.0, // + 0.0, 2.0, 0.0, // + 0.0, 0.0, 3.0, + ] + ); } // ===== structured storage roundtrip for c64 ===== #[test] -fn storage_native_roundtrip_structured_c64_preserves_axis_classes() { - let payload = - NativeTensor::from_slice(&[Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)], &[2]) - .unwrap(); - let native = NativeTensor::with_axis_classes(payload, &[0, 0]).unwrap(); +fn storage_native_roundtrip_structured_c64_densifies_at_public_bridge() { + let storage = Storage::new_structured( + vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)], + vec![2], + vec![1], + vec![0, 0], + ) + .unwrap(); - let storage = native_tensor_primal_to_storage(&native).unwrap(); - let roundtrip = storage_to_native_tensor(&storage, &[2, 2]).unwrap(); + let native = storage_to_native_tensor(&storage, &[2, 2]).unwrap(); + let roundtrip = native_tensor_primal_to_storage(&native).unwrap(); + let expected = storage.to_dense_storage(&[2, 2]); - match storage.repr() { - StorageRepr::C64(value) => { - assert_eq!(value.axis_classes(), &[0, 0]); - } - other => panic!("expected C64 storage, got {other:?}"), - } - assert_eq!(roundtrip.dims(), &[2, 2]); - assert_eq!(roundtrip.axis_classes(), &[0, 0]); + assert_eq!(native.dims(), &[2, 2]); + assert!(native.is_dense()); + assert_storage_eq(&roundtrip, &expected); } // ===== sum_native_tensor for f64 ===== @@ -533,7 +604,7 @@ fn einsum_native_tensors_rejects_empty_operands() { } #[test] -fn qr_native_tensor_dense_primal_uses_cached_cpu_context_fast_path() { +fn qr_native_tensor_dense_primal_uses_default_runtime_path() { reset_runtime_caches_for_tests(); let native = dense_native_tensor_from_col_major(&[1.0_f64, 3.0, 2.0, 4.0], &[2, 2]).unwrap(); @@ -541,12 +612,12 @@ fn qr_native_tensor_dense_primal_uses_cached_cpu_context_fast_path() { assert_eq!(q.dims(), &[2, 2]); assert_eq!(r.dims(), &[2, 2]); - assert_eq!(cpu_context_install_count_for_tests(), 1); - assert_eq!(default_runtime_install_count_for_tests(), 0); + assert_eq!(cpu_context_install_count_for_tests(), 0); + assert_eq!(default_runtime_install_count_for_tests(), 1); } #[test] -fn svd_native_tensor_dense_primal_uses_cached_cpu_context_fast_path() { +fn svd_native_tensor_dense_primal_uses_default_runtime_path() { reset_runtime_caches_for_tests(); let native = dense_native_tensor_from_col_major(&[1.0_f64, 3.0, 2.0, 4.0], &[2, 2]).unwrap(); @@ -555,8 +626,8 @@ fn svd_native_tensor_dense_primal_uses_cached_cpu_context_fast_path() { assert_eq!(u.dims(), &[2, 2]); assert_eq!(s.dims(), &[2]); assert_eq!(vt.dims(), &[2, 2]); - assert_eq!(cpu_context_install_count_for_tests(), 1); - assert_eq!(default_runtime_install_count_for_tests(), 0); + assert_eq!(cpu_context_install_count_for_tests(), 0); + assert_eq!(default_runtime_install_count_for_tests(), 1); } // ===== build_binary_einsum_ids error paths ===== diff --git a/crates/tensor4all-tensorbackend/src/tensor_element.rs b/crates/tensor4all-tensorbackend/src/tensor_element.rs index dcd13aed..76e430d6 100644 --- a/crates/tensor4all-tensorbackend/src/tensor_element.rs +++ b/crates/tensor4all-tensorbackend/src/tensor_element.rs @@ -1,8 +1,6 @@ use anyhow::{anyhow, Result}; use num_complex::{Complex32, Complex64}; -use tenferro::{snapshot, Tensor as NativeTensor}; -use tenferro_algebra::Conjugate; -use tenferro_device::LogicalMemorySpace; +use tenferro::Tensor as NativeTensor; use tenferro_tensor::{MemoryOrder, Tensor as TypedTensor}; /// Public scalar element types supported by tensor4all dense/diag constructors. @@ -26,36 +24,31 @@ pub trait TensorElement: Copy + Send + Sync + 'static { fn diag_values_from_native_temp(tensor: &NativeTensor) -> Result>; } -fn materialize_typed_values( - tensor: &TypedTensor, - order: MemoryOrder, - op: &'static str, -) -> Result> -where - T: tenferro_algebra::Scalar + Copy + Conjugate, -{ - let contiguous = tensor.contiguous(order); - let is_conjugated = contiguous.is_conjugated(); - let contiguous = if contiguous.logical_memory_space() == LogicalMemorySpace::MainMemory { - contiguous - } else { - contiguous - .to_memory_space_async(LogicalMemorySpace::MainMemory) - .map_err(|e| anyhow!("{op}: failed to move tensor to host memory: {e}"))? - }; - let offset = usize::try_from(contiguous.offset()) - .map_err(|_| anyhow!("{op}: negative offset {}", contiguous.offset()))?; - let len = contiguous.len(); - let slice = contiguous - .buffer() - .as_slice() - .and_then(|values: &[T]| values.get(offset..offset + len)) - .ok_or_else(|| anyhow!("{op}: expected host-accessible contiguous tensor buffer"))?; - if is_conjugated { - Ok(slice.iter().copied().map(Conjugate::conj).collect()) - } else { - Ok(slice.to_vec()) +fn diagonal_multi_index(rank: usize, value: usize) -> Vec { + vec![value; rank] +} + +fn dense_diagonal_values(diag: &[T], logical_rank: usize) -> Result> { + anyhow::ensure!( + logical_rank >= 1, + "diagonal tensor construction requires at least one logical axis" + ); + let diag_len = diag.len(); + let dims = vec![diag_len; logical_rank]; + let total_len = dims.iter().product::(); + let mut dense = vec![T::default(); total_len]; + let stride_prefix = (0..logical_rank) + .scan(1usize, |state, _| { + let current = *state; + *state = state.saturating_mul(diag_len); + Some(current) + }) + .collect::>(); + let diagonal_stride = stride_prefix.iter().sum::(); + for (i, value) in diag.iter().copied().enumerate() { + dense[i * diagonal_stride] = value; } + Ok(dense) } macro_rules! impl_tensor_element { @@ -67,71 +60,54 @@ macro_rules! impl_tensor_element { ) -> Result { let typed = TypedTensor::::from_slice(data, dims, MemoryOrder::ColumnMajor) .map_err(|e| anyhow!("failed to build native dense tensor: {e}"))?; - Ok(NativeTensor::from_tensor(typed)) + Ok(NativeTensor::from(typed)) } fn diag_native_tensor_from_col_major( data: &[Self], logical_rank: usize, ) -> Result { - if logical_rank == 0 { - return Err(anyhow!( - "diagonal tensor construction requires at least one logical axis" - )); - } - - let payload = - TypedTensor::::from_slice(data, &[data.len()], MemoryOrder::ColumnMajor) - .map_err(|e| anyhow!("failed to build native diagonal payload: {e}"))?; - NativeTensor::from_tensor(payload) - .diag_embed(logical_rank) - .map_err(|e| anyhow!("failed to build native diagonal tensor: {e}")) + let dims = vec![data.len(); logical_rank]; + let dense = dense_diagonal_values(data, logical_rank)?; + Self::dense_native_tensor_from_col_major(&dense, &dims) } fn scalar_native_tensor(value: Self) -> Result { let typed = TypedTensor::::from_slice(&[value], &[], MemoryOrder::ColumnMajor) .map_err(|e| anyhow!("failed to build native rank-0 tensor: {e}"))?; - Ok(NativeTensor::from_tensor(typed)) + Ok(NativeTensor::from(typed)) } fn dense_values_from_native_col_major(tensor: &NativeTensor) -> Result> { - let snap = tensor.primal_snapshot(); - let dense = if snap.is_dense() { - snap + let dense = if tensor.is_dense() { + tensor.try_to_vec::() } else { - snap.to_dense() - .map_err(|e| anyhow!("failed to densify native tensor snapshot: {e}"))? - }; - match dense { - snapshot::DynTensor::$variant(value) => materialize_typed_values( - value.$payload(), - MemoryOrder::ColumnMajor, - "dense native tensor extraction", - ), - other => Err(anyhow!( - "expected {:?} tensor snapshot, got {:?}", - stringify!($variant), - other.scalar_type() - )), + tensor.to_dense()?.try_to_vec::() } + .map_err(|e| anyhow!("dense native tensor extraction failed: {e}"))?; + Ok(dense) } fn diag_values_from_native_temp(tensor: &NativeTensor) -> Result> { - let snap = tensor.primal_snapshot(); - anyhow::ensure!(snap.is_diag(), "expected diagonal native tensor snapshot"); - match snap { - snapshot::DynTensor::$variant(value) => materialize_typed_values( - value.$payload(), - MemoryOrder::ColumnMajor, - "diagonal native tensor extraction", - ), - other => Err(anyhow!( - "expected {:?} diagonal tensor snapshot, got {:?}", - stringify!($variant), - other.scalar_type() - )), + let rank = tensor.ndim(); + anyhow::ensure!(rank >= 1, "diagonal native tensor rank must be at least 1"); + let diag_len = tensor.dims()[0]; + anyhow::ensure!( + tensor.dims().iter().all(|&dim| dim == diag_len), + "expected square/equal logical dims for diagonal extraction, got {:?}", + tensor.dims() + ); + let mut values = Vec::with_capacity(diag_len); + for i in 0..diag_len { + let index = diagonal_multi_index(rank, i); + values.push( + tensor.try_get::(&index).map_err(|e| { + anyhow!("diagonal native tensor extraction failed: {e}") + })?, + ); } + Ok(values) } } }; From a2c2a499dbe23a2b48dde481ccc69956d0f3579a Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 31 Mar 2026 14:18:44 +0900 Subject: [PATCH 2/4] test: reconnect downstream reverse ad coverage --- .../tests/tensortrain_native_ad.rs | 60 ------ crates/tensor4all-treetn/tests/ad_treetn.rs | 181 +----------------- 2 files changed, 7 insertions(+), 234 deletions(-) delete mode 100644 crates/tensor4all-itensorlike/tests/tensortrain_native_ad.rs diff --git a/crates/tensor4all-itensorlike/tests/tensortrain_native_ad.rs b/crates/tensor4all-itensorlike/tests/tensortrain_native_ad.rs deleted file mode 100644 index 52b746e7..00000000 --- a/crates/tensor4all-itensorlike/tests/tensortrain_native_ad.rs +++ /dev/null @@ -1,60 +0,0 @@ -use tensor4all_core::{forward_ad, Index, TensorDynLen}; -use tensor4all_itensorlike::{CanonicalForm, TensorTrain, TruncateOptions}; - -fn make_two_site_tt(fw: &forward_ad::DualLevel<'_>) -> TensorTrain { - let s0 = Index::new_dyn(2); - let s1 = Index::new_dyn(2); - let bond = Index::new_dyn(2); - - let t0_primal = - TensorDynLen::from_dense(vec![s0.clone(), bond.clone()], vec![1.0, 0.0, 0.0, 2.0]).unwrap(); - let t0_tangent = - TensorDynLen::from_dense(vec![s0, bond.clone()], vec![0.1, 0.0, 0.0, -0.2]).unwrap(); - let t1_primal = - TensorDynLen::from_dense(vec![bond.clone(), s1.clone()], vec![3.0, 0.0, 0.0, 4.0]).unwrap(); - let t1_tangent = TensorDynLen::from_dense(vec![bond, s1], vec![0.3, 0.0, 0.0, -0.4]).unwrap(); - - let t0 = fw.make_dual(&t0_primal, &t0_tangent).unwrap(); - let t1 = fw.make_dual(&t1_primal, &t1_tangent).unwrap(); - - TensorTrain::new(vec![t0, t1]).unwrap() -} - -#[test] -fn orthogonalize_preserves_forward_payload() { - forward_ad::dual_level(|fw| { - let mut tt = make_two_site_tt(fw); - - tt.orthogonalize_with(1, CanonicalForm::Unitary).unwrap(); - - for site in 0..tt.len() { - let tensor = tt.tensor(site); - let (_primal, tangent) = fw.unpack_dual(tensor)?; - assert!(tangent.is_some(), "site {site} lost tangent information"); - } - - Ok(()) - }) - .unwrap(); -} - -#[test] -fn truncate_preserves_forward_payload() { - forward_ad::dual_level(|fw| { - let mut tt = make_two_site_tt(fw); - - tt.truncate(&TruncateOptions::svd().with_max_rank(1)) - .unwrap(); - - assert_eq!(tt.tensor(0).dims()[1], 1); - assert_eq!(tt.tensor(1).dims()[0], 1); - for site in 0..tt.len() { - let tensor = tt.tensor(site); - let (_primal, tangent) = fw.unpack_dual(tensor)?; - assert!(tangent.is_some(), "site {site} lost tangent information"); - } - - Ok(()) - }) - .unwrap(); -} diff --git a/crates/tensor4all-treetn/tests/ad_treetn.rs b/crates/tensor4all-treetn/tests/ad_treetn.rs index 055f6c4c..08111730 100644 --- a/crates/tensor4all-treetn/tests/ad_treetn.rs +++ b/crates/tensor4all-treetn/tests/ad_treetn.rs @@ -1,12 +1,7 @@ -//! Tests for automatic differentiation through TreeTN operations. -//! -//! Forward-mode: verifies tangent propagation through canonicalize, truncate, to_dense. -//! Backward-mode: verifies gradient accumulation through to_dense. +//! Tests for reverse-mode automatic differentiation through TreeTN operations. -use tensor4all_core::{ - contract_multi, forward_ad, AllowedPairs, DynIndex, IndexLike, TensorDynLen, -}; -use tensor4all_treetn::{CanonicalizationOptions, TreeTN, TruncationOptions}; +use tensor4all_core::{contract_multi, AllowedPairs, DynIndex, IndexLike, TensorDynLen}; +use tensor4all_treetn::TreeTN; fn make_three_site_mps_data() -> (Vec>, Vec>) { let s0 = DynIndex::new_dyn(2); @@ -37,158 +32,6 @@ fn with_runtime(f: impl FnOnce() -> R) -> R { f() } -// --------------------------------------------------------------------------- -// Forward-mode AD tests -// --------------------------------------------------------------------------- - -#[test] -fn forward_ad_to_dense_preserves_tangent() { - let (index_sets, data) = make_three_site_mps_data(); - - forward_ad::dual_level(|fw| { - let tensors: Vec = index_sets - .iter() - .zip(&data) - .map(|(idx, d)| { - let primal = TensorDynLen::from_dense(idx.clone(), d.clone()).unwrap(); - let tangent = TensorDynLen::from_dense(idx.clone(), vec![1.0; d.len()]).unwrap(); - fw.make_dual(&primal, &tangent).unwrap() - }) - .collect(); - - let ttn = TreeTN::from_tensors(tensors, vec![0, 1, 2]).unwrap(); - let dense = ttn.to_dense().unwrap(); - - let (_primal, tangent) = fw.unpack_dual(&dense)?; - assert!(tangent.is_some(), "to_dense lost tangent"); - Ok(()) - }) - .unwrap(); -} - -#[test] -fn forward_ad_canonicalize_preserves_tangent() { - let (index_sets, data) = make_three_site_mps_data(); - - forward_ad::dual_level(|fw| { - let tensors: Vec = index_sets - .iter() - .zip(&data) - .map(|(idx, d)| { - let primal = TensorDynLen::from_dense(idx.clone(), d.clone()).unwrap(); - let tangent = TensorDynLen::from_dense(idx.clone(), vec![0.1; d.len()]).unwrap(); - fw.make_dual(&primal, &tangent).unwrap() - }) - .collect(); - - let ttn = TreeTN::from_tensors(tensors, vec![0, 1, 2]).unwrap(); - let ttn = ttn.canonicalize([1], CanonicalizationOptions::default())?; - - for node_idx in ttn.node_indices() { - let tensor = ttn.tensor(node_idx).unwrap(); - let (_primal, tangent) = fw.unpack_dual(tensor)?; - assert!( - tangent.is_some(), - "canonicalize lost tangent at {:?}", - node_idx - ); - } - Ok(()) - }) - .unwrap(); -} - -#[test] -fn forward_ad_truncate_preserves_tangent() { - let (index_sets, data) = make_three_site_mps_data(); - - forward_ad::dual_level(|fw| { - let tensors: Vec = index_sets - .iter() - .zip(&data) - .map(|(idx, d)| { - let primal = TensorDynLen::from_dense(idx.clone(), d.clone()).unwrap(); - let tangent = TensorDynLen::from_dense(idx.clone(), vec![0.1; d.len()]).unwrap(); - fw.make_dual(&primal, &tangent).unwrap() - }) - .collect(); - - let ttn = TreeTN::from_tensors(tensors, vec![0, 1, 2]).unwrap(); - let ttn = ttn.truncate([1], TruncationOptions::default().with_max_rank(1))?; - - for node_idx in ttn.node_indices() { - let tensor = ttn.tensor(node_idx).unwrap(); - let (_primal, tangent) = fw.unpack_dual(tensor)?; - assert!(tangent.is_some(), "truncate lost tangent at {:?}", node_idx); - } - Ok(()) - }) - .unwrap(); -} - -#[test] -fn forward_ad_dense_value_matches_finite_diff() { - let (index_sets, data) = make_three_site_mps_data(); - let eps = 1e-7; - let perturb_idx = 0; - let perturb_dir: Vec = vec![1.0, 0.0, 0.0, 0.0]; - - let tangent_sum = forward_ad::dual_level(|fw| { - let tensors: Vec = index_sets - .iter() - .zip(&data) - .enumerate() - .map(|(i, (idx, d))| { - let primal = TensorDynLen::from_dense(idx.clone(), d.clone()).unwrap(); - let tangent_data = if i == perturb_idx { - perturb_dir.clone() - } else { - vec![0.0; d.len()] - }; - let tangent = TensorDynLen::from_dense(idx.clone(), tangent_data).unwrap(); - fw.make_dual(&primal, &tangent).unwrap() - }) - .collect(); - - let ttn = TreeTN::from_tensors(tensors, vec![0, 1, 2]).unwrap(); - let dense = ttn.to_dense().unwrap(); - let (_primal, tangent) = fw.unpack_dual(&dense)?; - Ok(tangent.expect("tangent missing").sum()) - }) - .unwrap(); - - let make_mps_sum = |perturbation: f64| -> f64 { - let tensors: Vec = index_sets - .iter() - .zip(&data) - .enumerate() - .map(|(i, (idx, d))| { - let mut d = d.clone(); - if i == perturb_idx { - for (j, val) in d.iter_mut().enumerate() { - *val += perturbation * perturb_dir[j]; - } - } - TensorDynLen::from_dense(idx.clone(), d).unwrap() - }) - .collect(); - let ttn = TreeTN::from_tensors(tensors, vec![0, 1, 2]).unwrap(); - ttn.to_dense().unwrap().sum().real() - }; - - let fd = (make_mps_sum(eps) - make_mps_sum(-eps)) / (2.0 * eps); - let ad = tangent_sum.real(); - let err = (ad - fd).abs(); - assert!( - err < 1e-5, - "forward AD ({ad}) vs finite diff ({fd}): err={err}" - ); -} - -// --------------------------------------------------------------------------- -// Backward-mode AD tests -// --------------------------------------------------------------------------- - #[test] fn backward_ad_to_dense_propagates_gradients() { with_runtime(|| { @@ -215,16 +58,10 @@ fn backward_ad_to_dense_propagates_gradients() { .unwrap(); let scalar = contract_multi(&[&dense, &ones], AllowedPairs::All).unwrap(); - // Collect input tensors from inside the TreeTN - let node_tensors: Vec<&TensorDynLen> = ttn - .node_indices() - .iter() - .map(|&ni| ttn.tensor(ni).unwrap()) - .collect(); - scalar.backward(None, &node_tensors).unwrap(); + scalar.backward(None).unwrap(); for (i, &ni) in ttn.node_indices().iter().enumerate() { - let grad = ttn.tensor(ni).unwrap().grad(); + let grad = ttn.tensor(ni).unwrap().grad().unwrap(); assert!(grad.is_some(), "node {i} has no gradient after backward"); } }); @@ -256,12 +93,7 @@ fn backward_ad_gradient_matches_finite_diff() { .unwrap(); let scalar = contract_multi(&[&dense, &ones], AllowedPairs::All).unwrap(); - let node_tensors: Vec<&TensorDynLen> = ttn - .node_indices() - .iter() - .map(|&ni| ttn.tensor(ni).unwrap()) - .collect(); - scalar.backward(None, &node_tensors).unwrap(); + scalar.backward(None).unwrap(); // Verify gradient of first node vs finite difference let node_0 = ttn.node_index(&0).unwrap(); @@ -269,6 +101,7 @@ fn backward_ad_gradient_matches_finite_diff() { .tensor(node_0) .unwrap() .grad() + .unwrap() .expect("t0 gradient missing"); let grad_t0_vec = grad_t0.to_vec::().unwrap(); From ec3716b2cfce852ab7cb091ae236ecdd89021572 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Wed, 1 Apr 2026 10:42:42 +0900 Subject: [PATCH 3/4] fix: follow with_requires_grad downstream rename --- crates/tensor4all-core/src/defaults/tensordynlen.rs | 2 +- crates/tensor4all-tensorbackend/src/any_scalar.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/tensor4all-core/src/defaults/tensordynlen.rs b/crates/tensor4all-core/src/defaults/tensordynlen.rs index ecb7a0e1..5cfb4a0f 100644 --- a/crates/tensor4all-core/src/defaults/tensordynlen.rs +++ b/crates/tensor4all-core/src/defaults/tensordynlen.rs @@ -239,7 +239,7 @@ impl TensorDynLen { /// Enables or disables reverse-mode gradient tracking. pub fn set_requires_grad(&mut self, enabled: bool) -> Result<()> { - self.native = Arc::new(self.native.as_ref().detach().requires_grad_(enabled)); + self.native = Arc::new(self.native.as_ref().detach().with_requires_grad(enabled)); Ok(()) } diff --git a/crates/tensor4all-tensorbackend/src/any_scalar.rs b/crates/tensor4all-tensorbackend/src/any_scalar.rs index ded606d8..b8e25922 100644 --- a/crates/tensor4all-tensorbackend/src/any_scalar.rs +++ b/crates/tensor4all-tensorbackend/src/any_scalar.rs @@ -244,7 +244,7 @@ impl Scalar { pub fn set_requires_grad(&mut self, enabled: bool) -> Result<()> { let placeholder = rank0_real_tensor(0.0); let native = std::mem::replace(&mut self.native, placeholder); - self.native = native.requires_grad_(enabled); + self.native = native.with_requires_grad(enabled); Ok(()) } From 9e2985b2fb40e7dd9b2e97ad15c4191853c8978c Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Wed, 1 Apr 2026 12:19:48 +0900 Subject: [PATCH 4/4] chore: pin tenferro crates to merged upstream rev --- Cargo.toml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c74004e7..53273a76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,11 +51,11 @@ hdf5-metno = { version = "0.12", default-features = false } ndarray = "0.17" quanticsgrids = { git = "https://github.com/tensor4all/quanticsgrids-rs", rev = "a76b8fb" } hdf5-rt = { git = "https://github.com/tensor4all/hdf5-rt", default-features = false } -tenferro = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro" } -tenferro-algebra = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-algebra" } -tenferro-device = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-device" } -tenferro-einsum = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-einsum" } -tenferro-prims = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-prims" } -tenferro-tensor = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-tensor" } -tenferro-linalg = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-linalg" } -tenferro-tensor-compute = { path = "../../../tenferro-rs/.worktrees/linearize-hard-cut/tenferro-tensor-compute" } +tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-algebra = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-prims = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" } +tenferro-tensor-compute = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }