Skip to content

Commit 8a68db8

Browse files
authored
tensor4all: reconnect downstream AD bridge to the linearize-first stack (#394)
* fix: close core reconnect under public bridge contract * test: reconnect downstream reverse ad coverage * fix: follow with_requires_grad downstream rename * chore: pin tenferro crates to merged upstream rev
1 parent d8d85c8 commit 8a68db8

18 files changed

Lines changed: 838 additions & 1194 deletions

File tree

Cargo.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ hdf5-metno = { version = "0.12", default-features = false }
5151
ndarray = "0.17"
5252
quanticsgrids = { git = "https://github.com/tensor4all/quanticsgrids-rs", rev = "a76b8fb" }
5353
hdf5-rt = { git = "https://github.com/tensor4all/hdf5-rt", default-features = false }
54-
tenferro = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
55-
tenferro-algebra = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
56-
tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
57-
tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
58-
tenferro-prims = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
59-
tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
60-
tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
61-
tenferro-tensor-compute = { git = "https://github.com/tensor4all/tenferro-rs", rev = "a7b97c8" }
54+
tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
55+
tenferro-algebra = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
56+
tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
57+
tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
58+
tenferro-prims = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
59+
tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
60+
tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }
61+
tenferro-tensor-compute = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "c4e18845ba1735df4d37e122b808d75056b84b60" }

crates/tensor4all-core/src/defaults/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ pub use contract::{
3737
};
3838
pub use index::{DefaultIndex, DefaultTagSet, DynId, DynIndex, Index, TagSet};
3939
pub use tensordynlen::{
40-
compute_permutation_from_indices, diag_tensor_dyn_len, is_diag_tensor, unfold_split,
41-
RandomScalar, TensorAccess, TensorDynLen,
40+
compute_permutation_from_indices, diag_tensor_dyn_len, unfold_split, RandomScalar,
41+
TensorAccess, TensorDynLen,
4242
};
4343

4444
// Re-export linear algebra functions and types

crates/tensor4all-core/src/defaults/qr.rs

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use crate::truncation::TruncationParams;
88
use crate::{unfold_split, TensorDynLen};
99
use num_complex::ComplexFloat;
1010
use tensor4all_tensorbackend::{
11-
native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
12-
qr_native_tensor, reshape_col_major_native_tensor,
11+
dense_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major,
12+
native_tensor_primal_to_dense_f64_col_major, qr_native_tensor, reshape_col_major_native_tensor,
13+
TensorElement,
1314
};
1415
use thiserror::Error;
1516

@@ -112,6 +113,28 @@ where
112113
Ok(r.max(1))
113114
}
114115

116+
fn truncate_matrix_cols<T: TensorElement>(
117+
data: &[T],
118+
rows: usize,
119+
keep_cols: usize,
120+
) -> anyhow::Result<tenferro::Tensor> {
121+
dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
122+
}
123+
124+
fn truncate_matrix_rows<T: TensorElement>(
125+
data: &[T],
126+
rows: usize,
127+
cols: usize,
128+
keep_rows: usize,
129+
) -> anyhow::Result<tenferro::Tensor> {
130+
let mut truncated = Vec::with_capacity(keep_rows * cols);
131+
for col in 0..cols {
132+
let start = col * rows;
133+
truncated.extend_from_slice(&data[start..start + keep_rows]);
134+
}
135+
dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
136+
}
137+
115138
/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
116139
///
117140
/// This function uses the global default rtol for truncation.
@@ -227,12 +250,33 @@ pub fn qr_with<T>(
227250
}
228251
};
229252
if r < k {
230-
q_native = q_native.take_prefix(1, r).map_err(|e| {
231-
QrError::ComputationError(anyhow::anyhow!("native QR truncation on Q failed: {e}"))
232-
})?;
233-
r_native = r_native.take_prefix(0, r).map_err(|e| {
234-
QrError::ComputationError(anyhow::anyhow!("native QR truncation on R failed: {e}"))
235-
})?;
253+
match q_native.scalar_type() {
254+
tenferro::ScalarType::F64 => {
255+
let q_values = native_tensor_primal_to_dense_f64_col_major(&q_native)
256+
.map_err(QrError::ComputationError)?;
257+
let r_values = native_tensor_primal_to_dense_f64_col_major(&r_native)
258+
.map_err(QrError::ComputationError)?;
259+
q_native =
260+
truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
261+
r_native =
262+
truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
263+
}
264+
tenferro::ScalarType::C64 => {
265+
let q_values = native_tensor_primal_to_dense_c64_col_major(&q_native)
266+
.map_err(QrError::ComputationError)?;
267+
let r_values = native_tensor_primal_to_dense_c64_col_major(&r_native)
268+
.map_err(QrError::ComputationError)?;
269+
q_native =
270+
truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
271+
r_native =
272+
truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
273+
}
274+
other => {
275+
return Err(QrError::ComputationError(anyhow::anyhow!(
276+
"native QR returned unsupported scalar type {other:?}"
277+
)));
278+
}
279+
}
236280
}
237281

