@@ -8,8 +8,9 @@ use crate::index_like::IndexLike;
88use crate :: truncation:: { HasTruncationParams , TruncationParams } ;
99use crate :: { unfold_split, TensorDynLen } ;
1010use tensor4all_tensorbackend:: {
11+ dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
1112 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
12- reshape_col_major_native_tensor, svd_native_tensor,
13+ reshape_col_major_native_tensor, svd_native_tensor, TensorElement ,
1314} ;
1415use thiserror:: Error ;
1516
@@ -134,6 +135,28 @@ fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, Sv
134135 }
135136}
136137
138+ fn truncate_matrix_cols < T : TensorElement > (
139+ data : & [ T ] ,
140+ rows : usize ,
141+ keep_cols : usize ,
142+ ) -> anyhow:: Result < tenferro:: Tensor > {
143+ dense_native_tensor_from_col_major ( & data[ ..rows * keep_cols] , & [ rows, keep_cols] )
144+ }
145+
146+ fn truncate_matrix_rows < T : TensorElement > (
147+ data : & [ T ] ,
148+ rows : usize ,
149+ cols : usize ,
150+ keep_rows : usize ,
151+ ) -> anyhow:: Result < tenferro:: Tensor > {
152+ let mut truncated = Vec :: with_capacity ( keep_rows * cols) ;
153+ for col in 0 ..cols {
154+ let start = col * rows;
155+ truncated. extend_from_slice ( & data[ start..start + keep_rows] ) ;
156+ }
157+ dense_native_tensor_from_col_major ( & truncated, & [ keep_rows, cols] )
158+ }
159+
137160type SvdTruncatedNativeResult = (
138161 tenferro:: Tensor ,
139162 tenferro:: Tensor ,
@@ -167,17 +190,35 @@ fn svd_truncated_native(
167190 r = r. min ( max_rank) ;
168191 }
169192 if r < k {
170- u_native = u_native. take_prefix ( 1 , r) . map_err ( |e| {
171- SvdError :: ComputationError ( anyhow:: anyhow!( "native SVD truncation on U failed: {e}" ) )
172- } ) ?;
173- s_native = s_native. take_prefix ( 0 , r) . map_err ( |e| {
174- SvdError :: ComputationError ( anyhow:: anyhow!(
175- "native SVD truncation on singular values failed: {e}"
176- ) )
177- } ) ?;
178- vt_native = vt_native. take_prefix ( 0 , r) . map_err ( |e| {
179- SvdError :: ComputationError ( anyhow:: anyhow!( "native SVD V^T truncation failed: {e}" ) )
180- } ) ?;
193+ match u_native. scalar_type ( ) {
194+ tenferro:: ScalarType :: F64 => {
195+ let u_values = native_tensor_primal_to_dense_f64_col_major ( & u_native)
196+ . map_err ( SvdError :: ComputationError ) ?;
197+ let vt_values = native_tensor_primal_to_dense_f64_col_major ( & vt_native)
198+ . map_err ( SvdError :: ComputationError ) ?;
199+ u_native =
200+ truncate_matrix_cols ( & u_values, m, r) . map_err ( SvdError :: ComputationError ) ?;
201+ vt_native = truncate_matrix_rows ( & vt_values, k, n, r)
202+ . map_err ( SvdError :: ComputationError ) ?;
203+ }
204+ tenferro:: ScalarType :: C64 => {
205+ let u_values = native_tensor_primal_to_dense_c64_col_major ( & u_native)
206+ . map_err ( SvdError :: ComputationError ) ?;
207+ let vt_values = native_tensor_primal_to_dense_c64_col_major ( & vt_native)
208+ . map_err ( SvdError :: ComputationError ) ?;
209+ u_native =
210+ truncate_matrix_cols ( & u_values, m, r) . map_err ( SvdError :: ComputationError ) ?;
211+ vt_native = truncate_matrix_rows ( & vt_values, k, n, r)
212+ . map_err ( SvdError :: ComputationError ) ?;
213+ }
214+ other => {
215+ return Err ( SvdError :: ComputationError ( anyhow:: anyhow!(
216+ "native SVD returned unsupported singular-vector scalar type {other:?}"
217+ ) ) ) ;
218+ }
219+ }
220+ s_native = dense_native_tensor_from_col_major ( & s_full[ ..r] , & [ r] )
221+ . map_err ( SvdError :: ComputationError ) ?;
181222 }
182223
183224 let bond_index = DynIndex :: new_bond ( r)
@@ -225,9 +266,8 @@ pub fn svd_with<T>(
225266 let u = TensorDynLen :: from_native ( u_indices, u_reshaped) . map_err ( SvdError :: ComputationError ) ?;
226267
227268 let s_indices = vec ! [ bond_index. clone( ) , bond_index. sim( ) ] ;
228- let s_diag = s_native. diag_embed ( 2 ) . map_err ( |e| {
229- SvdError :: ComputationError ( anyhow:: anyhow!( "native SVD diagonal embedding failed: {e}" ) )
230- } ) ?;
269+ let s_diag = diag_native_tensor_from_col_major ( & singular_values_from_native ( & s_native) ?, 2 )
270+ . map_err ( SvdError :: ComputationError ) ?;
231271 let s = TensorDynLen :: from_native ( s_indices, s_diag) . map_err ( SvdError :: ComputationError ) ?;
232272
233273 let mut vh_indices = vec ! [ bond_index. clone( ) ] ;
@@ -273,9 +313,8 @@ pub(crate) fn svd_for_factorize(
273313 let u = TensorDynLen :: from_native ( u_indices, u_reshaped) . map_err ( SvdError :: ComputationError ) ?;
274314
275315 let s_indices = vec ! [ bond_index. clone( ) , bond_index. sim( ) ] ;
276- let s_diag = s_native. diag_embed ( 2 ) . map_err ( |e| {
277- SvdError :: ComputationError ( anyhow:: anyhow!( "native SVD diagonal embedding failed: {e}" ) )
278- } ) ?;
316+ let s_diag = diag_native_tensor_from_col_major ( & singular_values_from_native ( & s_native) ?, 2 )
317+ . map_err ( SvdError :: ComputationError ) ?;
279318 let s = TensorDynLen :: from_native ( s_indices, s_diag) . map_err ( SvdError :: ComputationError ) ?;
280319
281320 let mut vh_indices = vec ! [ bond_index. clone( ) ] ;
0 commit comments