238282
let bond_index = DynIndex::new_bond(r)

crates/tensor4all-core/src/defaults/svd.rs

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use crate::index_like::IndexLike;
88
use crate::truncation::{HasTruncationParams, TruncationParams};
99
use crate::{unfold_split, TensorDynLen};
1010
use tensor4all_tensorbackend::{
11+
dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
1112
native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
12-
reshape_col_major_native_tensor, svd_native_tensor,
13+
reshape_col_major_native_tensor, svd_native_tensor, TensorElement,
1314
};
1415
use thiserror::Error;
1516

@@ -134,6 +135,28 @@ fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, Sv
134135
}
135136
}
136137

138+
fn truncate_matrix_cols<T: TensorElement>(
139+
data: &[T],
140+
rows: usize,
141+
keep_cols: usize,
142+
) -> anyhow::Result<tenferro::Tensor> {
143+
dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
144+
}
145+
146+
fn truncate_matrix_rows<T: TensorElement>(
147+
data: &[T],
148+
rows: usize,
149+
cols: usize,
150+
keep_rows: usize,
151+
) -> anyhow::Result<tenferro::Tensor> {
152+
let mut truncated = Vec::with_capacity(keep_rows * cols);
153+
for col in 0..cols {
154+
let start = col * rows;
155+
truncated.extend_from_slice(&data[start..start + keep_rows]);
156+
}
157+
dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
158+
}
159+
137160
type SvdTruncatedNativeResult = (
138161
tenferro::Tensor,
139162
tenferro::Tensor,
@@ -167,17 +190,35 @@ fn svd_truncated_native(
167190
r = r.min(max_rank);
168191
}
169192
if r < k {
170-
u_native = u_native.take_prefix(1, r).map_err(|e| {
171-
SvdError::ComputationError(anyhow::anyhow!("native SVD truncation on U failed: {e}"))
172-
})?;
173-
s_native = s_native.take_prefix(0, r).map_err(|e| {
174-
SvdError::ComputationError(anyhow::anyhow!(
175-
"native SVD truncation on singular values failed: {e}"
176-
))
177-
})?;
178-
vt_native = vt_native.take_prefix(0, r).map_err(|e| {
179-
SvdError::ComputationError(anyhow::anyhow!("native SVD V^T truncation failed: {e}"))
180-
})?;
193+
match u_native.scalar_type() {
194+
tenferro::ScalarType::F64 => {
195+
let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native)
196+
.map_err(SvdError::ComputationError)?;
197+
let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native)
198+
.map_err(SvdError::ComputationError)?;
199+
u_native =
200+
truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
201+
vt_native = truncate_matrix_rows(&vt_values, k, n, r)
202+
.map_err(SvdError::ComputationError)?;
203+
}
204+
tenferro::ScalarType::C64 => {
205+
let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native)
206+
.map_err(SvdError::ComputationError)?;
207+
let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native)
208+
.map_err(SvdError::ComputationError)?;
209+
u_native =
210+
truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
211+
vt_native = truncate_matrix_rows(&vt_values, k, n, r)
212+
.map_err(SvdError::ComputationError)?;
213+
}
214+
other => {
215+
return Err(SvdError::ComputationError(anyhow::anyhow!(
216+
"native SVD returned unsupported singular-vector scalar type {other:?}"
217+
)));
218+
}
219+
}
220+
s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r])
221+
.map_err(SvdError::ComputationError)?;
181222
}
182223

183224
let bond_index = DynIndex::new_bond(r)
@@ -225,9 +266,8 @@ pub fn svd_with<T>(
225266
let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
226267

227268
let s_indices = vec![bond_index.clone(), bond_index.sim()];
228-
let s_diag = s_native.diag_embed(2).map_err(|e| {
229-
SvdError::ComputationError(anyhow::anyhow!("native SVD diagonal embedding failed: {e}"))
230-
})?;
269+
let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
270+
.map_err(SvdError::ComputationError)?;
231271
let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
232272

233273
let mut vh_indices = vec![bond_index.clone()];
@@ -273,9 +313,8 @@ pub(crate) fn svd_for_factorize(
273313
let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
274314

275315
let s_indices = vec![bond_index.clone(), bond_index.sim()];
276-
let s_diag = s_native.diag_embed(2).map_err(|e| {
277-
SvdError::ComputationError(anyhow::anyhow!("native SVD diagonal embedding failed: {e}"))
278-
})?;
316+
let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
317+
.map_err(SvdError::ComputationError)?;
279318
let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
280319

281320
let mut vh_indices = vec![bond_index.clone()];

0 commit comments

Comments
 (0)