diff --git a/Manifest.toml.bak b/Manifest.toml.bak new file mode 100644 index 0000000..6a20dd0 --- /dev/null +++ b/Manifest.toml.bak @@ -0,0 +1,287 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.0" +manifest_format = "2.0" +project_hash = "efcb77218b46d3233173434c4aa34666730c5c95" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "9d8a54ce4b17aa5bdce0ea5c34bc5e7c340d16ad" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.18.1" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.5+1" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "e94f84da9af7ce9c6be049e9067e511e17ff89ec" 
+uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.6+0" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "XML2_jll", "Xorg_libpciaccess_jll"] +git-tree-sha1 = "157e2e5838984449e44af851a52fe374d56b9ada" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.13.0+0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "0533e564aae234aff59ab625543145446d8b6ec2" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.7.1" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "be484f5c92fad0bd8acfef35fe017900b0b73809" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.18.0+0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", 
"LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "9341048b9f723f2ae2a72a5269ac2f15f80534dc" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.3.2+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "36c2d142e7d45fb98b5f83925213feb3292ca348" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.5.5+0" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "bc95bf4149bf535c09602e3acdf950d9b4376227" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+3" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+2" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] +git-tree-sha1 = "ab6596a9d8236041dcd59b5b69316f28a8753592" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "5.0.9+0" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "c9cbeda6aceffc52d8a0017e71db27c7a7c0beaf" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.5.5+0" + +[[deps.Pkg]] 
+deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "522f093a29b31a93e34eaea17ba055d850edea28" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.5.1" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.QuanticsGrids]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae0eea18762a145ad9c10dc41f4f93cd6c88f48d" +uuid = "634c7f73-3e90-4749-a1bd-001b8efc642d" +version = "0.7.2" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.1" + +[[deps.RustToolChain]] +deps = ["Pkg"] +git-tree-sha1 = "1cd1bd0ee81956b9a4d48b32e9dae5f20f3471e6" +uuid = "e9dc52e2-edb8-4742-9783-5e542d30dbb5" +version = "0.1.2" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.XML2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] +git-tree-sha1 = "80d3930c6347cfce7ccf96bd3bafdf079d9c0390" +uuid = 
"02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.13.9+0" + +[[deps.Xorg_libpciaccess_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "4909eb8f1cbf6bd4b1c30dd18b2ead9019ef2fad" +uuid = "a65dc6b1-eb27-53a1-bb3e-dea574b5389e" +version = "0.18.1+0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "13b760f97c6e753b47df30cb438d4dc3b50df282" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.5+0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index fc2a30b..d7cf67e 100644 --- a/Project.toml +++ b/Project.toml @@ -11,19 +11,23 @@ RustToolChain = "e9dc52e2-edb8-4742-9783-5e542d30dbb5" [weakdeps] ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" +TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" [extensions] Tensor4allITensorsExt = ["ITensors"] +Tensor4allTCIExt = ["TensorCrossInterpolation"] [compat] HDF5 = "0.17" ITensors = "0.6, 0.7, 0.8, 0.9" RustToolChain = "0.1" +TensorCrossInterpolation = "0.9" julia = "1.9" [extras] ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" +TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "ITensors"] +test = ["Test", "ITensors", "TensorCrossInterpolation"] diff --git a/docs/plans/2026-03-30-treetci-capi.md b/docs/plans/2026-03-30-treetci-capi.md new file mode 100644 index 0000000..b7ccbc5 --- /dev/null +++ b/docs/plans/2026-03-30-treetci-capi.md @@ -0,0 +1,1441 @@ +# TreeTCI C API 
Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add C API bindings for tensor4all-treetci to the tensor4all-capi crate, enabling Julia/Python to call TreeTCI via FFI. + +**Architecture:** Stateful API exposing `SimpleTreeTci` as an opaque handle with graph construction, pivot management, sweep execution, state inspection, materialization, and a high-level convenience function. Single batch callback unifies point and batch evaluation. + +**Tech Stack:** Rust, C FFI, tensor4all-treetci, tensor4all-capi patterns (opaque handles, catch_unwind, status codes) + +**Reference:** Design spec at `Tensor4all.jl/docs/specs/2026-03-30-treetci-capi-and-julia-wrapper-design.md` + +**Working directory:** `/home/shinaoka/tensor4all/tensor4all-rs` + +--- + +## File Structure + +| Action | File | Responsibility | +|--------|------|----------------| +| Modify | `crates/tensor4all-capi/Cargo.toml` | Add `tensor4all-treetci` dependency | +| Modify | `crates/tensor4all-capi/src/lib.rs` | Add `mod treetci; pub use treetci::*;` | +| Modify | `crates/tensor4all-capi/src/types.rs` | Add `t4a_treetci_graph`, `t4a_treetci_f64`, `t4a_treetci_proposer_kind` | +| Create | `crates/tensor4all-capi/src/treetci.rs` | All TreeTCI C API functions | + +--- + +## Task 1: Scaffold — Cargo.toml, types, lib.rs, empty module + +**Files:** +- Modify: `crates/tensor4all-capi/Cargo.toml` +- Modify: `crates/tensor4all-capi/src/types.rs` +- Modify: `crates/tensor4all-capi/src/lib.rs` +- Create: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Add tensor4all-treetci dependency to Cargo.toml** + +In `crates/tensor4all-capi/Cargo.toml`, add to `[dependencies]`: + +```toml +tensor4all-treetci = { path = "../tensor4all-treetci" } +``` + +- [ ] **Step 2: Add opaque types and enum to types.rs** + +At 
the end of `crates/tensor4all-capi/src/types.rs`, add: + +```rust +// ============================================================================ +// TreeTCI types +// ============================================================================ + +use tensor4all_treetci::{TreeTciGraph, SimpleTreeTci}; +use std::ffi::c_void; + +/// Opaque tree graph type for TreeTCI +#[repr(C)] +pub struct t4a_treetci_graph { + pub(crate) _private: *const c_void, +} + +impl t4a_treetci_graph { + pub(crate) fn new(inner: TreeTciGraph) -> Self { + Self { + _private: Box::into_raw(Box::new(inner)) as *const c_void, + } + } + + pub(crate) fn inner(&self) -> &TreeTciGraph { + unsafe { &*(self._private as *const TreeTciGraph) } + } +} + +impl Clone for t4a_treetci_graph { + fn clone(&self) -> Self { + Self::new(self.inner().clone()) + } +} + +impl Drop for t4a_treetci_graph { + fn drop(&mut self) { + if !self._private.is_null() { + unsafe { + let _ = Box::from_raw(self._private as *mut TreeTciGraph); + } + } + } +} + +unsafe impl Send for t4a_treetci_graph {} +unsafe impl Sync for t4a_treetci_graph {} + +/// Opaque TreeTCI state (f64) +#[repr(C)] +pub struct t4a_treetci_f64 { + pub(crate) _private: *const c_void, +} + +impl t4a_treetci_f64 { + pub(crate) fn new(inner: SimpleTreeTci) -> Self { + Self { + _private: Box::into_raw(Box::new(inner)) as *const c_void, + } + } + + pub(crate) fn inner(&self) -> &SimpleTreeTci { + unsafe { &*(self._private as *const SimpleTreeTci) } + } + + pub(crate) fn inner_mut(&mut self) -> &mut SimpleTreeTci { + unsafe { &mut *(self._private as *mut SimpleTreeTci) } + } +} + +impl Drop for t4a_treetci_f64 { + fn drop(&mut self) { + if !self._private.is_null() { + unsafe { + let _ = Box::from_raw(self._private as *mut SimpleTreeTci); + } + } + } +} + +// No Clone — same as t4a_tci2_f64 +unsafe impl Send for t4a_treetci_f64 {} +unsafe impl Sync for t4a_treetci_f64 {} + +/// Proposer kind selection for TreeTCI +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, 
Eq)] +pub enum t4a_treetci_proposer_kind { + /// DefaultProposer: neighbor-product (matches TreeTCI.jl) + Default = 0, + /// SimpleProposer: random with seed + Simple = 1, + /// TruncatedDefaultProposer: truncated random subset of default candidates + TruncatedDefault = 2, +} +``` + +Note: `c_void` is likely already imported at the top of types.rs. If not, add `use std::ffi::c_void;`. Similarly, check if `TreeTciGraph` and `SimpleTreeTci` imports conflict with existing imports — they should not since they're from a new crate. + +- [ ] **Step 3: Add mod and re-export to lib.rs** + +In `crates/tensor4all-capi/src/lib.rs`, add alongside the existing module declarations: + +```rust +mod treetci; +``` + +And alongside the existing `pub use` statements: + +```rust +pub use treetci::*; +``` + +- [ ] **Step 4: Create empty treetci.rs** + +Create `crates/tensor4all-capi/src/treetci.rs`: + +```rust +//! C API for TreeTCI (tree-structured tensor cross interpolation) + +use crate::types::{t4a_treetci_f64, t4a_treetci_graph, t4a_treetci_proposer_kind}; +use crate::{err_status, set_last_error, StatusCode, T4A_INTERNAL_ERROR, T4A_NULL_POINTER, T4A_SUCCESS}; +use crate::t4a_treetn; +use std::ffi::c_void; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use tensor4all_treetci::{ + DefaultProposer, GlobalIndexBatch, SimpleProposer, SimpleTreeTci, TreeTciEdge, + TreeTciGraph, TreeTciOptions, TruncatedDefaultProposer, +}; +``` + +- [ ] **Step 5: Verify it compiles** + +Run: `cargo build -p tensor4all-capi --release 2>&1 | head -20` + +Expected: Successful compilation (warnings OK, no errors). 
+ +- [ ] **Step 6: Commit** + +```bash +git add crates/tensor4all-capi/Cargo.toml crates/tensor4all-capi/src/types.rs \ + crates/tensor4all-capi/src/lib.rs crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): scaffold TreeTCI C API module with opaque types" +``` + +--- + +## Task 2: Graph construction functions + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Write the failing test** + +Append to `crates/tensor4all-capi/src/treetci.rs`: + +```rust +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: 7-site branching tree + /// 0 + /// | + /// 1---2 + /// | + /// 3 + /// | + /// 4 + /// / \ + /// 5 6 + fn sample_edges() -> Vec { + // flat: [u0,v0, u1,v1, ...] + vec![0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 4, 6] + } + + #[test] + fn test_graph_new_and_query() { + let edges = sample_edges(); + let graph = t4a_treetci_graph_new(7, edges.as_ptr(), 6); + assert!(!graph.is_null()); + + let graph_ref = unsafe { &*graph }; + + let mut n_sites: libc::size_t = 0; + let status = t4a_treetci_graph_n_sites(graph, &mut n_sites); + assert_eq!(status, T4A_SUCCESS); + assert_eq!(n_sites, 7); + + let mut n_edges: libc::size_t = 0; + let status = t4a_treetci_graph_n_edges(graph, &mut n_edges); + assert_eq!(status, T4A_SUCCESS); + assert_eq!(n_edges, 6); + + t4a_treetci_graph_release(graph); + } + + #[test] + fn test_graph_invalid_disconnected() { + // 4 sites but only 2 edges, disconnected + let edges: Vec = vec![0, 1, 2, 3]; + let graph = t4a_treetci_graph_new(4, edges.as_ptr(), 2); + assert!(graph.is_null()); // should fail validation + } +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo nextest run --release -p tensor4all-capi test_graph_new_and_query 2>&1 | tail -10` + +Expected: FAIL — functions `t4a_treetci_graph_new`, etc. not defined. 
+ +- [ ] **Step 3: Implement graph functions** + +Add to `crates/tensor4all-capi/src/treetci.rs` (before the `#[cfg(test)]` block): + +```rust +// ============================================================================ +// Graph lifecycle +// ============================================================================ + +impl_opaque_type_common!(treetci_graph); + +/// Create a new tree graph. +/// +/// # Arguments +/// - `n_sites`: Number of sites (>= 1) +/// - `edges_flat`: Edge pairs [u0, v0, u1, v1, ...] (length = n_edges * 2) +/// - `n_edges`: Number of edges (must equal n_sites - 1 for a tree) +/// +/// # Returns +/// New graph handle, or NULL on error (invalid tree structure). +/// +/// # Safety +/// `edges_flat` must point to a valid buffer of `n_edges * 2` elements. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_graph_new( + n_sites: libc::size_t, + edges_flat: *const libc::size_t, + n_edges: libc::size_t, +) -> *mut t4a_treetci_graph { + if edges_flat.is_null() && n_edges > 0 { + set_last_error("edges_flat is null"); + return std::ptr::null_mut(); + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let edges: Vec = (0..n_edges) + .map(|i| { + let u = unsafe { *edges_flat.add(2 * i) }; + let v = unsafe { *edges_flat.add(2 * i + 1) }; + TreeTciEdge::new(u, v) + }) + .collect(); + + match TreeTciGraph::new(n_sites, &edges) { + Ok(graph) => Box::into_raw(Box::new(t4a_treetci_graph::new(graph))), + Err(e) => { + set_last_error(&e.to_string()); + std::ptr::null_mut() + } + } + })); + + crate::unwrap_catch_ptr(result) +} + +/// Get the number of sites in the graph. 
+#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_graph_n_sites( + graph: *const t4a_treetci_graph, + out: *mut libc::size_t, +) -> StatusCode { + if graph.is_null() || out.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let g = unsafe { &*graph }; + unsafe { *out = g.inner().n_sites() }; + T4A_SUCCESS + })); + + crate::unwrap_catch(result) +} + +/// Get the number of edges in the graph. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_graph_n_edges( + graph: *const t4a_treetci_graph, + out: *mut libc::size_t, +) -> StatusCode { + if graph.is_null() || out.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let g = unsafe { &*graph }; + unsafe { *out = g.inner().edges().len() }; + T4A_SUCCESS + })); + + crate::unwrap_catch(result) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cargo nextest run --release -p tensor4all-capi test_graph 2>&1 | tail -10` + +Expected: 2 tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): add TreeTCI graph construction functions" +``` + +--- + +## Task 3: Callback type and closure helpers + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Add callback type and closure helpers** + +Add to `crates/tensor4all-capi/src/treetci.rs` (after the imports, before graph functions): + +```rust +// ============================================================================ +// Callback type +// ============================================================================ + +/// Batch evaluation callback for TreeTCI. +/// +/// Evaluates the target function at multiple points simultaneously. +/// When `n_points == 1`, this acts as a single-point evaluation. +/// +/// # Arguments +/// * `batch_data` - Column-major (n_sites, n_points) index array. +/// Element at (site, point) is at `batch_data[site + n_sites * point]`. 
+/// * `n_sites` - Number of sites +/// * `n_points` - Number of evaluation points +/// * `results` - Output buffer for `n_points` f64 values +/// * `user_data` - User data pointer passed through from the calling function +/// +/// # Returns +/// 0 on success, non-zero on error +pub type TreeTciBatchEvalCallback = extern "C" fn( + batch_data: *const libc::size_t, + n_sites: libc::size_t, + n_points: libc::size_t, + results: *mut libc::c_double, + user_data: *mut c_void, +) -> i32; + +// ============================================================================ +// Internal helpers +// ============================================================================ + +/// Create a batch eval closure from the C callback. +/// +/// Returns a closure compatible with `Fn(GlobalIndexBatch<'_>) -> Result>`. +fn make_batch_eval_closure( + eval_fn: TreeTciBatchEvalCallback, + user_data: *mut c_void, +) -> impl Fn(GlobalIndexBatch<'_>) -> anyhow::Result> { + move |batch: GlobalIndexBatch<'_>| -> anyhow::Result> { + let mut results = vec![0.0f64; batch.n_points()]; + let status = eval_fn( + batch.data().as_ptr(), + batch.n_sites(), + batch.n_points(), + results.as_mut_ptr(), + user_data, + ); + if status != 0 { + anyhow::bail!("TreeTCI batch eval callback returned error status {}", status); + } + Ok(results) + } +} + +/// Create a point eval closure from the C batch callback (n_points=1). +fn make_point_eval_closure( + eval_fn: TreeTciBatchEvalCallback, + user_data: *mut c_void, +) -> impl Fn(&[usize]) -> f64 { + move |indices: &[usize]| -> f64 { + let mut result: f64 = 0.0; + let status = eval_fn( + indices.as_ptr(), + indices.len(), + 1, + &mut result, + user_data, + ); + if status != 0 { + f64::NAN + } else { + result + } + } +} + +/// Convert proposer kind enum to a boxed proposer trait object is not needed; +/// instead we dispatch at call sites. This helper creates TreeTciOptions from +/// C API parameters. 
+fn make_options(tolerance: f64, max_bond_dim: libc::size_t, max_iter: libc::size_t, normalize_error: bool) -> TreeTciOptions { + TreeTciOptions { + tolerance, + max_bond_dim: if max_bond_dim == 0 { usize::MAX } else { max_bond_dim }, + max_iter, + normalize_error, + } +} +``` + +Also add `anyhow` import at the top if not already present. Check existing imports — if `anyhow` is not a dependency of tensor4all-capi, add it to Cargo.toml: + +```toml +anyhow.workspace = true +``` + +(Check if `anyhow` is in the workspace `[workspace.dependencies]` in the root `Cargo.toml` first. If not, use `anyhow = "1"` directly.) + +- [ ] **Step 2: Verify it compiles** + +Run: `cargo build -p tensor4all-capi --release 2>&1 | head -20` + +Expected: Successful compilation. + +- [ ] **Step 3: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs crates/tensor4all-capi/Cargo.toml +git commit -m "feat(capi): add TreeTCI callback type and closure helpers" +``` + +--- + +## Task 4: State lifecycle and pivot management + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Write the failing test** + +Add to the `tests` module in `treetci.rs`: + +```rust + #[test] + fn test_state_new_and_add_pivots() { + let edges = sample_edges(); + let graph = t4a_treetci_graph_new(7, edges.as_ptr(), 6); + assert!(!graph.is_null()); + + let local_dims: Vec = vec![2; 7]; + let state = t4a_treetci_f64_new(local_dims.as_ptr(), 7, graph); + assert!(!state.is_null()); + + // Add one pivot: all zeros (column-major, n_sites=7, n_pivots=1) + let pivot: Vec = vec![0; 7]; + let status = t4a_treetci_f64_add_global_pivots(state, pivot.as_ptr(), 7, 1); + assert_eq!(status, T4A_SUCCESS); + + // Add two pivots at once (column-major) + // pivot0 = [0,0,0,0,0,0,0], pivot1 = [1,0,1,0,1,0,1] + let pivots: Vec = vec![ + 0, 0, 0, 0, 0, 0, 0, // column 0 (sites 0-6 for point 0) + 1, 0, 1, 0, 1, 0, 1, // column 1 (sites 0-6 for point 1) + ]; + let status = 
t4a_treetci_f64_add_global_pivots(state, pivots.as_ptr(), 7, 2); + assert_eq!(status, T4A_SUCCESS); + + t4a_treetci_f64_release(state); + t4a_treetci_graph_release(graph); + } +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo nextest run --release -p tensor4all-capi test_state_new_and_add_pivots 2>&1 | tail -10` + +Expected: FAIL — functions not defined. + +- [ ] **Step 3: Implement state lifecycle and pivot functions** + +Add to `treetci.rs` (after graph functions, before `#[cfg(test)]`): + +```rust +// ============================================================================ +// State lifecycle +// ============================================================================ + +/// Create a new TreeTCI state. +/// +/// # Arguments +/// - `local_dims`: Local dimension at each site (length = n_sites) +/// - `n_sites`: Number of sites (must match graph) +/// - `graph`: Tree graph handle (not consumed; cloned internally) +/// +/// # Returns +/// New state handle, or NULL on error. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_new( + local_dims: *const libc::size_t, + n_sites: libc::size_t, + graph: *const t4a_treetci_graph, +) -> *mut t4a_treetci_f64 { + if local_dims.is_null() || graph.is_null() { + set_last_error("local_dims or graph is null"); + return std::ptr::null_mut(); + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let dims: Vec = (0..n_sites) + .map(|i| unsafe { *local_dims.add(i) }) + .collect(); + let g = unsafe { &*graph }; + let graph_clone = g.inner().clone(); + + match SimpleTreeTci::new(dims, graph_clone) { + Ok(state) => Box::into_raw(Box::new(t4a_treetci_f64::new(state))), + Err(e) => { + set_last_error(&e.to_string()); + std::ptr::null_mut() + } + } + })); + + crate::unwrap_catch_ptr(result) +} + +/// Release a TreeTCI state. 
+#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_release(ptr: *mut t4a_treetci_f64) { + if !ptr.is_null() { + unsafe { + let _ = Box::from_raw(ptr); + } + } +} + +// ============================================================================ +// Pivot management +// ============================================================================ + +/// Add global pivots to the TreeTCI state. +/// +/// Each pivot is a multi-index over all sites. The pivots are projected +/// to per-edge pivot sets internally. +/// +/// # Arguments +/// - `ptr`: State handle +/// - `pivots_flat`: Column-major (n_sites, n_pivots) index array +/// - `n_sites`: Number of sites (must match state) +/// - `n_pivots`: Number of pivots +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_add_global_pivots( + ptr: *mut t4a_treetci_f64, + pivots_flat: *const libc::size_t, + n_sites: libc::size_t, + n_pivots: libc::size_t, +) -> StatusCode { + if ptr.is_null() { + return T4A_NULL_POINTER; + } + if pivots_flat.is_null() && n_pivots > 0 { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &mut *ptr }; + let state_inner = state.inner_mut(); + + // Unpack column-major (n_sites, n_pivots) to Vec> + let pivots: Vec> = (0..n_pivots) + .map(|p| { + (0..n_sites) + .map(|s| unsafe { *pivots_flat.add(s + n_sites * p) }) + .collect() + }) + .collect(); + + match state_inner.add_global_pivots(&pivots) { + Ok(()) => T4A_SUCCESS, + Err(e) => err_status(e, T4A_INTERNAL_ERROR), + } + })); + + crate::unwrap_catch(result) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cargo nextest run --release -p tensor4all-capi test_state_new_and_add_pivots 2>&1 | tail -10` + +Expected: PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): add TreeTCI state lifecycle and pivot management" +``` + +--- + +## Task 5: Sweep function + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Write the failing test** + +Add to the `tests` module: + +```rust + /// Product function: f(idx) = prod(idx[s] + 1.0) + /// This has an exact TT representation with bond dim 1. + extern "C" fn product_batch_eval( + batch_data: *const libc::size_t, + n_sites: libc::size_t, + n_points: libc::size_t, + results: *mut libc::c_double, + _user_data: *mut c_void, + ) -> i32 { + for p in 0..n_points { + let mut val = 1.0f64; + for s in 0..n_sites { + let idx = unsafe { *batch_data.add(s + n_sites * p) }; + val *= (idx as f64) + 1.0; + } + unsafe { *results.add(p) = val }; + } + 0 + } + + #[test] + fn test_sweep() { + let edges = sample_edges(); + let graph = t4a_treetci_graph_new(7, edges.as_ptr(), 6); + let local_dims: Vec = vec![2; 7]; + let state = t4a_treetci_f64_new(local_dims.as_ptr(), 7, graph); + + // Add initial pivot + let pivot: Vec = vec![0; 7]; + t4a_treetci_f64_add_global_pivots(state, pivot.as_ptr(), 7, 1); + + // Run one sweep + let status = t4a_treetci_f64_sweep( + state, + product_batch_eval, + std::ptr::null_mut(), // no user_data needed + t4a_treetci_proposer_kind::Default, + 1e-12, + 0, // unlimited bond dim + ); + assert_eq!(status, T4A_SUCCESS); + + t4a_treetci_f64_release(state); + t4a_treetci_graph_release(graph); + } +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo nextest run --release -p tensor4all-capi test_sweep 2>&1 | tail -10` + +Expected: FAIL — `t4a_treetci_f64_sweep` not defined. 
+ +- [ ] **Step 3: Implement sweep function** + +Add to `treetci.rs` (after pivot management): + +```rust +// ============================================================================ +// Sweep execution +// ============================================================================ + +/// Run one optimization iteration (visit all edges once). +/// +/// Internally calls `optimize_with_proposer` with `max_iter=1`. +/// +/// # Arguments +/// - `ptr`: State handle (mutable) +/// - `eval_cb`: Batch evaluation callback +/// - `user_data`: User data passed to callback +/// - `proposer_kind`: Proposer selection +/// - `tolerance`: Relative tolerance for this iteration +/// - `max_bond_dim`: Maximum bond dimension (0 = unlimited) +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_sweep( + ptr: *mut t4a_treetci_f64, + eval_cb: TreeTciBatchEvalCallback, + user_data: *mut c_void, + proposer_kind: t4a_treetci_proposer_kind, + tolerance: libc::c_double, + max_bond_dim: libc::size_t, +) -> StatusCode { + if ptr.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &mut *ptr }; + let state_inner = state.inner_mut(); + let batch_eval = make_batch_eval_closure(eval_cb, user_data); + let options = make_options(tolerance, max_bond_dim, 1, true); + + let res = match proposer_kind { + t4a_treetci_proposer_kind::Default => { + let proposer = DefaultProposer; + tensor4all_treetci::optimize_with_proposer( + state_inner, batch_eval, &options, &proposer, + ) + } + t4a_treetci_proposer_kind::Simple => { + let proposer = SimpleProposer::default(); + tensor4all_treetci::optimize_with_proposer( + state_inner, batch_eval, &options, &proposer, + ) + } + t4a_treetci_proposer_kind::TruncatedDefault => { + let proposer = TruncatedDefaultProposer::default(); + tensor4all_treetci::optimize_with_proposer( + state_inner, batch_eval, &options, &proposer, + ) + } + }; + + match res { + Ok(_) => T4A_SUCCESS, + Err(e) => err_status(e, 
T4A_INTERNAL_ERROR), + } + })); + + crate::unwrap_catch(result) +} +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cargo nextest run --release -p tensor4all-capi test_sweep 2>&1 | tail -10` + +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): add TreeTCI sweep function" +``` + +--- + +## Task 6: State inspection functions + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Write the failing test** + +Add to the `tests` module: + +```rust + #[test] + fn test_state_inspection() { + let edges = sample_edges(); + let graph = t4a_treetci_graph_new(7, edges.as_ptr(), 6); + let local_dims: Vec<usize> = vec![2; 7]; + let state = t4a_treetci_f64_new(local_dims.as_ptr(), 7, graph); + + let pivot: Vec<usize> = vec![0; 7]; + t4a_treetci_f64_add_global_pivots(state, pivot.as_ptr(), 7, 1); + + // Run a few sweeps + for _ in 0..4 { + t4a_treetci_f64_sweep( + state, + product_batch_eval, + std::ptr::null_mut(), + t4a_treetci_proposer_kind::Default, + 1e-12, + 0, + ); + } + + // max_bond_error + let mut error: libc::c_double = 0.0; + let status = t4a_treetci_f64_max_bond_error(state, &mut error); + assert_eq!(status, T4A_SUCCESS); + assert!(error < 1e-10, "error = {}", error); + + // max_rank + let mut rank: libc::size_t = 0; + let status = t4a_treetci_f64_max_rank(state, &mut rank); + assert_eq!(status, T4A_SUCCESS); + assert!(rank >= 1); + + // max_sample_value + let mut max_val: libc::c_double = 0.0; + let status = t4a_treetci_f64_max_sample_value(state, &mut max_val); + assert_eq!(status, T4A_SUCCESS); + assert!(max_val > 0.0); + + // bond_dims: query size first + let mut n_edges: libc::size_t = 0; + let status = t4a_treetci_f64_bond_dims( + state, + std::ptr::null_mut(), + 0, + &mut n_edges, + ); + assert_eq!(status, T4A_SUCCESS); + assert_eq!(n_edges, 6); + + // bond_dims: fill buffer + let mut dims = vec![0usize; n_edges]; + let status = 
t4a_treetci_f64_bond_dims( + state, + dims.as_mut_ptr(), + n_edges, + &mut n_edges, + ); + assert_eq!(status, T4A_SUCCESS); + for &d in &dims { + assert!(d >= 1); + } + + t4a_treetci_f64_release(state); + t4a_treetci_graph_release(graph); + } +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo nextest run --release -p tensor4all-capi test_state_inspection 2>&1 | tail -10` + +Expected: FAIL — inspection functions not defined. + +- [ ] **Step 3: Implement state inspection functions** + +Add to `treetci.rs`: + +```rust +// ============================================================================ +// State inspection +// ============================================================================ + +/// Get the maximum bond error across all edges. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_max_bond_error( + ptr: *const t4a_treetci_f64, + out: *mut libc::c_double, +) -> StatusCode { + if ptr.is_null() || out.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &*ptr }; + unsafe { *out = state.inner().max_bond_error() }; + T4A_SUCCESS + })); + + crate::unwrap_catch(result) +} + +/// Get the maximum rank (bond dimension) across all edges. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_max_rank( + ptr: *const t4a_treetci_f64, + out: *mut libc::size_t, +) -> StatusCode { + if ptr.is_null() || out.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &*ptr }; + unsafe { *out = state.inner().max_rank() }; + T4A_SUCCESS + })); + + crate::unwrap_catch(result) +} + +/// Get the maximum observed sample value (used for normalization). 
+#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_max_sample_value( + ptr: *const t4a_treetci_f64, + out: *mut libc::c_double, +) -> StatusCode { + if ptr.is_null() || out.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &*ptr }; + unsafe { *out = state.inner().max_sample_value }; + T4A_SUCCESS + })); + + crate::unwrap_catch(result) +} + +/// Get the bond dimensions (ranks) at each edge. +/// +/// Uses query-then-fill: pass `out_ranks = NULL` to query `out_n_edges` only. +/// +/// # Arguments +/// - `out_ranks`: Output buffer (length >= n_edges), or NULL to query size +/// - `buf_len`: Buffer capacity +/// - `out_n_edges`: Outputs the number of edges +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_bond_dims( + ptr: *const t4a_treetci_f64, + out_ranks: *mut libc::size_t, + buf_len: libc::size_t, + out_n_edges: *mut libc::size_t, +) -> StatusCode { + if ptr.is_null() || out_n_edges.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &*ptr }; + let inner = state.inner(); + let edges = inner.graph.edges(); + let n_edges = edges.len(); + + unsafe { *out_n_edges = n_edges }; + + if out_ranks.is_null() { + return T4A_SUCCESS; + } + + if buf_len < n_edges { + return err_status( + format!("Buffer too small: need {}, got {}", n_edges, buf_len), + crate::T4A_BUFFER_TOO_SMALL, + ); + } + + for (i, edge) in edges.iter().enumerate() { + // Bond dim = number of pivot rows for either side of this edge + let (key_u, _key_v) = inner.graph.subregion_vertices(*edge).unwrap(); + let rank = inner.ijset.get(&key_u).map_or(0, |v| v.len()); + unsafe { *out_ranks.add(i) = rank }; + } + + T4A_SUCCESS + })); + + crate::unwrap_catch(result) +} +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cargo nextest run --release -p tensor4all-capi test_state_inspection 2>&1 | tail -10` + +Expected: PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): add TreeTCI state inspection functions" +``` + +--- + +## Task 7: Materialization (to_treetn) + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Write the failing test** + +Add to the `tests` module: + +```rust + #[test] + fn test_to_treetn() { + let edges = sample_edges(); + let graph = t4a_treetci_graph_new(7, edges.as_ptr(), 6); + let local_dims: Vec = vec![2; 7]; + let state = t4a_treetci_f64_new(local_dims.as_ptr(), 7, graph); + + let pivot: Vec = vec![0; 7]; + t4a_treetci_f64_add_global_pivots(state, pivot.as_ptr(), 7, 1); + + for _ in 0..4 { + t4a_treetci_f64_sweep( + state, + product_batch_eval, + std::ptr::null_mut(), + t4a_treetci_proposer_kind::Default, + 1e-12, + 0, + ); + } + + // Materialize to TreeTN + let mut treetn_ptr: *mut t4a_treetn = std::ptr::null_mut(); + let status = t4a_treetci_f64_to_treetn( + state, + product_batch_eval, + std::ptr::null_mut(), + 0, // center_site + &mut treetn_ptr, + ); + assert_eq!(status, T4A_SUCCESS); + assert!(!treetn_ptr.is_null()); + + // Verify TreeTN is valid by checking vertex count + let mut n_vertices: libc::size_t = 0; + let status = crate::t4a_treetn_num_vertices(treetn_ptr, &mut n_vertices); + assert_eq!(status, T4A_SUCCESS); + assert_eq!(n_vertices, 7); + + crate::t4a_treetn_release(treetn_ptr); + t4a_treetci_f64_release(state); + t4a_treetci_graph_release(graph); + } +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo nextest run --release -p tensor4all-capi test_to_treetn 2>&1 | tail -10` + +Expected: FAIL — `t4a_treetci_f64_to_treetn` not defined. 
+ +- [ ] **Step 3: Implement materialization** + +Add to `treetci.rs`: + +```rust +// ============================================================================ +// Materialization +// ============================================================================ + +/// Materialize the converged TreeTCI state into a TreeTN. +/// +/// Internally re-evaluates tensor values using the batch callback and +/// performs LU factorization to construct per-vertex tensors. +/// +/// # Arguments +/// - `ptr`: State handle (const — state is not modified) +/// - `eval_cb`: Batch evaluation callback +/// - `user_data`: User data passed to callback +/// - `center_site`: BFS root site for materialization +/// - `out_treetn`: Output TreeTN handle pointer +/// +/// # Returns +/// The result is a `t4a_treetn` handle. Release with `t4a_treetn_release`. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_to_treetn( + ptr: *const t4a_treetci_f64, + eval_cb: TreeTciBatchEvalCallback, + user_data: *mut c_void, + center_site: libc::size_t, + out_treetn: *mut *mut t4a_treetn, +) -> StatusCode { + if ptr.is_null() || out_treetn.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + let state = unsafe { &*ptr }; + let batch_eval = make_batch_eval_closure(eval_cb, user_data); + + match tensor4all_treetci::to_treetn(state.inner(), batch_eval, Some(center_site)) { + Ok(treetn) => { + unsafe { *out_treetn = Box::into_raw(Box::new(t4a_treetn::new(treetn))) }; + T4A_SUCCESS + } + Err(e) => err_status(e, T4A_INTERNAL_ERROR), + } + })); + + crate::unwrap_catch(result) +} +``` + +Note: `t4a_treetn::new(treetn)` — verify that `t4a_treetn` wraps `TreeTN` (which is `DefaultTreeTN`). The inner type alias may differ. Check `types.rs` for the exact inner type and adjust if needed. If `t4a_treetn` wraps a different type (e.g. `DefaultTreeTN` which may be a type alias), ensure the `to_treetn` return type matches. 
+ +- [ ] **Step 4: Run test to verify it passes** + +Run: `cargo nextest run --release -p tensor4all-capi test_to_treetn 2>&1 | tail -10` + +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): add TreeTCI materialization to TreeTN" +``` + +--- + +## Task 8: High-level convenience function + +**Files:** +- Modify: `crates/tensor4all-capi/src/treetci.rs` + +- [ ] **Step 1: Write the failing test** + +Add to the `tests` module: + +```rust + #[test] + fn test_crossinterpolate_tree_f64() { + let edges = sample_edges(); + let graph = t4a_treetci_graph_new(7, edges.as_ptr(), 6); + let local_dims: Vec = vec![2; 7]; + let initial_pivot: Vec = vec![0; 7]; + + let max_iter: libc::size_t = 10; + let mut out_treetn: *mut t4a_treetn = std::ptr::null_mut(); + let mut out_ranks = vec![0usize; max_iter]; + let mut out_errors = vec![0.0f64; max_iter]; + let mut out_n_iters: libc::size_t = 0; + + let status = t4a_crossinterpolate_tree_f64( + product_batch_eval, + std::ptr::null_mut(), + local_dims.as_ptr(), + 7, + graph, + initial_pivot.as_ptr(), + 1, // n_pivots + t4a_treetci_proposer_kind::Default, + 1e-12, // tolerance + 0, // max_bond_dim (unlimited) + max_iter, + 1, // normalize_error = true + 0, // center_site + &mut out_treetn, + out_ranks.as_mut_ptr(), + out_errors.as_mut_ptr(), + &mut out_n_iters, + ); + + assert_eq!(status, T4A_SUCCESS); + assert!(!out_treetn.is_null()); + assert!(out_n_iters > 0); + assert!(out_n_iters <= max_iter); + + // Verify convergence + let actual_iters = out_n_iters; + let last_error = out_errors[actual_iters - 1]; + assert!(last_error < 1e-10, "last_error = {}", last_error); + + // Verify TreeTN + let mut n_vertices: libc::size_t = 0; + crate::t4a_treetn_num_vertices(out_treetn, &mut n_vertices); + assert_eq!(n_vertices, 7); + + crate::t4a_treetn_release(out_treetn); + t4a_treetci_graph_release(graph); + } +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: 
`cargo nextest run --release -p tensor4all-capi test_crossinterpolate_tree_f64 2>&1 | tail -10` + +Expected: FAIL — function not defined. + +- [ ] **Step 3: Implement high-level function** + +Add to `treetci.rs`: + +```rust +// ============================================================================ +// High-level convenience function +// ============================================================================ + +/// Run TreeTCI to convergence and return a TreeTN. +/// +/// Equivalent to: new → add_pivots → sweep loop → materialize. +/// +/// # Arguments +/// - `eval_cb`: Batch evaluation callback +/// - `user_data`: User data passed to callback +/// - `local_dims`: Local dimension at each site (length = n_sites) +/// - `n_sites`: Number of sites +/// - `graph`: Tree graph handle +/// - `initial_pivots_flat`: Column-major (n_sites, n_pivots), or NULL for empty +/// - `n_pivots`: Number of initial pivots +/// - `proposer_kind`: Proposer selection +/// - `tolerance`: Relative tolerance +/// - `max_bond_dim`: Maximum bond dimension (0 = unlimited) +/// - `max_iter`: Maximum number of iterations +/// - `normalize_error`: Whether to normalize errors (0=false, 1=true) +/// - `center_site`: Materialization center site +/// - `out_treetn`: Output TreeTN handle +/// - `out_ranks`: Buffer for max rank per iteration (length >= max_iter), or NULL +/// - `out_errors`: Buffer for normalized error per iteration (length >= max_iter), or NULL +/// - `out_n_iters`: Output: actual number of iterations performed +#[unsafe(no_mangle)] +pub extern "C" fn t4a_crossinterpolate_tree_f64( + eval_cb: TreeTciBatchEvalCallback, + user_data: *mut c_void, + local_dims: *const libc::size_t, + n_sites: libc::size_t, + graph: *const t4a_treetci_graph, + initial_pivots_flat: *const libc::size_t, + n_pivots: libc::size_t, + proposer_kind: t4a_treetci_proposer_kind, + tolerance: libc::c_double, + max_bond_dim: libc::size_t, + max_iter: libc::size_t, + normalize_error: libc::c_int, + 
center_site: libc::size_t, + out_treetn: *mut *mut t4a_treetn, + out_ranks: *mut libc::size_t, + out_errors: *mut libc::c_double, + out_n_iters: *mut libc::size_t, +) -> StatusCode { + if local_dims.is_null() || graph.is_null() || out_treetn.is_null() || out_n_iters.is_null() { + return T4A_NULL_POINTER; + } + + let result = catch_unwind(AssertUnwindSafe(|| { + // Parse local_dims + let dims: Vec = (0..n_sites) + .map(|i| unsafe { *local_dims.add(i) }) + .collect(); + + // Clone graph + let g = unsafe { &*graph }; + let graph_clone = g.inner().clone(); + + // Create state + let mut state = match SimpleTreeTci::new(dims, graph_clone) { + Ok(s) => s, + Err(e) => return err_status(e, T4A_INTERNAL_ERROR), + }; + + // Add initial pivots + if !initial_pivots_flat.is_null() && n_pivots > 0 { + let pivots: Vec> = (0..n_pivots) + .map(|p| { + (0..n_sites) + .map(|s| unsafe { *initial_pivots_flat.add(s + n_sites * p) }) + .collect() + }) + .collect(); + if let Err(e) = state.add_global_pivots(&pivots) { + return err_status(e, T4A_INTERNAL_ERROR); + } + } + + // Run optimization + let batch_eval = make_batch_eval_closure(eval_cb, user_data); + let options = make_options( + tolerance, + max_bond_dim, + max_iter, + normalize_error != 0, + ); + + let (ranks, errors) = match proposer_kind { + t4a_treetci_proposer_kind::Default => { + let proposer = DefaultProposer; + match tensor4all_treetci::optimize_with_proposer( + &mut state, &batch_eval, &options, &proposer, + ) { + Ok(r) => r, + Err(e) => return err_status(e, T4A_INTERNAL_ERROR), + } + } + t4a_treetci_proposer_kind::Simple => { + let proposer = SimpleProposer::default(); + match tensor4all_treetci::optimize_with_proposer( + &mut state, &batch_eval, &options, &proposer, + ) { + Ok(r) => r, + Err(e) => return err_status(e, T4A_INTERNAL_ERROR), + } + } + t4a_treetci_proposer_kind::TruncatedDefault => { + let proposer = TruncatedDefaultProposer::default(); + match tensor4all_treetci::optimize_with_proposer( + &mut state, 
&batch_eval, &options, &proposer, + ) { + Ok(r) => r, + Err(e) => return err_status(e, T4A_INTERNAL_ERROR), + } + } + }; + + let n_iters = ranks.len(); + unsafe { *out_n_iters = n_iters }; + + // Copy ranks and errors to output buffers + if !out_ranks.is_null() { + for (i, &r) in ranks.iter().enumerate() { + unsafe { *out_ranks.add(i) = r }; + } + } + if !out_errors.is_null() { + for (i, &e) in errors.iter().enumerate() { + unsafe { *out_errors.add(i) = e }; + } + } + + // Materialize + match tensor4all_treetci::to_treetn(&state, &batch_eval, Some(center_site)) { + Ok(treetn) => { + unsafe { *out_treetn = Box::into_raw(Box::new(t4a_treetn::new(treetn))) }; + T4A_SUCCESS + } + Err(e) => err_status(e, T4A_INTERNAL_ERROR), + } + })); + + crate::unwrap_catch(result) +} +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cargo nextest run --release -p tensor4all-capi test_crossinterpolate_tree_f64 2>&1 | tail -10` + +Expected: PASS. + +- [ ] **Step 5: Run all TreeTCI tests** + +Run: `cargo nextest run --release -p tensor4all-capi treetci 2>&1 | tail -20` + +Expected: All tests PASS (test_graph_new_and_query, test_graph_invalid_disconnected, test_state_new_and_add_pivots, test_sweep, test_state_inspection, test_to_treetn, test_crossinterpolate_tree_f64). + +- [ ] **Step 6: Run clippy and fmt** + +Run: + +```bash +cargo fmt --all +cargo clippy --workspace --all-targets -- -D warnings 2>&1 | tail -20 +``` + +Fix any warnings or formatting issues. + +- [ ] **Step 7: Commit** + +```bash +git add crates/tensor4all-capi/src/treetci.rs +git commit -m "feat(capi): add TreeTCI high-level convenience function" +``` + +--- + +## Task 9: Final validation — full test suite + +**Files:** None (validation only) + +- [ ] **Step 1: Run the full capi test suite** + +Run: `cargo nextest run --release -p tensor4all-capi 2>&1 | tail -30` + +Expected: All existing tests + all new TreeTCI tests PASS. No regressions. 
+ +- [ ] **Step 2: Run the full workspace test suite** + +Run: `cargo nextest run --release --workspace 2>&1 | tail -30` + +Expected: No regressions across the workspace. + +- [ ] **Step 3: Run CI checks** + +Run: `cargo xtask ci 2>&1 | tail -30` + +Expected: fmt, clippy, tests, docs all pass. + +- [ ] **Step 4: Update API docs** + +Run: `cargo run -p api-dump --release -- . -o docs/api 2>&1 | tail -10` + +Verify `docs/api/tensor4all_capi.md` now includes the TreeTCI functions. + +- [ ] **Step 5: Commit docs update** + +```bash +git add docs/api/ +git commit -m "docs(capi): update API reference with TreeTCI functions" +``` diff --git a/docs/plans/2026-03-30-treetci-global-pivot.md b/docs/plans/2026-03-30-treetci-global-pivot.md new file mode 100644 index 0000000..a299da4 --- /dev/null +++ b/docs/plans/2026-03-30-treetci-global-pivot.md @@ -0,0 +1,600 @@ +# TreeTCI Global Pivot Search Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 
+ +**Goal:** TreeTCI の optimize ループに global pivot search を追加し、局所的な特徴を持つ関数の収束を改善する。 + +**Architecture:** TCI2 の `DefaultGlobalPivotFinder` と同じ greedy local search アルゴリズムを TreeTCI 用に適応。`to_treetn` で materialization → `TreeTN::evaluate` で近似値を計算 → batch_eval で真値と比較。`global_search_interval` で実行頻度を制御。 + +**Tech Stack:** Rust, tensor4all-treetci crate, tensor4all-treetn (evaluate) + +**Working directory:** `/home/shinaoka/tensor4all/tensor4all-rs` + +**Reference:** +- TCI2 GlobalPivotFinder: `crates/tensor4all-tensorci/src/globalpivot.rs` +- TCI.jl テスト: `/home/shinaoka/tensor4all/TensorCrossInterpolation.jl/test/test_tensorci2.jl` L433-458 + +--- + +## File Structure + +| Action | File | Responsibility | +|--------|------|----------------| +| Modify | `crates/tensor4all-treetci/src/optimize.rs` | TreeTciOptions にパラメータ追加 + ループに global search 組み込み | +| Create | `crates/tensor4all-treetci/src/globalpivot.rs` | TreeTCI 用 global pivot finder | +| Modify | `crates/tensor4all-treetci/src/lib.rs` | mod globalpivot + re-exports | +| Modify | `crates/tensor4all-treetci/src/optimize.rs` (tests) | テスト追加 | + +--- + +## Task 1: TreeTciOptions にパラメータ追加 + +**Files:** +- Modify: `crates/tensor4all-treetci/src/optimize.rs` + +- [ ] **Step 1: TreeTciOptions に global pivot パラメータを追加** + +```rust +#[derive(Clone, Debug)] +pub struct TreeTciOptions { + pub tolerance: f64, + pub max_iter: usize, + pub max_bond_dim: usize, + pub normalize_error: bool, + /// Run global pivot search every N iterations. 0 = disabled. + pub global_search_interval: usize, + /// Maximum number of global pivots to add per search. + pub max_global_pivots: usize, + /// Number of random starting points for global pivot search. + pub num_global_searches: usize, + /// Only add pivots where error > tolerance × this margin. 
+ pub global_search_tol_margin: f64, +} + +impl Default for TreeTciOptions { + fn default() -> Self { + Self { + tolerance: 1e-8, + max_iter: 20, + max_bond_dim: usize::MAX, + normalize_error: true, + global_search_interval: 0, // disabled by default + max_global_pivots: 5, + num_global_searches: 5, + global_search_tol_margin: 10.0, + } + } +} +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `cargo build -p tensor4all-treetci --release 2>&1 | tail -5` + +- [ ] **Step 3: Commit** + +```bash +git add crates/tensor4all-treetci/src/optimize.rs +git commit -m "feat(treetci): add global pivot search parameters to TreeTciOptions" +``` + +--- + +## Task 2: Global pivot finder モジュール + +**Files:** +- Create: `crates/tensor4all-treetci/src/globalpivot.rs` +- Modify: `crates/tensor4all-treetci/src/lib.rs` + +- [ ] **Step 1: Write failing test** + +Append to `crates/tensor4all-treetci/src/globalpivot.rs` (at the end, inside `#[cfg(test)]` module): + +```rust +#[cfg(test)] +mod tests { + use super::*; + use crate::{TreeTciGraph, TreeTciEdge, SimpleTreeTci}; + + #[test] + fn test_find_global_pivots_finds_error_points() { + // 3-site chain: 0-1-2, local_dims = [4, 4, 4] + let graph = TreeTciGraph::new(3, &[ + TreeTciEdge::new(0, 1), + TreeTciEdge::new(1, 2), + ]).unwrap(); + let local_dims = vec![4, 4, 4]; + + // Function with localized feature: large only when all indices > 2 + let f = |idx: &[usize]| -> f64 { + if idx.iter().all(|&x| x >= 2) { + 100.0 + } else { + (idx[0] as f64) * 0.1 + (idx[1] as f64) * 0.01 + } + }; + + // Zero approximation → evaluate always returns 0 + // So error = |f(x) - 0| = |f(x)| + let approx_eval = |idx: &[usize]| -> f64 { 0.0 }; + + let finder = DefaultTreeGlobalPivotFinder::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let pivots = finder.find_pivots( + &local_dims, + &f, + &approx_eval, + 0.1, // abs_tol: only accept error > 0.1 * 10.0 = 1.0 + &mut rng, + ); + + // Should find pivots in the high-value region (indices >= 
2) + assert!(!pivots.is_empty(), "Should find at least one global pivot"); + for pivot in &pivots { + assert_eq!(pivot.len(), 3); + } + } +} +``` + +- [ ] **Step 2: Implement global pivot finder** + +Create `crates/tensor4all-treetci/src/globalpivot.rs`: + +```rust +//! Global pivot finder for TreeTCI. +//! +//! Adapted from TCI2's DefaultGlobalPivotFinder. Uses greedy local search +//! from random starting points to find multi-indices with large interpolation +//! error, which are then added as global pivots. + +use rand::Rng; +use tensor4all_tcicore::MultiIndex; + +/// Default global pivot finder for TreeTCI. +/// +/// Algorithm (same as TCI2): +/// 1. Generate `num_searches` random initial points +/// 2. From each point, sweep all dimensions to find local error maximum +/// 3. Keep points where error > `abs_tol × tol_margin` +/// 4. Limit to `max_pivots` results +#[derive(Debug, Clone)] +pub struct DefaultTreeGlobalPivotFinder { + /// Number of random starting points for greedy search. + pub num_searches: usize, + /// Maximum number of pivots to return per call. + pub max_pivots: usize, + /// Only return pivots where error > abs_tol × tol_margin. + pub tol_margin: f64, +} + +impl Default for DefaultTreeGlobalPivotFinder { + fn default() -> Self { + Self { + num_searches: 5, + max_pivots: 5, + tol_margin: 10.0, + } + } +} + +impl DefaultTreeGlobalPivotFinder { + pub fn new(num_searches: usize, max_pivots: usize, tol_margin: f64) -> Self { + Self { + num_searches, + max_pivots, + tol_margin, + } + } + + /// Find global pivots by comparing `f` (true function) with `approx` + /// (current approximation) at random + locally-optimized points. 
+ /// + /// # Arguments + /// - `local_dims`: dimension at each site + /// - `f`: true function, f(multi_index) -> scalar magnitude + /// - `approx`: current approximation, approx(multi_index) -> scalar magnitude + /// - `abs_tol`: absolute tolerance (combined with tol_margin) + /// - `rng`: random number generator + /// + /// The error at a point is `|f(x) - approx(x)|`. + /// Points with error > `abs_tol * tol_margin` are candidates. + pub fn find_pivots<F, G>( + &self, + local_dims: &[usize], + f: &F, + approx: &G, + abs_tol: f64, + rng: &mut impl Rng, + ) -> Vec<MultiIndex> + where + F: Fn(&[usize]) -> f64, + G: Fn(&[usize]) -> f64, + { + let n = local_dims.len(); + let threshold = abs_tol * self.tol_margin; + let mut found: Vec<(MultiIndex, f64)> = Vec::new(); + + for _ in 0..self.num_searches { + // Random starting point + let mut point: MultiIndex = (0..n) + .map(|p| rng.random_range(0..local_dims[p])) + .collect(); + + // Greedy local search: sweep all dimensions + let mut best_error = 0.0f64; + let mut best_point = point.clone(); + + for p in 0..n { + for v in 0..local_dims[p] { + point[p] = v; + let error = (f(&point) - approx(&point)).abs(); + if error > best_error { + best_error = error; + best_point = point.clone(); + } + } + // Move to best value found for this dimension + point = best_point.clone(); + } + + if best_error > threshold { + found.push((best_point, best_error)); + } + } + + // Sort by error descending, deduplicate, truncate + found.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + found.dedup_by(|a, b| a.0 == b.0); + found.truncate(self.max_pivots); + found.into_iter().map(|(pivot, _)| pivot).collect() + } +} +``` + +注意: TCI2 の finder は `Fn(&MultiIndex) -> T` (ジェネリック型)を取るが、TreeTCI 版は +`f` と `approx` を分離して受け取る設計。これにより: +- `f` は batch_eval から作った point eval closure +- `approx` は materialized TreeTN の evaluate を使った closure +として呼び出し側で組み立てる。型パラメータ T に依存しない (f64 の誤差計算のみ)。 + +complex の場合: 呼び出し側で `|idx| { let v = f_complex(idx); 
v.abs_sq().sqrt() }` のように +abs 値に変換してから渡す。 + +- [ ] **Step 3: Add module to lib.rs** + +In `crates/tensor4all-treetci/src/lib.rs`, add: +```rust +pub mod globalpivot; +pub use globalpivot::DefaultTreeGlobalPivotFinder; +``` + +- [ ] **Step 4: Run test** + +Run: `cargo nextest run --release -p tensor4all-treetci test_find_global_pivots` + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-treetci/src/globalpivot.rs crates/tensor4all-treetci/src/lib.rs +git commit -m "feat(treetci): add DefaultTreeGlobalPivotFinder" +``` + +--- + +## Task 3: optimize ループに global pivot search を組み込み + +**Files:** +- Modify: `crates/tensor4all-treetci/src/optimize.rs` + +- [ ] **Step 1: optimize_with_proposer に global search を追加** + +`optimize_with_proposer` 関数の型境界に `FullPivLuScalar` を追加 +(`to_treetn` に必要)。ループの `_iter` ごとに、`global_search_interval > 0` かつ +`iter % interval == 0` のとき global pivot search を実行する。 + +```rust +use crate::globalpivot::DefaultTreeGlobalPivotFinder; +use crate::materialize::{to_treetn, FullPivLuScalar}; + +pub fn optimize_with_proposer( + state: &mut SimpleTreeTci, + batch_eval: F, + options: &TreeTciOptions, + proposer: &P, +) -> Result<(Vec, Vec)> +where + T: FullPivLuScalar, // was: Scalar. Tightened for to_treetn + DenseFaerLuKernel: PivotKernel, + F: Fn(GlobalIndexBatch<'_>) -> Result>, + P: PivotCandidateProposer, +{ + // ... existing setup ... + + let global_finder = if options.global_search_interval > 0 { + Some(DefaultTreeGlobalPivotFinder::new( + options.num_global_searches, + options.max_global_pivots, + options.global_search_tol_margin, + )) + } else { + None + }; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + for iter in 0..options.max_iter { + // ... existing inner edge passes ... + + ranks.push(state.max_rank()); + let normalized_error = /* ... existing ... 
*/; + errors.push(normalized_error); + + // Global pivot search + if let Some(ref finder) = global_finder { + if (iter + 1) % options.global_search_interval == 0 { + let error_scale = if options.normalize_error && state.max_sample_value > 0.0 { + state.max_sample_value + } else { + 1.0 + }; + let abs_tol = options.tolerance * error_scale; + + // Materialize current state + let treetn = to_treetn(state, &batch_eval, Some(0))?; + + // Build point eval closures + let point_eval_f = |idx: &[usize]| -> f64 { + let batch_data: Vec = idx.to_vec(); + let batch = GlobalIndexBatch::new(&batch_data, idx.len(), 1).unwrap(); + match batch_eval(batch) { + Ok(vals) => T::abs_val(vals[0]), + Err(_) => 0.0, + } + }; + + let point_eval_approx = |idx: &[usize]| -> f64 { + let n_sites = state.local_dims.len(); + let index_values: std::collections::HashMap> = + (0..n_sites).map(|s| (s, vec![idx[s]])).collect(); + match treetn.evaluate(&index_values) { + Ok(scalar) => scalar.abs(), + Err(_) => 0.0, + } + }; + + let pivots = finder.find_pivots( + &state.local_dims, + &point_eval_f, + &point_eval_approx, + abs_tol, + &mut rng, + ); + + if !pivots.is_empty() { + state.add_global_pivots(&pivots)?; + } + } + } + + // Early exit on convergence (existing) + } + + Ok((ranks, errors)) +} +``` + +**重要な変更点:** +- 型境界: `T: Scalar` → `T: FullPivLuScalar` (`to_treetn` が必要) +- `optimize_default` も同じ型境界に更新が必要 +- `point_eval_f` で complex の場合は `T::abs_val` で f64 に変換 +- `point_eval_approx` は `AnyScalar::abs()` で f64 に変換 + +`Scalar` → `FullPivLuScalar` の変更が既存コードに影響しないか確認: +`FullPivLuScalar` は `f32, f64, Complex32, Complex64` に実装済み。 +`Scalar` のスーパートレイトなので、既存の呼び出し側が `f64` や `Complex64` を使っていれば問題なし。 + +- [ ] **Step 2: 早期終了条件に global pivot を考慮** + +現在は `max_iter` まで常に回っている。収束判定を追加: + +```rust + // Early exit: error below tolerance AND no global pivots added + if normalized_error < options.tolerance { + // If global search is enabled, only stop if last search found nothing + if global_finder.is_none() { + break; 
+ } + // Otherwise continue until next global search finds nothing + } +``` + +(完全な収束判定は TCI2 の `convergence_criterion` を参考にするが、MVP ではシンプルに) + +- [ ] **Step 3: Verify it compiles** + +Run: `cargo build -p tensor4all-treetci --release` + +- [ ] **Step 4: Commit** + +```bash +git add crates/tensor4all-treetci/src/optimize.rs +git commit -m "feat(treetci): integrate global pivot search into optimize loop" +``` + +--- + +## Task 4: テスト — チェーン木 + nasty function (TCI.jl parity) + +**Files:** +- Create or modify: `crates/tensor4all-treetci/tests/global_pivot.rs` + +- [ ] **Step 1: Write integration test** + +Create `crates/tensor4all-treetci/tests/global_pivot.rs`: + +```rust +//! Integration test: global pivot search improves convergence on difficult functions. +//! +//! Adapted from TensorCrossInterpolation.jl test_tensorci2.jl "globalsearch" test. +//! Uses a chain tree (equivalent to MPS) with a nasty oscillatory function +//! that requires global pivot search for convergence. + +use tensor4all_treetci::{ + crossinterpolate_tree_with_proposer, SimpleProposer, + TreeTciEdge, TreeTciGraph, TreeTciOptions, +}; +use std::f64::consts::PI; + +/// Chain tree: 0--1--2--...--N-1 +fn chain_graph(n: usize) -> TreeTciGraph { + let edges: Vec = (0..n - 1) + .map(|i| TreeTciEdge::new(i, i + 1)) + .collect(); + TreeTciGraph::new(n, &edges).unwrap() +} + +/// Quantics-like encoding: bitlist → x in [0, 1) +fn bits_to_x(bits: &[usize], n_bits: usize) -> f64 { + let mut x = 0.0; + for (i, &b) in bits.iter().enumerate() { + x += (b as f64) * 2.0f64.powi(-(i as i32 + 1)); + } + x +} + +/// Nasty oscillatory function from TCI.jl test suite: +/// f(x) = exp(-10x) * sin(2π * 100 * x^1.1) +fn nasty_function(x: f64) -> f64 { + (-10.0 * x).exp() * (2.0 * PI * 100.0 * x.powf(1.1)).sin() +} + +#[test] +fn global_pivot_search_improves_convergence_on_nasty_function() { + let n_bits = 10; + let graph = chain_graph(n_bits); + let local_dims = vec![2; n_bits]; + + let f = |idx: &[usize]| -> f64 { + 
let x = bits_to_x(idx, n_bits); + nasty_function(x) + }; + + // Without global pivot search + let options_no_global = TreeTciOptions { + tolerance: 1e-8, + max_iter: 30, + max_bond_dim: 100, + normalize_error: true, + global_search_interval: 0, // disabled + ..Default::default() + }; + + let (_, _ranks_no, errors_no) = crossinterpolate_tree_with_proposer( + f, + None::) -> anyhow::Result>>, + local_dims.clone(), + graph.clone(), + vec![vec![0; n_bits]], + options_no_global, + Some(0), + &SimpleProposer::default(), + ).unwrap(); + + let error_no_global = errors_no.last().copied().unwrap_or(f64::INFINITY); + + // With global pivot search + let options_global = TreeTciOptions { + tolerance: 1e-8, + max_iter: 30, + max_bond_dim: 100, + normalize_error: true, + global_search_interval: 1, // every iteration + num_global_searches: 10, + max_global_pivots: 5, + global_search_tol_margin: 10.0, + }; + + let (_, _ranks_yes, errors_yes) = crossinterpolate_tree_with_proposer( + f, + None::) -> anyhow::Result>>, + local_dims.clone(), + graph.clone(), + vec![vec![0; n_bits]], + options_global, + Some(0), + &SimpleProposer::default(), + ).unwrap(); + + let error_global = errors_yes.last().copied().unwrap_or(f64::INFINITY); + + // Global pivot search should achieve better or equal convergence + // The nasty function typically requires global pivots for good convergence + eprintln!( + "Without global: error={:.2e}, With global: error={:.2e}", + error_no_global, error_global + ); + + // At minimum, with global search should converge below tolerance + assert!( + error_global < 1e-6, + "Global pivot search should help converge: got {:.2e}", + error_global + ); +} +``` + +注意: `crossinterpolate_tree_with_proposer` のシグネチャが `FullPivLuScalar` に変わるため、 +`f64` は問題なし。`point_eval` と `batch_eval` の両方を渡す現在の API に合わせる。 + +テスト関数は TCI.jl の `exp(-10x) * sin(2π * 100 * x^1.1)` そのもの。 +quantics grid の代わりに手動の `bits_to_x` で binary → float 変換。 + +- [ ] **Step 2: Run test** + +Run: `cargo nextest run 
--release -p tensor4all-treetci global_pivot` + +Expected: PASS. Global pivot search version achieves error < 1e-6. + +- [ ] **Step 3: Run full test suite** + +Run: `cargo nextest run --release --workspace` + +Expected: No regressions. `FullPivLuScalar` 型境界の変更が既存テストに影響しないことを確認。 + +- [ ] **Step 4: fmt + clippy** + +Run: `cargo fmt --all && cargo clippy --workspace --all-targets -- -D warnings` + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-treetci/tests/global_pivot.rs \ + crates/tensor4all-treetci/src/globalpivot.rs \ + crates/tensor4all-treetci/src/optimize.rs \ + crates/tensor4all-treetci/src/lib.rs +git commit -m "feat(treetci): global pivot search with nasty function test" +``` + +--- + +## 設計上の注意点 + +1. **`Scalar` → `FullPivLuScalar` 型境界の変更** + - `optimize_default` と `optimize_with_proposer` の両方で変更 + - `FullPivLuScalar: Scalar + TensorElement` なのでスーパートレイト + - `f32, f64, Complex32, Complex64` に実装済み → 既存コードに影響なし + - ただし `crossinterpolate_tree` / `crossinterpolate_tree_with_proposer` の型境界も連鎖的に更新が必要 + +2. **Materialization コスト** + - `to_treetn` は毎回全テンソルを再構築 → `global_search_interval` で頻度制御 + - デフォルト 0 (無効) なので、パフォーマンスへの影響はユーザーが opt-in + +3. **Complex 対応** + - `find_pivots` は `Fn(&[usize]) -> f64` を受け取る(実数の誤差値) + - 呼び出し側で `T::abs_val()` / `AnyScalar::abs()` で f64 に変換 + - Complex64 でも同じコードパスで動作 + +4. **RNG** + - `StdRng::seed_from_u64(42)` で決定論的。テストの再現性を保証 + - 将来的に `TreeTciOptions` に seed パラメータを追加可能 diff --git a/docs/plans/2026-03-30-treetci-julia-wrapper.md b/docs/plans/2026-03-30-treetci-julia-wrapper.md new file mode 100644 index 0000000..36d8f0d --- /dev/null +++ b/docs/plans/2026-03-30-treetci-julia-wrapper.md @@ -0,0 +1,797 @@ +# TreeTCI Julia Wrapper Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 
+ +**Goal:** Julia ラッパーモジュール `TreeTCI` を Tensor4all.jl に追加し、tensor4all-rs の TreeTCI C API を呼び出せるようにする。 + +**Architecture:** `src/TreeTCI.jl` サブモジュールとして既存パターン(TensorCI, SimpleTT)に準拠。バッチコールバック trampoline で Julia 関数を C API に渡す。出力は既存の `TreeTN.TreeTensorNetwork` 型。 + +**Tech Stack:** Julia, ccall, Tensor4all.jl C_API module, tensor4all-capi TreeTCI functions + +**Prerequisite:** tensor4all-rs の `feat/treetci-capi` ブランチがビルドされ、`libtensor4all_capi.so` が `deps/` に配置されていること。 + +**Working directory:** `/home/shinaoka/tensor4all/Tensor4all.jl` + +--- + +## File Structure + +| Action | File | Responsibility | +|--------|------|----------------| +| Create | `src/TreeTCI.jl` | TreeTCI モジュール全体 | +| Modify | `src/Tensor4all.jl` | `include("TreeTCI.jl")` 追加 | +| Create | `test/test_treetci.jl` | TreeTCI テスト | +| Modify | `test/runtests.jl` | テストファイル include 追加 | + +--- + +## Task 1: Scaffold — モジュール骨格とメインモジュール統合 + +**Files:** +- Create: `src/TreeTCI.jl` +- Modify: `src/Tensor4all.jl` + +- [ ] **Step 1: Create empty TreeTCI module** + +Create `src/TreeTCI.jl`: + +```julia +""" + TreeTCI + +Tree-structured tensor cross interpolation via tensor4all-rs. + +Provides `TreeTciGraph` for defining tree topologies and `SimpleTreeTci` +for running TCI on arbitrary tree structures. Results are returned as +`TreeTN.TreeTensorNetwork`. 
+ +# Usage +```julia +using Tensor4all.TreeTCI + +graph = TreeTciGraph(4, [(0,1), (1,2), (2,3)]) +f(batch) = [sum(Float64, batch[:, j]) for j in 1:size(batch, 2)] +ttn, ranks, errors = crossinterpolate_tree(f, [3, 3, 3, 3], graph) +``` +""" +module TreeTCI + +using ..C_API +import ..TreeTN: TreeTensorNetwork + +export TreeTciGraph, SimpleTreeTci +export crossinterpolate_tree + +end # module TreeTCI +``` + +- [ ] **Step 2: Add include to Tensor4all.jl** + +In `src/Tensor4all.jl`, add after the `include("QuanticsTransform.jl")` line (line 938): + +```julia +# ============================================================================ +# Tree-structured TCI (tree tensor cross interpolation). +# Use: using Tensor4all.TreeTCI +include("TreeTCI.jl") +``` + +- [ ] **Step 3: Verify it loads** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); using Tensor4all; using Tensor4all.TreeTCI; println("OK")'` + +Expected: `OK` (module loads without error) + +- [ ] **Step 4: Commit** + +```bash +git add src/TreeTCI.jl src/Tensor4all.jl +git commit -m "feat: scaffold TreeTCI module" +``` + +--- + +## Task 2: TreeTciGraph type + +**Files:** +- Modify: `src/TreeTCI.jl` +- Create: `test/test_treetci.jl` +- Modify: `test/runtests.jl` + +- [ ] **Step 1: Write the failing test** + +Create `test/test_treetci.jl`: + +```julia +using Tensor4all.TreeTCI +using Tensor4all.TreeTN: TreeTensorNetwork +using Test + +@testset "TreeTCI" begin + @testset "TreeTciGraph" begin + # Linear chain: 0-1-2-3 + graph = TreeTciGraph(4, [(0, 1), (1, 2), (2, 3)]) + @test graph.n_sites == 4 + @test graph.ptr != C_NULL + + # Star graph: 0 at center + graph_star = TreeTciGraph(4, [(0, 1), (0, 2), (0, 3)]) + @test graph_star.n_sites == 4 + + # Invalid: disconnected + @test_throws ErrorException TreeTciGraph(4, [(0, 1), (2, 3)]) + end +end +``` + +Add to `test/runtests.jl`, inside the `@testset "Tensor4all.jl"` block: + +```julia + include("test_treetci.jl") +``` + +- [ ] **Step 2: Run test to verify 
it fails** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); Pkg.test()' 2>&1 | tail -20` + +Expected: FAIL — `TreeTciGraph` not defined. + +- [ ] **Step 3: Implement TreeTciGraph** + +Add to `src/TreeTCI.jl` (inside the module, before `end`): + +```julia +# ============================================================================ +# TreeTciGraph +# ============================================================================ + +""" + TreeTciGraph(n_sites, edges) + +Define a tree graph structure for TreeTCI. + +# Arguments +- `n_sites::Int`: Number of sites +- `edges::Vector{Tuple{Int,Int}}`: Edge list (0-based site indices) + +# Examples +```julia +# Linear chain: 0-1-2-3 +graph = TreeTciGraph(4, [(0, 1), (1, 2), (2, 3)]) + +# Star graph: 0 at center +graph = TreeTciGraph(4, [(0, 1), (0, 2), (0, 3)]) + +# 7-site branching tree +graph = TreeTciGraph(7, [(0,1), (1,2), (1,3), (3,4), (4,5), (4,6)]) +``` +""" +mutable struct TreeTciGraph + ptr::Ptr{Cvoid} + n_sites::Int + + function TreeTciGraph(n_sites::Int, edges::Vector{Tuple{Int,Int}}) + n_edges = length(edges) + edges_flat = Vector{Csize_t}(undef, n_edges * 2) + for (i, (u, v)) in enumerate(edges) + edges_flat[2i - 1] = Csize_t(u) + edges_flat[2i] = Csize_t(v) + end + ptr = ccall( + C_API._sym(:t4a_treetci_graph_new), + Ptr{Cvoid}, + (Csize_t, Ptr{Csize_t}, Csize_t), + Csize_t(n_sites), edges_flat, Csize_t(n_edges), + ) + if ptr == C_NULL + error("Failed to create TreeTciGraph: $(C_API.last_error_message())") + end + obj = new(ptr, n_sites) + finalizer(obj) do x + if x.ptr != C_NULL + ccall(C_API._sym(:t4a_treetci_graph_release), Cvoid, (Ptr{Cvoid},), x.ptr) + x.ptr = C_NULL + end + end + obj + end +end +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); include("test/test_treetci.jl")'` + +Expected: All 3 tests PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add src/TreeTCI.jl test/test_treetci.jl test/runtests.jl +git commit -m "feat: add TreeTciGraph type" +``` + +--- + +## Task 3: Batch callback trampoline and proposer helpers + +**Files:** +- Modify: `src/TreeTCI.jl` + +- [ ] **Step 1: Add trampoline and helpers** + +Add to `src/TreeTCI.jl` (after TreeTciGraph, before `end`): + +```julia +# ============================================================================ +# Batch Eval Trampoline +# ============================================================================ + +""" +Internal trampoline for C batch callback. + +The user function signature is: `f(batch::Matrix{Csize_t}) -> Vector{Float64}` +where `batch` is column-major (n_sites, n_points) with 0-based indices. +""" +function _treetci_batch_trampoline( + batch_data::Ptr{Csize_t}, + n_sites::Csize_t, + n_points::Csize_t, + results::Ptr{Cdouble}, + user_data::Ptr{Cvoid}, +)::Cint + try + f_ref = unsafe_pointer_to_objref(user_data)::Ref{Any} + f = f_ref[] + batch = unsafe_wrap(Array, batch_data, (Int(n_sites), Int(n_points))) + vals = f(batch) + for i in 1:Int(n_points) + unsafe_store!(results, Float64(vals[i]), i) + end + return Cint(0) + catch e + @error "TreeTCI batch eval callback error" exception = (e, catch_backtrace()) + return Cint(-1) + end +end + +const _BATCH_TRAMPOLINE_PTR = Ref{Ptr{Cvoid}}(C_NULL) + +function _get_batch_trampoline() + if _BATCH_TRAMPOLINE_PTR[] == C_NULL + _BATCH_TRAMPOLINE_PTR[] = @cfunction( + _treetci_batch_trampoline, + Cint, + (Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Cdouble}, Ptr{Cvoid}), + ) + end + _BATCH_TRAMPOLINE_PTR[] +end + +# ============================================================================ +# Proposer helpers +# ============================================================================ + +const _PROPOSER_DEFAULT = Cint(0) +const _PROPOSER_SIMPLE = Cint(1) +const _PROPOSER_TRUNCATED_DEFAULT = Cint(2) + +function _proposer_to_cint(proposer::Symbol)::Cint + if 
proposer === :default + _PROPOSER_DEFAULT + elseif proposer === :simple + _PROPOSER_SIMPLE + elseif proposer === :truncated_default + _PROPOSER_TRUNCATED_DEFAULT + else + error("Unknown proposer: $proposer. Use :default, :simple, or :truncated_default") + end +end +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); using Tensor4all.TreeTCI; println("OK")'` + +Expected: `OK` + +- [ ] **Step 3: Commit** + +```bash +git add src/TreeTCI.jl +git commit -m "feat: add TreeTCI batch trampoline and proposer helpers" +``` + +--- + +## Task 4: SimpleTreeTci — state, pivots, sweep, inspection + +**Files:** +- Modify: `src/TreeTCI.jl` +- Modify: `test/test_treetci.jl` + +- [ ] **Step 1: Write the failing test** + +Add to `test/test_treetci.jl`, inside the `@testset "TreeTCI"` block: + +```julia + @testset "Stateful API" begin + n_sites = 7 + local_dims = fill(2, n_sites) + edges = [(0, i) for i in 1:6] # star graph + graph = TreeTciGraph(n_sites, edges) + + # Product function: f(idx) = prod(idx[s] + 1.0) + function f_batch(batch) + n_pts = size(batch, 2) + results = Vector{Float64}(undef, n_pts) + for j in 1:n_pts + val = 1.0 + for i in 1:size(batch, 1) + val *= (batch[i, j] + 1.0) + end + results[j] = val + end + results + end + + tci = SimpleTreeTci(local_dims, graph) + add_global_pivots!(tci, [zeros(Int, n_sites)]) + + for _ in 1:4 + sweep!(tci, f_batch; tolerance=1e-12) + end + + @test max_bond_error(tci) < 1e-10 + @test max_rank(tci) >= 1 + @test max_sample_value(tci) > 0.0 + + bd = bond_dims(tci) + @test length(bd) == 6 # n_edges + @test all(d -> d >= 1, bd) + + ttn = to_treetn(tci, f_batch) + @test ttn isa TreeTensorNetwork + end +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); include("test/test_treetci.jl")' 2>&1 | tail -10` + +Expected: FAIL — `SimpleTreeTci` not defined. 
+ +- [ ] **Step 3: Implement SimpleTreeTci and all methods** + +Add to `src/TreeTCI.jl` (after proposer helpers, before `end`): + +```julia +# ============================================================================ +# SimpleTreeTci +# ============================================================================ + +""" + SimpleTreeTci(local_dims, graph) + +Stateful TreeTCI object for tree-structured tensor cross interpolation. + +# Arguments +- `local_dims::Vector{Int}`: Local dimension at each site (length = graph.n_sites) +- `graph::TreeTciGraph`: Tree graph structure + +# Lifecycle +```julia +tci = SimpleTreeTci([2, 2, 2, 2], graph) +add_global_pivots!(tci, [zeros(Int, 4)]) +for i in 1:max_iter + sweep!(tci, f; tolerance=1e-8) + max_bond_error(tci) < 1e-8 && break +end +ttn = to_treetn(tci, f) +``` +""" +mutable struct SimpleTreeTci + ptr::Ptr{Cvoid} + graph::TreeTciGraph # prevent GC + local_dims::Vector{Int} + + function SimpleTreeTci(local_dims::Vector{Int}, graph::TreeTciGraph) + length(local_dims) == graph.n_sites || + error("local_dims length ($(length(local_dims))) != graph.n_sites ($(graph.n_sites))") + dims_csize = Csize_t.(local_dims) + ptr = ccall( + C_API._sym(:t4a_treetci_f64_new), + Ptr{Cvoid}, + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}), + dims_csize, Csize_t(length(dims_csize)), graph.ptr, + ) + if ptr == C_NULL + error("Failed to create SimpleTreeTci: $(C_API.last_error_message())") + end + obj = new(ptr, graph, local_dims) + finalizer(obj) do x + if x.ptr != C_NULL + ccall(C_API._sym(:t4a_treetci_f64_release), Cvoid, (Ptr{Cvoid},), x.ptr) + x.ptr = C_NULL + end + end + obj + end +end + +# ============================================================================ +# Pivot management +# ============================================================================ + +""" + add_global_pivots!(tci, pivots) + +Add global pivots. Each pivot is a full multi-index over all sites (0-based). 
+ +# Arguments +- `tci::SimpleTreeTci` +- `pivots::Vector{Vector{Int}}`: Each element has length n_sites, 0-based indices +""" +function add_global_pivots!(tci::SimpleTreeTci, pivots::Vector{Vector{Int}}) + n_sites = length(tci.local_dims) + n_pivots = length(pivots) + n_pivots == 0 && return + pivots_flat = Vector{Csize_t}(undef, n_sites * n_pivots) + for j in 1:n_pivots + length(pivots[j]) == n_sites || + error("Pivot $j has length $(length(pivots[j])), expected $n_sites") + for i in 1:n_sites + pivots_flat[i + n_sites * (j - 1)] = Csize_t(pivots[j][i]) + end + end + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_add_global_pivots), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t), + tci.ptr, pivots_flat, Csize_t(n_sites), Csize_t(n_pivots), + )) +end + +# ============================================================================ +# Sweep +# ============================================================================ + +""" + sweep!(tci, f; proposer=:default, tolerance=1e-8, max_bond_dim=0) + +Run one optimization iteration (visit all edges once). 
+ +# Arguments +- `tci::SimpleTreeTci` +- `f`: Batch evaluation function `f(batch::Matrix{Csize_t}) -> Vector{Float64}` + - `batch` is column-major (n_sites, n_points), 0-based indices +- `proposer`: `:default`, `:simple`, or `:truncated_default` +- `tolerance`: Relative tolerance +- `max_bond_dim`: Maximum bond dimension (0 = unlimited) +""" +function sweep!(tci::SimpleTreeTci, f; + proposer::Symbol = :default, + tolerance::Float64 = 1e-8, + max_bond_dim::Int = 0, +) + f_ref = Ref{Any}(f) + GC.@preserve f_ref begin + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_sweep), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cint, Cdouble, Csize_t), + tci.ptr, + _get_batch_trampoline(), + pointer_from_objref(f_ref), + _proposer_to_cint(proposer), + tolerance, + Csize_t(max_bond_dim), + )) + end +end + +# ============================================================================ +# State inspection +# ============================================================================ + +"""Maximum bond error across all edges.""" +function max_bond_error(tci::SimpleTreeTci)::Float64 + out = Ref{Cdouble}(0.0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_max_bond_error), + Cint, (Ptr{Cvoid}, Ptr{Cdouble}), + tci.ptr, out, + )) + out[] +end + +"""Maximum rank (bond dimension) across all edges.""" +function max_rank(tci::SimpleTreeTci)::Int + out = Ref{Csize_t}(0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_max_rank), + Cint, (Ptr{Cvoid}, Ptr{Csize_t}), + tci.ptr, out, + )) + Int(out[]) +end + +"""Maximum observed sample value (for normalization).""" +function max_sample_value(tci::SimpleTreeTci)::Float64 + out = Ref{Cdouble}(0.0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_max_sample_value), + Cint, (Ptr{Cvoid}, Ptr{Cdouble}), + tci.ptr, out, + )) + out[] +end + +"""Bond dimensions (ranks) at each edge.""" +function bond_dims(tci::SimpleTreeTci)::Vector{Int} + n_edges_ref = Ref{Csize_t}(0) + C_API.check_status(ccall( + 
C_API._sym(:t4a_treetci_f64_bond_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + tci.ptr, C_NULL, Csize_t(0), n_edges_ref, + )) + n_edges = Int(n_edges_ref[]) + buf = Vector{Csize_t}(undef, n_edges) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_bond_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + tci.ptr, buf, Csize_t(n_edges), n_edges_ref, + )) + Int.(buf) +end + +# ============================================================================ +# Materialization +# ============================================================================ + +""" + to_treetn(tci, f; center_site=0) + +Convert converged TreeTCI state to a TreeTensorNetwork. + +# Arguments +- `tci::SimpleTreeTci`: Converged state +- `f`: Batch evaluation function (same as `sweep!`) +- `center_site`: BFS root site for materialization (0-based) +""" +function to_treetn(tci::SimpleTreeTci, f; center_site::Int = 0) + f_ref = Ref{Any}(f) + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + GC.@preserve f_ref begin + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_to_treetn), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Ptr{Ptr{Cvoid}}), + tci.ptr, + _get_batch_trampoline(), + pointer_from_objref(f_ref), + Csize_t(center_site), + out_ptr, + )) + end + TreeTensorNetwork(out_ptr[]) +end +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); include("test/test_treetci.jl")'` + +Expected: All tests PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add src/TreeTCI.jl test/test_treetci.jl +git commit -m "feat: add SimpleTreeTci with sweep, inspection, and materialization" +``` + +--- + +## Task 5: High-level convenience function + +**Files:** +- Modify: `src/TreeTCI.jl` +- Modify: `test/test_treetci.jl` + +- [ ] **Step 1: Write the failing test** + +Add to `test/test_treetci.jl`, inside the `@testset "TreeTCI"` block: + +```julia + @testset "High-level API" begin + n_sites = 4 + local_dims = fill(3, n_sites) + graph = TreeTciGraph(n_sites, [(0, 1), (1, 2), (2, 3)]) + + f_batch(batch) = [sum(Float64, batch[:, j]) for j in 1:size(batch, 2)] + + ttn, ranks, errors = crossinterpolate_tree( + f_batch, local_dims, graph; + initial_pivots = [zeros(Int, n_sites)], + tolerance = 1e-10, + max_iter = 20, + ) + + @test ttn isa TreeTensorNetwork + @test length(ranks) > 0 + @test last(errors) < 1e-8 + end + + @testset "High-level API without initial pivots" begin + # Tests the default zero-pivot behavior + n_sites = 3 + local_dims = fill(2, n_sites) + graph = TreeTciGraph(n_sites, [(0, 1), (1, 2)]) + + f_batch(batch) = [prod(batch[i, j] + 1.0 for i in 1:size(batch, 1)) for j in 1:size(batch, 2)] + + ttn, ranks, errors = crossinterpolate_tree(f_batch, local_dims, graph) + + @test ttn isa TreeTensorNetwork + @test length(ranks) > 0 + end +``` + +- [ ] **Step 2: Run test to verify it fails** + +Expected: FAIL — `crossinterpolate_tree` not defined. + +- [ ] **Step 3: Implement crossinterpolate_tree** + +Add to `src/TreeTCI.jl` (after `to_treetn`, before `end # module`): + +```julia +# ============================================================================ +# High-level convenience function +# ============================================================================ + +""" + crossinterpolate_tree(f, local_dims, graph; kwargs...) -> (ttn, ranks, errors) + +Run TreeTCI to convergence and return a TreeTensorNetwork. 
+ +# Arguments +- `f`: Batch evaluation function `f(batch::Matrix{Csize_t}) -> Vector{Float64}` +- `local_dims::Vector{Int}`: Local dimension at each site +- `graph::TreeTciGraph`: Tree graph structure + +# Keyword Arguments +- `initial_pivots::Vector{Vector{Int}} = []`: Initial pivots (0-based). If empty, defaults to zero pivot. +- `proposer::Symbol = :default`: `:default`, `:simple`, or `:truncated_default` +- `tolerance::Float64 = 1e-8`: Relative tolerance +- `max_bond_dim::Int = 0`: Maximum bond dimension (0 = unlimited) +- `max_iter::Int = 20`: Maximum iterations +- `normalize_error::Bool = true`: Normalize errors by max sample value +- `center_site::Int = 0`: Materialization center site (0-based) + +# Returns +- `ttn::TreeTensorNetwork` +- `ranks::Vector{Int}`: Max rank per iteration +- `errors::Vector{Float64}`: Normalized error per iteration +""" +function crossinterpolate_tree( + f, local_dims::Vector{Int}, graph::TreeTciGraph; + initial_pivots::Vector{Vector{Int}} = Vector{Int}[], + proposer::Symbol = :default, + tolerance::Float64 = 1e-8, + max_bond_dim::Int = 0, + max_iter::Int = 20, + normalize_error::Bool = true, + center_site::Int = 0, +) + n_sites = length(local_dims) + n_pivots = length(initial_pivots) + + # Pack initial pivots column-major (n_sites, n_pivots) + pivots_flat = if n_pivots > 0 + buf = Vector{Csize_t}(undef, n_sites * n_pivots) + for j in 1:n_pivots + length(initial_pivots[j]) == n_sites || + error("Pivot $j has length $(length(initial_pivots[j])), expected $n_sites") + for i in 1:n_sites + buf[i + n_sites * (j - 1)] = Csize_t(initial_pivots[j][i]) + end + end + buf + else + Csize_t[] + end + + # Output buffers (pre-allocate max_iter) + out_ranks = Vector{Csize_t}(undef, max_iter) + out_errors = Vector{Cdouble}(undef, max_iter) + out_n_iters = Ref{Csize_t}(0) + out_treetn = Ref{Ptr{Cvoid}}(C_NULL) + + dims_csize = Csize_t.(local_dims) + f_ref = Ref{Any}(f) + + GC.@preserve f_ref pivots_flat out_ranks out_errors dims_csize begin + 
C_API.check_status(ccall( + C_API._sym(:t4a_crossinterpolate_tree_f64), + Cint, + ( + Ptr{Cvoid}, Ptr{Cvoid}, # eval_cb, user_data + Ptr{Csize_t}, Csize_t, # local_dims, n_sites + Ptr{Cvoid}, # graph + Ptr{Csize_t}, Csize_t, # initial_pivots, n_pivots + Cint, # proposer_kind + Cdouble, Csize_t, Csize_t, # tol, max_bond_dim, max_iter + Cint, # normalize_error + Csize_t, # center_site + Ptr{Ptr{Cvoid}}, # out_treetn + Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}, # out_ranks, errors, n_iters + ), + _get_batch_trampoline(), + pointer_from_objref(f_ref), + dims_csize, Csize_t(n_sites), + graph.ptr, + n_pivots > 0 ? pivots_flat : C_NULL, Csize_t(n_pivots), + _proposer_to_cint(proposer), + tolerance, Csize_t(max_bond_dim), Csize_t(max_iter), + normalize_error ? Cint(1) : Cint(0), + Csize_t(center_site), + out_treetn, + out_ranks, out_errors, out_n_iters, + )) + end + + n_iters = Int(out_n_iters[]) + ttn = TreeTensorNetwork(out_treetn[]) + ranks = Int.(out_ranks[1:n_iters]) + errors = Float64.(out_errors[1:n_iters]) + return (ttn, ranks, errors) +end +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no -e 'using Pkg; Pkg.activate("."); include("test/test_treetci.jl")'` + +Expected: All tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/TreeTCI.jl test/test_treetci.jl +git commit -m "feat: add crossinterpolate_tree high-level API" +``` + +--- + +## Task 6: Final validation + +**Files:** None (validation only) + +- [ ] **Step 1: Ensure tensor4all-rs is on the correct branch** + +```bash +cd /home/shinaoka/tensor4all/tensor4all-rs && git checkout feat/treetci-capi +``` + +- [ ] **Step 2: Rebuild the Rust library via Pkg.build** + +`deps/build.jl` が自動的に sibling の `../tensor4all-rs/` を検出して RustToolChain.jl 経由でビルドする。 + +```bash +cd /home/shinaoka/tensor4all/Tensor4all.jl +julia --startup-file=no -e 'using Pkg; Pkg.activate("."); Pkg.build()' +``` + +Expected: `Build complete! 
Library installed to: .../deps/libtensor4all_capi.so` + +- [ ] **Step 3: Run full test suite** + +```bash +julia --startup-file=no -e 'using Pkg; Pkg.activate("."); Pkg.test()' +``` + +Expected: All tests pass, including new TreeTCI tests. No regressions. diff --git a/docs/plans/2026-03-31-api-expansion.md b/docs/plans/2026-03-31-api-expansion.md new file mode 100644 index 0000000..87afafa --- /dev/null +++ b/docs/plans/2026-03-31-api-expansion.md @@ -0,0 +1,722 @@ +# Tensor4all.jl API Expansion Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Expand Tensor4all.jl to wrap new C API functions, unify to 1-indexed, align naming with Pure Julia ecosystem, add ComplexF64 support, and add type conversions. + +**Architecture:** Add ccall bindings in C_API.jl for ~40 new Rust C API functions. Refactor SimpleTT.jl for ComplexF64 support via type dispatch (_f64/_c64). Rename functions to match Pure Julia convention (no underscores). Convert all user-facing indices to 1-based. Add operator overloading. Update QuanticsTCI for ComplexF64 and tuple returns. Add SimpleTT↔TreeTN conversion and TCI extension. 
+ +**Tech Stack:** Julia, ccall FFI, Tensor4all.jl modules (SimpleTT, TreeTN, QuanticsTCI, QuanticsGrids, C_API) + +**Spec:** `docs/specs/2026-03-31-api-expansion-design.md` + +--- + +## File Structure + +| File | Action | Responsibility | +|------|--------|---------------| +| `src/C_API.jl` | Modify | Add ccall bindings for ~40 new C API functions | +| `src/SimpleTT.jl` | Rewrite | ComplexF64, 1-indexed, renamed functions, new operations | +| `src/QuanticsTCI.jl` | Rewrite | ComplexF64, tuple return, QtciOptions, new accessors | +| `src/QuanticsGrids.jl` | Modify | Rename `local_dimensions` → `localdimensions` | +| `src/TreeTN.jl` | Modify | Add `MPS(tt::SimpleTensorTrain)` constructor | +| `src/TreeTCI.jl` | Modify | 1-indexed pivots/evaluate/callbacks | +| `ext/Tensor4allTCIExt.jl` | Create | TCI.TensorTrain ↔ SimpleTensorTrain conversion | +| `Project.toml` | Modify | Add TensorCrossInterpolation weak dep | +| `test/test_simplett.jl` | Rewrite | Tests for all SimpleTT changes | +| `test/test_quanticstci.jl` | Create | Tests for QuanticsTCI updates | +| `test/test_treetci.jl` | Create | Tests for TreeTCI 1-indexed | +| `test/test_conversions.jl` | Create | Tests for SimpleTT↔TreeTN conversion | +| `test/runtests.jl` | Modify | Include new test files | + +--- + +## Task 1: C_API.jl — Add SimpleTT new function bindings + +**Files:** +- Modify: `src/C_API.jl` + +Add ccall wrappers for all new SimpleTT C API functions. Follow existing patterns in C_API.jl (look at `t4a_simplett_f64_evaluate` etc. for reference). 
+ +- [ ] **Step 1: Add f64 operation bindings** + +Add to C_API.jl after existing simplett_f64 functions: + +```julia +# SimpleTT f64 new operations +function t4a_simplett_f64_from_site_tensors(n_sites, left_dims, site_dims, right_dims, data, data_len, out_ptr) + ccall(_sym(:t4a_simplett_f64_from_site_tensors), Cint, + (Csize_t, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Cdouble}, Csize_t, Ptr{Ptr{Cvoid}}), + n_sites, left_dims, site_dims, right_dims, data, data_len, out_ptr) +end + +function t4a_simplett_f64_add(a, b, out) + ccall(_sym(:t4a_simplett_f64_add), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), a, b, out) +end + +function t4a_simplett_f64_scale(ptr, factor) + ccall(_sym(:t4a_simplett_f64_scale), Cint, + (Ptr{Cvoid}, Cdouble), ptr, factor) +end + +function t4a_simplett_f64_dot(a, b, out_value) + ccall(_sym(:t4a_simplett_f64_dot), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cdouble}), a, b, out_value) +end + +function t4a_simplett_f64_reverse(ptr, out) + ccall(_sym(:t4a_simplett_f64_reverse), Cint, + (Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), ptr, out) +end + +function t4a_simplett_f64_fulltensor(ptr, out_data, buf_len, out_data_len) + ccall(_sym(:t4a_simplett_f64_fulltensor), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}), ptr, out_data, buf_len, out_data_len) +end +``` + +- [ ] **Step 2: Add c64 SimpleTT bindings (all functions including existing equivalents)** + +Add full c64 SimpleTT API. 
c64 needs release, clone, len, sitedims, linkdims, rank, evaluate, sum, norm, site_tensor, compress, partial_sum, PLUS the new operations: + +```julia +# SimpleTT c64 lifecycle +function t4a_simplett_c64_release(ptr) + ccall(_sym(:t4a_simplett_c64_release), Cvoid, (Ptr{Cvoid},), ptr) +end + +function t4a_simplett_c64_clone(ptr) + ccall(_sym(:t4a_simplett_c64_clone), Ptr{Cvoid}, (Ptr{Cvoid},), ptr) +end + +function t4a_simplett_c64_constant(site_dims, value_re, value_im) + ccall(_sym(:t4a_simplett_c64_constant), Ptr{Cvoid}, + (Ptr{Csize_t}, Cdouble, Cdouble), site_dims, value_re, value_im) +end + +function t4a_simplett_c64_zeros(site_dims) + ccall(_sym(:t4a_simplett_c64_zeros), Ptr{Cvoid}, (Ptr{Csize_t},), site_dims) +end + +function t4a_simplett_c64_len(ptr, out_len) + ccall(_sym(:t4a_simplett_c64_len), Cint, (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_len) +end + +function t4a_simplett_c64_site_dims(ptr, out_dims) + ccall(_sym(:t4a_simplett_c64_site_dims), Cint, (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_dims) +end + +function t4a_simplett_c64_link_dims(ptr, out_dims) + ccall(_sym(:t4a_simplett_c64_link_dims), Cint, (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_dims) +end + +function t4a_simplett_c64_rank(ptr, out_rank) + ccall(_sym(:t4a_simplett_c64_rank), Cint, (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_rank) +end + +function t4a_simplett_c64_evaluate(ptr, indices, out_re, out_im) + ccall(_sym(:t4a_simplett_c64_evaluate), Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Cdouble}), ptr, indices, out_re, out_im) +end + +function t4a_simplett_c64_sum(ptr, out_re, out_im) + ccall(_sym(:t4a_simplett_c64_sum), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), ptr, out_re, out_im) +end + +function t4a_simplett_c64_norm(ptr, out_value) + ccall(_sym(:t4a_simplett_c64_norm), Cint, (Ptr{Cvoid}, Ptr{Cdouble}), ptr, out_value) +end + +function t4a_simplett_c64_site_tensor(ptr, site, out_data, buf_len, out_left, out_site, out_right) + ccall(_sym(:t4a_simplett_c64_site_tensor), Cint, + (Ptr{Cvoid}, 
Csize_t, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}), + ptr, site, out_data, buf_len, out_left, out_site, out_right) +end + +function t4a_simplett_c64_compress(ptr, method, tolerance, max_bonddim) + ccall(_sym(:t4a_simplett_c64_compress), Cint, + (Ptr{Cvoid}, Cint, Cdouble, Csize_t), ptr, method, tolerance, max_bonddim) +end + +function t4a_simplett_c64_partial_sum(ptr, dims, out) + ccall(_sym(:t4a_simplett_c64_partial_sum), Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Ptr{Ptr{Cvoid}}), ptr, dims, out) +end + +# c64 new operations +function t4a_simplett_c64_from_site_tensors(n_sites, left_dims, site_dims, right_dims, data, data_len, out_ptr) + ccall(_sym(:t4a_simplett_c64_from_site_tensors), Cint, + (Csize_t, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Cdouble}, Csize_t, Ptr{Ptr{Cvoid}}), + n_sites, left_dims, site_dims, right_dims, data, data_len, out_ptr) +end + +function t4a_simplett_c64_add(a, b, out) + ccall(_sym(:t4a_simplett_c64_add), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), a, b, out) +end + +function t4a_simplett_c64_scale(ptr, factor_re, factor_im) + ccall(_sym(:t4a_simplett_c64_scale), Cint, + (Ptr{Cvoid}, Cdouble, Cdouble), ptr, factor_re, factor_im) +end + +function t4a_simplett_c64_dot(a, b, out_re, out_im) + ccall(_sym(:t4a_simplett_c64_dot), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), a, b, out_re, out_im) +end + +function t4a_simplett_c64_reverse(ptr, out) + ccall(_sym(:t4a_simplett_c64_reverse), Cint, + (Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), ptr, out) +end + +function t4a_simplett_c64_fulltensor(ptr, out_data, buf_len, out_data_len) + ccall(_sym(:t4a_simplett_c64_fulltensor), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}), ptr, out_data, buf_len, out_data_len) +end +``` + +- [ ] **Step 3: Verify C_API loads** + +Run: `julia --startup-file=no -e "using Pkg; Pkg.activate(\".\"); include(\"src/C_API.jl\")"` +Expected: No errors. 
+ +- [ ] **Step 4: Commit** + +```bash +git add src/C_API.jl +git commit -m "feat(C_API): add ccall bindings for SimpleTT new operations (f64+c64)" +``` + +--- + +## Task 2: C_API.jl — Add QuanticsTCI new function bindings + +**Files:** +- Modify: `src/C_API.jl` + +- [ ] **Step 1: Add QtciOptions bindings** + +```julia +# QtciOptions lifecycle +function t4a_qtci_options_default() + ccall(_sym(:t4a_qtci_options_default), Ptr{Cvoid}, ()) +end + +function t4a_qtci_options_release(ptr) + ccall(_sym(:t4a_qtci_options_release), Cvoid, (Ptr{Cvoid},), ptr) +end + +function t4a_qtci_options_clone(ptr) + ccall(_sym(:t4a_qtci_options_clone), Ptr{Cvoid}, (Ptr{Cvoid},), ptr) +end + +# QtciOptions setters +function t4a_qtci_options_set_tolerance(ptr, tol) + ccall(_sym(:t4a_qtci_options_set_tolerance), Cint, (Ptr{Cvoid}, Cdouble), ptr, tol) +end + +function t4a_qtci_options_set_maxbonddim(ptr, dim) + ccall(_sym(:t4a_qtci_options_set_maxbonddim), Cint, (Ptr{Cvoid}, Csize_t), ptr, dim) +end + +function t4a_qtci_options_set_maxiter(ptr, iter) + ccall(_sym(:t4a_qtci_options_set_maxiter), Cint, (Ptr{Cvoid}, Csize_t), ptr, iter) +end + +function t4a_qtci_options_set_nrandominitpivot(ptr, n) + ccall(_sym(:t4a_qtci_options_set_nrandominitpivot), Cint, (Ptr{Cvoid}, Csize_t), ptr, n) +end + +function t4a_qtci_options_set_unfoldingscheme(ptr, scheme) + ccall(_sym(:t4a_qtci_options_set_unfoldingscheme), Cint, (Ptr{Cvoid}, Cint), ptr, scheme) +end + +function t4a_qtci_options_set_normalize_error(ptr, flag) + ccall(_sym(:t4a_qtci_options_set_normalize_error), Cint, (Ptr{Cvoid}, Cint), ptr, flag) +end + +function t4a_qtci_options_set_verbosity(ptr, level) + ccall(_sym(:t4a_qtci_options_set_verbosity), Cint, (Ptr{Cvoid}, Csize_t), ptr, level) +end + +function t4a_qtci_options_set_nsearchglobalpivot(ptr, n) + ccall(_sym(:t4a_qtci_options_set_nsearchglobalpivot), Cint, (Ptr{Cvoid}, Csize_t), ptr, n) +end + +function t4a_qtci_options_set_nsearch(ptr, n) + ccall(_sym(:t4a_qtci_options_set_nsearch), 
Cint, (Ptr{Cvoid}, Csize_t), ptr, n) +end +``` + +- [ ] **Step 2: Add QuanticsTCI c64 + updated f64 bindings** + +Note: The existing `t4a_quanticscrossinterpolate_f64` signature has CHANGED (breaking). It now takes options, initial_pivots, and convergence output buffers. Update the existing binding AND add c64 variants. + +```julia +# Updated f64 crossinterpolate (BREAKING: new signature) +function t4a_quanticscrossinterpolate_f64(grid, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) + ccall(_sym(:t4a_quanticscrossinterpolate_f64), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, + Cdouble, Csize_t, Csize_t, Ptr{Int64}, Csize_t, + Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + grid, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) +end + +# Updated discrete f64 (BREAKING: new signature) +function t4a_quanticscrossinterpolate_discrete_f64(sizes, ndims, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, unfoldingscheme, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) + ccall(_sym(:t4a_quanticscrossinterpolate_discrete_f64), Cint, + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, + Cdouble, Csize_t, Csize_t, Cint, Ptr{Int64}, Csize_t, + Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + sizes, ndims, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, unfoldingscheme, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) +end + +# c64 crossinterpolate +function t4a_quanticscrossinterpolate_c64(grid, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) + ccall(_sym(:t4a_quanticscrossinterpolate_c64), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, + Cdouble, Csize_t, Csize_t, 
Ptr{Int64}, Csize_t, + Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + grid, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) +end + +function t4a_quanticscrossinterpolate_discrete_c64(sizes, ndims, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, unfoldingscheme, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) + ccall(_sym(:t4a_quanticscrossinterpolate_discrete_c64), Cint, + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, + Cdouble, Csize_t, Csize_t, Cint, Ptr{Int64}, Csize_t, + Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + sizes, ndims, eval_fn, user_data, options, + tolerance, max_bonddim, max_iter, unfoldingscheme, initial_pivots, n_pivots, + out_qtci, out_ranks, out_errors, out_n_iters) +end + +# c64 QTCI lifecycle + accessors +function t4a_qtci_c64_release(ptr) + ccall(_sym(:t4a_qtci_c64_release), Cvoid, (Ptr{Cvoid},), ptr) +end + +function t4a_qtci_c64_clone(ptr) + ccall(_sym(:t4a_qtci_c64_clone), Ptr{Cvoid}, (Ptr{Cvoid},), ptr) +end + +function t4a_qtci_c64_rank(ptr, out_rank) + ccall(_sym(:t4a_qtci_c64_rank), Cint, (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_rank) +end + +function t4a_qtci_c64_link_dims(ptr, out_dims, buf_len) + ccall(_sym(:t4a_qtci_c64_link_dims), Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), ptr, out_dims, buf_len) +end + +function t4a_qtci_c64_evaluate(ptr, indices, n_indices, out_re, out_im) + ccall(_sym(:t4a_qtci_c64_evaluate), Cint, + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}), ptr, indices, n_indices, out_re, out_im) +end + +function t4a_qtci_c64_sum(ptr, out_re, out_im) + ccall(_sym(:t4a_qtci_c64_sum), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), ptr, out_re, out_im) +end + +function t4a_qtci_c64_integral(ptr, out_re, out_im) + ccall(_sym(:t4a_qtci_c64_integral), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), ptr, out_re, out_im) +end + 
+function t4a_qtci_c64_to_tensor_train(ptr) + ccall(_sym(:t4a_qtci_c64_to_tensor_train), Ptr{Cvoid}, (Ptr{Cvoid},), ptr) +end + +# TreeTCI2 state accessors (f64 + c64) +function t4a_qtci_f64_max_bond_error(ptr, out_value) + ccall(_sym(:t4a_qtci_f64_max_bond_error), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), ptr, out_value) +end + +function t4a_qtci_f64_max_rank(ptr, out_rank) + ccall(_sym(:t4a_qtci_f64_max_rank), Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_rank) +end + +function t4a_qtci_c64_max_bond_error(ptr, out_value) + ccall(_sym(:t4a_qtci_c64_max_bond_error), Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), ptr, out_value) +end + +function t4a_qtci_c64_max_rank(ptr, out_rank) + ccall(_sym(:t4a_qtci_c64_max_rank), Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), ptr, out_rank) +end +``` + +- [ ] **Step 3: Commit** + +```bash +git add src/C_API.jl +git commit -m "feat(C_API): add ccall bindings for QtciOptions, QuanticsTCI c64, state accessors" +``` + +--- + +## Task 3: Rewrite SimpleTT.jl + +Complete rewrite of `src/SimpleTT.jl` to support ComplexF64, 1-indexed, renamed functions, and new operations. The implementer should read the current file first, then replace it entirely. + +**Files:** +- Modify: `src/SimpleTT.jl` + +**Key design decisions for the implementer:** + +1. **Type dispatch pattern:** Use `_suffix(::Type{Float64}) = "f64"` and `_api(T, name) = getfield(C_API, Symbol("t4a_simplett_", _suffix(T), "_", name))` pattern (already exists in current code) for dispatching f64/c64 calls. + +2. **ComplexF64 type parameter:** `SimpleTensorTrain{T}` where `T <: Union{Float64, ComplexF64}`. The `_SimpleTTScalar = Union{Float64, ComplexF64}`. + +3. **1-indexed:** All user-facing functions use 1-based indices. Internal C API calls subtract 1. + +4. **Renamed functions (no aliases):** + - `site_dims` → `sitedims` + - `link_dims` → `linkdims` + - `site_tensor` → `sitetensor` + +5. 
**New operations to add:** + - `Base.:+(a, b)` — calls `_add` C API + - `Base.:-(a, b)` — `a + (-1 * b)` + - `Base.:*(α, tt)` and `Base.:*(tt, α)` — clone + scale + - `LinearAlgebra.dot(a, b)` — calls `_dot` C API + - `scale!(tt, α)` — in-place, calls `_scale` C API + - `Base.reverse(tt)` — calls `_reverse` C API + - `fulltensor(tt)` — calls `_fulltensor` C API + - `SimpleTensorTrain(site_tensors::Vector{<:AbstractArray{T,3}})` — calls `_from_site_tensors` + +6. **ComplexF64 specifics:** + - `evaluate` returns `ComplexF64` via (out_re, out_im) + - `sum` returns `ComplexF64` via (out_re, out_im) + - `sitetensor` returns `Array{ComplexF64,3}` by reinterpreting interleaved doubles + - `scale!(tt, α::Complex)` passes (re, im) + - `dot` returns `ComplexF64` via (out_re, out_im) + - `from_site_tensors` converts ComplexF64 data to interleaved doubles + - `fulltensor` reinterprets interleaved doubles to ComplexF64 + +- [ ] **Step 1: Rewrite SimpleTT.jl** + +The implementer should write the complete new SimpleTT.jl following the design decisions above. Reference the existing code for patterns, but produce a complete rewrite. All functions must work for both Float64 and ComplexF64 via type dispatch. + +- [ ] **Step 2: Verify it compiles** + +Run: `julia --startup-file=no -e "using Pkg; Pkg.activate(\".\"); Pkg.instantiate(); using Tensor4all; using Tensor4all.SimpleTT"` +Expected: No errors. 
+ +- [ ] **Step 3: Commit** + +```bash +git add src/SimpleTT.jl +git commit -m "feat(SimpleTT): rewrite with ComplexF64, 1-indexed, renamed functions, new operations" +``` + +--- + +## Task 4: Rewrite test_simplett.jl + +**Files:** +- Modify: `test/test_simplett.jl` +- Modify: `test/runtests.jl` — add `include("test_simplett.jl")` if not present + +- [ ] **Step 1: Write comprehensive tests** + +Tests must cover for BOTH Float64 and ComplexF64: +- Construction: constant, zeros, from site tensors +- Accessors: length, sitedims, linkdims, rank +- Evaluation: 1-indexed evaluate, callable interface `tt(1, 2, 3)` +- Site tensor: 1-indexed sitetensor +- Arithmetic: `+`, `-`, `*` (scalar), `dot` +- In-place: `scale!` +- Other: `reverse`, `fulltensor`, `copy`, `norm`, `sum` +- Verify 1-indexing: `evaluate(tt, [1,1,1])` works, `evaluate(tt, [0,0,0])` errors + +- [ ] **Step 2: Run tests** + +Run: `julia --startup-file=no --project=. -e "using Pkg; Pkg.test()"` +Or: `julia --startup-file=no --project=. test/test_simplett.jl` +Expected: All tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add test/test_simplett.jl test/runtests.jl +git commit -m "test(SimpleTT): comprehensive tests for rewritten module" +``` + +--- + +## Task 5: Rename QuanticsGrids.localdimensions + +**Files:** +- Modify: `src/QuanticsGrids.jl` + +- [ ] **Step 1: Rename function** + +In `src/QuanticsGrids.jl`, rename `local_dimensions` → `localdimensions` for both `DiscretizedGrid` and `InherentDiscreteGrid`. Update the export list. + +- [ ] **Step 2: Update callers** + +Search for `local_dimensions` in all src/ files and update to `localdimensions`. Key callers: `src/QuanticsTCI.jl`. + +- [ ] **Step 3: Commit** + +```bash +git add src/QuanticsGrids.jl src/QuanticsTCI.jl +git commit -m "refactor: rename local_dimensions to localdimensions" +``` + +--- + +## Task 6: Rewrite QuanticsTCI.jl + +**Files:** +- Modify: `src/QuanticsTCI.jl` + +**Key changes:** +1. 
`QuanticsTensorCI2{V}` — type-parameterized, V = Float64 or ComplexF64 +2. Return `(qtci, ranks, errors)` tuple from `quanticscrossinterpolate` +3. First argument is `::Type{V}` (matching Pure Julia API) +4. Build `t4a_qtci_options` handle from kwargs, pass to C API, release after +5. Pass `initial_pivots` if provided +6. Receive `out_ranks`, `out_errors`, `out_n_iters` from C API +7. Add `max_bond_error(qtci)` and `max_rank(qtci)` accessors +8. Add c64 callback trampolines (result writes [re, im]) +9. `evaluate` for c64 returns ComplexF64 +10. `sum`, `integral` for c64 return ComplexF64 +11. `to_tensor_train` for c64 returns `SimpleTensorTrain{ComplexF64}` +12. Add overloads: size tuple, xvals, Array (internally construct grid then call main function) + +- [ ] **Step 1: Rewrite QuanticsTCI.jl** + +Complete rewrite. Reference existing code for callback trampoline patterns. + +- [ ] **Step 2: Commit** + +```bash +git add src/QuanticsTCI.jl +git commit -m "feat(QuanticsTCI): rewrite with ComplexF64, tuple return, QtciOptions, overloads" +``` + +--- + +## Task 7: Write QuanticsTCI tests + +**Files:** +- Create: `test/test_quanticstci.jl` +- Modify: `test/runtests.jl` + +- [ ] **Step 1: Write tests** + +Cover: +- `quanticscrossinterpolate(Float64, f, grid)` returns (qtci, ranks, errors) +- `quanticscrossinterpolate(ComplexF64, f, grid)` works +- Discrete variant with size tuple +- `evaluate`, `sum`, `integral` for both Float64 and ComplexF64 +- `max_bond_error`, `max_rank` +- `to_tensor_train` returns correct type +- kwargs: tolerance, maxbonddim, maxiter, verbosity + +- [ ] **Step 2: Run tests and commit** + +```bash +git add test/test_quanticstci.jl test/runtests.jl +git commit -m "test(QuanticsTCI): comprehensive tests for rewritten module" +``` + +--- + +## Task 8: TreeTCI 1-indexed + +**Files:** +- Modify: `src/TreeTCI.jl` + +- [ ] **Step 1: Update indexing** + +Changes needed: +- `crossinterpolate2`: `initialpivots` default changes from `[zeros(Int, n)]` to 
`[ones(Int, n)]`. Before passing to C API, subtract 1 from each pivot. +- `evaluate(ttn, indices)`: subtract 1 before passing to C API +- `evaluate(ttn, batch::Matrix)`: subtract 1 from all indices +- Callback wrapper: when Rust passes 0-indexed batch to Julia callback, add 1 before calling user function + +- [ ] **Step 2: Write tests** + +Create `test/test_treetci.jl` with basic crossinterpolate2 test using 1-indexed pivots and evaluation. + +- [ ] **Step 3: Commit** + +```bash +git add src/TreeTCI.jl test/test_treetci.jl test/runtests.jl +git commit -m "feat(TreeTCI): convert to 1-indexed user API" +``` + +--- + +## Task 9: SimpleTT ↔ TreeTN.MPS conversion + +**Files:** +- Modify: `src/TreeTN.jl` — add `MPS(tt::SimpleTensorTrain)` +- Modify: `src/SimpleTT.jl` — add `SimpleTensorTrain(mps::TreeTensorNetwork{Int})` + +- [ ] **Step 1: Implement MPS(tt::SimpleTensorTrain)** + +In `src/TreeTN.jl`: +```julia +function MPS(tt::SimpleTensorTrain{T}) where T + n = length(tt) + tensors = Tensor[] + links = Index[] + + # Create site and link indices + for i in 1:n + st = sitetensor(tt, i) # 1-indexed, shape (left, site, right) + left_dim, site_dim, right_dim = size(st) + + site_idx = Index(site_dim) + inds = Index[] + + if i > 1 + push!(inds, links[end]) # left link + end + push!(inds, site_idx) + if i < n + link = Index(right_dim; tags="Link,l=$i") + push!(links, link) + push!(inds, link) + end + + push!(tensors, Tensor(inds, st)) + end + + return MPS(tensors) +end +``` + +- [ ] **Step 2: Implement SimpleTensorTrain(mps::TreeTensorNetwork{Int})** + +In `src/SimpleTT.jl`, import TreeTN types and add: +```julia +function SimpleTensorTrain(mps::TreeTN.TreeTensorNetwork{Int}) + n = TreeTN.nv(mps) + site_tensors = Array{Float64,3}[] # or ComplexF64 based on storage + + for i in 1:n + tensor = mps[i] + # Extract data and determine (left, site, right) ordering + # ... 
(implementation depends on TreeTN tensor index ordering) + end + + return SimpleTensorTrain(site_tensors) +end +``` + +- [ ] **Step 3: Write round-trip tests** + +Create `test/test_conversions.jl`: +- Create SimpleTT → convert to MPS → convert back → compare values +- Create MPS (via random_mps) → convert to SimpleTT → convert back → compare + +- [ ] **Step 4: Commit** + +```bash +git add src/TreeTN.jl src/SimpleTT.jl test/test_conversions.jl test/runtests.jl +git commit -m "feat: add SimpleTT ↔ TreeTN.MPS bidirectional conversion" +``` + +--- + +## Task 10: TCI conversion bridge extension + +**Files:** +- Create: `ext/Tensor4allTCIExt.jl` +- Modify: `Project.toml` — add TensorCrossInterpolation weak dep + +- [ ] **Step 1: Update Project.toml** + +Add to `[weakdeps]`: +```toml +TensorCrossInterpolation = "" +``` + +Add to `[extensions]`: +```toml +Tensor4allTCIExt = ["TensorCrossInterpolation"] +``` + +Get UUID from `../TensorCrossInterpolation.jl/Project.toml`. + +- [ ] **Step 2: Create extension** + +```julia +module Tensor4allTCIExt + +using Tensor4all +using TensorCrossInterpolation +import Tensor4all.SimpleTT: SimpleTensorTrain + +function SimpleTensorTrain(tt::TensorCrossInterpolation.TensorTrain{T}) where T + return SimpleTensorTrain(tt.sitetensors) +end + +function TensorCrossInterpolation.TensorTrain(stt::SimpleTensorTrain{T}) where T + n = length(stt) + site_tensors = [Tensor4all.SimpleTT.sitetensor(stt, i) for i in 1:n] + return TensorCrossInterpolation.TensorTrain(site_tensors) +end + +end # module +``` + +- [ ] **Step 3: Write tests (if TCI is available in test deps)** + +- [ ] **Step 4: Commit** + +```bash +git add ext/Tensor4allTCIExt.jl Project.toml +git commit -m "feat: add TCI.TensorTrain ↔ SimpleTensorTrain conversion extension" +``` + +--- + +## Implementation Notes + +### Rust library must be rebuilt + +After pulling the latest tensor4all-rs (with #393), the shared library `libtensor4all_capi.so` must be rebuilt. 
Run: +```bash +TENSOR4ALL_RS_PATH=/home/shinaoka/tensor4all/tensor4all-rs julia --startup-file=no deps/build.jl +``` + +### Testing approach + +Tests should NOT use ITensors (to verify the "ITensors-free" goal). Use only Tensor4all modules directly. ITensors tests are separate in `itensors_ext_test.jl`. + +### Breaking changes + +This plan introduces breaking changes to: +- SimpleTT: 0→1 indexed, function renames +- QuanticsTCI: return type (single → tuple), argument order +- TreeTCI: 0→1 indexed +- QuanticsGrids: `local_dimensions` → `localdimensions` + +No backward compatibility aliases are provided. diff --git a/docs/plans/2026-03-31-quanticstci-treetci-backend.md b/docs/plans/2026-03-31-quanticstci-treetci-backend.md new file mode 100644 index 0000000..6475cfa --- /dev/null +++ b/docs/plans/2026-03-31-quanticstci-treetci-backend.md @@ -0,0 +1,642 @@ +# QuanticsTCI TreeTCI Backend Migration — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Migrate `tensor4all-quanticstci` crate from `tensor4all-tensorci` (chain-specific TCI) to `tensor4all-treetci` (tree-general TCI) as its backend, using a linear chain graph internally. + +**Architecture:** Replace the `crossinterpolate2` call from tensorci with treetci's version. Auto-generate a linear chain `TreeTciGraph` from the grid's site count. Convert the resulting `TreeTN` back to `TensorTrain` (SimpleTT) for `QuanticsTensorCI2`'s evaluation/sum methods. Adapt the point-wise callback to treetci's batch callback. + +**Tech Stack:** Rust, tensor4all-treetci, tensor4all-simplett, tensor4all-core, quanticsgrids + +**Tracking issue:** tensor4all/tensor4all-rs#384 + +--- + +## File Structure + +All changes are in the `tensor4all-rs` repository. 
+ +| File | Action | Responsibility | +|------|--------|---------------| +| `crates/tensor4all-treetci/src/graph.rs` | Modify | Add `linear_chain()` constructor to `TreeTciGraph` | +| `crates/tensor4all-treetci/src/lib.rs` | Verify | Ensure `linear_chain` is accessible | +| `crates/tensor4all-quanticstci/Cargo.toml` | Modify | Replace `tensor4all-tensorci` dep with `tensor4all-treetci` | +| `crates/tensor4all-quanticstci/src/lib.rs` | Modify | Update re-exports | +| `crates/tensor4all-quanticstci/src/options.rs` | Modify | Replace `to_tci2_options()` with `to_treetci_options()` | +| `crates/tensor4all-quanticstci/src/quantics_tci.rs` | Modify | Core migration: replace crossinterpolate2 call, adapt callback, add TreeTN→TensorTrain conversion | +| `crates/tensor4all-quanticstci/src/options/tests/mod.rs` | Modify | Update options tests | + +--- + +## Task 1: Add `linear_chain()` constructor to `TreeTciGraph` + +**Files:** +- Modify: `crates/tensor4all-treetci/src/graph.rs` + +- [ ] **Step 1: Write the test** + +Add to the test module in `crates/tensor4all-treetci/src/graph.rs` (or its test submodule): + +```rust +#[test] +fn test_linear_chain() { + let graph = TreeTciGraph::linear_chain(5).unwrap(); + assert_eq!(graph.n_sites(), 5); + let edges = graph.edges(); + assert_eq!(edges.len(), 4); + assert_eq!(edges[0], TreeTciEdge::new(0, 1)); + assert_eq!(edges[1], TreeTciEdge::new(1, 2)); + assert_eq!(edges[2], TreeTciEdge::new(2, 3)); + assert_eq!(edges[3], TreeTciEdge::new(3, 4)); +} + +#[test] +fn test_linear_chain_single_site() { + let graph = TreeTciGraph::linear_chain(1).unwrap(); + assert_eq!(graph.n_sites(), 1); + assert_eq!(graph.edges().len(), 0); +} + +#[test] +fn test_linear_chain_two_sites() { + let graph = TreeTciGraph::linear_chain(2).unwrap(); + assert_eq!(graph.n_sites(), 2); + assert_eq!(graph.edges().len(), 1); +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cargo test -p tensor4all-treetci test_linear_chain` +Expected: FAIL — 
`linear_chain` method does not exist. + +- [ ] **Step 3: Implement `linear_chain`** + +In `crates/tensor4all-treetci/src/graph.rs`, add to `impl TreeTciGraph`: + +```rust +/// Create a linear chain graph: 0—1—2—…—(n-1). +pub fn linear_chain(n_sites: usize) -> Result<Self> { + if n_sites == 0 { + return Err(anyhow::anyhow!("linear_chain requires at least 1 site")); + } + let edges: Vec<TreeTciEdge> = (0..n_sites.saturating_sub(1)) + .map(|i| TreeTciEdge::new(i, i + 1)) + .collect(); + Self::new(n_sites, &edges) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cargo test -p tensor4all-treetci test_linear_chain` +Expected: All 3 tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-treetci/src/graph.rs +git commit -m "feat(treetci): add TreeTciGraph::linear_chain() constructor" +``` + +--- + +## Task 2: Update `QtciOptions` to map to `TreeTciOptions` + +**Files:** +- Modify: `crates/tensor4all-quanticstci/src/options.rs` +- Modify: `crates/tensor4all-quanticstci/src/options/tests/mod.rs` +- Modify: `crates/tensor4all-quanticstci/Cargo.toml` + +- [ ] **Step 1: Update Cargo.toml dependencies** + +In `crates/tensor4all-quanticstci/Cargo.toml`: + +```diff + [dependencies] + tensor4all-tcicore = { path = "../tensor4all-tcicore" } +-tensor4all-tensorci = { path = "../tensor4all-tensorci" } ++tensor4all-treetci = { path = "../tensor4all-treetci" } + tensor4all-simplett = { path = "../tensor4all-simplett" } + quanticsgrids.workspace = true +``` + +- [ ] **Step 2: Update options.rs imports and mapping** + +In `crates/tensor4all-quanticstci/src/options.rs`: + +Replace imports: +```diff +-use tensor4all_tensorci::{PivotSearchStrategy, TCI2Options}; ++use tensor4all_treetci::TreeTciOptions; +``` + +Remove `PivotSearchStrategy` from the `QtciOptions` struct and related builder methods. Replace `to_tci2_options` with `to_treetci_options`: + +```rust +/// Convert to TreeTciOptions for the underlying algorithm. 
+pub fn to_treetci_options(&self) -> TreeTciOptions { + TreeTciOptions { + tolerance: self.tolerance, + max_iter: self.maxiter, + max_bond_dim: self.maxbonddim.unwrap_or(usize::MAX), + normalize_error: self.normalize_error, + } +} +``` + +Remove `pivot_search` field from `QtciOptions`, its builder method `with_pivot_search`, and the default value. Keep `nsearchglobalpivot` and `nsearch` fields for future GlobalPivotFinder integration (but unused for now). + +- [ ] **Step 3: Update options tests** + +In `crates/tensor4all-quanticstci/src/options/tests/mod.rs`: + +Update `test_default_options` — remove `pivot_search` assertion. + +Update `test_builder_pattern` — remove `pivot_search` builder call and assertion. + +Replace `test_to_tci2_options` with: + +```rust +#[test] +fn test_to_treetci_options() { + let opts = QtciOptions::default() + .with_tolerance(1e-6) + .with_maxbonddim(100); + + let tree_opts = opts.to_treetci_options(); + assert!((tree_opts.tolerance - 1e-6).abs() < 1e-15); + assert_eq!(tree_opts.max_bond_dim, 100); + assert_eq!(tree_opts.max_iter, 200); + assert!(tree_opts.normalize_error); +} +``` + +- [ ] **Step 4: Verify options tests compile and pass** + +Run: `cargo test -p tensor4all-quanticstci -- options` +Expected: PASS (tests may fail due to other imports still referencing tensorci — that's OK, we'll fix in Task 3). + +- [ ] **Step 5: Commit** + +```bash +git add crates/tensor4all-quanticstci/ +git commit -m "refactor(quanticstci): replace TCI2Options with TreeTciOptions" +``` + +--- + +## Task 3: Migrate `QuanticsTensorCI2` and `quanticscrossinterpolate` + +This is the core migration task. Changes are in `crates/tensor4all-quanticstci/src/quantics_tci.rs`. 
+ +**Files:** +- Modify: `crates/tensor4all-quanticstci/src/quantics_tci.rs` +- Modify: `crates/tensor4all-quanticstci/src/lib.rs` + +### Step-by-step changes: + +- [ ] **Step 1: Update imports** + +In `quantics_tci.rs`, replace: + +```diff +-use tensor4all_simplett::{AbstractTensorTrain, TTScalar, TensorTrain}; +-use tensor4all_tcicore::{DenseFaerLuKernel, LazyBlockRookKernel, PivotKernel}; +-use tensor4all_tensorci::Scalar; +-use tensor4all_tensorci::{crossinterpolate2, TensorCI2}; ++use tensor4all_simplett::{AbstractTensorTrain, TTScalar, TensorTrain, Tensor3Ops}; ++use tensor4all_tcicore::DenseFaerLuKernel; ++use tensor4all_tcicore::PivotKernel; ++use tensor4all_treetci::{ ++ crossinterpolate2 as treetci_crossinterpolate2, ++ DefaultProposer, GlobalIndexBatch, TreeTciGraph, ++}; ++use tensor4all_treetci::TreeTciOptions; ++use tensor4all_core::TensorDynLen; +``` + +- [ ] **Step 2: Change `QuanticsTensorCI2` to store `TensorTrain`** + +Replace the struct definition: + +```diff +-pub struct QuanticsTensorCI2<V> { +- tci: TensorCI2<V>, ++pub struct QuanticsTensorCI2<V> { ++ tt: TensorTrain<V>, + discretized_grid: Option<DiscretizedGrid>, + inherent_grid: Option<InherentDiscreteGrid>, + cache: HashMap<Vec<i64>, V>, + } +``` + +- [ ] **Step 3: Update `QuanticsTensorCI2` constructors** + +```rust +impl<V> QuanticsTensorCI2<V> +where + V: TTScalar + Default + Clone, +{ + pub fn from_discretized( + tt: TensorTrain<V>, + grid: DiscretizedGrid, + cache: HashMap<Vec<i64>, V>, + ) -> Self { + Self { + tt, + discretized_grid: Some(grid), + inherent_grid: None, + cache, + } + } + + pub fn from_inherent( + tt: TensorTrain<V>, + grid: InherentDiscreteGrid, + cache: HashMap<Vec<i64>, V>, + ) -> Self { + Self { + tt, + discretized_grid: None, + inherent_grid: Some(grid), + cache, + } + } +``` + +- [ ] **Step 4: Update `QuanticsTensorCI2` methods** + +```rust + pub fn rank(&self) -> usize { + self.tt.rank() + } + + pub fn link_dims(&self) -> Vec<usize> { + self.tt.link_dims() + } + + pub fn evaluate(&self, indices: &[i64]) -> Result<V> { + let quantics = 
self.grididx_to_quantics(indices)?; + // Convert 1-indexed quantics to 0-indexed for tensor train + let quantics_usize: Vec<usize> = quantics.iter().map(|&x| (x - 1) as usize).collect(); + self.tt.evaluate(&quantics_usize) + .map_err(|e| anyhow!("Evaluation error: {}", e)) + } + + pub fn sum(&self) -> V { + self.tt.sum() + } + + pub fn integral(&self) -> Result<V> + where + V: std::ops::Mul<f64, Output = V>, + { + let sum_val = self.sum(); + if let Some(grid) = &self.discretized_grid { + let step_product: f64 = grid.grid_step().iter().product(); + Ok(sum_val * step_product) + } else { + Ok(sum_val) + } + } + + pub fn tensor_train(&self) -> TensorTrain<V> { + self.tt.clone() + } +``` + +Note: `sum()` no longer returns `Result` — it returns `V` directly since `TensorTrain::sum()` is infallible. `grididx_to_quantics` stays the same. + +- [ ] **Step 5: Add `treetn_to_tensor_train` conversion helper** + +Add a private helper function in `quantics_tci.rs`: + +```rust +/// Convert a linear-chain TreeTN to a SimpleTT TensorTrain. +/// +/// The TreeTN must have been produced by treetci::crossinterpolate2 with a +/// linear chain graph. Nodes are numbered 0..n-1. +/// +/// TreeTN tensors from `to_treetn` have index order: +/// [site_dim, incoming_bond_dims..., outgoing_bond_dims...] +/// where incoming = children, outgoing = parent in BFS from root=0. +/// +/// For SimpleTT we need (left_bond, site_dim, right_bond). 
+fn treetn_to_tensor_train<V>( + treetn: &tensor4all_treetn::TreeTN<V>, + n_sites: usize, + local_dims: &[usize], +) -> Result<TensorTrain<V>> +where + V: TTScalar + Default + Clone + tensor4all_core::TensorElement, +{ + use tensor4all_simplett::types::tensor3_from_data; + + let mut tensors = Vec::with_capacity(n_sites); + + for site in 0..n_sites { + let node_idx = treetn.node_index(&site) + .ok_or_else(|| anyhow!("node {} not found in TreeTN", site))?; + let tensor = treetn.tensor(node_idx) + .ok_or_else(|| anyhow!("tensor not found at node {}", site))?; + + let site_dim = local_dims[site]; + + if n_sites == 1 { + // Single site: tensor has only site index, shape (site_dim,) + let data = tensor.to_column_major_vec::<V>()?; + tensors.push(tensor3_from_data(1, site_dim, 1, data)?); + } else if site == 0 { + // Root (leftmost): indices = [site, bond_01] + // Data shape: (site_dim, bond_dim) column-major + // Need: (1, site_dim, bond_dim) + let bond_dim = tensor.total_size() / site_dim; + let data = tensor.to_column_major_vec::<V>()?; + // Reshape: (site, bond) → (1, site, bond) + // Column-major (site, bond): data[s + site_dim * b] + // Target (1, site, bond): data[0 + 1*(s + site_dim * b)] — same layout + tensors.push(tensor3_from_data(1, site_dim, bond_dim, data)?); + } else if site == n_sites - 1 { + // Leaf (rightmost): indices = [site, bond_{n-2,n-1}] + // Data shape: (site_dim, bond_dim) column-major + // Need: (bond_dim, site_dim, 1) + let bond_dim = tensor.total_size() / site_dim; + let data = tensor.to_column_major_vec::<V>()?; + // Permute: (site, bond) → (bond, site) + let mut permuted = vec![V::default(); data.len()]; + for b in 0..bond_dim { + for s in 0..site_dim { + permuted[b + bond_dim * s] = data[s + site_dim * b]; + } + } + tensors.push(tensor3_from_data(bond_dim, site_dim, 1, permuted)?); + } else { + // Middle node: indices = [site, bond_{site,site+1}, bond_{site-1,site}] + // (site is root=0, so parent is towards 0) + // incoming = bond to right child, outgoing = bond to 
parent (left) + // Data shape: (site_dim, right_bond, left_bond) column-major + // Need: (left_bond, site_dim, right_bond) + let total = tensor.total_size(); + let left_bond = tensor.dim_of_index(2)?; // outgoing = parent bond + let right_bond = tensor.dim_of_index(1)?; // incoming = child bond + let data = tensor.to_column_major_vec::<V>()?; + // Permute: (site, right, left) → (left, site, right) + let mut permuted = vec![V::default(); total]; + for l in 0..left_bond { + for s in 0..site_dim { + for r in 0..right_bond { + let src = s + site_dim * (r + right_bond * l); + let dst = l + left_bond * (s + site_dim * r); + permuted[dst] = data[src]; + } + } + } + tensors.push(tensor3_from_data(left_bond, site_dim, right_bond, permuted)?); + } + } + + TensorTrain::new(tensors).map_err(|e| anyhow!("Failed to build TensorTrain: {}", e)) +} +``` + +**Important note:** The exact API for extracting typed data from `TensorDynLen` (`to_column_major_vec::<V>()`, `dim_of_index()`, `total_size()`) may differ from what's shown. During implementation, check `tensor4all_core::TensorDynLen` API and adapt accordingly. The core logic (index permutation) is correct. + +- [ ] **Step 6: Migrate `quanticscrossinterpolate` (continuous)** + +Replace the function body. Key changes: +1. Wrap the point-wise function `qf` into a batch function for treetci +2. Generate linear chain graph +3. Call `treetci_crossinterpolate2` instead of `crossinterpolate2` +4. 
Convert TreeTN result to TensorTrain + +```rust +pub fn quanticscrossinterpolate<V, F>( + grid: &DiscretizedGrid, + f: F, + initial_pivots: Option<Vec<Vec<i64>>>, + options: QtciOptions, +) -> Result<(QuanticsTensorCI2<V>, Vec<usize>, Vec<f64>)> +where + V: TTScalar + Default + Clone + 'static + + tensor4all_core::TensorElement + + tensor4all_treetci::materialize::FullPivLuScalar, + DenseFaerLuKernel: PivotKernel<V>, + F: Fn(&[f64]) -> V + 'static, +{ + let local_dims = grid.local_dimensions(); + let n_sites = local_dims.len(); + + let cache: Rc<RefCell<HashMap<Vec<i64>, V>>> = Rc::new(RefCell::new(HashMap::new())); + let cache_clone = cache.clone(); + + // Wrap function: original coords → quantics → 0-indexed for TCI + let grid_clone = grid.clone(); + let qf = move |q: &Vec<usize>| -> V { + let q_i64: Vec<i64> = q.iter().map(|&x| (x + 1) as i64).collect(); + if let Some(v) = cache_clone.borrow().get(&q_i64) { + return v.clone(); + } + let coords = match grid_clone.quantics_to_origcoord(&q_i64) { + Ok(coords) => coords, + Err(_) => return V::default(), + }; + let value = f(&coords); + cache_clone.borrow_mut().insert(q_i64, value.clone()); + value + }; + + // Batch adapter: treetci expects Fn(GlobalIndexBatch) -> Result<Vec<V>> + let batch_eval = move |batch: GlobalIndexBatch<'_>| -> Result<Vec<V>> { + let n_points = batch.n_points(); + let n = batch.n_sites(); + let mut results = Vec::with_capacity(n_points); + for p in 0..n_points { + let point: Vec<usize> = (0..n).map(|s| batch.get(s, p)).collect(); + results.push(qf(&point)); + } + Ok(results) + }; + + // Prepare initial pivots (0-indexed) + let mut qinitialpivots: Vec<Vec<usize>> = if let Some(pivots) = initial_pivots { + pivots.iter().filter_map(|p| { + grid.grididx_to_quantics(p).ok() + .map(|q| q.iter().map(|&x| (x - 1) as usize).collect()) + }).collect() + } else { + vec![vec![0; n_sites]] + }; + + let mut rng = rand::rng(); + for _ in 0..options.nrandominitpivot { + let pivot: Vec<usize> = local_dims.iter().map(|&d| rng.random_range(0..d)).collect(); + qinitialpivots.push(pivot); + } + + // Run TreeTCI with linear 
chain + let graph = TreeTciGraph::linear_chain(n_sites)?; + let tree_opts = options.to_treetci_options(); + let proposer = DefaultProposer; + let (treetn, ranks, errors) = treetci_crossinterpolate2( + batch_eval, + local_dims.clone(), + graph, + qinitialpivots, + tree_opts, + Some(0), // center_site = 0 (root at left end) + &proposer, + )?; + + // Convert TreeTN → TensorTrain + let tt = treetn_to_tensor_train::(&treetn, n_sites, &local_dims)?; + + let final_cache = Rc::try_unwrap(cache) + .map_err(|_| anyhow!("Failed to extract cache"))? + .into_inner(); + + Ok(( + QuanticsTensorCI2::from_discretized(tt, grid.clone(), final_cache), + ranks, + errors, + )) +} +``` + +- [ ] **Step 7: Migrate `quanticscrossinterpolate_discrete`** + +Apply the same pattern as Step 6. The changes are identical in structure: +1. Wrap `qf` into batch adapter +2. Generate linear chain graph +3. Call `treetci_crossinterpolate2` +4. Convert TreeTN → TensorTrain +5. Build `QuanticsTensorCI2::from_inherent` + +The discrete version uses `InherentDiscreteGrid` and the function receives `&[i64]` indices instead of `&[f64]` coordinates. The callback wrapping logic stays the same as existing code, just with the batch adapter added. + +- [ ] **Step 8: Update `lib.rs` re-exports** + +In `crates/tensor4all-quanticstci/src/lib.rs`: + +```diff +-pub use tensor4all_tensorci::{PivotSearchStrategy, Scalar, TCI2Options, TensorCI2}; ++pub use tensor4all_treetci::{TreeTciGraph, TreeTciOptions, DefaultProposer}; ++pub use tensor4all_simplett::{AbstractTensorTrain, TensorTrain}; +``` + +- [ ] **Step 9: Verify it compiles** + +Run: `cargo build -p tensor4all-quanticstci` +Expected: Successful compilation. Fix any type mismatches. 
+ +- [ ] **Step 10: Commit** + +```bash +git add crates/tensor4all-quanticstci/ +git commit -m "feat(quanticstci): migrate backend from tensorci to treetci" +``` + +--- + +## Task 4: Verify all existing tests pass + +**Files:** +- All test files in `crates/tensor4all-quanticstci/src/quantics_tci/tests/mod.rs` +- All test files in `crates/tensor4all-quanticstci/src/options/tests/mod.rs` + +- [ ] **Step 1: Run all quanticstci tests** + +Run: `cargo test -p tensor4all-quanticstci` +Expected: All tests PASS. Key tests to verify: +- `test_discrete_simple_function` — f(i,j) = i+j +- `test_continuous_grid_interpolation` — f(x) = x² +- `test_continuous_grid_integral` — integral of x² +- `test_discrete_with_initial_pivots` +- `test_continuous_grid_with_initial_pivots` +- `test_from_arrays_valid` +- `test_from_arrays_1d` + +- [ ] **Step 2: Run full workspace tests** + +Run: `cargo test --workspace` +Expected: No regressions in other crates. + +- [ ] **Step 3: If tests fail, debug and fix** + +Common issues to check: +- Index permutation order in `treetn_to_tensor_train` (the most likely source of bugs) +- Sign/value differences due to different canonicalization +- Tolerance differences (TreeTCI may converge differently than chain TCI) + +For tolerance issues: the existing tests use `approx::assert_abs_diff_eq!` with tolerances. If TreeTCI converges to slightly different accuracy, adjust the tolerance in the test assertions rather than changing the algorithm. + +- [ ] **Step 4: Commit any fixes** + +```bash +git add -A +git commit -m "fix(quanticstci): fix test failures after treetci migration" +``` + +--- + +## Task 5: Clean up old `tensor4all-tensorci` dependency + +- [ ] **Step 1: Verify no remaining references to tensorci in quanticstci** + +Run: `grep -r "tensor4all.tensorci\|tensorci" crates/tensor4all-quanticstci/src/` +Expected: No matches (only treetci references). 
+ +- [ ] **Step 2: Check if other crates still depend on tensorci** + +Run: `grep -r "tensor4all-tensorci" crates/*/Cargo.toml` +Expected: Only `tensor4all-tcicore` (dev-dependency) and `tensor4all-tensorci` itself. + +- [ ] **Step 3: Commit final state** + +```bash +git add crates/tensor4all-quanticstci/ +git commit -m "chore(quanticstci): remove all tensor4all-tensorci references" +``` + +--- + +## Implementation Notes + +### TreeTN → TensorTrain conversion details + +The `to_treetn` function in treetci produces tensors with this index ordering per node: + +``` +indices = [site_index, incoming_bond_indices..., outgoing_bond_indices...] +``` + +For a linear chain 0—1—2—…—(n-1) with BFS root at 0: +- **Node 0** (root): `[site, bond_01]` — no parent, one child +- **Node k** (middle): `[site, bond_{k,k+1}, bond_{k-1,k}]` — one child (incoming), one parent (outgoing) +- **Node n-1** (leaf): `[site, bond_{n-2,n-1}]` — no children, one parent + +SimpleTT needs `(left_bond, site, right_bond)`: +- Node 0: `(1, site, bond_01)` — insert dummy left=1 +- Node k: `(bond_{k-1,k}, site, bond_{k,k+1})` — permute from (site, right, left) +- Node n-1: `(bond_{n-2,n-1}, site, 1)` — permute from (site, left), insert dummy right=1 + +### Type constraint changes + +Old: `V: Scalar + TTScalar + Default + Clone + MatrixLuciScalar` +New: `V: TTScalar + Default + Clone + TensorElement + FullPivLuScalar` + +`FullPivLuScalar` is implemented for f32, f64, Complex32, Complex64 — same as before. + +### Callback adaptation + +tensorci: `Fn(&Vec) -> V` (single point) +treetci: `Fn(GlobalIndexBatch<'_>) -> Result>` (batch) + +The adapter iterates over batch points and calls the point-wise function for each. + +### `sum()` return type change + +`TensorCI2::sum()` required `to_tensor_train()` first (fallible). +`TensorTrain::sum()` is direct and infallible. + +`QuanticsTensorCI2::sum()` changes from `Result` to `V`. This is a **breaking API change** for downstream callers. 
The C API wrapper in `tensor4all-capi` will need updating to match (but that's a separate task). diff --git a/docs/plans/2026-04-06-quantics-rust-parity.md b/docs/plans/2026-04-06-quantics-rust-parity.md new file mode 100644 index 0000000..ff01b81 --- /dev/null +++ b/docs/plans/2026-04-06-quantics-rust-parity.md @@ -0,0 +1,242 @@ +# Quantics Rust Parity Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Expose the pure Rust quantics surface through `tensor4all-capi` and `Tensor4all.jl`, so downstream code can use grids and multivariable quantics transforms without reviving the removed `TensorCI` wrapper. + +**Architecture:** Keep Rust as the source of truth for quantics functionality. Fill the missing FFI layer in `tensor4all-capi`, then add thin Julia bindings in `C_API.jl` and ergonomic wrappers in `QuanticsGrids.jl` and `QuanticsTransform.jl`. Use `SimpleTT.SimpleTensorTrain` as the canonical Julia TT object and leave external TT conversions in extension modules. + +**Tech Stack:** Rust, `tensor4all-rs`, `tensor4all-capi`, Julia `ccall`, `Tensor4all.jl` (`C_API.jl`, `QuanticsGrids.jl`, `QuanticsTransform.jl`, `SimpleTT.jl`) + +**Spec:** `docs/specs/2026-04-06-quantics-rust-parity-design.md` + +--- + +### Task 1: Tighten the QuanticsGrids Julia surface + +**Files:** +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/src/QuanticsGrids.jl` +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/src/Tensor4all.jl` +- Test: `/home/shinaoka/tensor4all/Tensor4all.jl/test/test_quanticsgrids.jl` + +**Step 1: Write the failing test** + +Add tests that require: + +- `using Tensor4all` to access `DiscretizedGrid` and `InherentDiscreteGrid` +- `unfolding=:grouped` to be accepted when Rust/C API support it + +**Step 2: Run test to verify it fails** + +Run: `julia --startup-file=no --project=. 
-e 'using Pkg; Pkg.test(; test_args=[\"test_quanticsgrids\"])'` + +Expected: failures showing missing top-level exports or unsupported `:grouped`. + +**Step 3: Write minimal implementation** + +- extend `_unfolding_to_cint` in `QuanticsGrids.jl` to include `:grouped` +- export selected grid types/functions from `Tensor4all.jl` +- keep the submodule intact for namespaced use + +**Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticsgrids\"])'` + +Expected: pass. + +**Step 5: Commit** + +```bash +git add src/QuanticsGrids.jl src/Tensor4all.jl test/test_quanticsgrids.jl +git commit -m "feat: polish QuanticsGrids Julia surface" +``` + +### Task 2: Expose multivariable quantics transform constructors in the C API + +**Files:** +- Modify: `/home/shinaoka/tensor4all/tensor4all-rs/crates/tensor4all-capi/src/quanticstransform.rs` +- Modify: `/home/shinaoka/tensor4all/tensor4all-rs/crates/tensor4all-capi/src/types.rs` +- Test: `/home/shinaoka/tensor4all/tensor4all-rs/crates/tensor4all-capi/src/quanticstransform/tests/mod.rs` + +**Step 1: Write the failing Rust tests** + +Add tests that attempt to construct: + +- `shift_operator_multivar` +- `flip_operator_multivar` +- `phase_rotation_operator_multivar` +- `affine_operator` +- `binaryop_operator` + +with mixed boundary conditions and asymmetric input/output dimensions where applicable. + +**Step 2: Run tests to verify they fail** + +Run: `cargo test -p tensor4all-capi quanticstransform -- --nocapture` + +Expected: compile or symbol failures because the C API does not expose these constructors yet. 
+ +**Step 3: Write minimal implementation** + +- add FFI-safe boundary-condition conversions in `types.rs` if needed +- expose the missing constructor entry points in `quanticstransform.rs` +- keep function signatures close to the Rust backend surface + +**Step 4: Run tests to verify they pass** + +Run: `cargo test -p tensor4all-capi quanticstransform -- --nocapture` + +Expected: pass. + +**Step 5: Commit** + +```bash +git add crates/tensor4all-capi/src/types.rs crates/tensor4all-capi/src/quanticstransform.rs crates/tensor4all-capi/src/quanticstransform/tests/mod.rs +git commit -m "feat: expose multivariable quantics transforms in C API" +``` + +### Task 3: Add Julia C bindings for the new quantics transform constructors + +**Files:** +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/src/C_API.jl` +- Test: `/home/shinaoka/tensor4all/Tensor4all.jl/test/test_quanticstransform.jl` + +**Step 1: Write the failing Julia test** + +Add tests that call the new C-API wrappers from Julia and fail if any symbol or argument conversion is missing. + +**Step 2: Run test to verify it fails** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticstransform\"])'` + +Expected: `UndefVarError`, `MethodError`, or symbol lookup failures. + +**Step 3: Write minimal implementation** + +Add `ccall` wrappers in `C_API.jl` for: + +- boundary-condition enum conversion +- multivariable shift/flip/phase constructors +- affine operator constructor +- binaryop operator constructor + +**Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticstransform\"])'` + +Expected: pass. 
+ +**Step 5: Commit** + +```bash +git add src/C_API.jl test/test_quanticstransform.jl +git commit -m "feat: add Julia bindings for quantics transform C API" +``` + +### Task 4: Build the Julia QuanticsTransform wrapper surface + +**Files:** +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/src/QuanticsTransform.jl` +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/src/Tensor4all.jl` +- Test: `/home/shinaoka/tensor4all/Tensor4all.jl/test/test_quanticstransform.jl` + +**Step 1: Write the failing wrapper tests** + +Cover: + +- `BoundaryCondition` exposure +- multivariable operator constructors +- mixed boundary conditions +- a simple affine embedding case that changes logical variable count + +**Step 2: Run test to verify it fails** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticstransform\"])'` + +Expected: failures in wrapper construction or type dispatch. + +**Step 3: Write minimal implementation** + +- extend `QuanticsTransform.jl` with thin wrappers over the new `C_API.jl` bindings +- keep the operator object model consistent with the existing `apply` path +- export the new constructors without adding a fake `TensorCI` layer + +**Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticstransform\"])'` + +Expected: pass. 
+ +**Step 5: Commit** + +```bash +git add src/QuanticsTransform.jl src/Tensor4all.jl test/test_quanticstransform.jl +git commit -m "feat: add Julia wrappers for multivariable quantics transforms" +``` + +### Task 5: Add a ReFrequenTT-shaped regression test + +**Files:** +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/test/test_quanticstransform.jl` +- Optionally modify: `/home/shinaoka/tensor4all/Tensor4all.jl/test/runtests.jl` + +**Step 1: Write the failing regression test** + +Add a small, local test that exercises the two patterns that matter most for `ReFrequenTT`: + +- same-dimension affine remap with mixed BC +- dimension-embedding remap such as `(omega, q) -> (nu-nup, k-kp)` + +The test only needs to verify operator construction and action on a small TT, not full physics correctness. + +**Step 2: Run test to verify it fails** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticstransform\"])'` + +Expected: failure until the wrapper surface matches the needed semantics. + +**Step 3: Write minimal implementation** + +- refine argument conventions +- fix BC ordering if needed +- add helper utilities only if tests show repetition + +**Step 4: Run test to verify it passes** + +Run: `julia --startup-file=no --project=. -e 'using Pkg; Pkg.test(; test_args=[\"test_quanticstransform\"])'` + +Expected: pass. + +**Step 5: Commit** + +```bash +git add test/test_quanticstransform.jl test/runtests.jl +git commit -m "test: cover ReFrequenTT-style quantics remaps" +``` + +### Task 6: Re-evaluate whether a TensorCI shim is still needed + +**Files:** +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/docs/specs/2026-04-06-quantics-rust-parity-design.md` +- Modify: `/home/shinaoka/tensor4all/Tensor4all.jl/docs/plans/2026-04-06-quantics-rust-parity.md` + +**Step 1: Review the resulting public surface** + +Inspect whether the new grid/transform API already gives downstream users a direct migration path. 
+ +**Step 2: Decide based on evidence** + +If the answer is yes, leave `TensorCI` removed. + +If the answer is no, write down the smallest possible shim and the exact gap it closes. + +**Step 3: Update docs** + +Record the outcome in the design/plan docs instead of guessing up front. + +**Step 4: Commit** + +```bash +git add docs/specs/2026-04-06-quantics-rust-parity-design.md docs/plans/2026-04-06-quantics-rust-parity.md +git commit -m "docs: record TensorCI shim decision" +``` diff --git a/docs/specs/2026-03-30-col-major-array-design.md b/docs/specs/2026-03-30-col-major-array-design.md new file mode 100644 index 0000000..7cedc1e --- /dev/null +++ b/docs/specs/2026-03-30-col-major-array-design.md @@ -0,0 +1,467 @@ +# API 統一 & tensorci 削除 設計 (最終版) + +## Overview + +1. `tensor4all-core` に軽量 N 次元 column-major array 型を追加 +2. `TreeTN::evaluate` を IndexId ベース batch API に移行(HashSet 順序バグ修正) +3. TreeTCI API を統一(`crossinterpolate2`、batch evaluate のみ) +4. tensorci (chain TCI) の C API と Julia ラッパーを削除 +5. 
TreeTCI 内部状態を `ColMajorArray` に移行 + +## 動機 + +- HashMap ベースの evaluate は遅い + HashSet 順序バグ +- `Vec>` (nested array) はキャッシュ非効率 +- `point_eval` と `batch_eval` の二重 API は不要 +- tensorci (chain) は treetci の下位互換 → 削除して保守コスト削減 +- C API 命名が不一致 → `t4a__` を徹底 + +--- + +## Part 1: 削除 + +### tensor4all-capi: tensorci C API 削除 + +| 削除対象 | ファイル | +|---------|---------| +| `tensorci.rs` 全体 | `crates/tensor4all-capi/src/tensorci.rs` | +| テスト | `crates/tensor4all-capi/src/tensorci/tests/mod.rs` | +| types.rs から `t4a_tci2_f64`, `t4a_tci2_c64` | `crates/tensor4all-capi/src/types.rs` | +| lib.rs から `mod tensorci; pub use tensorci::*;` | `crates/tensor4all-capi/src/lib.rs` | + +**削除される C API 関数:** +- `t4a_tci2_f64_*` (18 関数) +- `t4a_tci2_c64_*` (18 関数) +- `t4a_crossinterpolate2_f64`, `t4a_crossinterpolate2_c64` +- `t4a_estimate_true_error_f64`, `t4a_opt_first_pivot_f64` +- `EvalCallback`, `EvalCallbackC64` 型 + +### Tensor4all.jl: TensorCI Julia ラッパー削除 + +| 削除対象 | ファイル | +|---------|---------| +| TensorCI モジュール | `src/TensorCI.jl` | +| テスト | `test/test_tensorci.jl`, `test/test_tensorci_advanced.jl` | +| Tensor4all.jl から `include("TensorCI.jl")` | `src/Tensor4all.jl` | +| runtests.jl から include | `test/runtests.jl` | + +### Tensor4all.jl: SimpleTT Julia ラッパーの tensorci 依存除去 + +`SimpleTT.jl` が `TensorCI.jl` に依存している箇所があれば除去。 +`to_tensor_train` 等の TCI → SimpleTT 変換は TreeTCI 側で提供。 + +--- + +## Part 2: ColMajorArray 型 (`tensor4all-core`) + +### ファイル + +`crates/tensor4all-core/src/col_major_array.rs` + +### 3つの型 + +```rust +/// Borrowed N-dimensional column-major array view. +/// Element at [i0, i1, ...] is at: i0 + shape[0] * (i1 + shape[1] * ...) +#[derive(Clone, Copy, Debug)] +pub struct ColMajorArrayRef<'a, T> { + data: &'a [T], + shape: &'a [usize], +} + +/// Mutable borrowed N-dimensional column-major array view. +#[derive(Debug)] +pub struct ColMajorArrayMut<'a, T> { + data: &'a mut [T], + shape: &'a [usize], +} + +/// Owned N-dimensional column-major array. 
+#[derive(Clone, Debug)] +pub struct ColMajorArray<T> { + data: Vec<T>, + shape: Vec<usize>, +} +``` + +`T` に trait 境界なし。 + +### API + +**全型共通:** +```rust +fn ndim(&self) -> usize +fn shape(&self) -> &[usize] +fn len(&self) -> usize // 全要素数 +fn is_empty(&self) -> bool +fn data(&self) -> &[T] +fn get(&self, indices: &[usize]) -> Option<&T> +``` + +**Mut, Owned:** +```rust +fn data_mut(&mut self) -> &mut [T] +fn get_mut(&mut self, indices: &[usize]) -> Option<&mut T> +``` + +**Owned のみ:** +```rust +fn new(data: Vec<T>, shape: Vec<usize>) -> Result<Self> +fn into_data(self) -> Vec<T> +fn as_ref(&self) -> ColMajorArrayRef<'_, T> +fn as_mut(&mut self) -> ColMajorArrayMut<'_, T> +``` + +**2D 便利メソッド (Owned のみ):** +```rust +fn nrows(&self) -> usize // shape[0], panics if not 2D +fn ncols(&self) -> usize // shape[1], panics if not 2D +fn column(&self, j: usize) -> Option<&[T]> +fn push_column(&mut self, column: &[T]) -> Result<()> // column-major append +``` + +**ファクトリ:** +```rust +fn filled(shape: Vec<usize>, value: T) -> Self // T: Clone +fn zeros(shape: Vec<usize>) -> Self // T: Default + Clone +``` + +### flat_offset (checked arithmetic) + +```rust +fn flat_offset(shape: &[usize], indices: &[usize]) -> Option<usize> { + if indices.len() != shape.len() { return None; } + let mut offset = 0usize; + let mut stride = 1usize; + for (&idx, &dim) in indices.iter().zip(shape.iter()) { + if idx >= dim { return None; } + offset = offset.checked_add(idx.checked_mul(stride)?)?; + stride = stride.checked_mul(dim)?; + } + Some(offset) +} +``` + +--- + +## Part 3: TreeTN API 変更 + +### all_site_index_ids (新規) + +```rust +/// Returns all site index IDs and their owning vertex names. +/// +/// Returns (index_ids, vertex_names) where index_ids[i] belongs to +/// vertex vertex_names[i]. Order is unspecified but consistent +/// between the two vectors. +/// +/// For evaluate(), pass index_ids and arrange values in the same order. 
+pub fn all_site_index_ids(&self) -> Result<( + Vec<::Id>, + Vec, +)> +where + V: Clone, + ::Id: Clone, +``` + +### evaluate (新規、旧版を置き換え) + +```rust +/// Evaluate the TreeTN at multiple multi-indices (batch). +/// +/// index_ids: which indices to fix (n_indices 個). +/// Each ID identifies a specific site index unambiguously. +/// Must enumerate every site index exactly once. +/// values: shape = [n_indices, n_points], column-major. +/// values.get(&[i, p]) = value of index_ids[i] at point p. +/// +/// Returns one AnyScalar per point. +pub fn evaluate( + &self, + index_ids: &[::Id], + values: ColMajorArrayRef<'_, usize>, +) -> Result> +where + ::Id: + Clone + Hash + Eq + Ord + Debug + Send + Sync, +``` + +**内部実装:** +- `index_ids` から各 index を ID で直接 lookup(HashMap 不要) +- point ごとに onehot contraction(既存ロジック流用) +- HashSet 順序バグ解消 — index を ID で直接指定するため + +**旧 `evaluate(&HashMap>)` は削除。** + +### C API + +**削除:** `t4a_treetn_evaluate_batch` (旧 vertex 名ベース) + +**新規:** + +```rust +/// Get all site index IDs and vertex names. +/// Query-then-fill: pass NULL buffers to get out_n_indices only. +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetn_all_site_index_ids( + ptr: *const t4a_treetn, + out_index_ids: *mut u64, // DynId as u64, NULL for query + out_vertex_names: *mut libc::size_t, // vertex name (usize), NULL for query + buf_len: libc::size_t, + out_n_indices: *mut libc::size_t, +) -> StatusCode; + +/// Evaluate TreeTN at multiple points. 
+/// index_ids: n_indices index IDs (from all_site_index_ids) +/// values: column-major [n_indices, n_points] +/// out_re: n_points results (real part) +/// out_im: n_points results (imag part, NULL for real-only) +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetn_evaluate( + ptr: *const t4a_treetn, + index_ids: *const u64, + n_indices: libc::size_t, + values: *const libc::size_t, + n_points: libc::size_t, + out_re: *mut libc::c_double, + out_im: *mut libc::c_double, +) -> StatusCode; +``` + +### 呼び出し箇所の更新 + +| ファイル | 変更 | +|---------|------| +| `treetn/tests/ops.rs` (6箇所) | `HashMap::from(...)` → `all_site_index_ids` + `ColMajorArrayRef` | +| `treetci/tests/simple_parity.rs` (4箇所) | 同上 | +| `treetci/tests/advanced_quantics.rs` (1箇所) | 同上 | +| `treetci/src/materialize/tests.rs` | 同上 | +| `tensor4all-capi/src/treetn.rs` | 旧 `evaluate_batch` → 新 `evaluate` | + +--- + +## Part 4: TreeTCI API 統一 + +### 命名変更 + +| Before | After | +|--------|-------| +| `crossinterpolate_tree` | `crossinterpolate2` | +| `crossinterpolate_tree_with_proposer` | 削除(`crossinterpolate2` に統合) | +| `batch_eval` (引数名) | `evaluate` | +| `point_eval` | 削除 | +| `fallback_batch_eval` | 削除 | +| `TreeTciBatchEvalCallback` (C API) | `TreeTciEvalCallback` | +| `TreeTciBatchEvalCallbackC64` (C API) | `TreeTciEvalCallbackC64` | +| `t4a_crossinterpolate_tree_f64` (C API) | `t4a_treetci_crossinterpolate2_f64` | +| `t4a_crossinterpolate_tree_c64` (C API) | `t4a_treetci_crossinterpolate2_c64` | + +### crossinterpolate2 (統一版) + +```rust +pub fn crossinterpolate2( + evaluate: F, + local_dims: Vec, + graph: TreeTciGraph, + initial_pivots: Vec, + options: TreeTciOptions, + center_site: Option, + proposer: &P, +) -> Result +where + T: FullPivLuScalar, + DenseFaerLuKernel: PivotKernel, + F: Fn(GlobalIndexBatch<'_>) -> Result>, + P: PivotCandidateProposer, +{ + let pivots = if initial_pivots.is_empty() { + vec![vec![0; local_dims.len()]] + } else { + initial_pivots + }; + + let mut tci = 
SimpleTreeTci::<T>::new(local_dims, graph)?; + tci.add_global_pivots(&pivots)?; + + // Initialize max_sample_value via batch evaluate + let n_sites = tci.local_dims.len(); + let flat: Vec<usize> = pivots.iter().flat_map(|p| p.iter().copied()).collect(); + let batch = GlobalIndexBatch::new(&flat, n_sites, pivots.len())?; + let init_vals = evaluate(batch)?; + tci.max_sample_value = init_vals.iter() + .map(|v| T::abs_val(*v)) + .fold(0.0f64, f64::max); + ensure!(tci.max_sample_value > 0.0, ...); + + let (ranks, errors) = optimize_with_proposer(&mut tci, &evaluate, &options, proposer)?; + let treetn = to_treetn(&tci, &evaluate, center_site)?; + + Ok((treetn, ranks, errors)) +} +``` + +### GlobalIndexBatch + +**shape 所有問題への対応:** + +`GlobalIndexBatch` は pure newtype にせず、独自に shape 情報を所有: + +```rust +#[derive(Clone, Copy, Debug)] +pub struct GlobalIndexBatch<'a> { + data: &'a [usize], + n_sites: usize, + n_points: usize, +} + +impl<'a> GlobalIndexBatch<'a> { + pub fn new(data: &'a [usize], n_sites: usize, n_points: usize) -> Result<Self>; + + pub fn n_sites(&self) -> usize; + pub fn n_points(&self) -> usize; + pub fn data(&self) -> &'a [usize]; + pub fn get(&self, site: usize, point: usize) -> Option<usize>; + + /// Convert to ColMajorArrayRef (caller must provide shape with sufficient lifetime). 
+ pub fn as_col_major<'s>(&self, shape: &'s [usize]) -> Result> + where 'a: 's; +} +``` + +**現在の構造を維持**。`ColMajorArrayRef` への変換メソッドを追加するだけ。 + +### optimize_with_proposer + +```rust +pub fn optimize_with_proposer( + state: &mut SimpleTreeTci, + evaluate: &F, // renamed from batch_eval + options: &TreeTciOptions, + proposer: &P, +) -> Result<(Vec, Vec)> +where + T: FullPivLuScalar, // tightened from Scalar + DenseFaerLuKernel: PivotKernel, + F: Fn(GlobalIndexBatch<'_>) -> Result>, + P: PivotCandidateProposer, +``` + +### to_treetn + +```rust +pub fn to_treetn( + state: &SimpleTreeTci, + evaluate: &F, // renamed from batch_eval + center_site: Option, +) -> Result> +where + T: FullPivLuScalar, + F: Fn(GlobalIndexBatch<'_>) -> Result>, +``` + +--- + +## Part 5: TreeTCI 内部状態 + +### ijset + ijset_history + +```rust +// Before +pub ijset: HashMap>, +pub ijset_history: Vec>>, + +// After +pub ijset: HashMap>, +// shape = [n_subtree_sites, n_pivots] +// .ncols() = ピボット数 +// .column(j) = j 番目のピボット +// .push_column(pivot) = ピボット追加 + +pub ijset_history: Vec>>, +``` + +**空の初期状態:** `ColMajorArray::new(vec![], vec![n_subtree_sites, 0])` → 0列の 2D 配列。 + +**push_unique_column:** +```rust +fn push_unique_column(array: &mut ColMajorArray, column: &[usize]) { + // 既存列と比較、重複なければ追加 + let nrows = array.nrows(); + for j in 0..array.ncols() { + if array.column(j) == Some(column) { + return; // duplicate + } + } + array.push_column(column).unwrap(); +} +``` + +### .len() → .ncols() への置き換え + +全箇所でピボット数を参照する `.len()` を `.ncols()` に変更: +- `state.rs` L106 +- `materialize.rs` L89, L217 +- `proposer.rs` L82, L248 + +--- + +## Part 6: Julia ラッパー更新 + +### 削除 + +- `src/TensorCI.jl` 全体 +- `test/test_tensorci.jl`, `test/test_tensorci_advanced.jl` +- `src/Tensor4all.jl` の TensorCI include と using +- `test/runtests.jl` の TensorCI include + +### TreeTCI 更新 + +- `crossinterpolate_tree` → `crossinterpolate2` (Julia 側は既に改名済み、内部実装を更新) +- `evaluate` 関数: `all_site_index_ids` + 新 `t4a_treetn_evaluate` を使用 +- C 
API シンボル名の更新 (`_sym` 呼び出し) + +--- + +## 実装順序 + +| Task | 内容 | crate | +|------|------|-------| +| 1 | `ColMajorArray` 型を core に追加 + テスト | tensor4all-core | +| 2 | tensorci C API 削除 | tensor4all-capi | +| 3 | `TreeTN::all_site_index_ids` 追加 | tensor4all-treetn | +| 4 | `TreeTN::evaluate` を新 API に変更 + 旧削除 | tensor4all-treetn | +| 5 | 旧 evaluate 呼び出し箇所を全更新 | tests, treetci tests | +| 6 | TreeTCI: `evaluate` 改名 + `crossinterpolate2` 統一 + point_eval 削除 | tensor4all-treetci | +| 7 | TreeTCI: `ijset` + `ijset_history` → `ColMajorArray` | tensor4all-treetci | +| 8 | C API: treetci 改名 + treetn 新 API | tensor4all-capi | +| 9 | Julia: TensorCI 削除 + TreeTCI/TreeTN 更新 | Tensor4all.jl | +| 10 | 全テスト + clippy + fmt | validation | + +--- + +## 設計上の注意 + +1. **`ColMajorArrayRef` の shape**: borrowed `&'a [usize]`。SmallVec 不使用。 + +2. **`GlobalIndexBatch`**: pure newtype にしない。独自に `n_sites, n_points` を所有する + 現在の設計を維持。`as_col_major()` 変換メソッドを追加。 + +3. **`TreeTN::evaluate` の IndexId**: `&[::Id]` で index を直接指定。 + `all_site_index_ids()` で事前取得。HashSet 順序バグ解消。 + `index_ids` は全 site index を exactly once 列挙する契約 (docstring に明記)。 + +4. **型境界 `FullPivLuScalar`**: `f32, f64, Complex32, Complex64` に実装済み。 + +5. **後方互換性不要** (AGENTS.md): 旧 API は即削除。 + +6. **ijset_history**: `ijset` と同時に `ColMajorArray` に移行。 + 空配列は `[n_subtree_sites, 0]` shaped。 + +7. **DynId**: `pub struct DynId(pub u64)` → C API で `uint64_t` として渡せる。 + +8. **Python ラッパー**: 無視 (disable 方針)。 + +9. 
**tensorci Rust crate 自体は残す**: C API と Julia ラッパーのみ削除。 + Rust 内部で他 crate が依存している可能性があるため。 diff --git a/docs/specs/2026-03-30-complex-interleaved-design.md b/docs/specs/2026-03-30-complex-interleaved-design.md new file mode 100644 index 0000000..e746acd --- /dev/null +++ b/docs/specs/2026-03-30-complex-interleaved-design.md @@ -0,0 +1,305 @@ +# Complex64 対応 & Interleaved 統一 設計 + +## Overview + +tensor4all-rs C API と Tensor4all.jl において、Complex64 (c64) サポートを追加し、 +complex データ表現を **interleaved** (`[re0, im0, re1, im1, ...]`) に統一する。 + +## 背景 + +- Rust 側は `TensorTrain`, `TensorCI2`, `SimpleTreeTci` を + ジェネリクスで完全サポート済み +- C API は現在 f64 のみ。唯一の complex 対応 (`t4a_tensor_new_dense_c64`) は + **separated buffers** (re配列 + im配列) だが、これは不自然 +- Julia `ComplexF64`, C `double _Complex`, Rust `Complex64` は全てメモリ上 `[re, im]` の + interleaved なので、separated は変換コストが無駄 + +## 設計方針 + +### Interleaved 表現 + +Complex データは全て `*const f64` / `*mut f64` バッファで、長さ `2 * n_elements`。 +`data[2*i]` = real part, `data[2*i+1]` = imaginary part。 + +Julia 側は `reinterpret(Float64, ::Vector{ComplexF64})` で zero-copy 変換可能。 + +### C API の命名規則 + +既存パターンに従い `_f64` / `_c64` サフィックスで区別: + +``` +t4a_simplett_f64_evaluate → t4a_simplett_c64_evaluate +t4a_tci2_f64_sweep2site → t4a_tci2_c64_sweep2site +t4a_treetci_f64_sweep → t4a_treetci_c64_sweep +``` + +--- + +## Part 1: C API 変更 (tensor4all-rs) + +### 1.1 既存 tensor API: separated → interleaved 移行 + +**変更対象:** +- `t4a_tensor_new_dense_c64`: `(data_re, data_im, data_len)` → `(data_interleaved, data_len)` + - `data_interleaved`: `*const f64`, length = `2 * n_elements` + - `data_len`: 要素数(complex 要素数, interleaved 配列長の半分) +- `t4a_tensor_get_data_c64`: `(buf_re, buf_im, buf_len, out_len)` → `(buf, buf_len, out_len)` + - `buf`: `*mut f64`, length = `2 * n_elements` + - `buf_len`, `out_len`: complex 要素数 + +**破壊的変更**: はい。Python ラッパーは disable のため問題なし。 + +### 1.2 新規: SimpleTT c64 + +**新規オペーク型** (`types.rs`): +```rust +pub struct t4a_simplett_c64 { + pub(crate) _private: *const 
c_void, +} +// inner: TensorTrain +// Clone, Drop, Send+Sync 実装 +``` + +**新規関数** (`simplett.rs` に追加): + +| 関数 | 説明 | +|------|------| +| `t4a_simplett_c64_release` | 解放 | +| `t4a_simplett_c64_clone` | 複製 | +| `t4a_simplett_c64_constant(site_dims, n_sites, value_re, value_im)` | 定数 TT 作成 | +| `t4a_simplett_c64_zeros(site_dims, n_sites)` | ゼロ TT 作成 | +| `t4a_simplett_c64_len(ptr, out)` | サイト数 | +| `t4a_simplett_c64_site_dims(ptr, buf, ...)` | サイト次元 | +| `t4a_simplett_c64_link_dims(ptr, buf, ...)` | ボンド次元 | +| `t4a_simplett_c64_rank(ptr, out)` | 最大ランク | +| `t4a_simplett_c64_evaluate(ptr, indices, n, out_re, out_im)` | 評価 → complex 値 | +| `t4a_simplett_c64_sum(ptr, out_re, out_im)` | 総和 → complex 値 | +| `t4a_simplett_c64_norm(ptr, out)` | ノルム → f64 | +| `t4a_simplett_c64_site_tensor(ptr, site, buf, buf_len, out_len, out_dims, ...)` | サイトテンソル → interleaved | +| `t4a_simplett_c64_compress(ptr, method, tol, max_bonddim)` | 圧縮 | +| `t4a_simplett_c64_partial_sum(ptr, dims, n_dims, out)` | 部分和 | + +**スカラー戻り値**: `evaluate`, `sum` は `out_re: *mut f64, out_im: *mut f64` の2引数で返す。 +**テンソルデータ**: `site_tensor` は interleaved buffer。 + +### 1.3 新規: TensorCI2 c64 + +**新規コールバック型:** +```rust +/// Complex evaluation callback. 
+/// result: interleaved [re, im] (2 doubles per point) +pub type EvalCallbackC64 = extern "C" fn( + indices: *const i64, + n_indices: libc::size_t, + result_re: *mut f64, + result_im: *mut f64, + user_data: *mut c_void, +) -> i32; +``` + +注: 単一値の戻りなので `result_re, result_im` の2ポインタ。 +バッチではなく1点評価のため interleaved バッファではなく分離で OK。 + +**新規オペーク型:** `t4a_tci2_c64` (wraps `TensorCI2`) + +**新規関数** (`tensorci.rs` に追加): + +f64 版と同じ 18 関数 + 高レベル関数: + +| 関数 | 差異 | +|------|------| +| `t4a_tci2_c64_new` | 同じ | +| `t4a_tci2_c64_release` | 同じ | +| `t4a_tci2_c64_sweep2site(ptr, eval_cb, user_data, ...)` | EvalCallbackC64 使用 | +| `t4a_tci2_c64_sweep1site(ptr, eval_cb, user_data, ...)` | EvalCallbackC64 使用 | +| `t4a_tci2_c64_fill_site_tensors(ptr, eval_cb, user_data)` | EvalCallbackC64 使用 | +| `t4a_tci2_c64_to_tensor_train(ptr, out)` | `*mut *mut t4a_simplett_c64` | +| `t4a_crossinterpolate2_c64(...)` | EvalCallbackC64 + `*mut *mut t4a_tci2_c64` | +| その他アクセサ | f64 版と同じシグネチャ(rank, link_dims, errors は f64 を返す) | + +### 1.4 新規: TreeTCI c64 + +**新規コールバック型:** +```rust +/// Complex batch evaluation callback. +/// results: interleaved [re0, im0, re1, im1, ...] 
(2 * n_points doubles) +pub type TreeTciBatchEvalCallbackC64 = extern "C" fn( + batch_data: *const libc::size_t, + n_sites: libc::size_t, + n_points: libc::size_t, + results: *mut libc::c_double, // interleaved, length = 2 * n_points + user_data: *mut c_void, +) -> i32; +``` + +バッチなので interleaved buffer を使う。 + +**新規オペーク型:** `t4a_treetci_c64` (wraps `SimpleTreeTci`) + +**新規関数** (`treetci.rs` に追加): + +| 関数 | 差異 | +|------|------| +| `t4a_treetci_c64_new` | 同じ | +| `t4a_treetci_c64_release` | 同じ | +| `t4a_treetci_c64_add_global_pivots` | 同じ | +| `t4a_treetci_c64_sweep(ptr, eval_cb, ...)` | TreeTciBatchEvalCallbackC64 | +| `t4a_treetci_c64_max_bond_error` | 同じ (f64 戻り) | +| `t4a_treetci_c64_max_rank` | 同じ | +| `t4a_treetci_c64_max_sample_value` | 同じ (f64 戻り) | +| `t4a_treetci_c64_bond_dims` | 同じ | +| `t4a_treetci_c64_to_treetn(ptr, eval_cb, ..., out_treetn)` | TreeTciBatchEvalCallbackC64、出力は `t4a_treetn` | +| `t4a_crossinterpolate_tree_c64(...)` | TreeTciBatchEvalCallbackC64 | + +**Graph / Options / Proposer は f64/c64 共通** — 型に依存しないので既存のものを再利用。 + +### 1.5 テスト方針 (C API) + +各モジュールに c64 テストを追加: + +- **tensor**: interleaved 版 `new_dense_c64` / `get_data_c64` のラウンドトリップ +- **simplett_c64**: complex constant TT の作成・評価・sum・norm・圧縮 +- **tci2_c64**: complex product function `f(idx) = prod((idx[s]+1) + i*(2*idx[s]+1))` でTCI実行 +- **treetci_c64**: 同じ complex product function で7-site branching tree、Rust parity テストと同じ tolerance (1e-12) + +--- + +## Part 2: Julia ラッパー変更 (Tensor4all.jl) + +### 2.1 Tensor: interleaved 対応 + +**変更:** `src/Tensor4all.jl` の `Tensor(inds, data::ComplexF64)` コンストラクタと `data()` アクセサ。 + +```julia +# Before (separated) +data_re = Cdouble[real(z) for z in flat_data] +data_im = Cdouble[imag(z) for z in flat_data] +ptr = C_API.t4a_tensor_new_dense_c64(r, index_ptrs, dims_vec, data_re, data_im) + +# After (interleaved) +data_interleaved = reinterpret(Cdouble, ComplexF64.(vec(data))) +ptr = C_API.t4a_tensor_new_dense_c64(r, index_ptrs, dims_vec, data_interleaved) +``` 
+
+取得側:
+```julia
+# Before
+buf_re = Vector{Cdouble}(undef, n)
+buf_im = Vector{Cdouble}(undef, n)
+C_API.t4a_tensor_get_data_c64(ptr, buf_re, buf_im, n, out_len)
+buf = [ComplexF64(r, i) for (r, i) in zip(buf_re, buf_im)]
+
+# After
+buf = Vector{Cdouble}(undef, 2 * n)
+C_API.t4a_tensor_get_data_c64(ptr, buf, n, out_len)
+result = reinterpret(ComplexF64, buf)
+```
+
+### 2.2 SimpleTT: c64 対応
+
+`src/SimpleTT.jl` の `SimpleTensorTrain` を `SimpleTensorTrain{T}` にジェネリック化:
+
+```julia
+mutable struct SimpleTensorTrain{T<:Union{Float64, ComplexF64}}
+    ptr::Ptr{Cvoid}
+end
+```
+
+各メソッドで `T` に応じて `_f64` / `_c64` の C API 関数を dispatch:
+
+```julia
+_sym_for(::Type{Float64}, name::Symbol) = C_API._sym(Symbol("t4a_simplett_f64_", name))
+_sym_for(::Type{ComplexF64}, name::Symbol) = C_API._sym(Symbol("t4a_simplett_c64_", name))
+```
+
+`evaluate` の戻り値が `T` になる:
+- `Float64`: `Ref{Cdouble}` → `Float64`
+- `ComplexF64`: `Ref{Cdouble}` × 2 (re, im) → `ComplexF64`
+
+### 2.3 TensorCI: c64 対応
+
+`src/TensorCI.jl` の `TensorCI2{T}` を拡張:
+
+```julia
+mutable struct TensorCI2{T<:Union{Float64, ComplexF64}}
+    ptr::Ptr{Cvoid}
+    local_dims::Vector{Int}
+end
+```
+
+コールバック trampoline を追加:
+- f64: 既存 `_trampoline` (変更なし)
+- c64: 新規 `_trampoline_c64` — Julia 関数が `ComplexF64` を返し、re/im に分離
+
+```julia
+function _trampoline_c64(indices_ptr, n_indices, result_re, result_im, user_data)::Cint
+    f = unsafe_pointer_to_objref(user_data)::Ref{Any} |> x -> x[]
+    indices = unsafe_wrap(Array, indices_ptr, Int(n_indices))
+    val = ComplexF64(f(indices...))
+    unsafe_store!(result_re, real(val))
+    unsafe_store!(result_im, imag(val))
+    Cint(0)
+end
+```
+
+### 2.4 TreeTCI: c64 対応
+
+`src/TreeTCI.jl` に c64 対応を追加:
+
+- `SimpleTreeTci{T}` にジェネリック化
+- バッチ trampoline:
+  - f64: 既存 (変更なし)
+  - c64: Julia 関数が `Vector{ComplexF64}` を返し、`reinterpret(Float64, vals)` で interleaved に変換
+
+```julia
+function _treetci_batch_trampoline_c64(batch_data, n_sites, n_points, results, user_data)::Cint
+    f = 
unsafe_pointer_to_objref(user_data)::Ref{Any} |> x -> x[] + batch = unsafe_wrap(Array, batch_data, (Int(n_sites), Int(n_points))) + vals = ComplexF64.(f(batch)) + interleaved = reinterpret(Float64, vals) + for i in eachindex(interleaved) + unsafe_store!(results, interleaved[i], i) + end + Cint(0) +end +``` + +`crossinterpolate_tree` も `T` パラメータで dispatch。 + +### 2.5 テスト方針 (Julia) + +既存の f64 テストに加え、各モジュールに c64 テストを追加: + +- **tensor**: ComplexF64 ラウンドトリップ(interleaved 版) +- **simplett_c64**: complex constant TT 作成・評価 +- **tci2_c64**: complex product function でTCI +- **treetci_c64**: complex product function で7-site branching tree + +Rust parity テストと **同一の関数・パラメータ・tolerance** を使用。 + +--- + +## 実装順序 + +| Step | リポジトリ | 内容 | +|------|-----------|------| +| 1 | tensor4all-rs | tensor API: separated → interleaved 移行 | +| 2 | tensor4all-rs | simplett_c64 追加 | +| 3 | tensor4all-rs | tci2_c64 追加 | +| 4 | tensor4all-rs | treetci_c64 追加 | +| 5 | Tensor4all.jl | tensor interleaved 対応 | +| 6 | Tensor4all.jl | SimpleTT{T} ジェネリック化 | +| 7 | Tensor4all.jl | TensorCI{T} ジェネリック化 | +| 8 | Tensor4all.jl | TreeTCI{T} ジェネリック化 | + +各ステップは独立した PR にできる。Step 1-4 は tensor4all-rs 内で1つの PR にまとめても可。 + +--- + +## 設計上の注意 + +1. **Graph / Options / Proposer は共通** — complex で変わらない +2. **norm, max_bond_error, max_sample_value は常に f64** — complex でも実数値 +3. **evaluate, sum の戻り値は T** — complex の場合は `(re, im)` ペアで返す +4. **site_tensor データは interleaved** — `length = 2 * n_elements` の f64 バッファ +5. 
**Python ラッパーは無視** — disable 方針 diff --git a/docs/specs/2026-03-30-treetci-capi-and-julia-wrapper-design.md b/docs/specs/2026-03-30-treetci-capi-and-julia-wrapper-design.md new file mode 100644 index 0000000..e757da4 --- /dev/null +++ b/docs/specs/2026-03-30-treetci-capi-and-julia-wrapper-design.md @@ -0,0 +1,1014 @@ +# TreeTCI C API & Julia Wrapper Design + +## Overview + +tensor4all-rs の `tensor4all-treetci` クレートで実装された TreeTCI (tree-structured tensor cross interpolation) を C API 経由で Julia から利用できるようにする。 + +**スコープ:** +1. tensor4all-rs 側: `tensor4all-capi` に TreeTCI の C API バインディングを追加 +2. Tensor4all.jl 側: C API を呼び出す Julia ラッパーモジュール `TreeTCI` を追加 + +**設計方針:** +- **ステートフル API (Approach B)**: `SimpleTreeTci` をオペークハンドルとして公開し、ピボット追加 → sweep → 検査 → materialization のライフサイクルを制御可能にする +- **単一バッチコールバック**: point eval と batch eval を統一。n_points=1 が point eval に相当 +- **Proposer 列挙型選択**: `DefaultProposer`, `SimpleProposer`, `TruncatedDefaultProposer` を enum で切り替え +- **高レベル便利関数**: `crossinterpolate_tree` 一発実行版も提供 + +--- + +## Part 1: C API (tensor4all-rs 側) + +### 1.1 ファイル構成 + +``` +crates/tensor4all-capi/ + src/ + treetci.rs # 新規: TreeTCI C API 関数すべて + types.rs # 既存: t4a_treetci_graph, t4a_treetci_f64 の型定義を追加 + lib.rs # 既存: mod treetci; pub use treetci::*; を追加 + Cargo.toml # 既存: tensor4all-treetci への依存を追加 +``` + +### 1.2 依存関係 (Cargo.toml) + +```toml +[dependencies] +tensor4all-treetci = { path = "../tensor4all-treetci" } +``` + +### 1.3 コールバック型 + +```rust +/// バッチ評価コールバック +/// +/// # Arguments +/// - `batch_data`: column-major (n_sites, n_points) のインデックス配列 +/// - batch_data[site + n_sites * point] でアクセス +/// - `n_sites`: サイト数 +/// - `n_points`: 評価点数 (n_points=1 のとき point eval 相当) +/// - `results`: 呼び出し側が n_points 個の f64 を書き込む出力バッファ +/// - `user_data`: ユーザーデータポインタ +/// +/// # Returns +/// 0 on success, non-zero on error +pub type TreeTciBatchEvalCallback = extern "C" fn( + batch_data: *const libc::size_t, + n_sites: libc::size_t, + n_points: libc::size_t, + results: *mut 
libc::c_double,
+    user_data: *mut c_void,
+) -> i32;
+```
+
+Rust 側での closure 変換:
+
+```rust
+fn make_batch_eval_closure(
+    eval_fn: TreeTciBatchEvalCallback,
+    user_data: *mut c_void,
+) -> impl Fn(GlobalIndexBatch) -> Result<Vec<f64>> {
+    move |batch: GlobalIndexBatch| -> Result<Vec<f64>> {
+        let mut results = vec![0.0; batch.n_points()];
+        let status = eval_fn(
+            batch.data().as_ptr(),
+            batch.n_sites(),
+            batch.n_points(),
+            results.as_mut_ptr(),
+            user_data,
+        );
+        if status != 0 {
+            anyhow::bail!("Batch eval callback returned error status {}", status);
+        }
+        Ok(results)
+    }
+}
+
+fn make_point_eval_closure(
+    eval_fn: TreeTciBatchEvalCallback,
+    user_data: *mut c_void,
+) -> impl Fn(&[usize]) -> f64 {
+    move |indices: &[usize]| -> f64 {
+        let mut result: f64 = 0.0;
+        let status = eval_fn(
+            indices.as_ptr(),
+            indices.len(),
+            1,
+            &mut result,
+            user_data,
+        );
+        if status != 0 { f64::NAN } else { result }
+    }
+}
+```
+
+### 1.4 列挙型
+
+```rust
+/// Proposer 選択
+#[repr(C)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum t4a_treetci_proposer_kind {
+    /// DefaultProposer: neighbor-product (TreeTCI.jl と同等)
+    Default = 0,
+    /// SimpleProposer: ランダム (seed ベース)
+    Simple = 1,
+    /// TruncatedDefaultProposer: Default の truncated ランダムサブセット
+    TruncatedDefault = 2,
+}
+```
+
+`types.rs` に追加し、`From` で Rust の proposer オブジェクトに変換する。
+
+### 1.5 オペーク型
+
+#### t4a_treetci_graph
+
+```rust
+// types.rs に追加
+#[repr(C)]
+pub struct t4a_treetci_graph {
+    pub(crate) _private: *const c_void,
+}
+
+impl t4a_treetci_graph {
+    pub(crate) fn new(inner: TreeTciGraph) -> Self {
+        Self {
+            _private: Box::into_raw(Box::new(inner)) as *const c_void,
+        }
+    }
+
+    pub(crate) fn inner(&self) -> &TreeTciGraph {
+        unsafe { &*(self._private as *const TreeTciGraph) }
+    }
+}
+
+impl Clone for t4a_treetci_graph {
+    fn clone(&self) -> Self {
+        Self::new(self.inner().clone())
+    }
+}
+
+impl Drop for t4a_treetci_graph {
+    fn drop(&mut self) {
+        if !self._private.is_null() {
+            unsafe { let _ = 
Box::from_raw(self._private as *mut TreeTciGraph); } + } + } +} + +unsafe impl Send for t4a_treetci_graph {} +unsafe impl Sync for t4a_treetci_graph {} +``` + +#### t4a_treetci_f64 + +```rust +// types.rs に追加 +#[repr(C)] +pub struct t4a_treetci_f64 { + pub(crate) _private: *const c_void, +} + +impl t4a_treetci_f64 { + pub(crate) fn new(inner: SimpleTreeTci) -> Self { + Self { + _private: Box::into_raw(Box::new(inner)) as *const c_void, + } + } + + pub(crate) fn inner(&self) -> &SimpleTreeTci { + unsafe { &*(self._private as *const SimpleTreeTci) } + } + + pub(crate) fn inner_mut(&mut self) -> &mut SimpleTreeTci { + unsafe { &mut *(self._private as *mut SimpleTreeTci) } + } +} + +impl Drop for t4a_treetci_f64 { + fn drop(&mut self) { + if !self._private.is_null() { + unsafe { let _ = Box::from_raw(self._private as *mut SimpleTreeTci); } + } + } +} + +// Clone は実装しない (TCI2 と同様) +unsafe impl Send for t4a_treetci_f64 {} +unsafe impl Sync for t4a_treetci_f64 {} +``` + +### 1.6 C API 関数一覧 + +すべて `treetci.rs` に実装。各関数は `catch_unwind(AssertUnwindSafe(...))` でパニック保護。 + +#### 1.6.1 グラフ ライフサイクル + +```rust +impl_opaque_type_common!(treetci_graph); +// 生成: t4a_treetci_graph_release, t4a_treetci_graph_clone, t4a_treetci_graph_is_assigned +``` + +```rust +/// 木グラフを作成 +/// +/// # Arguments +/// - `n_sites`: サイト数 (>= 1) +/// - `edges_flat`: エッジ配列 [u0, v0, u1, v1, ...] 
(length = n_edges * 2) +/// - `n_edges`: エッジ数 (n_sites - 1 であること) +/// +/// # Returns +/// 新しい t4a_treetci_graph ポインタ。エラー時は NULL。 +/// バリデーション: 連結性、辺数 = n_sites - 1、自己ループなし、重複辺なし +/// +/// # Safety +/// edges_flat は n_edges * 2 個の size_t を含む有効なバッファであること +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_graph_new( + n_sites: libc::size_t, + edges_flat: *const libc::size_t, + n_edges: libc::size_t, +) -> *mut t4a_treetci_graph; + +/// サイト数を取得 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_graph_n_sites( + graph: *const t4a_treetci_graph, + out: *mut libc::size_t, +) -> StatusCode; + +/// エッジ数を取得 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_graph_n_edges( + graph: *const t4a_treetci_graph, + out: *mut libc::size_t, +) -> StatusCode; +``` + +#### 1.6.2 ステート ライフサイクル + +```rust +/// SimpleTreeTci ステートを作成 +/// +/// # Arguments +/// - `local_dims`: 各サイトの局所次元 (length = n_sites) +/// - `n_sites`: サイト数 (graph の n_sites と一致すること) +/// - `graph`: 木グラフハンドル (所有権は移転しない、内部で clone) +/// +/// # Returns +/// 新しい t4a_treetci_f64 ポインタ。エラー時は NULL。 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_new( + local_dims: *const libc::size_t, + n_sites: libc::size_t, + graph: *const t4a_treetci_graph, +) -> *mut t4a_treetci_f64; + +/// ステートを解放 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_release( + ptr: *mut t4a_treetci_f64, +); +``` + +#### 1.6.3 ピボット管理 + +```rust +/// グローバルピボットを追加 +/// +/// 各ピボットは全サイトのインデックスを持つマルチインデックス。 +/// pivots_flat は column-major (n_sites, n_pivots) レイアウト。 +/// +/// # Arguments +/// - `ptr`: ステートハンドル +/// - `pivots_flat`: column-major (n_sites, n_pivots) のインデックス配列 +/// - `n_sites`: サイト数 (ステートの n_sites と一致すること) +/// - `n_pivots`: ピボット数 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_add_global_pivots( + ptr: *mut t4a_treetci_f64, + pivots_flat: *const libc::size_t, + n_sites: libc::size_t, + n_pivots: libc::size_t, +) -> StatusCode; +``` + +#### 1.6.4 Sweep 実行 + +```rust +/// 1イテレーション実行(全辺を1回訪問) +/// +/// 内部で AllEdges visitor 
を使い、指定された proposer で候補を生成し、 +/// matrixluci で pivot 選択を行い、ステートを更新する。 +/// +/// # Arguments +/// - `ptr`: ステートハンドル (mutable) +/// - `eval_cb`: バッチ評価コールバック +/// - `user_data`: コールバックに渡すユーザーデータ +/// - `proposer_kind`: proposer 選択 (Default=0, Simple=1, TruncatedDefault=2) +/// - `tolerance`: このイテレーションの相対 tolerance +/// - `max_bond_dim`: 最大ボンド次元 (0 = 無制限) +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_sweep( + ptr: *mut t4a_treetci_f64, + eval_cb: TreeTciBatchEvalCallback, + user_data: *mut c_void, + proposer_kind: t4a_treetci_proposer_kind, + tolerance: libc::c_double, + max_bond_dim: libc::size_t, +) -> StatusCode; +``` + +#### 1.6.5 ステート検査 + +```rust +/// 全辺の最大ボンドエラーを取得 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_max_bond_error( + ptr: *const t4a_treetci_f64, + out: *mut libc::c_double, +) -> StatusCode; + +/// 最大ボンド次元(現在の最大 rank)を取得 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_max_rank( + ptr: *const t4a_treetci_f64, + out: *mut libc::size_t, +) -> StatusCode; + +/// 観測された最大サンプル値を取得(正規化用) +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_max_sample_value( + ptr: *const t4a_treetci_f64, + out: *mut libc::c_double, +) -> StatusCode; + +/// 各辺のボンド次元を取得 +/// +/// # Arguments +/// - `out_ranks`: 出力バッファ (length >= n_edges) +/// - `buf_len`: バッファサイズ +/// - `out_n_edges`: 実際のエッジ数を出力 +/// +/// query-then-fill パターン: out_ranks=NULL で out_n_edges のみ取得可能 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_bond_dims( + ptr: *const t4a_treetci_f64, + out_ranks: *mut libc::size_t, + buf_len: libc::size_t, + out_n_edges: *mut libc::size_t, +) -> StatusCode; +``` + +#### 1.6.6 Materialization + +```rust +/// 収束したステートから TreeTN を構築 +/// +/// 内部で batch_eval を使ってテンソル値を再評価し、LU 分解で +/// 各辺のテンソルを構築する。 +/// +/// # Arguments +/// - `ptr`: ステートハンドル (const — ステート自体は変更しない) +/// - `eval_cb`: バッチ評価コールバック (materialization 時にテンソル値を再評価) +/// - `user_data`: コールバックに渡すユーザーデータ +/// - `center_site`: BFS ルートサイト (materialization の中心) +/// - `out_treetn`: 出力 
TreeTN ハンドルポインタ +/// +/// # Returns +/// 既存の t4a_treetn 型として返す。呼び出し側は t4a_treetn_release で解放。 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_treetci_f64_to_treetn( + ptr: *const t4a_treetci_f64, + eval_cb: TreeTciBatchEvalCallback, + user_data: *mut c_void, + center_site: libc::size_t, + out_treetn: *mut *mut t4a_treetn, +) -> StatusCode; +``` + +#### 1.6.7 高レベル便利関数 + +```rust +/// TreeTCI を一発実行して TreeTN を取得 +/// +/// 内部で以下を実行: +/// 1. SimpleTreeTci を作成 +/// 2. initial_pivots を追加 +/// 3. max_iter 回まで sweep を繰り返す (convergence check あり) +/// 4. TreeTN に materialize +/// +/// # Arguments +/// - `eval_cb`: バッチ評価コールバック +/// - `user_data`: コールバックに渡すユーザーデータ +/// - `local_dims`: 各サイトの局所次元 (length = n_sites) +/// - `n_sites`: サイト数 +/// - `graph`: 木グラフハンドル +/// - `initial_pivots_flat`: column-major (n_sites, n_pivots)、NULL 可 (空ピボット) +/// - `n_pivots`: 初期ピボット数 +/// - `proposer_kind`: proposer 選択 +/// - `tolerance`: 相対 tolerance +/// - `max_bond_dim`: 最大ボンド次元 (0 = 無制限) +/// - `max_iter`: 最大イテレーション数 +/// - `normalize_error`: 正規化フラグ (0=false, 1=true) +/// - `center_site`: materialization の中心サイト +/// - `out_treetn`: 出力 TreeTN ハンドル +/// - `out_ranks`: 各イテレーションの最大 rank (バッファ length >= max_iter, NULL可) +/// - `out_errors`: 各イテレーションの正規化エラー (バッファ length >= max_iter, NULL可) +/// - `out_n_iters`: 実際のイテレーション数 +#[unsafe(no_mangle)] +pub extern "C" fn t4a_crossinterpolate_tree_f64( + eval_cb: TreeTciBatchEvalCallback, + user_data: *mut c_void, + local_dims: *const libc::size_t, + n_sites: libc::size_t, + graph: *const t4a_treetci_graph, + initial_pivots_flat: *const libc::size_t, + n_pivots: libc::size_t, + proposer_kind: t4a_treetci_proposer_kind, + tolerance: libc::c_double, + max_bond_dim: libc::size_t, + max_iter: libc::size_t, + normalize_error: libc::c_int, + center_site: libc::size_t, + out_treetn: *mut *mut t4a_treetn, + out_ranks: *mut libc::size_t, + out_errors: *mut libc::c_double, + out_n_iters: *mut libc::size_t, +) -> StatusCode; +``` + +### 1.7 テスト方針 + 
+`crates/tensor4all-capi/tests/test_treetci.rs` に統合テスト: + +1. **グラフ構築テスト**: 正常ケース + 不正なグラフ (非連結、自己ループ) のバリデーション +2. **ステートフル API テスト**: 7-site tree 上の既知関数で new → add_pivots → sweep → 検査 → to_treetn の全ライフサイクル +3. **高レベル関数テスト**: `crossinterpolate_tree_f64` で同じ既知関数を一発実行し、結果を比較 +4. **バッチコールバックテスト**: n_points > 1 のバッチが正しく column-major で渡されることを検証 + +TreeTCI.jl の既存 parity テストと同じテスト関数 (7-site tree) を使用。 + +--- + +## Part 2: Julia ラッパー (Tensor4all.jl 側) + +### 2.1 ファイル構成 + +``` +src/ + TreeTCI.jl # 新規: TreeTCI モジュール + C_API.jl # 既存: _sym() に新しい関数名を追加 + Tensor4all.jl # 既存: include("TreeTCI.jl") を追加 +test/ + test_treetci.jl # 新規: TreeTCI テスト + runtests.jl # 既存: test_treetci.jl を include +``` + +### 2.2 C API バインディング追加 (`src/C_API.jl`) + +既存の `_sym(name)` パターンで新しい関数シンボルを解決。追加コードは不要(`_sym` は動的に `dlsym` するため)。 + +### 2.3 モジュール定義 (`src/TreeTCI.jl`) + +```julia +module TreeTCI + +using ..Tensor4all: C_API, TreeTN # 内部依存 + +export TreeTciGraph, SimpleTreeTci +export crossinterpolate_tree + +# ============================================================================ +# TreeTciGraph +# ============================================================================ + +""" + TreeTciGraph(n_sites, edges) + +木グラフ構造を定義する。 + +# Arguments +- `n_sites::Int`: サイト数 +- `edges::Vector{Tuple{Int,Int}}`: エッジのリスト (0-based site indices) + +# Example +```julia +# 線形チェーン: 0-1-2-3 +graph = TreeTciGraph(4, [(0,1), (1,2), (2,3)]) + +# スターグラフ: 0 を中心に 1,2,3 が接続 +graph = TreeTciGraph(4, [(0,1), (0,2), (0,3)]) +``` +""" +mutable struct TreeTciGraph + ptr::Ptr{Cvoid} + n_sites::Int + + function TreeTciGraph(n_sites::Int, edges::Vector{Tuple{Int,Int}}) + n_edges = length(edges) + edges_flat = Vector{Csize_t}(undef, n_edges * 2) + for (i, (u, v)) in enumerate(edges) + edges_flat[2i - 1] = u + edges_flat[2i] = v + end + ptr = ccall( + C_API._sym(:t4a_treetci_graph_new), + Ptr{Cvoid}, + (Csize_t, Ptr{Csize_t}, Csize_t), + n_sites, edges_flat, n_edges, + ) + ptr == C_NULL && error("Failed to create TreeTciGraph: 
$(C_API.last_error())") + obj = new(ptr, n_sites) + finalizer(obj) do x + if x.ptr != C_NULL + ccall(C_API._sym(:t4a_treetci_graph_release), Cvoid, (Ptr{Cvoid},), x.ptr) + x.ptr = C_NULL + end + end + obj + end +end + +# ============================================================================ +# Batch Eval Trampoline +# ============================================================================ + +""" +コールバックトランポリン。 + +C側から呼ばれ、Julia のユーザー関数を呼び出す。 +batch_data は column-major (n_sites, n_points)。 +ユーザー関数のシグネチャ: f(batch::Matrix{Int}) -> Vector{Float64} +""" +function _treetci_batch_trampoline( + batch_data::Ptr{Csize_t}, + n_sites::Csize_t, + n_points::Csize_t, + results::Ptr{Cdouble}, + user_data::Ptr{Cvoid}, +)::Cint + try + f_ref = unsafe_pointer_to_objref(user_data)::Ref{Any} + f = f_ref[] + # column-major (n_sites, n_points) を Julia Matrix として wrap + batch = unsafe_wrap(Array, batch_data, (Int(n_sites), Int(n_points))) + vals = f(batch) + for i in 1:Int(n_points) + unsafe_store!(results, vals[i], i) + end + return Cint(0) + catch e + # エラーを stderr に出力し、非ゼロを返す + @error "TreeTCI batch eval callback error" exception = (e, catch_backtrace()) + return Cint(-1) + end +end + +# C function pointer (モジュールロード時に一度だけ生成) +const _BATCH_TRAMPOLINE_PTR = Ref{Ptr{Cvoid}}(C_NULL) + +function _get_batch_trampoline() + if _BATCH_TRAMPOLINE_PTR[] == C_NULL + _BATCH_TRAMPOLINE_PTR[] = @cfunction( + _treetci_batch_trampoline, + Cint, + (Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Cdouble}, Ptr{Cvoid}), + ) + end + _BATCH_TRAMPOLINE_PTR[] +end + +# ============================================================================ +# Proposer 変換 +# ============================================================================ + +const PROPOSER_DEFAULT = Cint(0) +const PROPOSER_SIMPLE = Cint(1) +const PROPOSER_TRUNCATED_DEFAULT = Cint(2) + +function _proposer_to_cint(proposer::Symbol)::Cint + if proposer == :default + PROPOSER_DEFAULT + elseif proposer == :simple + PROPOSER_SIMPLE + elseif proposer 
== :truncated_default + PROPOSER_TRUNCATED_DEFAULT + else + error("Unknown proposer: $proposer. Use :default, :simple, or :truncated_default") + end +end + +# ============================================================================ +# SimpleTreeTci +# ============================================================================ + +""" + SimpleTreeTci(local_dims, graph) + +ステートフルな TreeTCI オブジェクト。 + +# Arguments +- `local_dims::Vector{Int}`: 各サイトの局所次元 (length = graph.n_sites) +- `graph::TreeTciGraph`: 木グラフ構造 + +# Lifecycle +```julia +tci = SimpleTreeTci(local_dims, graph) +add_global_pivots!(tci, pivots) +for i in 1:max_iter + sweep!(tci, f; tolerance=1e-8) + @info "iter \$i" max_bond_error(tci) max_rank(tci) + max_bond_error(tci) < tolerance && break +end +ttn = to_treetn(tci, f) +``` +""" +mutable struct SimpleTreeTci + ptr::Ptr{Cvoid} + graph::TreeTciGraph # GC 保護のため参照保持 + local_dims::Vector{Int} + + function SimpleTreeTci(local_dims::Vector{Int}, graph::TreeTciGraph) + @assert length(local_dims) == graph.n_sites + dims_csize = Csize_t.(local_dims) + ptr = ccall( + C_API._sym(:t4a_treetci_f64_new), + Ptr{Cvoid}, + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}), + dims_csize, length(dims_csize), graph.ptr, + ) + ptr == C_NULL && error("Failed to create SimpleTreeTci: $(C_API.last_error())") + obj = new(ptr, graph, local_dims) + finalizer(obj) do x + if x.ptr != C_NULL + ccall(C_API._sym(:t4a_treetci_f64_release), Cvoid, (Ptr{Cvoid},), x.ptr) + x.ptr = C_NULL + end + end + obj + end +end + +# ============================================================================ +# ピボット管理 +# ============================================================================ + +""" + add_global_pivots!(tci, pivots) + +グローバルピボットを追加する。 + +# Arguments +- `tci::SimpleTreeTci`: ステート +- `pivots::Vector{Vector{Int}}`: 各ピボットは全サイトのインデックス (0-based) +""" +function add_global_pivots!(tci::SimpleTreeTci, pivots::Vector{Vector{Int}}) + n_sites = length(tci.local_dims) + n_pivots = length(pivots) + n_pivots 
== 0 && return + # column-major (n_sites, n_pivots) に pack + pivots_flat = Vector{Csize_t}(undef, n_sites * n_pivots) + for j in 1:n_pivots + @assert length(pivots[j]) == n_sites + for i in 1:n_sites + pivots_flat[i + n_sites * (j - 1)] = pivots[j][i] + end + end + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_add_global_pivots), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t), + tci.ptr, pivots_flat, n_sites, n_pivots, + )) +end + +# ============================================================================ +# Sweep 実行 +# ============================================================================ + +""" + sweep!(tci, f; proposer=:default, tolerance=1e-8, max_bond_dim=0) + +1イテレーション(全辺を1回訪問)を実行する。 + +# Arguments +- `tci::SimpleTreeTci`: ステート +- `f`: 評価関数 `f(batch::Matrix{Int}) -> Vector{Float64}` + - `batch` は column-major (n_sites, n_points), 0-based indices + - 戻り値は n_points 個の Float64 +- `proposer`: `:default`, `:simple`, `:truncated_default` +- `tolerance`: 相対 tolerance +- `max_bond_dim`: 最大ボンド次元 (0 = 無制限) +""" +function sweep!(tci::SimpleTreeTci, f; + proposer::Symbol = :default, + tolerance::Float64 = 1e-8, + max_bond_dim::Int = 0, +) + f_ref = Ref{Any}(f) + GC.@preserve f_ref begin + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_sweep), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cint, Cdouble, Csize_t), + tci.ptr, + _get_batch_trampoline(), + pointer_from_objref(f_ref), + _proposer_to_cint(proposer), + tolerance, + max_bond_dim, + )) + end +end + +# ============================================================================ +# ステート検査 +# ============================================================================ + +"""全辺の最大ボンドエラーを取得""" +function max_bond_error(tci::SimpleTreeTci)::Float64 + out = Ref{Cdouble}(0.0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_max_bond_error), + Cint, (Ptr{Cvoid}, Ptr{Cdouble}), + tci.ptr, out, + )) + out[] +end + +"""最大ボンド次元を取得""" +function max_rank(tci::SimpleTreeTci)::Int + 
out = Ref{Csize_t}(0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_max_rank), + Cint, (Ptr{Cvoid}, Ptr{Csize_t}), + tci.ptr, out, + )) + Int(out[]) +end + +"""観測された最大サンプル値を取得""" +function max_sample_value(tci::SimpleTreeTci)::Float64 + out = Ref{Cdouble}(0.0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_max_sample_value), + Cint, (Ptr{Cvoid}, Ptr{Cdouble}), + tci.ptr, out, + )) + out[] +end + +"""各辺のボンド次元を取得""" +function bond_dims(tci::SimpleTreeTci)::Vector{Int} + # query: エッジ数を取得 + n_edges_ref = Ref{Csize_t}(0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_bond_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + tci.ptr, C_NULL, 0, n_edges_ref, + )) + n_edges = Int(n_edges_ref[]) + # fill: バッファに書き込み + buf = Vector{Csize_t}(undef, n_edges) + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_bond_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + tci.ptr, buf, n_edges, n_edges_ref, + )) + Int.(buf) +end + +# ============================================================================ +# Materialization +# ============================================================================ + +""" + to_treetn(tci, f; center_site=0) + +収束した TreeTCI ステートから TreeTensorNetwork を構築する。 + +# Arguments +- `tci::SimpleTreeTci`: 収束したステート +- `f`: 評価関数 (sweep! 
と同じシグネチャ) +- `center_site`: materialization の中心サイト (0-based) + +# Returns +- `TreeTensorNetwork`: 既存の TreeTN ラッパー型 +""" +function to_treetn(tci::SimpleTreeTci, f; center_site::Int = 0) + f_ref = Ref{Any}(f) + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + GC.@preserve f_ref begin + C_API.check_status(ccall( + C_API._sym(:t4a_treetci_f64_to_treetn), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Ptr{Ptr{Cvoid}}), + tci.ptr, + _get_batch_trampoline(), + pointer_from_objref(f_ref), + center_site, + out_ptr, + )) + end + TreeTN.TreeTensorNetwork(out_ptr[]) +end + +# ============================================================================ +# 高レベル便利関数 +# ============================================================================ + +""" + crossinterpolate_tree(f, local_dims, graph; kwargs...) -> (ttn, ranks, errors) + +TreeTCI を一発実行して TreeTensorNetwork を取得する。 + +# Arguments +- `f`: 評価関数 `f(batch::Matrix{Int}) -> Vector{Float64}` +- `local_dims::Vector{Int}`: 各サイトの局所次元 +- `graph::TreeTciGraph`: 木グラフ構造 + +# Keyword Arguments +- `initial_pivots::Vector{Vector{Int}} = []`: 初期ピボット (0-based) +- `proposer::Symbol = :default`: proposer 選択 +- `tolerance::Float64 = 1e-8`: 相対 tolerance +- `max_bond_dim::Int = 0`: 最大ボンド次元 (0=無制限) +- `max_iter::Int = 20`: 最大イテレーション数 +- `normalize_error::Bool = true`: エラー正規化フラグ +- `center_site::Int = 0`: materialization の中心サイト + +# Returns +- `ttn::TreeTensorNetwork`: 結果のテンソルネットワーク +- `ranks::Vector{Int}`: 各イテレーションの最大 rank +- `errors::Vector{Float64}`: 各イテレーションの正規化エラー +""" +function crossinterpolate_tree( + f, local_dims::Vector{Int}, graph::TreeTciGraph; + initial_pivots::Vector{Vector{Int}} = Vector{Int}[], + proposer::Symbol = :default, + tolerance::Float64 = 1e-8, + max_bond_dim::Int = 0, + max_iter::Int = 20, + normalize_error::Bool = true, + center_site::Int = 0, +) + n_sites = length(local_dims) + n_pivots = length(initial_pivots) + + # initial_pivots を column-major (n_sites, n_pivots) に pack + pivots_flat = if n_pivots > 0 + buf = 
Vector{Csize_t}(undef, n_sites * n_pivots) + for j in 1:n_pivots + for i in 1:n_sites + buf[i + n_sites * (j - 1)] = initial_pivots[j][i] + end + end + buf + else + Csize_t[] + end + + # 出力バッファ (max_iter 分を事前確保) + out_ranks = Vector{Csize_t}(undef, max_iter) + out_errors = Vector{Cdouble}(undef, max_iter) + out_n_iters = Ref{Csize_t}(0) + out_treetn = Ref{Ptr{Cvoid}}(C_NULL) + + f_ref = Ref{Any}(f) + GC.@preserve f_ref pivots_flat out_ranks out_errors begin + C_API.check_status(ccall( + C_API._sym(:t4a_crossinterpolate_tree_f64), + Cint, + ( + Ptr{Cvoid}, Ptr{Cvoid}, # eval_cb, user_data + Ptr{Csize_t}, Csize_t, # local_dims, n_sites + Ptr{Cvoid}, # graph + Ptr{Csize_t}, Csize_t, # initial_pivots, n_pivots + Cint, # proposer_kind + Cdouble, Csize_t, Csize_t, # tol, max_bond_dim, max_iter + Cint, # normalize_error + Csize_t, # center_site + Ptr{Ptr{Cvoid}}, # out_treetn + Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}, # out_ranks, errors, n_iters + ), + _get_batch_trampoline(), + pointer_from_objref(f_ref), + Csize_t.(local_dims), n_sites, + graph.ptr, + n_pivots > 0 ? pivots_flat : C_NULL, n_pivots, + _proposer_to_cint(proposer), + tolerance, max_bond_dim, max_iter, + normalize_error ? 
Cint(1) : Cint(0), + center_site, + out_treetn, + out_ranks, out_errors, out_n_iters, + )) + end + + n_iters = Int(out_n_iters[]) + ttn = TreeTN.TreeTensorNetwork(out_treetn[]) + ranks = Int.(out_ranks[1:n_iters]) + errors = Float64.(out_errors[1:n_iters]) + return (ttn, ranks, errors) +end + +end # module TreeTCI +``` + +### 2.4 メインモジュールへの統合 (`src/Tensor4all.jl`) + +```julia +# 既存の include の後に追加 +include("TreeTCI.jl") +``` + +### 2.5 テスト (`test/test_treetci.jl`) + +```julia +using Tensor4all.TreeTCI +using Test + +@testset "TreeTCI" begin + @testset "TreeTciGraph" begin + # 線形チェーン + graph = TreeTciGraph(4, [(0,1), (1,2), (2,3)]) + # スターグラフ + graph_star = TreeTciGraph(4, [(0,1), (0,2), (0,3)]) + # 不正なグラフはエラー + @test_throws ErrorException TreeTciGraph(4, [(0,1), (2,3)]) # 非連結 + end + + @testset "Stateful API" begin + # 7-site star tree (TreeTCI.jl parity test に相当) + n_sites = 7 + local_dims = fill(2, n_sites) + edges = [(0,i) for i in 1:6] # site 0 を中心としたスター + graph = TreeTciGraph(n_sites, edges) + + # テスト関数: 全サイトの積 + function f_batch(batch::Matrix{<:Integer}) + n_pts = size(batch, 2) + results = Vector{Float64}(undef, n_pts) + for j in 1:n_pts + val = 1.0 + for i in 1:size(batch, 1) + val *= (batch[i, j] + 1.0) + end + results[j] = val + end + results + end + + tci = SimpleTreeTci(local_dims, graph) + add_global_pivots!(tci, [zeros(Int, n_sites)]) # all-zero pivot + + for iter in 1:10 + sweep!(tci, f_batch; tolerance=1e-12) + max_bond_error(tci) < 1e-12 && break + end + + @test max_bond_error(tci) < 1e-10 + @test max_rank(tci) >= 1 + + ttn = to_treetn(tci, f_batch) + # ttn を使って既存の TreeTN API で検証可能 + end + + @testset "High-level API" begin + n_sites = 4 + local_dims = fill(3, n_sites) + graph = TreeTciGraph(n_sites, [(0,1), (1,2), (2,3)]) + + function f_batch(batch) + [sum(Float64, batch[:, j]) for j in 1:size(batch, 2)] + end + + ttn, ranks, errors = crossinterpolate_tree( + f_batch, local_dims, graph; + initial_pivots = [zeros(Int, n_sites)], + tolerance = 1e-10, 
+ max_iter = 20, + ) + + @test length(ranks) > 0 + @test last(errors) < 1e-8 + end +end +``` + +--- + +## まとめ: 実装順序 + +| Step | リポジトリ | 内容 | +|------|-----------|------| +| 1 | tensor4all-rs | `Cargo.toml` に `tensor4all-treetci` 依存追加 | +| 2 | tensor4all-rs | `types.rs` に `t4a_treetci_graph`, `t4a_treetci_f64`, `t4a_treetci_proposer_kind` 追加 | +| 3 | tensor4all-rs | `treetci.rs` 新規作成: 全 C API 関数実装 | +| 4 | tensor4all-rs | `lib.rs` に `mod treetci; pub use treetci::*;` 追加 | +| 5 | tensor4all-rs | `tests/test_treetci.rs` で C API テスト | +| 6 | Tensor4all.jl | `src/TreeTCI.jl` 新規作成 | +| 7 | Tensor4all.jl | `src/Tensor4all.jl` に include 追加 | +| 8 | Tensor4all.jl | `test/test_treetci.jl` で統合テスト | diff --git a/docs/specs/2026-03-30-treetci-global-pivot-design.md b/docs/specs/2026-03-30-treetci-global-pivot-design.md new file mode 100644 index 0000000..eb3ff43 --- /dev/null +++ b/docs/specs/2026-03-30-treetci-global-pivot-design.md @@ -0,0 +1,517 @@ +# TreeTCI Global Pivot Search 詳細設計 + +## Overview + +TreeTCI の optimize ループに global pivot search を追加し、局所的な特徴を持つ関数の収束を改善する。 + +## スコープ + +1. `TreeTN::evaluate_batch` — batch 版 evaluate を TreeTN に追加 +2. `DefaultTreeGlobalPivotFinder` — greedy local search (TCI2 の finder と同アルゴリズム) +3. `TreeTciOptions` にパラメータ追加 +4. `optimize_with_proposer` ループに組み込み +5. テスト: chain tree + TCI.jl parity の nasty function + +## スコープ外 + +- `TreeTN::evaluate_batch` の部分木キャッシュ最適化 (将来タスク) +- index type の generic 化 (usize 固定のまま) +- C API / Julia 側の変更 (Rust のみ) + +--- + +## 1. 
TreeTN::evaluate — batch 版に置き換え
+
+### 設計方針
+
+batch evaluation が唯一の API。`evaluate` という名前で batch を受け取る。
+1点評価は `evaluate(&[single_idx])` で表現。
+
+### 現状
+
+```rust
+// crates/tensor4all-treetn/src/treetn/ops.rs
+impl TreeTN<V> {
+    pub fn evaluate(
+        &self,
+        index_values: &HashMap<V, Vec<usize>>,
+    ) -> Result<AnyScalar>
+}
+```
+
+1点ずつ `HashMap` を構築して評価。
+
+### 変更
+
+既存の 1点版 `evaluate` を **削除** し、batch 版 `evaluate` に置き換える。
+1点評価は `evaluate(&[single_idx])` で表現。
+
+```rust
+// crates/tensor4all-treetn/src/treetn/ops.rs
+
+/// Evaluate the TreeTN at multiple multi-indices (batch).
+///
+/// Each element of `indices` is a HashMap mapping vertex names to their
+/// site index values. Returns one AnyScalar per evaluation point.
+///
+/// For single-point evaluation, pass a slice of length 1.
+///
+/// Internal implementation evaluates point-by-point.
+/// Future: subtree cache optimization for shared prefixes.
+pub fn evaluate(
+    &self,
+    indices: &[HashMap<V, Vec<usize>>],
+) -> Result<Vec<AnyScalar>>
+where
+    V::Id:
+        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
+{
+    // existing single-point logic inlined here, called per point
+    indices.iter().map(|idx| {
+        // ... existing evaluate body (onehot contraction) ...
+    }).collect()
+}
+```
+
+旧 `evaluate(&HashMap)` の呼び出し箇所を全て `evaluate(&[idx])` に更新する。
+これは **破壊的変更** だが、early development のため後方互換性不要 (AGENTS.md)。
+
+### TreeTCI 向けヘルパー
+
+vertex 名 = `usize` で各 vertex が 1 site index のみ持つケースに特化:
+
+```rust
+/// Evaluate at multiple multi-indices given as flat site-order arrays.
+///
+/// 各 multi-index は `[idx_site0, idx_site1, ..., idx_site_{n-1}]` (0-based)。
+/// vertex 名が usize の TreeTN 専用。
+///
+/// 1点評価: `evaluate_at_site_indices(&[vec![0, 1, 2]])`
+pub fn evaluate_at_site_indices(
+    &self,
+    indices: &[Vec<usize>],
+) -> Result<Vec<AnyScalar>>
+where
+    V: From<usize> + Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
+    ::Id:
+        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
+{
+    let n_sites = self.node_count();
+    let hash_indices: Vec<HashMap<V, Vec<usize>>> = indices.iter()
+        .map(|multi_idx| {
+            (0..n_sites)
+                .map(|s| (V::from(s), vec![multi_idx[s]]))
+                .collect()
+        })
+        .collect();
+    self.evaluate(&hash_indices)
+}
+```
+
+### C API への影響
+
+既存の `t4a_treetn_evaluate_batch` は内部で旧 1点版 `evaluate` を
+point ごとに呼んでいる。新しい batch 版 `evaluate` を使うように更新。
+
+### 既存コードの呼び出し箇所の更新
+
+旧 `evaluate(&HashMap)` を呼んでいる箇所を全て
+`evaluate(&[hashmap])[0]` に置き換える。grep で洗い出して全更新。
+
+---
+
+## 2. TreeTciOptions パラメータ追加
+
+```rust
+// crates/tensor4all-treetci/src/optimize.rs
+
+#[derive(Clone, Debug)]
+pub struct TreeTciOptions {
+    // 既存
+    pub tolerance: f64,
+    pub max_iter: usize,
+    pub max_bond_dim: usize,
+    pub normalize_error: bool,
+
+    // Global pivot search
+    /// Run global pivot search every N iterations. 0 = disabled (default).
+    pub global_search_interval: usize,
+    /// Maximum number of global pivots to add per search round.
+    pub max_global_pivots: usize,
+    /// Number of random starting points for greedy search.
+    pub num_global_searches: usize,
+    /// Only add pivots where error > tolerance × this margin.
+    pub global_search_tol_margin: f64,
+}
+
+impl Default for TreeTciOptions {
+    fn default() -> Self {
+        Self {
+            tolerance: 1e-8,
+            max_iter: 20,
+            max_bond_dim: usize::MAX,
+            normalize_error: true,
+            global_search_interval: 0,
+            max_global_pivots: 5,
+            num_global_searches: 5,
+            global_search_tol_margin: 10.0,
+        }
+    }
+}
+```
+
+---
+
+## 3. 
DefaultTreeGlobalPivotFinder + +### ファイル + +`crates/tensor4all-treetci/src/globalpivot.rs` + +### アルゴリズム + +TCI2 の `DefaultGlobalPivotFinder` と同じ greedy local search。 +全ての関数評価は **batch** で行う。 + +``` +for each of num_searches random starting points: + current_point = random_point(local_dims) + best_error = 0 + best_point = current_point + + for each site p in 0..n_sites: + # Batch: local_dims[p] 個の候補点を一括評価 + candidates = [current_point with site p = v for v in 0..local_dims[p]] + + f_vals = batch_eval(candidates) # 真値 (batch) + approx_vals = treetn.evaluate_at_site_indices(candidates) # 近似値 + + errors = |f_vals - approx_vals| + v_best = argmax(errors) + + if errors[v_best] > best_error: + best_error = errors[v_best] + best_point = candidates[v_best] + + current_point[p] = v_best # 次の dimension は最良値から出発 + + if best_error > abs_tol * tol_margin: + found_pivots.append(best_point) + +sort by error descending, dedup, truncate to max_pivots +``` + +### 型シグネチャ + +```rust +use crate::{GlobalIndexBatch, OwnedGlobalIndexBatch}; +use crate::materialize::FullPivLuScalar; +use tensor4all_tcicore::MultiIndex; +use tensor4all_treetn::TreeTN; +use tensor4all_core::TensorDynLen; +use rand::Rng; +use anyhow::Result; + +#[derive(Debug, Clone)] +pub struct DefaultTreeGlobalPivotFinder { + pub num_searches: usize, + pub max_pivots: usize, + pub tol_margin: f64, +} + +impl Default for DefaultTreeGlobalPivotFinder { + fn default() -> Self { + Self { + num_searches: 5, + max_pivots: 5, + tol_margin: 10.0, + } + } +} + +impl DefaultTreeGlobalPivotFinder { + pub fn new(num_searches: usize, max_pivots: usize, tol_margin: f64) -> Self { + Self { num_searches, max_pivots, tol_margin } + } + + /// Find global pivots where interpolation error is large. 
+ /// + /// # Arguments + /// - `local_dims`: site dimensions + /// - `batch_eval`: the user's batch evaluation function (true values) + /// - `treetn`: materialized current approximation + /// - `abs_tol`: absolute tolerance threshold + /// - `rng`: random number generator + /// + /// # Returns + /// Multi-indices where |f(x) - approx(x)| > abs_tol * tol_margin + pub fn find_pivots( + &self, + local_dims: &[usize], + batch_eval: &F, + treetn: &TreeTN, + abs_tol: f64, + rng: &mut impl Rng, + ) -> Result> + where + T: FullPivLuScalar, + F: Fn(GlobalIndexBatch<'_>) -> Result>, + { + // ... implementation ... + } +} +``` + +### Batch 評価の実装詳細 + +dimension sweep の各 site `p` で `local_dims[p]` 個の候補点を作り、 +一括で `batch_eval` と `treetn.evaluate_at_site_indices` を呼ぶ。 + +```rust +// site p の sweep: local_dims[p] 個の候補 +let d = local_dims[p]; +let mut candidates: Vec> = Vec::with_capacity(d); +for v in 0..d { + let mut point = current_point.clone(); + point[p] = v; + candidates.push(point); +} + +// batch_eval で真値を取得 +let batch_data: Vec = candidates.iter() + .flat_map(|c| c.iter().copied()) + .collect(); +let batch = GlobalIndexBatch::new(&batch_data, n_sites, d)?; +let f_vals: Vec = batch_eval(batch)?; + +// treetn で近似値を取得 +let approx_vals: Vec = treetn.evaluate_at_site_indices(&candidates)?; + +// 誤差計算 +for i in 0..d { + let f_abs = T::abs_val(f_vals[i]); + let approx_abs = approx_vals[i].abs(); + let error = (f_abs - approx_abs).abs(); + // ... track best ... 
+}
+```
+
+注意: `f_vals` は `T` 型 (f64 or Complex64), `approx_vals` は `AnyScalar`。
+誤差は `|f(x)| - |approx(x)|` ではなく `|f(x) - approx(x)|` であるべき。
+
+**Complex 対応の誤差計算:**
+
+```rust
+// T から f64 への abs 変換
+let f_abs = f64::sqrt(T::abs_sq(f_vals[i]));
+// AnyScalar からの abs
+let approx_abs_re = approx_vals[i].real();
+let approx_abs_im = approx_vals[i].imag();
+
+// 差の abs: |f - approx|
+// f_vals[i] を re/im に分解する必要がある
+// T: FullPivLuScalar は Scalar trait を持つので:
+let f_re: f64 = /* T の実部取得 */;
+let f_im: f64 = /* T の虚部取得 */;
+let diff_re = f_re - approx_abs_re;
+let diff_im = f_im - approx_abs_im;
+let error = (diff_re * diff_re + diff_im * diff_im).sqrt();
+```
+
+`Scalar` trait に `real_part()`, `imag_part()` -> f64 があるか要確認。
+なければ f64 の場合は `(val as f64, 0.0)`, Complex64 は `(val.re, val.im)` で
+マッチする。
+
+実装上は `T` の concrete type (`f64` or `Complex64`) でディスパッチする
+ヘルパー関数を使う:
+
+```rust
+fn scalar_to_re_im<T: Scalar>(val: T) -> (f64, f64) {
+    // T::abs_sq は f64 を返す (Scalar trait)
+    // T が f64 なら (val, 0.0)
+    // T が Complex64 なら (val.re, val.im)
+    // Scalar trait に直接 re/im アクセスがなければ:
+    // val を AnyScalar に変換して .real(), .imag() を使うのが安全
+    todo!("check Scalar trait for re/im access")
+}
+```
+
+→ 実装時に `Scalar` trait のメソッドを確認して適切な変換を選ぶ。
+
+---
+
+## 4. 
optimize ループへの組み込み + +### 型境界の変更 + +```rust +pub fn optimize_with_proposer( + state: &mut SimpleTreeTci, + batch_eval: F, + options: &TreeTciOptions, + proposer: &P, +) -> Result<(Vec, Vec)> +where + T: FullPivLuScalar, // was: Scalar (tightened for to_treetn) + DenseFaerLuKernel: PivotKernel, + F: Fn(GlobalIndexBatch<'_>) -> Result>, + P: PivotCandidateProposer, +``` + +`FullPivLuScalar: Scalar + TensorElement` なので上位互換。 +`f32, f64, Complex32, Complex64` に実装済み。 + +### ループ変更 + +```rust +let global_finder = if options.global_search_interval > 0 { + Some(DefaultTreeGlobalPivotFinder::new( + options.num_global_searches, + options.max_global_pivots, + options.global_search_tol_margin, + )) +} else { + None +}; +let mut rng = rand::rngs::StdRng::seed_from_u64(42); +let mut nglobal_pivots_history: Vec = Vec::new(); + +for iter in 0..options.max_iter { + // --- 既存の inner edge passes --- + for _pass in 0..INNER_EDGE_PASSES { + // ... update_edge ... + } + + ranks.push(state.max_rank()); + let normalized_error = /* 既存 */; + errors.push(normalized_error); + + // --- Global pivot search --- + let n_global = if let Some(ref finder) = global_finder { + if (iter + 1) % options.global_search_interval == 0 { + let error_scale = if options.normalize_error && state.max_sample_value > 0.0 { + state.max_sample_value + } else { + 1.0 + }; + let abs_tol = options.tolerance * error_scale; + + // Materialize current state + let treetn = to_treetn(state, &batch_eval, Some(0))?; + + // Find global pivots + let pivots = finder.find_pivots::( + &state.local_dims, + &batch_eval, + &treetn, + abs_tol, + &mut rng, + )?; + + let n = pivots.len(); + if !pivots.is_empty() { + state.add_global_pivots(&pivots)?; + } + n + } else { + 0 + } + } else { + 0 + }; + nglobal_pivots_history.push(n_global); + + // --- Early exit --- + if normalized_error < options.tolerance { + // global search が有効なら、直近の search で 0 pivots のとき終了 + if global_finder.is_none() || n_global == 0 { + break; + } + } +} +``` + +### 
戻り値の変更
+
+現在は `(Vec<usize>, Vec<f64>)` (ranks, errors)。
+Global pivot 情報も返すとデバッグに有用:
+
+```rust
+/// Optimization result.
+pub struct OptimizeResult {
+    pub ranks: Vec<usize>,
+    pub errors: Vec<f64>,
+    pub nglobal_pivots: Vec<usize>,
+}
+```
+
+→ **破壊的変更**になるため、今回は `nglobal_pivots` は内部で使うだけにして
+戻り値は変えない。将来 `OptimizeResult` 構造体に移行。
+
+---
+
+## 5. テスト
+
+### テスト関数: TCI.jl parity
+
+```rust
+// crates/tensor4all-treetci/tests/global_pivot.rs
+
+/// Chain tree: 0--1--2--...--N-1
+fn chain_graph(n: usize) -> TreeTciGraph { ... }
+
+/// Quantics-like: bitlist [b0, b1, ...] → x = sum(b_i * 2^{-(i+1)}) ∈ [0, 1)
+fn bits_to_x(bits: &[usize]) -> f64 { ... }
+
+/// TCI.jl "nasty function": f(x) = exp(-10x) * sin(2π * 100 * x^1.1)
+fn nasty_function(x: f64) -> f64 {
+    (-10.0 * x).exp() * (2.0 * PI * 100.0 * x.powf(1.1)).sin()
+}
+```
+
+### テストケース
+
+**1. global search あり vs なし の比較:**
+- 10-bit chain (1024 grid points)
+- `global_search_interval = 1` (毎イテレーション) vs `0` (無効)
+- global search あり → error < 1e-6 を期待
+- global search なし → error がより大きい(または収束しない)を期待
+
+**2. global search の interval 動作確認:**
+- `global_search_interval = 3` で、3イテレーションごとに search が走ることを確認
+  (nglobal_pivots_history をチェック)
+
+**3. 既存テストのリグレッション:**
+- `global_search_interval = 0` (デフォルト) で既存の動作が変わらないことを確認
+
+---
+
+## 実装順序
+
+| Task | 内容 | ファイル |
+|------|------|---------|
+| 1 | `TreeTN::evaluate_at_site_indices` | `crates/tensor4all-treetn/src/treetn/ops.rs` |
+| 2 | `TreeTciOptions` パラメータ追加 | `crates/tensor4all-treetci/src/optimize.rs` |
+| 3 | `DefaultTreeGlobalPivotFinder` | `crates/tensor4all-treetci/src/globalpivot.rs` + `lib.rs` |
+| 4 | `optimize_with_proposer` 組み込み | `crates/tensor4all-treetci/src/optimize.rs` |
+| 5 | テスト: nasty function | `crates/tensor4all-treetci/tests/global_pivot.rs` |
+| 6 | 全テスト + clippy + fmt | validation |
+
+---
+
+## 設計上の注意
+
+1. **型境界 `Scalar` → `FullPivLuScalar`**: optimize の型境界を狭める。
+   `FullPivLuScalar` は `f32, f64, Complex32, Complex64` に実装済みなので
+   既存コードに影響なし。
+
+2. 
**Materialization コスト**: `to_treetn` は全テンソルを再構築。 + `global_search_interval` で頻度制御。デフォルト 0 (無効) なので opt-in。 + +3. **Complex 対応**: `find_pivots` 内の誤差計算で `T` → `(re, im)` 変換が必要。 + `AnyScalar` は `.real()`, `.imag()` を持つ。`T` 側は Scalar trait の + メソッドで対応(実装時に確認)。 + +4. **RNG**: `StdRng::seed_from_u64(42)` で決定論的。 + +5. **戻り値**: 今回は `(Vec, Vec)` のまま変えない。 diff --git a/docs/specs/2026-03-31-api-expansion-design.md b/docs/specs/2026-03-31-api-expansion-design.md new file mode 100644 index 0000000..6619e57 --- /dev/null +++ b/docs/specs/2026-03-31-api-expansion-design.md @@ -0,0 +1,212 @@ +# Tensor4all.jl API Expansion Design + +## Goal + +Expand Tensor4all.jl to wrap the new C API functions (#393), unify indexing conventions to 1-indexed, align naming with Pure Julia ecosystem (TensorCrossInterpolation.jl, QuanticsTCI.jl), and add type conversion between SimpleTT and TreeTN. + +## Principles + +- All Julia-facing APIs use **1-indexed** (Julia convention). Internal C API calls convert to 0-indexed. +- Function names match Pure Julia libraries where equivalent functionality exists. +- No backward compatibility aliases. Breaking changes are acceptable. +- Explicit type conversions — no implicit magic. + +--- + +## A. Indexing Convention: 1-indexed everywhere + +### SimpleTT + +| Function | Current | Change | +|---|---|---| +| `evaluate(tt, indices)` | 0-indexed | 1-indexed; subtract 1 internally | +| `sitetensor(tt, site)` | 0-indexed | 1-indexed; subtract 1 internally | +| Callable `tt(indices...)` | 0-indexed | 1-indexed | + +### TreeTCI + +| Function | Current | Change | +|---|---|---| +| `initialpivots` in `crossinterpolate2` | 0-indexed | 1-indexed; subtract 1 internally | +| `evaluate(ttn, indices)` | 0-indexed | 1-indexed; subtract 1 internally | +| Callback `f(batch)` | 0-indexed | 1-indexed; add 1 when passing to Julia callback | + +### TreeTN + +No change — already 1-indexed. + +--- + +## B. 
Naming Alignment + +Rename to match Pure Julia convention (no underscores): + +| Current | New | Module | +|---|---|---| +| `link_dims` | `linkdims` | SimpleTT | +| `site_dims` | `sitedims` | SimpleTT | +| `site_tensor` | `sitetensor` | SimpleTT | +| `local_dimensions` | `localdimensions` | QuanticsGrids | +| `link_dims` | `linkdims` | QuanticsTCI | + +No deprecated aliases. Direct rename. + +--- + +## C. SimpleTT Basic Operations + +New functions in `src/SimpleTT.jl`. All support Float64 and ComplexF64. + +### Arithmetic (operator overloading) + +```julia +Base.:+(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) where T + → SimpleTensorTrain{T} # calls t4a_simplett_{f64,c64}_add + +Base.:-(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) where T + → SimpleTensorTrain{T} # add(a, scaled(b, -1)) + +Base.:*(α::Number, tt::SimpleTensorTrain{T}) where T + → SimpleTensorTrain{T} # clone + scale + +Base.:*(tt::SimpleTensorTrain{T}, α::Number) where T + → SimpleTensorTrain{T} # α * tt + +LinearAlgebra.dot(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) where T + → T # calls t4a_simplett_{f64,c64}_dot +``` + +### In-place + +```julia +scale!(tt::SimpleTensorTrain, α::Number) → tt + # calls t4a_simplett_{f64,c64}_scale +``` + +### Other operations + +```julia +Base.reverse(tt::SimpleTensorTrain{T}) where T → SimpleTensorTrain{T} + # calls t4a_simplett_{f64,c64}_reverse + +fulltensor(tt::SimpleTensorTrain{T}) where T → Array{T} + # calls t4a_simplett_{f64,c64}_fulltensor + # returns Array with dimensions = sitedims(tt) +``` + +### Constructor from site tensors + +```julia +SimpleTensorTrain(site_tensors::Vector{<:AbstractArray{T,3}}) where T<:Union{Float64,ComplexF64} + # calls t4a_simplett_{f64,c64}_from_site_tensors + # site_tensors[i] has shape (left_dim, site_dim, right_dim) +``` + +--- + +## D. 
SimpleTT ↔ TreeTN.MPS Conversion + +### In TreeTN.jl + +```julia +function MPS(tt::SimpleTensorTrain{T}) where T + # Extract site tensors from SimpleTT (1-indexed after B) + # Create Index objects for sites and links + # Build Tensor objects and call MPS(tensors::Vector{Tensor}) +end +``` + +### In SimpleTT.jl + +```julia +function SimpleTensorTrain(mps::TreeTensorNetwork{Int}) + # For each site 1..n, extract tensor data via TreeTN accessors + # Build 3D arrays (left, site, right) + # Call SimpleTensorTrain(site_tensors) +end +``` + +Both conversions are pure Julia — no new C API needed. They compose existing accessors. + +--- + +## E. QuanticsTCI Update + +### Type parameterization + +```julia +mutable struct QuanticsTensorCI2{V} + ptr::Ptr{Cvoid} + # V determines which C API functions to call (f64 vs c64) +end +``` + +### Updated main function + +```julia +function quanticscrossinterpolate( + ::Type{V}, f, grid::DiscretizedGrid; + tolerance=1e-8, maxbonddim=0, maxiter=200, + initialpivots=nothing, nrandominitpivot=5, + verbosity=0, unfoldingscheme=:interleaved, + nsearchglobalpivot=5, nsearch=100, + normalize_error=true +) where V <: Union{Float64, ComplexF64} + → (QuanticsTensorCI2{V}, Vector{Int}, Vector{Float64}) +end +``` + +### Overloads + +```julia +# Discrete domain (size tuple) +quanticscrossinterpolate(::Type{V}, f, size::NTuple{N,Int}; kwargs...) + +# From coordinate arrays +quanticscrossinterpolate(::Type{V}, f, xvals::Vector{Vector{Float64}}; kwargs...) + +# From dense array +quanticscrossinterpolate(F::Array{V}; kwargs...) +``` + +### New accessors + +```julia +max_bond_error(qtci::QuanticsTensorCI2) → Float64 +max_rank(qtci::QuanticsTensorCI2) → Int +``` + +### Internal: QtciOptions management + +Julia kwargs → create `t4a_qtci_options` handle → pass to C API → release after call. The options handle is ephemeral (not stored). + +--- + +## F. 
TCI Conversion Bridge + +New file: `ext/Tensor4allTCIExt.jl` + +Weak dependency on `TensorCrossInterpolation`. + +```julia +function SimpleTensorTrain(tt::TCI.TensorTrain{T}) where T + # Extract tt.sitetensors::Vector{Array{T,3}} + # Call SimpleTensorTrain(site_tensors) +end + +function TCI.TensorTrain(stt::SimpleTensorTrain{T}) where T + # Extract site tensors via sitetensor(stt, i) for i in 1:length(stt) + # Build TCI.TensorTrain from Vector{Array{T,3}} +end +``` + +--- + +## Scope Exclusions + +- QuanticsTransform extensions (multivar operators, etc.) +- TreeTCI Julia wrapper overhaul (beyond indexing fix) +- ITensorsExt updates +- QuanticsGrids naming beyond `localdimensions` +- `cachedata`, `quanticsfouriermpo` +- `compress!` for SimpleTT (already exists via C API) diff --git a/docs/specs/2026-04-06-quantics-rust-parity-design.md b/docs/specs/2026-04-06-quantics-rust-parity-design.md new file mode 100644 index 0000000..0276ca1 --- /dev/null +++ b/docs/specs/2026-04-06-quantics-rust-parity-design.md @@ -0,0 +1,182 @@ +# Quantics Rust Parity Design + +## Goal + +Make `Tensor4all.jl` expose the existing pure Rust quantics functionality through the C API and Julia wrapper, instead of rebuilding a legacy `TensorCI` compatibility layer. The immediate downstream target is `ReFrequenTT`, which needs grids, multivariable transforms, dimension embedding, and mixed boundary conditions more than it needs the old `TensorCI2` surface. + +## Decision + +Adopt `tensor4all-rs` as the source of truth for quantics functionality and make `Tensor4all.jl` a thin, ergonomic wrapper over that surface. + +This means: + +- Do not revive the old Rust-backed `TensorCI` wrapper. +- Do not make `TensorCrossInterpolation.jl` reexport masquerade as `Tensor4all.TensorCI`. +- Use `SimpleTT.SimpleTensorTrain` as the canonical Julia TT type inside `Tensor4all.jl`. +- Keep interoperability with external TT types in weak-dependency extensions. + +--- + +## A. 
Why This Direction + +Reviving `Tensor4all.TensorCI` would mostly solve naming continuity, but it would not solve the actual migration blockers for `ReFrequenTT`. The missing pieces there are: + +- multivariable affine transforms +- dimension-changing embeddings +- per-axis boundary conditions +- a clean path from quantics operators to Julia TT values + +Those capabilities belong to the quantics backend surface, not to a legacy `TensorCI2` facade. + +By contrast, `QuanticsGrids` is already close to the right architecture: Rust owns the implementation, the C API exposes opaque grid handles and conversion functions, and Julia wraps them. The correct direction is to finish that pattern across the quantics stack. + +--- + +## B. Public API Principles + +### 1. Rust/C API parity first + +If a capability already exists in `tensor4all-rs`, the Julia package should expose it with minimal semantic drift. Julia convenience is allowed, but the underlying feature boundary should stay aligned with Rust. + +### 2. Ergonomic top-level exports are secondary + +Top-level exports such as `DiscretizedGrid` or `InherentDiscreteGrid` improve usability, but they are not substitutes for missing backend exposure. Reexports should be treated as API polish, not architecture. + +### 3. One canonical TT type inside Tensor4all.jl + +`SimpleTT.SimpleTensorTrain` is the internal Julia TT representation that all quantics wrappers should produce and consume where possible. + +### 4. Interop lives in extensions + +Conversions to external tensor-train ecosystems should stay in weak-dependency extension files: + +- `TensorCrossInterpolation.TensorTrain <-> SimpleTensorTrain` +- future `ITensorLike.TensorTrain <-> SimpleTensorTrain`, if a concrete package surface exists + +This keeps the core package small and avoids coupling `Tensor4all.jl` to external API churn. + +--- + +## C. QuanticsGrids Scope + +`QuanticsGrids` is already C-API-backed, so the work here is parity and cleanup, not reinvention. 
+ +Planned direction: + +- keep `DiscretizedGrid` and `InherentDiscreteGrid` as opaque Julia wrappers over Rust-owned objects +- expose the full unfolding enum that Rust already supports, including `:grouped` +- make the common grid accessors part of the normal Julia surface +- reexport the most common grid symbols from `Tensor4all` top level for ergonomics + +Recommended top-level exports: + +- `DiscretizedGrid` +- `InherentDiscreteGrid` +- `localdimensions` +- coordinate conversion helpers only if name collisions remain manageable + +Non-goal: + +- do not duplicate grid logic in Julia + +--- + +## D. QuanticsTransform Scope + +This is the core gap for `ReFrequenTT`. + +`tensor4all-rs` already has the relevant operator constructors in `tensor4all-quanticstransform`, including: + +- multivariable shift/flip/phase operators +- affine operators +- asymmetric input/output layouts for embedding +- boundary conditions per output variable + +The missing piece is exposure through `tensor4all-capi` and then `Tensor4all.jl`. + +Julia should expose: + +- `BoundaryCondition` +- `shift_operator`, `flip_operator`, `phase_rotation_operator` +- `shift_operator_multivar`, `flip_operator_multivar`, `phase_rotation_operator_multivar` +- `affine_operator` +- `binaryop_operator` +- `apply` + +Design requirements: + +- support mixed boundary conditions, e.g. frequency axes open and momentum axes periodic +- support dimension-preserving transforms such as `(nu, omega, k, q) -> (nu, nu+omega, k, k+q)` +- support dimension-changing embeddings such as `f(omega, q) -> g(nu, nup, k, kp) = f(nu-nup, k-kp)` +- preserve the current operator-application path onto `SimpleTensorTrain` + +This surface is the one that directly enables `ReFrequenTT`-style kernels. + +--- + +## E. Tensor Representation and Interoperability + +### Canonical representation + +Inside `Tensor4all.jl`, the canonical TT object remains `SimpleTT.SimpleTensorTrain`. 
+ +Reasons: + +- it is already backed by Rust +- `QuanticsTCI` already converts into it +- arithmetic and partial reductions already exist on it +- it avoids making external packages part of the core abstraction boundary + +### TensorCrossInterpolation interoperability + +Keep the current weak-dep extension that converts between: + +- `TensorCrossInterpolation.TensorTrain` +- `Tensor4all.SimpleTT.SimpleTensorTrain` + +This is useful, but it should remain interoperability, not the primary API. + +### ITensorLike interoperability + +Do not design around `ITensorLike.TensorTrain` yet. + +There is no current `ITensorLike` module in this repository, so adding it now would create a second unfinished abstraction boundary. If a concrete downstream need appears, add a weak-dependency extension later with explicit constructors, mirroring the `TensorCrossInterpolation` pattern. + +--- + +## F. Migration Impact + +### What becomes smoother + +- `ReFrequenTT` can target `Tensor4all.jl` quantics primitives directly +- users can stay within one package family for grid, transform, and TT operations +- Julia wrappers stay stable as long as Rust/C API surfaces stay stable + +### What does not become smoother + +- code expecting the deleted Rust-backed `TensorCI2` API does not automatically recover +- a `TensorCrossInterpolation.jl` facade under the old `TensorCI` name is still likely to have behavioral mismatches + +### Migration recommendation + +For downstream ports, migrate toward: + +1. grid construction via `QuanticsGrids` +2. interpolation/compression via `QuanticsTCI` +3. TT data via `SimpleTensorTrain` +4. operator construction via `QuanticsTransform` + +Do not migrate toward a reintroduced `TensorCI` namespace unless a separate compatibility objective is explicitly accepted. 
+ +--- + +## Scope Exclusions + +This design does not yet include: + +- a `TTFunction`-style high-level wrapper bundling `(grid, tt, logical variables)` +- variable-aware contraction APIs analogous to `BubbleTeaCI.BasicContractOrder` +- a compatibility facade that recreates the removed Rust `TensorCI2` wrapper +- a new `ITensorLike` module + +Those are valid follow-up layers, but they should be built on top of the Rust-parity quantics surface, not before it. diff --git a/docs/test-reports/test-feature-20260330-130000.md b/docs/test-reports/test-feature-20260330-130000.md new file mode 100644 index 0000000..b982386 --- /dev/null +++ b/docs/test-reports/test-feature-20260330-130000.md @@ -0,0 +1,60 @@ +# Feature Test Report: Tensor4all.jl TreeTCI + +**Date:** 2026-03-30 +**Project type:** Julia library (FFI wrapper for Rust tensor network library) +**Features tested:** TreeTCI (crossinterpolate2, to_treetn, verbosity) +**Profile:** ephemeral — computational physicist trying custom functions +**Use Case:** Approximate custom multi-variable functions (Gaussian+cross-terms, oscillatory) on non-chain tree topologies +**Expected Outcome:** Successfully approximate, obtain TreeTensorNetwork, verify point-wise accuracy +**Verdict:** pass +**Critical Issues:** 0 + +## Summary + +| Feature | Discoverable | Setup | Works | Expected Outcome Met | Doc Quality | +|---------|-------------|-------|-------|---------------------|-------------| +| TreeTCI crossinterpolate2 | partial | yes | yes | yes | missing batch format docs | +| to_treetn materialization | yes | yes | yes | yes | good | +| verbosity logging | yes | yes | yes | yes | good | +| Point-wise evaluation of TTN | no | n/a | no API | no | missing evaluate() | + +## Per-Feature Details + +### TreeTCI crossinterpolate2 +- **Role:** Computational physicist, experienced with TCI.jl +- **Use Case:** Approximate Gaussian+cross-terms and oscillatory functions on star/branching trees +- **What they tried:** 5-site star graph + 
Gaussian, 7-site branching + sin*exp +- **Discoverability:** Partial — docstrings show kwargs but batch callback format (n_sites, n_points, 0-based) not prominent +- **Setup:** Docker worked perfectly +- **Functionality:** Both functions converged in 1 iteration to machine precision (~1e-16 error) +- **Expected vs Actual Outcome:** Exceeded expectations — expected multiple iterations, got instant convergence +- **Blocked steps:** None +- **Friction points:** Batch callback matrix layout not obvious from user-facing docstring +- **Doc suggestions:** Document batch format prominently in crossinterpolate2 docstring + +### TreeTensorNetwork point-wise evaluation +- **Role:** Same +- **Use Case:** Verify TCI result at sample points +- **What they tried:** Called to_dense() and manually indexed since no evaluate() exists +- **Discoverability:** No — had to discover workaround by reading source +- **Setup:** n/a +- **Functionality:** Workaround via to_dense works but impractical for large tensors +- **Expected vs Actual Outcome:** Expected evaluate(ttn, indices) like SimpleTensorTrain has — not available +- **Blocked steps:** None +- **Friction points:** No evaluate() for TreeTensorNetwork; maxbonddim() throws on TTN output +- **Doc suggestions:** Add evaluate() function; document to_dense index ordering + +## Issues Found + +1. **[Medium] No evaluate() for TreeTensorNetwork** — SimpleTensorTrain has evaluate()/callable syntax, TTN lacks both. Users must use to_dense() workaround. +2. **[Medium] Batch callback format underdocumented** — crossinterpolate2 docstring says `f(batch::Matrix{Csize_t}) -> Vector{T}` but doesn't specify (n_sites, n_points) column-major layout or 0-based indices prominently. +3. **[Low] maxbonddim() throws on TTN from to_treetn** — "Invalid argument" error, appears MPS-specific but is exported generically. +4. **[Low] to_dense() index ordering undocumented** — Users can't know which dimension corresponds to which site without inspecting Index IDs. 
+ +## Suggestions + +1. Add `evaluate(ttn, indices)` for TreeTensorNetwork (highest impact for usability) +2. Expand crossinterpolate2 docstring with explicit batch callback format +3. Add "Getting Started with TreeTCI" example to README +4. Fix or document maxbonddim limitation for general TTNs +5. Document to_dense index ordering guarantee diff --git a/ext/Tensor4allTCIExt.jl b/ext/Tensor4allTCIExt.jl new file mode 100644 index 0000000..2df1362 --- /dev/null +++ b/ext/Tensor4allTCIExt.jl @@ -0,0 +1,40 @@ +""" + Tensor4allTCIExt + +Extension module providing bidirectional conversion between +Tensor4all.SimpleTT.SimpleTensorTrain and TensorCrossInterpolation.TensorTrain. +""" +module Tensor4allTCIExt + +using Tensor4all +using TensorCrossInterpolation + +import Tensor4all.SimpleTT: SimpleTensorTrain, sitetensor + +""" + SimpleTensorTrain(tt::TensorCrossInterpolation.TensorTrain{V,3}) where V + +Convert a TensorCrossInterpolation.TensorTrain to a SimpleTensorTrain. + +Extracts the site tensors (each of shape (left, site, right)) and +constructs a SimpleTensorTrain from them. +""" +function Tensor4all.SimpleTT.SimpleTensorTrain(tt::TensorCrossInterpolation.TensorTrain{V,3}) where V + return SimpleTensorTrain(tt.sitetensors) +end + +""" + TensorCrossInterpolation.TensorTrain(stt::SimpleTensorTrain{T}) where T + +Convert a SimpleTensorTrain to a TensorCrossInterpolation.TensorTrain. + +Extracts site tensors via sitetensor(stt, i) for each site and +constructs a TensorCrossInterpolation.TensorTrain from them. 
+""" +function TensorCrossInterpolation.TensorTrain(stt::SimpleTensorTrain{T}) where T + n = length(stt) + site_tensors = [sitetensor(stt, i) for i in 1:n] + return TensorCrossInterpolation.TensorTrain(site_tensors) +end + +end # module Tensor4allTCIExt diff --git a/src/C_API.jl b/src/C_API.jl index 9171a84..6a2ef26 100644 --- a/src/C_API.jl +++ b/src/C_API.jl @@ -466,19 +466,18 @@ function t4a_tensor_get_data_f64(ptr::Ptr{Cvoid}, buf, buf_len::Integer, out_len end """ - t4a_tensor_get_data_c64(ptr::Ptr{Cvoid}, buf_re, buf_im, buf_len::Integer, out_len::Ref{Csize_t}) -> Cint + t4a_tensor_get_data_c64(ptr::Ptr{Cvoid}, buf, buf_len::Integer, out_len::Ref{Csize_t}) -> Cint -Get dense complex64 data from a tensor in column-major order. -If buf_re or buf_im is C_NULL, only out_len is written (to query required length). +Get dense complex64 data from a tensor in column-major interleaved order. +If `buf` is `C_NULL`, only `out_len` is written (to query required length). """ -function t4a_tensor_get_data_c64(ptr::Ptr{Cvoid}, buf_re, buf_im, buf_len::Integer, out_len::Ref{Csize_t}) +function t4a_tensor_get_data_c64(ptr::Ptr{Cvoid}, buf, buf_len::Integer, out_len::Ref{Csize_t}) return ccall( _sym(:t4a_tensor_get_data_c64), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}), + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}), ptr, - buf_re === nothing ? C_NULL : buf_re, - buf_im === nothing ? C_NULL : buf_im, + buf === nothing ? 
C_NULL : buf, Csize_t(buf_len), out_len ) @@ -507,22 +506,20 @@ function t4a_tensor_new_dense_f64(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, end """ - t4a_tensor_new_dense_c64(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, dims::Vector{Csize_t}, data_re::Vector{Cdouble}, data_im::Vector{Cdouble}) -> Ptr{Cvoid} + t4a_tensor_new_dense_c64(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, dims::Vector{Csize_t}, data::Vector{Cdouble}) -> Ptr{Cvoid} -Create a new dense complex64 tensor from indices and real/imag data in column-major order. +Create a new dense complex64 tensor from indices and interleaved data in column-major order. """ -function t4a_tensor_new_dense_c64(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, dims::Vector{Csize_t}, data_re::Vector{Cdouble}, data_im::Vector{Cdouble}) - @assert length(data_re) == length(data_im) "Real and imaginary data must have same length" +function t4a_tensor_new_dense_c64(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, dims::Vector{Csize_t}, data::Vector{Cdouble}) return ccall( _sym(:t4a_tensor_new_dense_c64), Ptr{Cvoid}, - (Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Cdouble}, Csize_t), + (Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Csize_t), Csize_t(rank), index_ptrs, dims, - data_re, - data_im, - Csize_t(length(data_re)) + data, + Csize_t(length(data) ÷ 2) ) end @@ -1241,152 +1238,179 @@ function t4a_simplett_f64_site_tensor( end # ============================================================================ -# TensorCI2 lifecycle functions +# SimpleTT f64 operations # ============================================================================ """ - t4a_tci2_f64_release(ptr::Ptr{Cvoid}) + t4a_simplett_f64_compress(ptr, method, tolerance, max_bonddim) -> Cint -Release a TensorCI2 object. +Compress a SimpleTT tensor train in-place. 
+method: 0=SVD, 1=LU, 2=CI """ -function t4a_tci2_f64_release(ptr::Ptr{Cvoid}) - ptr == C_NULL && return - ccall( - _sym(:t4a_tci2_f64_release), - Cvoid, - (Ptr{Cvoid},), - ptr +function t4a_simplett_f64_compress(ptr::Ptr{Cvoid}, method::Integer, tolerance::Float64, max_bonddim::Integer) + return ccall( + _sym(:t4a_simplett_f64_compress), + Cint, + (Ptr{Cvoid}, Cint, Cdouble, Csize_t), + ptr, + Cint(method), + tolerance, + Csize_t(max_bonddim) ) end """ - t4a_tci2_f64_new(local_dims::Vector{Csize_t}) -> Ptr{Cvoid} + t4a_simplett_f64_partial_sum(ptr, dims, n_dims, out) -> Cint -Create a new TensorCI2 object. +Compute a partial sum over specified dimensions. """ -function t4a_tci2_f64_new(local_dims::Vector{Csize_t}) +function t4a_simplett_f64_partial_sum(ptr::Ptr{Cvoid}, dims, n_dims::Integer, out) return ccall( - _sym(:t4a_tci2_f64_new), - Ptr{Cvoid}, - (Ptr{Csize_t}, Csize_t), - local_dims, - Csize_t(length(local_dims)) + _sym(:t4a_simplett_f64_partial_sum), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Ptr{Cvoid}}), + ptr, + dims, + Csize_t(n_dims), + out ) end -# ============================================================================ -# TensorCI2 accessors -# ============================================================================ - """ - t4a_tci2_f64_len(ptr::Ptr{Cvoid}, out_len::Ref{Csize_t}) -> Cint + t4a_simplett_f64_from_site_tensors(n_sites, left_dims, site_dims, right_dims, data, data_len, out_ptr) -> Cint -Get the number of sites. +Create a SimpleTT tensor train from site tensor data. 
""" -function t4a_tci2_f64_len(ptr::Ptr{Cvoid}, out_len::Ref{Csize_t}) +function t4a_simplett_f64_from_site_tensors( + n_sites::Integer, + left_dims, + site_dims, + right_dims, + data, + data_len::Integer, + out_ptr +) return ccall( - _sym(:t4a_tci2_f64_len), + _sym(:t4a_simplett_f64_from_site_tensors), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}), - ptr, - out_len + (Csize_t, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Cdouble}, Csize_t, Ptr{Ptr{Cvoid}}), + Csize_t(n_sites), + left_dims, + site_dims, + right_dims, + data, + Csize_t(data_len), + out_ptr ) end """ - t4a_tci2_f64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) -> Cint + t4a_simplett_f64_add(a, b, out) -> Cint -Get the current rank (maximum bond dimension). +Add two SimpleTT tensor trains. """ -function t4a_tci2_f64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) +function t4a_simplett_f64_add(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out) return ccall( - _sym(:t4a_tci2_f64_rank), + _sym(:t4a_simplett_f64_add), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}), - ptr, - out_rank + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), + a, + b, + out ) end """ - t4a_tci2_f64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}) -> Cint + t4a_simplett_f64_scale(ptr, factor) -> Cint -Get the link (bond) dimensions. +Scale a SimpleTT tensor train in-place. """ -function t4a_tci2_f64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}) +function t4a_simplett_f64_scale(ptr::Ptr{Cvoid}, factor::Float64) return ccall( - _sym(:t4a_tci2_f64_link_dims), + _sym(:t4a_simplett_f64_scale), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + (Ptr{Cvoid}, Cdouble), ptr, - out_dims, - Csize_t(length(out_dims)) + factor ) end """ - t4a_tci2_f64_max_sample_value(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) -> Cint + t4a_simplett_f64_dot(a, b, out_value) -> Cint -Get the maximum sample value encountered. +Compute the dot product of two SimpleTT tensor trains. 
""" -function t4a_tci2_f64_max_sample_value(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) +function t4a_simplett_f64_dot(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out_value::Ref{Cdouble}) return ccall( - _sym(:t4a_tci2_f64_max_sample_value), + _sym(:t4a_simplett_f64_dot), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}), - ptr, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cdouble}), + a, + b, out_value ) end """ - t4a_tci2_f64_max_bond_error(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) -> Cint + t4a_simplett_f64_reverse(ptr, out) -> Cint -Get the maximum bond error from the last sweep. +Reverse the site ordering of a SimpleTT tensor train. """ -function t4a_tci2_f64_max_bond_error(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) +function t4a_simplett_f64_reverse(ptr::Ptr{Cvoid}, out) return ccall( - _sym(:t4a_tci2_f64_max_bond_error), + _sym(:t4a_simplett_f64_reverse), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}), + (Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), ptr, - out_value + out ) end -# ============================================================================ -# TensorCI2 pivot operations -# ============================================================================ - """ - t4a_tci2_f64_add_global_pivots(ptr::Ptr{Cvoid}, pivots::Vector{Csize_t}, n_pivots::Integer, n_sites::Integer) -> Cint + t4a_simplett_f64_fulltensor(ptr, out_data, buf_len, out_data_len) -> Cint -Add global pivots to the TCI. Pivots are stored as a flat array. +Convert SimpleTT to a full tensor (dense array). +If out_data is C_NULL, only out_data_len is written (query mode). 
""" -function t4a_tci2_f64_add_global_pivots(ptr::Ptr{Cvoid}, pivots::Vector{Csize_t}, n_pivots::Integer, n_sites::Integer) +function t4a_simplett_f64_fulltensor(ptr::Ptr{Cvoid}, out_data, buf_len::Integer, out_data_len) return ccall( - _sym(:t4a_tci2_f64_add_global_pivots), + _sym(:t4a_simplett_f64_fulltensor), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t), + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}), ptr, - pivots, - Csize_t(n_pivots), - Csize_t(n_sites) + out_data, + Csize_t(buf_len), + out_data_len ) end # ============================================================================ -# TensorCI2 conversion +# SimpleTT c64 (complex) lifecycle functions # ============================================================================ """ - t4a_tci2_f64_to_tensor_train(ptr::Ptr{Cvoid}) -> Ptr{Cvoid} + t4a_simplett_c64_release(ptr::Ptr{Cvoid}) -Convert a TCI to a SimpleTT TensorTrain. +Release a complex SimpleTT tensor train. """ -function t4a_tci2_f64_to_tensor_train(ptr::Ptr{Cvoid}) +function t4a_simplett_c64_release(ptr::Ptr{Cvoid}) + ptr == C_NULL && return + ccall( + _sym(:t4a_simplett_c64_release), + Cvoid, + (Ptr{Cvoid},), + ptr + ) +end + +""" + t4a_simplett_c64_clone(ptr::Ptr{Cvoid}) -> Ptr{Cvoid} + +Clone a complex SimpleTT tensor train. 
+""" +function t4a_simplett_c64_clone(ptr::Ptr{Cvoid}) return ccall( - _sym(:t4a_tci2_f64_to_tensor_train), + _sym(:t4a_simplett_c64_clone), Ptr{Cvoid}, (Ptr{Cvoid},), ptr @@ -1394,176 +1418,658 @@ function t4a_tci2_f64_to_tensor_train(ptr::Ptr{Cvoid}) end # ============================================================================ -# TensorCI2 high-level crossinterpolate2 +# SimpleTT c64 constructors # ============================================================================ -# Callback type for evaluation function -# signature: (indices_ptr, n_indices, result_ptr, user_data) -> status -const EvalCallback = Ptr{Cvoid} +""" + t4a_simplett_c64_constant(site_dims, value_re, value_im) -> Ptr{Cvoid} +Create a constant complex SimpleTT tensor train. """ - t4a_crossinterpolate2_f64(local_dims, initial_pivots, n_initial_pivots, - eval_fn, user_data, tolerance, max_bonddim, max_iter, - out_tci, out_final_error) -> Cint +function t4a_simplett_c64_constant(site_dims::Vector{Csize_t}, value_re::Float64, value_im::Float64) + return ccall( + _sym(:t4a_simplett_c64_constant), + Ptr{Cvoid}, + (Ptr{Csize_t}, Csize_t, Cdouble, Cdouble), + site_dims, + Csize_t(length(site_dims)), + value_re, + value_im + ) +end -Perform cross interpolation of a function. """ -function t4a_crossinterpolate2_f64( - local_dims::Vector{Csize_t}, - initial_pivots::Vector{Csize_t}, # flat array - n_initial_pivots::Integer, - eval_fn::Ptr{Cvoid}, - user_data::Ptr{Cvoid}, - tolerance::Float64, - max_bonddim::Integer, - max_iter::Integer, - out_tci::Ref{Ptr{Cvoid}}, - out_final_error::Ref{Cdouble} -) + t4a_simplett_c64_zeros(site_dims) -> Ptr{Cvoid} + +Create a zero complex SimpleTT tensor train. 
+""" +function t4a_simplett_c64_zeros(site_dims::Vector{Csize_t}) return ccall( - _sym(:t4a_crossinterpolate2_f64), - Cint, - (Ptr{Csize_t}, Csize_t, Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, - Cdouble, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Cdouble}), - local_dims, - Csize_t(length(local_dims)), - isempty(initial_pivots) ? C_NULL : initial_pivots, - Csize_t(n_initial_pivots), - eval_fn, - user_data, - tolerance, - Csize_t(max_bonddim), - Csize_t(max_iter), - out_tci, - out_final_error + _sym(:t4a_simplett_c64_zeros), + Ptr{Cvoid}, + (Ptr{Csize_t}, Csize_t), + site_dims, + Csize_t(length(site_dims)) ) end # ============================================================================ -# HDF5 initialization and save/load functions +# SimpleTT c64 accessors # ============================================================================ """ - t4a_hdf5_init(library_path::AbstractString) -> Cint + t4a_simplett_c64_len(ptr, out_len) -> Cint -Initialize the HDF5 library by loading it from the specified path. -This must be called before using any HDF5 functions. - -# Arguments -- `library_path`: Path to the HDF5 shared library (e.g., libhdf5.so or libhdf5.dylib) - -# Returns -`T4A_SUCCESS` (0) on success, or an error code on failure. +Get the number of sites. """ -function t4a_hdf5_init(library_path::AbstractString) +function t4a_simplett_c64_len(ptr::Ptr{Cvoid}, out_len::Ref{Csize_t}) return ccall( - _sym(:t4a_hdf5_init), + _sym(:t4a_simplett_c64_len), Cint, - (Cstring,), - library_path + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, + out_len ) end """ - t4a_hdf5_is_initialized() -> Bool + t4a_simplett_c64_site_dims(ptr, out_dims) -> Cint -Check if the HDF5 library is initialized and ready to use. +Get the site dimensions. 
""" -function t4a_hdf5_is_initialized() +function t4a_simplett_c64_site_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}) return ccall( - _sym(:t4a_hdf5_is_initialized), - Bool, - () + _sym(:t4a_simplett_c64_site_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + ptr, + out_dims, + Csize_t(length(out_dims)) ) end """ - t4a_hdf5_save_itensor(filepath::AbstractString, name::AbstractString, tensor::Ptr{Cvoid}) -> Cint + t4a_simplett_c64_link_dims(ptr, out_dims) -> Cint -Save a tensor as an ITensors.jl-compatible ITensor in an HDF5 file. +Get the link (bond) dimensions. """ -function t4a_hdf5_save_itensor(filepath::AbstractString, name::AbstractString, tensor::Ptr{Cvoid}) +function t4a_simplett_c64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}) return ccall( - _sym(:t4a_hdf5_save_itensor), + _sym(:t4a_simplett_c64_link_dims), Cint, - (Cstring, Cstring, Ptr{Cvoid}), - filepath, - name, - tensor + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + ptr, + out_dims, + Csize_t(length(out_dims)) ) end """ - t4a_hdf5_load_itensor(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) -> Cint + t4a_simplett_c64_rank(ptr, out_rank) -> Cint -Load a tensor from an ITensors.jl-compatible ITensor in an HDF5 file. +Get the maximum bond dimension (rank). """ -function t4a_hdf5_load_itensor(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) +function t4a_simplett_c64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) return ccall( - _sym(:t4a_hdf5_load_itensor), + _sym(:t4a_simplett_c64_rank), Cint, - (Cstring, Cstring, Ptr{Ptr{Cvoid}}), - filepath, - name, - out + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, + out_rank ) end """ - t4a_hdf5_save_mps(filepath::AbstractString, name::AbstractString, ttn::Ptr{Cvoid}) -> Cint + t4a_simplett_c64_evaluate(ptr, indices, out_value_re, out_value_im) -> Cint -Save a tree tensor network (MPS) as an ITensorMPS.jl-compatible MPS in an HDF5 file. +Evaluate the complex tensor train at a given multi-index. 
""" -function t4a_hdf5_save_mps(filepath::AbstractString, name::AbstractString, ttn::Ptr{Cvoid}) +function t4a_simplett_c64_evaluate(ptr::Ptr{Cvoid}, indices::Vector{Csize_t}, out_value_re::Ref{Cdouble}, out_value_im::Ref{Cdouble}) return ccall( - _sym(:t4a_hdf5_save_mps), + _sym(:t4a_simplett_c64_evaluate), Cint, - (Cstring, Cstring, Ptr{Cvoid}), - filepath, - name, - ttn + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}), + ptr, + indices, + Csize_t(length(indices)), + out_value_re, + out_value_im ) end """ - t4a_hdf5_load_mps(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) -> Cint + t4a_simplett_c64_sum(ptr, out_value_re, out_value_im) -> Cint -Load a tree tensor network (MPS) from an ITensorMPS.jl-compatible MPS in an HDF5 file. -Returns a `t4a_treetn` handle. +Compute the sum over all indices (complex). """ -function t4a_hdf5_load_mps(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) +function t4a_simplett_c64_sum(ptr::Ptr{Cvoid}, out_value_re::Ref{Cdouble}, out_value_im::Ref{Cdouble}) return ccall( - _sym(:t4a_hdf5_load_mps), + _sym(:t4a_simplett_c64_sum), Cint, - (Cstring, Cstring, Ptr{Ptr{Cvoid}}), - filepath, - name, - out + (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), + ptr, + out_value_re, + out_value_im ) end -# ============================================================================ -# QuanticsGrids: DiscretizedGrid functions -# ============================================================================ - -# Unfolding scheme enum (must match Rust side) -const UNFOLDING_FUSED = Cint(0) -const UNFOLDING_INTERLEAVED = Cint(1) - """ - t4a_qgrid_disc_new(ndims, rs_arr, lower_arr, upper_arr, unfolding, out) -> Cint + t4a_simplett_c64_norm(ptr, out_value) -> Cint + +Compute the Frobenius norm (real-valued). 
""" -function t4a_qgrid_disc_new(ndims::Integer, rs_arr, lower_arr, upper_arr, unfolding::Integer, out) +function t4a_simplett_c64_norm(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) return ccall( - _sym(:t4a_qgrid_disc_new), + _sym(:t4a_simplett_c64_norm), Cint, - (Csize_t, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Cdouble}, Cint, Ptr{Ptr{Cvoid}}), - Csize_t(ndims), - rs_arr, - lower_arr, - upper_arr, - Cint(unfolding), - out + (Ptr{Cvoid}, Ptr{Cdouble}), + ptr, + out_value + ) +end + +""" + t4a_simplett_c64_site_tensor(ptr, site, out_data, buf_len, out_left_dim, out_site_dim, out_right_dim) -> Cint + +Get site tensor data at a specific site (interleaved re/im pairs). +Buffer must hold 2 * left_dim * site_dim * right_dim doubles. +""" +function t4a_simplett_c64_site_tensor( + ptr::Ptr{Cvoid}, + site::Integer, + out_data::Vector{Cdouble}, + out_left_dim::Ref{Csize_t}, + out_site_dim::Ref{Csize_t}, + out_right_dim::Ref{Csize_t} +) + return ccall( + _sym(:t4a_simplett_c64_site_tensor), + Cint, + (Ptr{Cvoid}, Csize_t, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}), + ptr, + Csize_t(site), + out_data, + Csize_t(length(out_data)), + out_left_dim, + out_site_dim, + out_right_dim + ) +end + +# ============================================================================ +# SimpleTT c64 operations +# ============================================================================ + +""" + t4a_simplett_c64_compress(ptr, method, tolerance, max_bonddim) -> Cint + +Compress a complex SimpleTT tensor train in-place. +method: 0=SVD, 1=LU, 2=CI +""" +function t4a_simplett_c64_compress(ptr::Ptr{Cvoid}, method::Integer, tolerance::Float64, max_bonddim::Integer) + return ccall( + _sym(:t4a_simplett_c64_compress), + Cint, + (Ptr{Cvoid}, Cint, Cdouble, Csize_t), + ptr, + Cint(method), + tolerance, + Csize_t(max_bonddim) + ) +end + +""" + t4a_simplett_c64_partial_sum(ptr, dims, n_dims, out) -> Cint + +Compute a partial sum over specified dimensions (complex). 
+""" +function t4a_simplett_c64_partial_sum(ptr::Ptr{Cvoid}, dims, n_dims::Integer, out) + return ccall( + _sym(:t4a_simplett_c64_partial_sum), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Ptr{Cvoid}}), + ptr, + dims, + Csize_t(n_dims), + out + ) +end + +""" + t4a_simplett_c64_from_site_tensors(n_sites, left_dims, site_dims, right_dims, data, data_len, out_ptr) -> Cint + +Create a complex SimpleTT tensor train from site tensor data (interleaved re/im pairs). +""" +function t4a_simplett_c64_from_site_tensors( + n_sites::Integer, + left_dims, + site_dims, + right_dims, + data, + data_len::Integer, + out_ptr +) + return ccall( + _sym(:t4a_simplett_c64_from_site_tensors), + Cint, + (Csize_t, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Cdouble}, Csize_t, Ptr{Ptr{Cvoid}}), + Csize_t(n_sites), + left_dims, + site_dims, + right_dims, + data, + Csize_t(data_len), + out_ptr + ) +end + +""" + t4a_simplett_c64_add(a, b, out) -> Cint + +Add two complex SimpleTT tensor trains. +""" +function t4a_simplett_c64_add(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out) + return ccall( + _sym(:t4a_simplett_c64_add), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), + a, + b, + out + ) +end + +""" + t4a_simplett_c64_scale(ptr, factor_re, factor_im) -> Cint + +Scale a complex SimpleTT tensor train in-place. +""" +function t4a_simplett_c64_scale(ptr::Ptr{Cvoid}, factor_re::Float64, factor_im::Float64) + return ccall( + _sym(:t4a_simplett_c64_scale), + Cint, + (Ptr{Cvoid}, Cdouble, Cdouble), + ptr, + factor_re, + factor_im + ) +end + +""" + t4a_simplett_c64_dot(a, b, out_re, out_im) -> Cint + +Compute the dot product of two complex SimpleTT tensor trains. 
+""" +function t4a_simplett_c64_dot(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out_re::Ref{Cdouble}, out_im::Ref{Cdouble}) + return ccall( + _sym(:t4a_simplett_c64_dot), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), + a, + b, + out_re, + out_im + ) +end + +""" + t4a_simplett_c64_reverse(ptr, out) -> Cint + +Reverse the site ordering of a complex SimpleTT tensor train. +""" +function t4a_simplett_c64_reverse(ptr::Ptr{Cvoid}, out) + return ccall( + _sym(:t4a_simplett_c64_reverse), + Cint, + (Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), + ptr, + out + ) +end + +""" + t4a_simplett_c64_fulltensor(ptr, out_data, buf_len, out_data_len) -> Cint + +Convert complex SimpleTT to a full tensor (interleaved re/im pairs). +If out_data is C_NULL, only out_data_len is written (query mode). +""" +function t4a_simplett_c64_fulltensor(ptr::Ptr{Cvoid}, out_data, buf_len::Integer, out_data_len) + return ccall( + _sym(:t4a_simplett_c64_fulltensor), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Csize_t}), + ptr, + out_data, + Csize_t(buf_len), + out_data_len + ) +end + +# ============================================================================ +# TensorCI2 lifecycle functions +# ============================================================================ + +""" + t4a_tci2_f64_release(ptr::Ptr{Cvoid}) + +Release a TensorCI2 object. +""" +function t4a_tci2_f64_release(ptr::Ptr{Cvoid}) + ptr == C_NULL && return + ccall( + _sym(:t4a_tci2_f64_release), + Cvoid, + (Ptr{Cvoid},), + ptr + ) +end + +""" + t4a_tci2_f64_new(local_dims::Vector{Csize_t}) -> Ptr{Cvoid} + +Create a new TensorCI2 object. 
+""" +function t4a_tci2_f64_new(local_dims::Vector{Csize_t}) + return ccall( + _sym(:t4a_tci2_f64_new), + Ptr{Cvoid}, + (Ptr{Csize_t}, Csize_t), + local_dims, + Csize_t(length(local_dims)) + ) +end + +# ============================================================================ +# TensorCI2 accessors +# ============================================================================ + +""" + t4a_tci2_f64_len(ptr::Ptr{Cvoid}, out_len::Ref{Csize_t}) -> Cint + +Get the number of sites. +""" +function t4a_tci2_f64_len(ptr::Ptr{Cvoid}, out_len::Ref{Csize_t}) + return ccall( + _sym(:t4a_tci2_f64_len), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, + out_len + ) +end + +""" + t4a_tci2_f64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) -> Cint + +Get the current rank (maximum bond dimension). +""" +function t4a_tci2_f64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) + return ccall( + _sym(:t4a_tci2_f64_rank), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, + out_rank + ) +end + +""" + t4a_tci2_f64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}) -> Cint + +Get the link (bond) dimensions. +""" +function t4a_tci2_f64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}) + return ccall( + _sym(:t4a_tci2_f64_link_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + ptr, + out_dims, + Csize_t(length(out_dims)) + ) +end + +""" + t4a_tci2_f64_max_sample_value(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) -> Cint + +Get the maximum sample value encountered. +""" +function t4a_tci2_f64_max_sample_value(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) + return ccall( + _sym(:t4a_tci2_f64_max_sample_value), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), + ptr, + out_value + ) +end + +""" + t4a_tci2_f64_max_bond_error(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) -> Cint + +Get the maximum bond error from the last sweep. 
+""" +function t4a_tci2_f64_max_bond_error(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) + return ccall( + _sym(:t4a_tci2_f64_max_bond_error), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), + ptr, + out_value + ) +end + +# ============================================================================ +# TensorCI2 pivot operations +# ============================================================================ + +""" + t4a_tci2_f64_add_global_pivots(ptr::Ptr{Cvoid}, pivots::Vector{Csize_t}, n_pivots::Integer, n_sites::Integer) -> Cint + +Add global pivots to the TCI. Pivots are stored as a flat array. +""" +function t4a_tci2_f64_add_global_pivots(ptr::Ptr{Cvoid}, pivots::Vector{Csize_t}, n_pivots::Integer, n_sites::Integer) + return ccall( + _sym(:t4a_tci2_f64_add_global_pivots), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t), + ptr, + pivots, + Csize_t(n_pivots), + Csize_t(n_sites) + ) +end + +# ============================================================================ +# TensorCI2 conversion +# ============================================================================ + +""" + t4a_tci2_f64_to_tensor_train(ptr::Ptr{Cvoid}) -> Ptr{Cvoid} + +Convert a TCI to a SimpleTT TensorTrain. +""" +function t4a_tci2_f64_to_tensor_train(ptr::Ptr{Cvoid}) + return ccall( + _sym(:t4a_tci2_f64_to_tensor_train), + Ptr{Cvoid}, + (Ptr{Cvoid},), + ptr + ) +end + +# ============================================================================ +# TensorCI2 high-level crossinterpolate2 +# ============================================================================ + +# Callback type for evaluation function +# signature: (indices_ptr, n_indices, result_ptr, user_data) -> status +const EvalCallback = Ptr{Cvoid} + +""" + t4a_crossinterpolate2_f64(local_dims, initial_pivots, n_initial_pivots, + eval_fn, user_data, tolerance, max_bonddim, max_iter, + out_tci, out_final_error) -> Cint + +Perform cross interpolation of a function. 
+""" +function t4a_crossinterpolate2_f64( + local_dims::Vector{Csize_t}, + initial_pivots::Vector{Csize_t}, # flat array + n_initial_pivots::Integer, + eval_fn::Ptr{Cvoid}, + user_data::Ptr{Cvoid}, + tolerance::Float64, + max_bonddim::Integer, + max_iter::Integer, + out_tci::Ref{Ptr{Cvoid}}, + out_final_error::Ref{Cdouble} +) + return ccall( + _sym(:t4a_crossinterpolate2_f64), + Cint, + (Ptr{Csize_t}, Csize_t, Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, + Cdouble, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Cdouble}), + local_dims, + Csize_t(length(local_dims)), + isempty(initial_pivots) ? C_NULL : initial_pivots, + Csize_t(n_initial_pivots), + eval_fn, + user_data, + tolerance, + Csize_t(max_bonddim), + Csize_t(max_iter), + out_tci, + out_final_error + ) +end + +# ============================================================================ +# HDF5 initialization and save/load functions +# ============================================================================ + +""" + t4a_hdf5_init(library_path::AbstractString) -> Cint + +Initialize the HDF5 library by loading it from the specified path. +This must be called before using any HDF5 functions. + +# Arguments +- `library_path`: Path to the HDF5 shared library (e.g., libhdf5.so or libhdf5.dylib) + +# Returns +`T4A_SUCCESS` (0) on success, or an error code on failure. +""" +function t4a_hdf5_init(library_path::AbstractString) + return ccall( + _sym(:t4a_hdf5_init), + Cint, + (Cstring,), + library_path + ) +end + +""" + t4a_hdf5_is_initialized() -> Bool + +Check if the HDF5 library is initialized and ready to use. +""" +function t4a_hdf5_is_initialized() + return ccall( + _sym(:t4a_hdf5_is_initialized), + Bool, + () + ) +end + +""" + t4a_hdf5_save_itensor(filepath::AbstractString, name::AbstractString, tensor::Ptr{Cvoid}) -> Cint + +Save a tensor as an ITensors.jl-compatible ITensor in an HDF5 file. 
+""" +function t4a_hdf5_save_itensor(filepath::AbstractString, name::AbstractString, tensor::Ptr{Cvoid}) + return ccall( + _sym(:t4a_hdf5_save_itensor), + Cint, + (Cstring, Cstring, Ptr{Cvoid}), + filepath, + name, + tensor + ) +end + +""" + t4a_hdf5_load_itensor(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) -> Cint + +Load a tensor from an ITensors.jl-compatible ITensor in an HDF5 file. +""" +function t4a_hdf5_load_itensor(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) + return ccall( + _sym(:t4a_hdf5_load_itensor), + Cint, + (Cstring, Cstring, Ptr{Ptr{Cvoid}}), + filepath, + name, + out + ) +end + +""" + t4a_hdf5_save_mps(filepath::AbstractString, name::AbstractString, ttn::Ptr{Cvoid}) -> Cint + +Save a tree tensor network (MPS) as an ITensorMPS.jl-compatible MPS in an HDF5 file. +""" +function t4a_hdf5_save_mps(filepath::AbstractString, name::AbstractString, ttn::Ptr{Cvoid}) + return ccall( + _sym(:t4a_hdf5_save_mps), + Cint, + (Cstring, Cstring, Ptr{Cvoid}), + filepath, + name, + ttn + ) +end + +""" + t4a_hdf5_load_mps(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) -> Cint + +Load a tree tensor network (MPS) from an ITensorMPS.jl-compatible MPS in an HDF5 file. +Returns a `t4a_treetn` handle. 
+""" +function t4a_hdf5_load_mps(filepath::AbstractString, name::AbstractString, out::Ref{Ptr{Cvoid}}) + return ccall( + _sym(:t4a_hdf5_load_mps), + Cint, + (Cstring, Cstring, Ptr{Ptr{Cvoid}}), + filepath, + name, + out + ) +end + +# ============================================================================ +# QuanticsGrids: DiscretizedGrid functions +# ============================================================================ + +# Unfolding scheme enum (must match Rust side) +const UNFOLDING_FUSED = Cint(0) +const UNFOLDING_INTERLEAVED = Cint(1) +const UNFOLDING_GROUPED = Cint(2) + +""" + t4a_qgrid_disc_new(ndims, rs_arr, lower_arr, upper_arr, unfolding, out) -> Cint +""" +function t4a_qgrid_disc_new(ndims::Integer, rs_arr, lower_arr, upper_arr, unfolding::Integer, out) + return ccall( + _sym(:t4a_qgrid_disc_new), + Cint, + (Csize_t, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Cdouble}, Cint, Ptr{Ptr{Cvoid}}), + Csize_t(ndims), + rs_arr, + lower_arr, + upper_arr, + Cint(unfolding), + out ) end @@ -1619,45 +2125,212 @@ function t4a_qgrid_disc_local_dims(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer, ) end -function t4a_qgrid_disc_lower_bound(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) +function t4a_qgrid_disc_lower_bound(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) + return ccall( + _sym(:t4a_qgrid_disc_lower_bound), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t), + ptr, + out_arr, + Csize_t(buf_size) + ) +end + +function t4a_qgrid_disc_upper_bound(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) + return ccall( + _sym(:t4a_qgrid_disc_upper_bound), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t), + ptr, + out_arr, + Csize_t(buf_size) + ) +end + +function t4a_qgrid_disc_grid_step(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) + return ccall( + _sym(:t4a_qgrid_disc_grid_step), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t), + ptr, + out_arr, + Csize_t(buf_size) + ) +end + +function t4a_qgrid_disc_origcoord_to_quantics(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, + out_arr, 
buf_size::Integer, n_out::Ref{Csize_t}) + return ccall( + _sym(:t4a_qgrid_disc_origcoord_to_quantics), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), + ptr, + coord_arr, + Csize_t(ndims), + out_arr, + Csize_t(buf_size), + n_out + ) +end + +function t4a_qgrid_disc_quantics_to_origcoord(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, + out_arr, buf_size::Integer) + return ccall( + _sym(:t4a_qgrid_disc_quantics_to_origcoord), + Cint, + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}, Csize_t), + ptr, + quantics_arr, + Csize_t(n_quantics), + out_arr, + Csize_t(buf_size) + ) +end + +function t4a_qgrid_disc_origcoord_to_grididx(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, + out_arr, buf_size::Integer) + return ccall( + _sym(:t4a_qgrid_disc_origcoord_to_grididx), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Int64}, Csize_t), + ptr, + coord_arr, + Csize_t(ndims), + out_arr, + Csize_t(buf_size) + ) +end + +function t4a_qgrid_disc_grididx_to_origcoord(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, + out_arr, buf_size::Integer) + return ccall( + _sym(:t4a_qgrid_disc_grididx_to_origcoord), + Cint, + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}, Csize_t), + ptr, + grididx_arr, + Csize_t(ndims), + out_arr, + Csize_t(buf_size) + ) +end + +function t4a_qgrid_disc_grididx_to_quantics(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, + out_arr, buf_size::Integer, n_out::Ref{Csize_t}) + return ccall( + _sym(:t4a_qgrid_disc_grididx_to_quantics), + Cint, + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), + ptr, + grididx_arr, + Csize_t(ndims), + out_arr, + Csize_t(buf_size), + n_out + ) +end + +function t4a_qgrid_disc_quantics_to_grididx(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, + out_arr, buf_size::Integer, n_out::Ref{Csize_t}) + return ccall( + _sym(:t4a_qgrid_disc_quantics_to_grididx), + Cint, + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), + ptr, + quantics_arr, + Csize_t(n_quantics), + 
out_arr, + Csize_t(buf_size), + n_out + ) +end + +# ============================================================================ +# QuanticsGrids: InherentDiscreteGrid functions +# ============================================================================ + +function t4a_qgrid_int_new(ndims::Integer, rs_arr, origin_arr, unfolding::Integer, out) + return ccall( + _sym(:t4a_qgrid_int_new), + Cint, + (Csize_t, Ptr{Csize_t}, Ptr{Int64}, Cint, Ptr{Ptr{Cvoid}}), + Csize_t(ndims), + rs_arr, + origin_arr, + Cint(unfolding), + out + ) +end + +function t4a_qgrid_int_release(ptr::Ptr{Cvoid}) + ptr == C_NULL && return + ccall( + _sym(:t4a_qgrid_int_release), + Cvoid, + (Ptr{Cvoid},), + ptr + ) +end + +function t4a_qgrid_int_clone(ptr::Ptr{Cvoid}) + return ccall( + _sym(:t4a_qgrid_int_clone), + Ptr{Cvoid}, + (Ptr{Cvoid},), + ptr + ) +end + +function t4a_qgrid_int_ndims(ptr::Ptr{Cvoid}, out::Ref{Csize_t}) + return ccall( + _sym(:t4a_qgrid_int_ndims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, + out + ) +end + +function t4a_qgrid_int_rs(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) return ccall( - _sym(:t4a_qgrid_disc_lower_bound), + _sym(:t4a_qgrid_int_rs), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t), + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), ptr, out_arr, Csize_t(buf_size) ) end -function t4a_qgrid_disc_upper_bound(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) +function t4a_qgrid_int_local_dims(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer, n_out::Ref{Csize_t}) return ccall( - _sym(:t4a_qgrid_disc_upper_bound), + _sym(:t4a_qgrid_int_local_dims), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t), + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), ptr, out_arr, - Csize_t(buf_size) + Csize_t(buf_size), + n_out ) end -function t4a_qgrid_disc_grid_step(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) +function t4a_qgrid_int_origin(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) return ccall( - _sym(:t4a_qgrid_disc_grid_step), + _sym(:t4a_qgrid_int_origin), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}, 
Csize_t), + (Ptr{Cvoid}, Ptr{Int64}, Csize_t), ptr, out_arr, Csize_t(buf_size) ) end -function t4a_qgrid_disc_origcoord_to_quantics(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, - out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +function t4a_qgrid_int_origcoord_to_quantics(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, + out_arr, buf_size::Integer, n_out::Ref{Csize_t}) return ccall( - _sym(:t4a_qgrid_disc_origcoord_to_quantics), + _sym(:t4a_qgrid_int_origcoord_to_quantics), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), ptr, coord_arr, Csize_t(ndims), @@ -1667,12 +2340,12 @@ function t4a_qgrid_disc_origcoord_to_quantics(ptr::Ptr{Cvoid}, coord_arr, ndims: ) end -function t4a_qgrid_disc_quantics_to_origcoord(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, - out_arr, buf_size::Integer) +function t4a_qgrid_int_quantics_to_origcoord(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, + out_arr, buf_size::Integer) return ccall( - _sym(:t4a_qgrid_disc_quantics_to_origcoord), + _sym(:t4a_qgrid_int_quantics_to_origcoord), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}, Csize_t), + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t), ptr, quantics_arr, Csize_t(n_quantics), @@ -1681,12 +2354,12 @@ function t4a_qgrid_disc_quantics_to_origcoord(ptr::Ptr{Cvoid}, quantics_arr, n_q ) end -function t4a_qgrid_disc_origcoord_to_grididx(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, - out_arr, buf_size::Integer) +function t4a_qgrid_int_origcoord_to_grididx(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, + out_arr, buf_size::Integer) return ccall( - _sym(:t4a_qgrid_disc_origcoord_to_grididx), + _sym(:t4a_qgrid_int_origcoord_to_grididx), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}, Csize_t, Ptr{Int64}, Csize_t), + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t), ptr, coord_arr, Csize_t(ndims), @@ -1695,12 +2368,12 @@ function t4a_qgrid_disc_origcoord_to_grididx(ptr::Ptr{Cvoid}, 
coord_arr, ndims:: ) end -function t4a_qgrid_disc_grididx_to_origcoord(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, - out_arr, buf_size::Integer) +function t4a_qgrid_int_grididx_to_origcoord(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, + out_arr, buf_size::Integer) return ccall( - _sym(:t4a_qgrid_disc_grididx_to_origcoord), + _sym(:t4a_qgrid_int_grididx_to_origcoord), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}, Csize_t), + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t), ptr, grididx_arr, Csize_t(ndims), @@ -1709,10 +2382,10 @@ function t4a_qgrid_disc_grididx_to_origcoord(ptr::Ptr{Cvoid}, grididx_arr, ndims ) end -function t4a_qgrid_disc_grididx_to_quantics(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, - out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +function t4a_qgrid_int_grididx_to_quantics(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, + out_arr, buf_size::Integer, n_out::Ref{Csize_t}) return ccall( - _sym(:t4a_qgrid_disc_grididx_to_quantics), + _sym(:t4a_qgrid_int_grididx_to_quantics), Cint, (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), ptr, @@ -1724,10 +2397,10 @@ function t4a_qgrid_disc_grididx_to_quantics(ptr::Ptr{Cvoid}, grididx_arr, ndims: ) end -function t4a_qgrid_disc_quantics_to_grididx(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, - out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +function t4a_qgrid_int_quantics_to_grididx(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, + out_arr, buf_size::Integer, n_out::Ref{Csize_t}) return ccall( - _sym(:t4a_qgrid_disc_quantics_to_grididx), + _sym(:t4a_qgrid_int_quantics_to_grididx), Cint, (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), ptr, @@ -1740,295 +2413,549 @@ function t4a_qgrid_disc_quantics_to_grididx(ptr::Ptr{Cvoid}, quantics_arr, n_qua end # ============================================================================ -# QuanticsGrids: InherentDiscreteGrid functions +# QuanticsTCI: QTCI lifecycle functions # 
============================================================================ -function t4a_qgrid_int_new(ndims::Integer, rs_arr, origin_arr, unfolding::Integer, out) +function t4a_qtci_f64_release(ptr::Ptr{Cvoid}) + ptr == C_NULL && return + ccall( + _sym(:t4a_qtci_f64_release), + Cvoid, + (Ptr{Cvoid},), + ptr + ) +end + +# ============================================================================ +# QuanticsTCI: High-level interpolation functions +# ============================================================================ + +""" + t4a_quanticscrossinterpolate_f64(grid, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters) -> Cint + +Continuous domain interpolation using a DiscretizedGrid. +options: Ptr to QtciOptions (or C_NULL for defaults). +initial_pivots: flat array of i64 pivots (or C_NULL). +out_ranks/out_errors: optional per-iteration output buffers (or C_NULL). +out_n_iters: optional output for number of iterations (or C_NULL). 
+""" +function t4a_quanticscrossinterpolate_f64( + grid::Ptr{Cvoid}, + eval_fn::Ptr{Cvoid}, + user_data::Ptr{Cvoid}, + options::Ptr{Cvoid}, + tolerance::Cdouble, + max_bonddim::Csize_t, + max_iter::Csize_t, + initial_pivots, + n_pivots::Csize_t, + out_qtci, + out_ranks, + out_errors, + out_n_iters, +) return ccall( - _sym(:t4a_qgrid_int_new), + _sym(:t4a_quanticscrossinterpolate_f64), Cint, - (Csize_t, Ptr{Csize_t}, Ptr{Int64}, Cint, Ptr{Ptr{Cvoid}}), - Csize_t(ndims), - rs_arr, - origin_arr, - Cint(unfolding), - out + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cdouble, Csize_t, Csize_t, + Ptr{Int64}, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + grid, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, + initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters ) end -function t4a_qgrid_int_release(ptr::Ptr{Cvoid}) - ptr == C_NULL && return - ccall( - _sym(:t4a_qgrid_int_release), - Cvoid, +""" + t4a_quanticscrossinterpolate_discrete_f64(sizes, ndims, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, unfoldingscheme, initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters) -> Cint + +Discrete domain interpolation with integer indices. +options: Ptr to QtciOptions (or C_NULL for defaults). +initial_pivots: flat array of i64 pivots (or C_NULL). +out_ranks/out_errors: optional per-iteration output buffers (or C_NULL). +out_n_iters: optional output for number of iterations (or C_NULL). 
+""" +function t4a_quanticscrossinterpolate_discrete_f64( + sizes::Vector{Csize_t}, + ndims::Csize_t, + eval_fn::Ptr{Cvoid}, + user_data::Ptr{Cvoid}, + options::Ptr{Cvoid}, + tolerance::Cdouble, + max_bonddim::Csize_t, + max_iter::Csize_t, + unfoldingscheme::Cint, + initial_pivots, + n_pivots::Csize_t, + out_qtci, + out_ranks, + out_errors, + out_n_iters, +) + return ccall( + _sym(:t4a_quanticscrossinterpolate_discrete_f64), + Cint, + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cdouble, Csize_t, Csize_t, Cint, + Ptr{Int64}, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + sizes, ndims, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, unfoldingscheme, + initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters + ) +end + +# ============================================================================ +# QuanticsTCI: Accessors +# ============================================================================ + +function t4a_qtci_f64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) + return ccall( + _sym(:t4a_qtci_f64_rank), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, out_rank + ) +end + +function t4a_qtci_f64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}, buf_len::Csize_t) + return ccall( + _sym(:t4a_qtci_f64_link_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + ptr, out_dims, buf_len + ) +end + +# ============================================================================ +# QuanticsTCI: Operations +# ============================================================================ + +function t4a_qtci_f64_evaluate(ptr::Ptr{Cvoid}, indices::Vector{Int64}, n_indices::Csize_t, out_value::Ref{Cdouble}) + return ccall( + _sym(:t4a_qtci_f64_evaluate), + Cint, + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}), + ptr, indices, n_indices, out_value + ) +end + +function t4a_qtci_f64_sum(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) + return ccall( + _sym(:t4a_qtci_f64_sum), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), 
+ ptr, out_value + ) +end + +function t4a_qtci_f64_integral(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) + return ccall( + _sym(:t4a_qtci_f64_integral), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), + ptr, out_value + ) +end + +function t4a_qtci_f64_to_tensor_train(ptr::Ptr{Cvoid}) + return ccall( + _sym(:t4a_qtci_f64_to_tensor_train), + Ptr{Cvoid}, (Ptr{Cvoid},), ptr ) end -function t4a_qgrid_int_clone(ptr::Ptr{Cvoid}) +""" + t4a_qtci_f64_clone(ptr) -> Ptr{Cvoid} + +Clone a QuanticsTCI f64 handle. +""" +function t4a_qtci_f64_clone(ptr::Ptr{Cvoid}) return ccall( - _sym(:t4a_qgrid_int_clone), + _sym(:t4a_qtci_f64_clone), Ptr{Cvoid}, (Ptr{Cvoid},), ptr ) end -function t4a_qgrid_int_ndims(ptr::Ptr{Cvoid}, out::Ref{Csize_t}) +""" + t4a_qtci_f64_max_bond_error(ptr, out_value) -> Cint + +Get the maximum bond error from the QuanticsTCI. +""" +function t4a_qtci_f64_max_bond_error(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) return ccall( - _sym(:t4a_qgrid_int_ndims), + _sym(:t4a_qtci_f64_max_bond_error), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}), + (Ptr{Cvoid}, Ptr{Cdouble}), ptr, - out + out_value ) end -function t4a_qgrid_int_rs(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) +""" + t4a_qtci_f64_max_rank(ptr, out_rank) -> Cint + +Get the maximum rank from the QuanticsTCI. 
+""" +function t4a_qtci_f64_max_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) return ccall( - _sym(:t4a_qgrid_int_rs), + _sym(:t4a_qtci_f64_max_rank), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + (Ptr{Cvoid}, Ptr{Csize_t}), ptr, - out_arr, - Csize_t(buf_size) + out_rank ) end -function t4a_qgrid_int_local_dims(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +# ============================================================================ +# QuanticsTCI c64: Lifecycle +# ============================================================================ + +function t4a_qtci_c64_release(ptr::Ptr{Cvoid}) + ptr == C_NULL && return + ccall( + _sym(:t4a_qtci_c64_release), + Cvoid, + (Ptr{Cvoid},), + ptr + ) +end + +function t4a_qtci_c64_clone(ptr::Ptr{Cvoid}) return ccall( - _sym(:t4a_qgrid_int_local_dims), - Cint, - (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), - ptr, - out_arr, - Csize_t(buf_size), - n_out + _sym(:t4a_qtci_c64_clone), + Ptr{Cvoid}, + (Ptr{Cvoid},), + ptr ) end -function t4a_qgrid_int_origin(ptr::Ptr{Cvoid}, out_arr, buf_size::Integer) +# ============================================================================ +# QuanticsTCI c64: Accessors +# ============================================================================ + +function t4a_qtci_c64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) return ccall( - _sym(:t4a_qgrid_int_origin), + _sym(:t4a_qtci_c64_rank), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t), - ptr, - out_arr, - Csize_t(buf_size) + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, out_rank ) end -function t4a_qgrid_int_origcoord_to_quantics(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, - out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +function t4a_qtci_c64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}, buf_len::Csize_t) return ccall( - _sym(:t4a_qgrid_int_origcoord_to_quantics), + _sym(:t4a_qtci_c64_link_dims), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), - ptr, - coord_arr, - Csize_t(ndims), - 
out_arr, - Csize_t(buf_size), - n_out + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), + ptr, out_dims, buf_len ) end -function t4a_qgrid_int_quantics_to_origcoord(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, - out_arr, buf_size::Integer) +# ============================================================================ +# QuanticsTCI c64: Operations +# ============================================================================ + +function t4a_qtci_c64_evaluate(ptr::Ptr{Cvoid}, indices::Vector{Int64}, n_indices::Csize_t, out_re::Ref{Cdouble}, out_im::Ref{Cdouble}) return ccall( - _sym(:t4a_qgrid_int_quantics_to_origcoord), + _sym(:t4a_qtci_c64_evaluate), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t), - ptr, - quantics_arr, - Csize_t(n_quantics), - out_arr, - Csize_t(buf_size) + (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}), + ptr, indices, n_indices, out_re, out_im ) end -function t4a_qgrid_int_origcoord_to_grididx(ptr::Ptr{Cvoid}, coord_arr, ndims::Integer, - out_arr, buf_size::Integer) +function t4a_qtci_c64_sum(ptr::Ptr{Cvoid}, out_re::Ref{Cdouble}, out_im::Ref{Cdouble}) return ccall( - _sym(:t4a_qgrid_int_origcoord_to_grididx), + _sym(:t4a_qtci_c64_sum), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t), - ptr, - coord_arr, - Csize_t(ndims), - out_arr, - Csize_t(buf_size) + (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), + ptr, out_re, out_im ) end -function t4a_qgrid_int_grididx_to_origcoord(ptr::Ptr{Cvoid}, grididx_arr, ndims::Integer, - out_arr, buf_size::Integer) +function t4a_qtci_c64_integral(ptr::Ptr{Cvoid}, out_re::Ref{Cdouble}, out_im::Ref{Cdouble}) return ccall( - _sym(:t4a_qgrid_int_grididx_to_origcoord), + _sym(:t4a_qtci_c64_integral), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t), - ptr, - grididx_arr, - Csize_t(ndims), - out_arr, - Csize_t(buf_size) + (Ptr{Cvoid}, Ptr{Cdouble}, Ptr{Cdouble}), + ptr, out_re, out_im ) end -function t4a_qgrid_int_grididx_to_quantics(ptr::Ptr{Cvoid}, grididx_arr, 
ndims::Integer, - out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +function t4a_qtci_c64_to_tensor_train(ptr::Ptr{Cvoid}) return ccall( - _sym(:t4a_qgrid_int_grididx_to_quantics), - Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), - ptr, - grididx_arr, - Csize_t(ndims), - out_arr, - Csize_t(buf_size), - n_out + _sym(:t4a_qtci_c64_to_tensor_train), + Ptr{Cvoid}, + (Ptr{Cvoid},), + ptr ) end -function t4a_qgrid_int_quantics_to_grididx(ptr::Ptr{Cvoid}, quantics_arr, n_quantics::Integer, - out_arr, buf_size::Integer, n_out::Ref{Csize_t}) +function t4a_qtci_c64_max_bond_error(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) return ccall( - _sym(:t4a_qgrid_int_quantics_to_grididx), + _sym(:t4a_qtci_c64_max_bond_error), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Int64}, Csize_t, Ptr{Csize_t}), + (Ptr{Cvoid}, Ptr{Cdouble}), ptr, - quantics_arr, - Csize_t(n_quantics), - out_arr, - Csize_t(buf_size), - n_out + out_value ) end -# ============================================================================ -# QuanticsTCI: QTCI lifecycle functions -# ============================================================================ - -function t4a_qtci_f64_release(ptr::Ptr{Cvoid}) - ptr == C_NULL && return - ccall( - _sym(:t4a_qtci_f64_release), - Cvoid, - (Ptr{Cvoid},), - ptr +function t4a_qtci_c64_max_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) + return ccall( + _sym(:t4a_qtci_c64_max_rank), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), + ptr, + out_rank ) end # ============================================================================ -# QuanticsTCI: High-level interpolation functions +# QuanticsTCI c64: High-level interpolation functions # ============================================================================ """ - t4a_quanticscrossinterpolate_f64(grid, eval_fn, user_data, tolerance, max_bonddim, max_iter, out_qtci) -> Cint + t4a_quanticscrossinterpolate_c64(grid, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, initial_pivots, n_pivots, 
out_qtci, out_ranks, out_errors, out_n_iters) -> Cint -Continuous domain interpolation using a DiscretizedGrid. +Continuous domain complex interpolation using a DiscretizedGrid. """ -function t4a_quanticscrossinterpolate_f64( +function t4a_quanticscrossinterpolate_c64( grid::Ptr{Cvoid}, eval_fn::Ptr{Cvoid}, user_data::Ptr{Cvoid}, + options::Ptr{Cvoid}, tolerance::Cdouble, max_bonddim::Csize_t, max_iter::Csize_t, - out_qtci::Ref{Ptr{Cvoid}}, + initial_pivots, + n_pivots::Csize_t, + out_qtci, + out_ranks, + out_errors, + out_n_iters, ) return ccall( - _sym(:t4a_quanticscrossinterpolate_f64), + _sym(:t4a_quanticscrossinterpolate_c64), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cdouble, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}), - grid, eval_fn, user_data, tolerance, max_bonddim, max_iter, out_qtci + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cdouble, Csize_t, Csize_t, + Ptr{Int64}, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + grid, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, + initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters ) end """ - t4a_quanticscrossinterpolate_discrete_f64(sizes, ndims, eval_fn, user_data, tolerance, max_bonddim, max_iter, unfoldingscheme, out_qtci) -> Cint + t4a_quanticscrossinterpolate_discrete_c64(sizes, ndims, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, unfoldingscheme, initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters) -> Cint -Discrete domain interpolation with integer indices. +Discrete domain complex interpolation with integer indices. 
""" -function t4a_quanticscrossinterpolate_discrete_f64( +function t4a_quanticscrossinterpolate_discrete_c64( sizes::Vector{Csize_t}, ndims::Csize_t, eval_fn::Ptr{Cvoid}, user_data::Ptr{Cvoid}, + options::Ptr{Cvoid}, tolerance::Cdouble, max_bonddim::Csize_t, max_iter::Csize_t, unfoldingscheme::Cint, - out_qtci::Ref{Ptr{Cvoid}}, + initial_pivots, + n_pivots::Csize_t, + out_qtci, + out_ranks, + out_errors, + out_n_iters, ) return ccall( - _sym(:t4a_quanticscrossinterpolate_discrete_f64), + _sym(:t4a_quanticscrossinterpolate_discrete_c64), Cint, - (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, Cdouble, Csize_t, Csize_t, Cint, Ptr{Ptr{Cvoid}}), - sizes, ndims, eval_fn, user_data, tolerance, max_bonddim, max_iter, unfoldingscheme, out_qtci + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cdouble, Csize_t, Csize_t, Cint, + Ptr{Int64}, Csize_t, Ptr{Ptr{Cvoid}}, Ptr{Csize_t}, Ptr{Cdouble}, Ptr{Csize_t}), + sizes, ndims, eval_fn, user_data, options, tolerance, max_bonddim, max_iter, unfoldingscheme, + initial_pivots, n_pivots, out_qtci, out_ranks, out_errors, out_n_iters ) end # ============================================================================ -# QuanticsTCI: Accessors +# QtciOptions lifecycle and setters # ============================================================================ -function t4a_qtci_f64_rank(ptr::Ptr{Cvoid}, out_rank::Ref{Csize_t}) +""" + t4a_qtci_options_default() -> Ptr{Cvoid} + +Create a new QtciOptions handle with default settings. +""" +function t4a_qtci_options_default() return ccall( - _sym(:t4a_qtci_f64_rank), + _sym(:t4a_qtci_options_default), + Ptr{Cvoid}, + () + ) +end + +""" + t4a_qtci_options_release(ptr) + +Release a QtciOptions handle. +""" +function t4a_qtci_options_release(ptr::Ptr{Cvoid}) + ptr == C_NULL && return + ccall( + _sym(:t4a_qtci_options_release), + Cvoid, + (Ptr{Cvoid},), + ptr + ) +end + +""" + t4a_qtci_options_clone(ptr) -> Ptr{Cvoid} + +Clone a QtciOptions handle. 
+""" +function t4a_qtci_options_clone(ptr::Ptr{Cvoid}) + return ccall( + _sym(:t4a_qtci_options_clone), + Ptr{Cvoid}, + (Ptr{Cvoid},), + ptr + ) +end + +""" + t4a_qtci_options_set_tolerance(ptr, tolerance) -> Cint + +Set the tolerance for QTCI. +""" +function t4a_qtci_options_set_tolerance(ptr::Ptr{Cvoid}, tolerance::Float64) + return ccall( + _sym(:t4a_qtci_options_set_tolerance), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}), - ptr, out_rank + (Ptr{Cvoid}, Cdouble), + ptr, + tolerance ) end -function t4a_qtci_f64_link_dims(ptr::Ptr{Cvoid}, out_dims::Vector{Csize_t}, buf_len::Csize_t) +""" + t4a_qtci_options_set_maxbonddim(ptr, dim) -> Cint + +Set the maximum bond dimension for QTCI. +""" +function t4a_qtci_options_set_maxbonddim(ptr::Ptr{Cvoid}, dim::Integer) return ccall( - _sym(:t4a_qtci_f64_link_dims), + _sym(:t4a_qtci_options_set_maxbonddim), Cint, - (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t), - ptr, out_dims, buf_len + (Ptr{Cvoid}, Csize_t), + ptr, + Csize_t(dim) ) end -# ============================================================================ -# QuanticsTCI: Operations -# ============================================================================ +""" + t4a_qtci_options_set_maxiter(ptr, iter) -> Cint -function t4a_qtci_f64_evaluate(ptr::Ptr{Cvoid}, indices::Vector{Int64}, n_indices::Csize_t, out_value::Ref{Cdouble}) +Set the maximum number of iterations for QTCI. +""" +function t4a_qtci_options_set_maxiter(ptr::Ptr{Cvoid}, iter::Integer) return ccall( - _sym(:t4a_qtci_f64_evaluate), + _sym(:t4a_qtci_options_set_maxiter), Cint, - (Ptr{Cvoid}, Ptr{Int64}, Csize_t, Ptr{Cdouble}), - ptr, indices, n_indices, out_value + (Ptr{Cvoid}, Csize_t), + ptr, + Csize_t(iter) ) end -function t4a_qtci_f64_sum(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) +""" + t4a_qtci_options_set_nrandominitpivot(ptr, n) -> Cint + +Set the number of random initial pivots for QTCI. 
+""" +function t4a_qtci_options_set_nrandominitpivot(ptr::Ptr{Cvoid}, n::Integer) return ccall( - _sym(:t4a_qtci_f64_sum), + _sym(:t4a_qtci_options_set_nrandominitpivot), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}), - ptr, out_value + (Ptr{Cvoid}, Csize_t), + ptr, + Csize_t(n) ) end -function t4a_qtci_f64_integral(ptr::Ptr{Cvoid}, out_value::Ref{Cdouble}) +""" + t4a_qtci_options_set_unfoldingscheme(ptr, scheme) -> Cint + +Set the unfolding scheme for QTCI. +scheme: 0=Fused, 1=Interleaved, 2=Grouped +""" +function t4a_qtci_options_set_unfoldingscheme(ptr::Ptr{Cvoid}, scheme::Integer) return ccall( - _sym(:t4a_qtci_f64_integral), + _sym(:t4a_qtci_options_set_unfoldingscheme), Cint, - (Ptr{Cvoid}, Ptr{Cdouble}), - ptr, out_value + (Ptr{Cvoid}, Cint), + ptr, + Cint(scheme) ) end -function t4a_qtci_f64_to_tensor_train(ptr::Ptr{Cvoid}) +""" + t4a_qtci_options_set_normalize_error(ptr, flag) -> Cint + +Set whether to normalize errors in QTCI. +""" +function t4a_qtci_options_set_normalize_error(ptr::Ptr{Cvoid}, flag::Integer) return ccall( - _sym(:t4a_qtci_f64_to_tensor_train), - Ptr{Cvoid}, - (Ptr{Cvoid},), - ptr + _sym(:t4a_qtci_options_set_normalize_error), + Cint, + (Ptr{Cvoid}, Cint), + ptr, + Cint(flag) + ) +end + +""" + t4a_qtci_options_set_verbosity(ptr, level) -> Cint + +Set the verbosity level for QTCI. +""" +function t4a_qtci_options_set_verbosity(ptr::Ptr{Cvoid}, level::Integer) + return ccall( + _sym(:t4a_qtci_options_set_verbosity), + Cint, + (Ptr{Cvoid}, Csize_t), + ptr, + Csize_t(level) + ) +end + +""" + t4a_qtci_options_set_nsearchglobalpivot(ptr, n) -> Cint + +Set the number of global pivot searches for QTCI. +""" +function t4a_qtci_options_set_nsearchglobalpivot(ptr::Ptr{Cvoid}, n::Integer) + return ccall( + _sym(:t4a_qtci_options_set_nsearchglobalpivot), + Cint, + (Ptr{Cvoid}, Csize_t), + ptr, + Csize_t(n) + ) +end + +""" + t4a_qtci_options_set_nsearch(ptr, n) -> Cint + +Set the number of searches for QTCI. 
+""" +function t4a_qtci_options_set_nsearch(ptr::Ptr{Cvoid}, n::Integer) + return ccall( + _sym(:t4a_qtci_options_set_nsearch), + Cint, + (Ptr{Cvoid}, Csize_t), + ptr, + Csize_t(n) ) end @@ -2098,6 +3025,27 @@ function t4a_qtransform_shift(r::Csize_t, offset::Int64, bc::Cint, out) ) end +""" + t4a_qtransform_shift_multivar(r, offset, bc, nvariables, target_var, out) -> Cint + +Create a shift operator acting on one variable in a multi-variable quantics system. +""" +function t4a_qtransform_shift_multivar( + r::Csize_t, + offset::Int64, + bc::Cint, + nvariables::Csize_t, + target_var::Csize_t, + out, +) + return ccall( + _sym(:t4a_qtransform_shift_multivar), + Cint, + (Csize_t, Int64, Cint, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}), + r, offset, bc, nvariables, target_var, out + ) +end + """ t4a_qtransform_flip(r, bc, out) -> Cint @@ -2112,6 +3060,26 @@ function t4a_qtransform_flip(r::Csize_t, bc::Cint, out) ) end +""" + t4a_qtransform_flip_multivar(r, bc, nvariables, target_var, out) -> Cint + +Create a flip operator acting on one variable in a multi-variable quantics system. +""" +function t4a_qtransform_flip_multivar( + r::Csize_t, + bc::Cint, + nvariables::Csize_t, + target_var::Csize_t, + out, +) + return ccall( + _sym(:t4a_qtransform_flip_multivar), + Cint, + (Csize_t, Cint, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}), + r, bc, nvariables, target_var, out + ) +end + """ t4a_qtransform_phase_rotation(r, theta, out) -> Cint @@ -2126,6 +3094,26 @@ function t4a_qtransform_phase_rotation(r::Csize_t, theta::Cdouble, out) ) end +""" + t4a_qtransform_phase_rotation_multivar(r, theta, nvariables, target_var, out) -> Cint + +Create a phase rotation operator acting on one variable in a multi-variable quantics system. 
+""" +function t4a_qtransform_phase_rotation_multivar( + r::Csize_t, + theta::Cdouble, + nvariables::Csize_t, + target_var::Csize_t, + out, +) + return ccall( + _sym(:t4a_qtransform_phase_rotation_multivar), + Cint, + (Csize_t, Cdouble, Csize_t, Csize_t, Ptr{Ptr{Cvoid}}), + r, theta, nvariables, target_var, out + ) +end + """ t4a_qtransform_cumsum(r, out) -> Cint @@ -2156,6 +3144,56 @@ function t4a_qtransform_fourier(r::Csize_t, forward::Cint, maxbonddim::Csize_t, ) end +""" + t4a_qtransform_affine(r, a_num, a_den, b_num, b_den, m, n, bc, out) -> Cint + +Create a general affine transformation operator with rational coefficients. +`a_num`/`a_den` encode the MxN matrix in column-major order. +`b_num`/`b_den` encode the M-element translation vector. +`bc` has length `m`. +""" +function t4a_qtransform_affine( + r::Csize_t, + a_num, + a_den, + b_num, + b_den, + m::Csize_t, + n::Csize_t, + bc, + out, +) + return ccall( + _sym(:t4a_qtransform_affine), + Cint, + (Csize_t, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}, Csize_t, Csize_t, Ptr{Cint}, Ptr{Ptr{Cvoid}}), + r, a_num, a_den, b_num, b_den, m, n, bc, out + ) +end + +""" + t4a_qtransform_binaryop(r, a1, b1, a2, b2, bc1, bc2, out) -> Cint + +Create a two-output binary-operation transform. +""" +function t4a_qtransform_binaryop( + r::Csize_t, + a1::Int8, + b1::Int8, + a2::Int8, + b2::Int8, + bc1::Cint, + bc2::Cint, + out, +) + return ccall( + _sym(:t4a_qtransform_binaryop), + Cint, + (Csize_t, Int8, Int8, Int8, Int8, Cint, Cint, Ptr{Ptr{Cvoid}}), + r, a1, b1, a2, b2, bc1, bc2, out + ) +end + """ t4a_linop_apply(op, state, method, rtol, maxdim, out) -> Cint @@ -2172,4 +3210,32 @@ function t4a_linop_apply(op::Ptr{Cvoid}, state::Ptr{Cvoid}, method::Cint, ) end +""" + t4a_linop_set_input_space(op, state) -> Cint + +Reset a linear operator's true input site indices to match a TreeTN state. 
+""" +function t4a_linop_set_input_space(op::Ptr{Cvoid}, state::Ptr{Cvoid}) + return ccall( + _sym(:t4a_linop_set_input_space), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}), + op, state + ) +end + +""" + t4a_linop_set_output_space(op, state) -> Cint + +Reset a linear operator's true output site indices to match a TreeTN state. +""" +function t4a_linop_set_output_space(op::Ptr{Cvoid}, state::Ptr{Cvoid}) + return ccall( + _sym(:t4a_linop_set_output_space), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}), + op, state + ) +end + end # module C_API diff --git a/src/QuanticsGrids.jl b/src/QuanticsGrids.jl index 0b526ba..d707bfe 100644 --- a/src/QuanticsGrids.jl +++ b/src/QuanticsGrids.jl @@ -29,6 +29,7 @@ export DiscretizedGrid, InherentDiscreteGrid export origcoord_to_quantics, quantics_to_origcoord export origcoord_to_grididx, grididx_to_origcoord export grididx_to_quantics, quantics_to_grididx +export localdimensions # ============================================================================ # Unfolding scheme helper @@ -39,8 +40,10 @@ function _unfolding_to_cint(unfolding::Symbol) return C_API.UNFOLDING_FUSED elseif unfolding == :interleaved return C_API.UNFOLDING_INTERLEAVED + elseif unfolding == :grouped + return C_API.UNFOLDING_GROUPED else - error("Unknown unfolding scheme: $unfolding. Use :fused or :interleaved.") + error("Unknown unfolding scheme: $unfolding. Use :fused, :interleaved, or :grouped.") end end @@ -62,7 +65,7 @@ A discretized grid with continuous domain and floating-point coordinates. - `ndims(g)` - Number of dimensions - `rs(g)` - Resolution (bits) per dimension -- `local_dimensions(g)` - Local dimensions of tensor sites +- `localdimensions(g)` - Local dimensions of tensor sites - `lower_bound(g)` - Lower bounds per dimension - `upper_bound(g)` - Upper bounds per dimension - `grid_step(g)` - Grid spacing per dimension @@ -136,11 +139,11 @@ Resolution (bits) per dimension. 
rs(g::DiscretizedGrid) = g._rs """ - local_dimensions(g::DiscretizedGrid) -> Vector{Int} + localdimensions(g::DiscretizedGrid) -> Vector{Int} Local dimensions of all tensor sites. """ -function local_dimensions(g::DiscretizedGrid) +function localdimensions(g::DiscretizedGrid) max_sites = sum(g._rs) + g._ndims # upper bound out = Vector{Csize_t}(undef, max_sites) n_out = Ref{Csize_t}(0) @@ -356,7 +359,7 @@ end Base.ndims(g::InherentDiscreteGrid) = g._ndims rs(g::InherentDiscreteGrid) = g._rs -function local_dimensions(g::InherentDiscreteGrid) +function localdimensions(g::InherentDiscreteGrid) max_sites = sum(g._rs) + g._ndims out = Vector{Csize_t}(undef, max_sites) n_out = Ref{Csize_t}(0) diff --git a/src/QuanticsTCI.jl b/src/QuanticsTCI.jl index f19aed0..b1f3d6f 100644 --- a/src/QuanticsTCI.jl +++ b/src/QuanticsTCI.jl @@ -5,23 +5,41 @@ High-level Julia wrappers for Quantics Tensor Cross Interpolation. This module combines TCI with quantics grid representations to efficiently interpolate functions on continuous or discrete domains. + +Supports both `Float64` and `ComplexF64` element types. 
""" module QuanticsTCI using ..C_API -using ..QuanticsGrids: DiscretizedGrid, InherentDiscreteGrid +using ..QuanticsGrids: DiscretizedGrid, InherentDiscreteGrid, localdimensions, + _unfolding_to_cint using ..SimpleTT: SimpleTensorTrain +import ..SimpleTT +import ..linkdims, ..evaluate, ..maxbonderror, ..maxrank, ..rank export QuanticsTensorCI2 -export quanticscrossinterpolate, quanticscrossinterpolate_discrete -export rank, link_dims, evaluate, sum, integral, to_tensor_train +export quanticscrossinterpolate +export evaluate, integral, to_tensor_train +export linkdims, maxbonderror, maxrank + +# ============================================================================ +# Type dispatch helpers +# ============================================================================ + +const _QtciScalar = Union{Float64, ComplexF64} + +_suffix(::Type{Float64}) = "f64" +_suffix(::Type{ComplexF64}) = "c64" + +_qtci_api(::Type{T}, name::Symbol) where {T<:_QtciScalar} = + getfield(C_API, Symbol("t4a_qtci_", _suffix(T), "_", name)) # ============================================================================ # Callback infrastructure for passing Julia functions to Rust # ============================================================================ -# For continuous domain (f64 coordinates) -function _trampoline_f64( +# For continuous domain (f64 coordinates) returning Float64 +function _trampoline_f64_f64( coords_ptr::Ptr{Float64}, ndims::Csize_t, result_ptr::Ptr{Float64}, @@ -35,13 +53,35 @@ function _trampoline_f64( unsafe_store!(result_ptr, val) return Cint(0) catch e - @error "Error in QTCI callback" exception = (e, catch_backtrace()) + @error "Error in QTCI f64 callback" exception = (e, catch_backtrace()) return Cint(-1) end end -# For discrete domain (i64 indices, 1-indexed) -function _trampoline_i64( +# For continuous domain (f64 coordinates) returning ComplexF64 +# Result buffer has 2 doubles: [re, im] +function _trampoline_f64_c64( + coords_ptr::Ptr{Float64}, + 
ndims::Csize_t, + result_ptr::Ptr{Float64}, + user_data::Ptr{Cvoid}, +)::Cint + try + f_ref = unsafe_pointer_to_objref(user_data)::Ref{Any} + f = f_ref[] + coords = unsafe_wrap(Array, coords_ptr, Int(ndims)) + val = ComplexF64(f(coords...)) + unsafe_store!(result_ptr, real(val), 1) + unsafe_store!(result_ptr, imag(val), 2) + return Cint(0) + catch e + @error "Error in QTCI c64 callback" exception = (e, catch_backtrace()) + return Cint(-1) + end +end + +# For discrete domain (i64 indices, 1-indexed) returning Float64 +function _trampoline_i64_f64( indices_ptr::Ptr{Int64}, ndims::Csize_t, result_ptr::Ptr{Float64}, @@ -55,35 +95,118 @@ function _trampoline_i64( unsafe_store!(result_ptr, val) return Cint(0) catch e - @error "Error in QTCI callback" exception = (e, catch_backtrace()) + @error "Error in QTCI discrete f64 callback" exception = (e, catch_backtrace()) + return Cint(-1) + end +end + +# For discrete domain (i64 indices, 1-indexed) returning ComplexF64 +function _trampoline_i64_c64( + indices_ptr::Ptr{Int64}, + ndims::Csize_t, + result_ptr::Ptr{Float64}, + user_data::Ptr{Cvoid}, +)::Cint + try + f_ref = unsafe_pointer_to_objref(user_data)::Ref{Any} + f = f_ref[] + indices = unsafe_wrap(Array, indices_ptr, Int(ndims)) + val = ComplexF64(f(indices...)) + unsafe_store!(result_ptr, real(val), 1) + unsafe_store!(result_ptr, imag(val), 2) + return Cint(0) + catch e + @error "Error in QTCI discrete c64 callback" exception = (e, catch_backtrace()) return Cint(-1) end end # Create C function pointers lazily to avoid precompilation issues -const _TRAMPOLINE_F64_PTR = Ref{Ptr{Cvoid}}(C_NULL) -const _TRAMPOLINE_I64_PTR = Ref{Ptr{Cvoid}}(C_NULL) +const _TRAMPOLINE_F64_F64_PTR = Ref{Ptr{Cvoid}}(C_NULL) +const _TRAMPOLINE_F64_C64_PTR = Ref{Ptr{Cvoid}}(C_NULL) +const _TRAMPOLINE_I64_F64_PTR = Ref{Ptr{Cvoid}}(C_NULL) +const _TRAMPOLINE_I64_C64_PTR = Ref{Ptr{Cvoid}}(C_NULL) + +function _get_trampoline_f64(::Type{Float64}) + if _TRAMPOLINE_F64_F64_PTR[] == C_NULL + 
_TRAMPOLINE_F64_F64_PTR[] = @cfunction( + _trampoline_f64_f64, + Cint, + (Ptr{Float64}, Csize_t, Ptr{Float64}, Ptr{Cvoid}) + ) + end + return _TRAMPOLINE_F64_F64_PTR[] +end -function _get_trampoline_f64_ptr() - if _TRAMPOLINE_F64_PTR[] == C_NULL - _TRAMPOLINE_F64_PTR[] = @cfunction( - _trampoline_f64, +function _get_trampoline_f64(::Type{ComplexF64}) + if _TRAMPOLINE_F64_C64_PTR[] == C_NULL + _TRAMPOLINE_F64_C64_PTR[] = @cfunction( + _trampoline_f64_c64, Cint, (Ptr{Float64}, Csize_t, Ptr{Float64}, Ptr{Cvoid}) ) end - return _TRAMPOLINE_F64_PTR[] + return _TRAMPOLINE_F64_C64_PTR[] +end + +function _get_trampoline_i64(::Type{Float64}) + if _TRAMPOLINE_I64_F64_PTR[] == C_NULL + _TRAMPOLINE_I64_F64_PTR[] = @cfunction( + _trampoline_i64_f64, + Cint, + (Ptr{Int64}, Csize_t, Ptr{Float64}, Ptr{Cvoid}) + ) + end + return _TRAMPOLINE_I64_F64_PTR[] end -function _get_trampoline_i64_ptr() - if _TRAMPOLINE_I64_PTR[] == C_NULL - _TRAMPOLINE_I64_PTR[] = @cfunction( - _trampoline_i64, +function _get_trampoline_i64(::Type{ComplexF64}) + if _TRAMPOLINE_I64_C64_PTR[] == C_NULL + _TRAMPOLINE_I64_C64_PTR[] = @cfunction( + _trampoline_i64_c64, Cint, (Ptr{Int64}, Csize_t, Ptr{Float64}, Ptr{Cvoid}) ) end - return _TRAMPOLINE_I64_PTR[] + return _TRAMPOLINE_I64_C64_PTR[] +end + +# ============================================================================ +# QtciOptions helper +# ============================================================================ + +""" +Build a QtciOptions handle from keyword arguments. +Caller must call `C_API.t4a_qtci_options_release(opts)` after use. 
+""" +function _build_options(; + tolerance::Float64 = 1e-8, + maxbonddim::Int = 0, + maxiter::Int = 200, + nrandominitpivot::Int = 5, + verbosity::Int = 0, + unfoldingscheme::Symbol = :interleaved, + nsearchglobalpivot::Int = 5, + nsearch::Int = 100, + normalizeerror::Bool = true, +) + opts = C_API.t4a_qtci_options_default() + opts == C_NULL && error("Failed to create QtciOptions") + + C_API.check_status(C_API.t4a_qtci_options_set_tolerance(opts, tolerance)) + C_API.check_status(C_API.t4a_qtci_options_set_maxbonddim(opts, maxbonddim)) + C_API.check_status(C_API.t4a_qtci_options_set_maxiter(opts, maxiter)) + C_API.check_status(C_API.t4a_qtci_options_set_nrandominitpivot(opts, nrandominitpivot)) + C_API.check_status(C_API.t4a_qtci_options_set_verbosity(opts, verbosity)) + + scheme_c = _unfolding_to_cint(unfoldingscheme) + C_API.check_status(C_API.t4a_qtci_options_set_unfoldingscheme(opts, scheme_c)) + + C_API.check_status(C_API.t4a_qtci_options_set_nsearchglobalpivot(opts, nsearchglobalpivot)) + C_API.check_status(C_API.t4a_qtci_options_set_nsearch(opts, nsearch)) + C_API.check_status(C_API.t4a_qtci_options_set_normalize_error(opts, normalizeerror ? 1 : 0)) + + return opts end # ============================================================================ @@ -91,21 +214,21 @@ end # ============================================================================ """ - QuanticsTensorCI2 + QuanticsTensorCI2{V} A Quantics TCI (Tensor Cross Interpolation) object. -This wraps the Rust QuanticsTCI and provides access to the interpolated -quantics tensor train representation of a function. +`V` can be `Float64` or `ComplexF64`, determining which C API variant is used. 
# Methods -- `rank(qtci)` - Get the maximum bond dimension -- `link_dims(qtci)` - Get the link (bond) dimensions +- `linkdims(qtci)` - Get the link (bond) dimensions - `evaluate(qtci, indices)` - Evaluate at grid indices - `sum(qtci)` - Compute the factorized sum over all grid points - `integral(qtci)` - Compute the integral over the continuous domain -- `to_tensor_train(qtci)` - Convert to SimpleTensorTrain +- `to_tensor_train(qtci)` - Convert to SimpleTensorTrain{V} +- `maxbonderror(qtci)` - Get the maximum bond error +- `maxrank(qtci)` - Get the maximum rank (bond dimension) # Callable Interface @@ -114,55 +237,87 @@ The object can be called directly with indices: qtci(1, 2) # Equivalent to evaluate(qtci, [1, 2]) ``` """ -mutable struct QuanticsTensorCI2 +mutable struct QuanticsTensorCI2{V<:_QtciScalar} ptr::Ptr{Cvoid} - function QuanticsTensorCI2(ptr::Ptr{Cvoid}) + function QuanticsTensorCI2{V}(ptr::Ptr{Cvoid}) where {V<:_QtciScalar} ptr == C_NULL && error("Failed to create QuanticsTensorCI2: null pointer") - qtci = new(ptr) + qtci = new{V}(ptr) finalizer(qtci) do obj - C_API.t4a_qtci_f64_release(obj.ptr) + _qtci_api(V, :release)(obj.ptr) end return qtci end end +# ============================================================================ +# Accessors +# ============================================================================ + """ rank(qtci::QuanticsTensorCI2) -> Int Get the maximum bond dimension (rank). """ -function rank(qtci::QuanticsTensorCI2) +function rank(qtci::QuanticsTensorCI2{V}) where {V} out_rank = Ref{Csize_t}(0) - status = C_API.t4a_qtci_f64_rank(qtci.ptr, out_rank) + status = _qtci_api(V, :rank)(qtci.ptr, out_rank) C_API.check_status(status) return Int(out_rank[]) end """ - link_dims(qtci::QuanticsTensorCI2) -> Vector{Int} + linkdims(qtci::QuanticsTensorCI2) -> Vector{Int} Get the link (bond) dimensions. 
""" -function link_dims(qtci::QuanticsTensorCI2) +function linkdims(qtci::QuanticsTensorCI2{V}) where {V} r = rank(qtci) r == 0 && return Int[] - # Use a generous buffer since we don't know the exact number of sites - buf = Vector{Csize_t}(undef, 1024) - status = C_API.t4a_qtci_f64_link_dims(qtci.ptr, buf, Csize_t(length(buf))) + # Convert to SimpleTT and query its linkdims + tt_ptr = _qtci_api(V, :to_tensor_train)(qtci.ptr) + tt_ptr == C_NULL && error("Failed to convert QTCI to tensor train") + tt = SimpleTensorTrain{Float64}(tt_ptr) # to_tensor_train always returns f64 SimpleTT + if V === ComplexF64 + tt = SimpleTensorTrain{ComplexF64}(tt_ptr) + end + return SimpleTT.linkdims(tt) +end + +""" + maxbonderror(qtci::QuanticsTensorCI2) -> Float64 + +Get the maximum bond error from the QuanticsTCI. +""" +function maxbonderror(qtci::QuanticsTensorCI2{V}) where {V} + out_value = Ref{Cdouble}(0.0) + status = _qtci_api(V, :max_bond_error)(qtci.ptr, out_value) + C_API.check_status(status) + return out_value[] +end + +""" + maxrank(qtci::QuanticsTensorCI2) -> Int + +Get the maximum rank (bond dimension) from the QuanticsTCI. +""" +function maxrank(qtci::QuanticsTensorCI2{V}) where {V} + out_rank = Ref{Csize_t}(0) + status = _qtci_api(V, :max_rank)(qtci.ptr, out_rank) C_API.check_status(status) - # Find actual length by looking for trailing zeros - result = Int.(buf) - last_nonzero = findlast(x -> x > 0, result) - return isnothing(last_nonzero) ? Int[] : result[1:last_nonzero] + return Int(out_rank[]) end +# ============================================================================ +# Operations: Float64 +# ============================================================================ + """ - evaluate(qtci::QuanticsTensorCI2, indices::Vector{<:Integer}) -> Float64 + evaluate(qtci::QuanticsTensorCI2{Float64}, indices::Vector{<:Integer}) -> Float64 -Evaluate the QTCI at the given grid indices (1-indexed). +Evaluate the QTCI at the given grid indices. 
""" -function evaluate(qtci::QuanticsTensorCI2, indices::Vector{<:Integer}) +function evaluate(qtci::QuanticsTensorCI2{Float64}, indices::Vector{<:Integer}) idx = Int64.(indices) out_value = Ref{Cdouble}(0.0) status = C_API.t4a_qtci_f64_evaluate(qtci.ptr, idx, Csize_t(length(idx)), out_value) @@ -170,17 +325,12 @@ function evaluate(qtci::QuanticsTensorCI2, indices::Vector{<:Integer}) return out_value[] end -# Callable interface: qtci(i, j, ...) -function (qtci::QuanticsTensorCI2)(indices::Integer...) - evaluate(qtci, collect(Int64, indices)) -end - """ - sum(qtci::QuanticsTensorCI2) -> Float64 + Base.sum(qtci::QuanticsTensorCI2{Float64}) -> Float64 Compute the factorized sum over all grid points. """ -function sum(qtci::QuanticsTensorCI2) +function Base.sum(qtci::QuanticsTensorCI2{Float64}) out_value = Ref{Cdouble}(0.0) status = C_API.t4a_qtci_f64_sum(qtci.ptr, out_value) C_API.check_status(status) @@ -188,15 +338,11 @@ function sum(qtci::QuanticsTensorCI2) end """ - integral(qtci::QuanticsTensorCI2) -> Float64 + integral(qtci::QuanticsTensorCI2{Float64}) -> Float64 Compute the integral over the continuous domain. - -This is the sum multiplied by the grid step sizes. -Only meaningful for QTCI constructed with a DiscretizedGrid. -For discrete grids, this returns the plain sum. """ -function integral(qtci::QuanticsTensorCI2) +function integral(qtci::QuanticsTensorCI2{Float64}) out_value = Ref{Cdouble}(0.0) status = C_API.t4a_qtci_f64_integral(qtci.ptr, out_value) C_API.check_status(status) @@ -204,20 +350,86 @@ function integral(qtci::QuanticsTensorCI2) end """ - to_tensor_train(qtci::QuanticsTensorCI2) -> SimpleTensorTrain + to_tensor_train(qtci::QuanticsTensorCI2{Float64}) -> SimpleTensorTrain{Float64} Convert the QTCI to a SimpleTensorTrain. 
""" -function to_tensor_train(qtci::QuanticsTensorCI2) +function to_tensor_train(qtci::QuanticsTensorCI2{Float64}) ptr = C_API.t4a_qtci_f64_to_tensor_train(qtci.ptr) ptr == C_NULL && error("Failed to convert QTCI to TensorTrain") - return SimpleTensorTrain(ptr) + return SimpleTensorTrain{Float64}(ptr) end +# ============================================================================ +# Operations: ComplexF64 +# ============================================================================ + +""" + evaluate(qtci::QuanticsTensorCI2{ComplexF64}, indices::Vector{<:Integer}) -> ComplexF64 + +Evaluate the QTCI at the given grid indices. +""" +function evaluate(qtci::QuanticsTensorCI2{ComplexF64}, indices::Vector{<:Integer}) + idx = Int64.(indices) + out_re = Ref{Cdouble}(0.0) + out_im = Ref{Cdouble}(0.0) + status = C_API.t4a_qtci_c64_evaluate(qtci.ptr, idx, Csize_t(length(idx)), out_re, out_im) + C_API.check_status(status) + return ComplexF64(out_re[], out_im[]) +end + +""" + Base.sum(qtci::QuanticsTensorCI2{ComplexF64}) -> ComplexF64 + +Compute the factorized sum over all grid points. +""" +function Base.sum(qtci::QuanticsTensorCI2{ComplexF64}) + out_re = Ref{Cdouble}(0.0) + out_im = Ref{Cdouble}(0.0) + status = C_API.t4a_qtci_c64_sum(qtci.ptr, out_re, out_im) + C_API.check_status(status) + return ComplexF64(out_re[], out_im[]) +end + +""" + integral(qtci::QuanticsTensorCI2{ComplexF64}) -> ComplexF64 + +Compute the integral over the continuous domain. +""" +function integral(qtci::QuanticsTensorCI2{ComplexF64}) + out_re = Ref{Cdouble}(0.0) + out_im = Ref{Cdouble}(0.0) + status = C_API.t4a_qtci_c64_integral(qtci.ptr, out_re, out_im) + C_API.check_status(status) + return ComplexF64(out_re[], out_im[]) +end + +""" + to_tensor_train(qtci::QuanticsTensorCI2{ComplexF64}) -> SimpleTensorTrain{ComplexF64} + +Convert the QTCI to a SimpleTensorTrain. 
+""" +function to_tensor_train(qtci::QuanticsTensorCI2{ComplexF64}) + ptr = C_API.t4a_qtci_c64_to_tensor_train(qtci.ptr) + ptr == C_NULL && error("Failed to convert QTCI to TensorTrain") + return SimpleTensorTrain{ComplexF64}(ptr) +end + +# ============================================================================ +# Callable interface +# ============================================================================ + +function (qtci::QuanticsTensorCI2)(indices::Integer...) + evaluate(qtci, collect(Int64, indices)) +end + +# ============================================================================ # Display -function Base.show(io::IO, qtci::QuanticsTensorCI2) - r = rank(qtci) - print(io, "QuanticsTensorCI2(rank=$r)") +# ============================================================================ + +function Base.show(io::IO, qtci::QuanticsTensorCI2{V}) where {V} + r = maxrank(qtci) + print(io, "QuanticsTensorCI2{$V}(maxrank=$r)") end # ============================================================================ @@ -225,18 +437,31 @@ end # ============================================================================ """ - quanticscrossinterpolate(grid::DiscretizedGrid, f; kwargs...) -> QuanticsTensorCI2 + quanticscrossinterpolate(::Type{V}, f, grid::DiscretizedGrid; kwargs...) -> (QuanticsTensorCI2{V}, Vector{Int}, Vector{Float64}) Perform quantics cross interpolation on a continuous domain. 
# Arguments +- `V`: Element type (`Float64` or `ComplexF64`) +- `f`: Function that takes Float64 coordinates and returns `V` - `grid`: DiscretizedGrid describing the domain -- `f`: Function that takes Float64 coordinates and returns Float64 # Keyword Arguments -- `tolerance`: Convergence tolerance (default: 1e-8) -- `max_bonddim`: Maximum bond dimension, 0 = unlimited (default: 0) -- `max_iter`: Maximum iterations (default: 200) +- `tolerance::Float64 = 1e-8`: Convergence tolerance +- `maxbonddim::Int = 0`: Maximum bond dimension (0 = unlimited) +- `maxiter::Int = 200`: Maximum iterations +- `initialpivots::Union{Nothing, Vector{Vector{Int}}} = nothing`: Initial pivots (1-indexed) +- `nrandominitpivot::Int = 5`: Number of random initial pivots +- `verbosity::Int = 0`: Verbosity level +- `unfoldingscheme::Symbol = :interleaved`: `:interleaved` or `:fused` +- `nsearchglobalpivot::Int = 5`: Number of global pivot searches +- `nsearch::Int = 100`: Number of searches +- `normalizeerror::Bool = true`: Whether to normalize errors + +# Returns +- `qtci::QuanticsTensorCI2{V}`: The QTCI object +- `ranks::Vector{Int}`: Per-iteration maximum ranks +- `errors::Vector{Float64}`: Per-iteration errors # Example ```julia @@ -244,97 +469,228 @@ using Tensor4all.QuanticsGrids using Tensor4all.QuanticsTCI grid = DiscretizedGrid(1, 10, [0.0], [1.0]) -qtci = quanticscrossinterpolate(grid, x -> sin(x)) +qtci, ranks, errors = quanticscrossinterpolate(Float64, x -> sin(x), grid) integral(qtci) # Should be close to 1 - cos(1) ~ 0.4597 ``` """ function quanticscrossinterpolate( - grid::DiscretizedGrid, - f; + ::Type{V}, + f, + grid::DiscretizedGrid; tolerance::Float64 = 1e-8, - max_bonddim::Int = 0, - max_iter::Int = 200, -) - f_ref = Ref{Any}(f) - out_qtci = Ref{Ptr{Cvoid}}(C_NULL) - - GC.@preserve f_ref begin - user_data = pointer_from_objref(f_ref) - trampoline_ptr = _get_trampoline_f64_ptr() - - status = C_API.t4a_quanticscrossinterpolate_f64( - grid.ptr, - trampoline_ptr, - user_data, - 
tolerance, - Csize_t(max_bonddim), - Csize_t(max_iter), - out_qtci, - ) + maxbonddim::Int = 0, + maxiter::Int = 200, + initialpivots::Union{Nothing, Vector{Vector{Int}}} = nothing, + nrandominitpivot::Int = 5, + verbosity::Int = 0, + unfoldingscheme::Symbol = :interleaved, + nsearchglobalpivot::Int = 5, + nsearch::Int = 100, + normalizeerror::Bool = true, +) where {V<:_QtciScalar} + # Build options + opts = _build_options(; + tolerance, maxbonddim, maxiter, nrandominitpivot, + verbosity, unfoldingscheme, nsearchglobalpivot, nsearch, normalizeerror, + ) - C_API.check_status(status) - end + try + # Prepare initial pivots (convert 1-indexed to C API format if needed) + if initialpivots !== nothing && !isempty(initialpivots) + n_sites = length(initialpivots[1]) + n_pivots = length(initialpivots) + flat_pivots = Vector{Int64}(undef, n_sites * n_pivots) + for j in 1:n_pivots + for i in 1:n_sites + flat_pivots[i + n_sites * (j - 1)] = Int64(initialpivots[j][i]) + end + end + pivots_ptr = flat_pivots + pivots_n = Csize_t(n_pivots) + else + pivots_ptr = C_NULL + pivots_n = Csize_t(0) + end - return QuanticsTensorCI2(out_qtci[]) + # Prepare output buffers + out_qtci = Ref{Ptr{Cvoid}}(C_NULL) + out_ranks = Vector{Csize_t}(undef, maxiter) + out_errors = Vector{Cdouble}(undef, maxiter) + out_n_iters = Ref{Csize_t}(0) + + f_ref = Ref{Any}(f) + GC.@preserve f_ref pivots_ptr begin + user_data = pointer_from_objref(f_ref) + trampoline_ptr = _get_trampoline_f64(V) + + crossinterp_fn = if V === Float64 + C_API.t4a_quanticscrossinterpolate_f64 + else + C_API.t4a_quanticscrossinterpolate_c64 + end + + status = crossinterp_fn( + grid.ptr, + trampoline_ptr, + user_data, + opts, + tolerance, + Csize_t(maxbonddim), + Csize_t(maxiter), + pivots_ptr, + pivots_n, + out_qtci, + out_ranks, + out_errors, + out_n_iters, + ) + + C_API.check_status(status) + end + + n_iters = Int(out_n_iters[]) + ranks = Vector{Int}(out_ranks[1:n_iters]) + errors = Vector{Float64}(out_errors[1:n_iters]) + qtci = 
QuanticsTensorCI2{V}(out_qtci[]) + + return qtci, ranks, errors + finally + C_API.t4a_qtci_options_release(opts) + end end """ - quanticscrossinterpolate_discrete(sizes, f; kwargs...) -> QuanticsTensorCI2 + quanticscrossinterpolate(::Type{V}, f, size::NTuple{N,Int}; kwargs...) -> (QuanticsTensorCI2{V}, Vector{Int}, Vector{Float64}) Perform quantics cross interpolation on a discrete integer domain. # Arguments -- `sizes`: Grid sizes per dimension (must be powers of 2) -- `f`: Function that takes 1-indexed integer indices and returns Float64 +- `V`: Element type (`Float64` or `ComplexF64`) +- `f`: Function that takes 1-indexed integer indices and returns `V` +- `size`: Grid sizes per dimension (must be powers of 2) # Keyword Arguments -- `tolerance`: Convergence tolerance (default: 1e-8) -- `max_bonddim`: Maximum bond dimension, 0 = unlimited (default: 0) -- `max_iter`: Maximum iterations (default: 200) -- `unfoldingscheme`: :interleaved or :fused (default: :interleaved) +Same as the `DiscretizedGrid` variant, plus: +- `unfoldingscheme::Symbol = :interleaved`: `:interleaved` or `:fused` # Example ```julia -using Tensor4all.QuanticsTCI - -qtci = quanticscrossinterpolate_discrete([8, 8], (i, j) -> Float64(i + j)) +qtci, ranks, errors = quanticscrossinterpolate(Float64, (i, j) -> Float64(i + j), (8, 8)) qtci(3, 4) # Should be close to 7.0 ``` """ -function quanticscrossinterpolate_discrete( - sizes::Vector{<:Integer}, - f; +function quanticscrossinterpolate( + ::Type{V}, + f, + size::NTuple{N,Int}; tolerance::Float64 = 1e-8, - max_bonddim::Int = 0, - max_iter::Int = 200, + maxbonddim::Int = 0, + maxiter::Int = 200, + initialpivots::Union{Nothing, Vector{Vector{Int}}} = nothing, + nrandominitpivot::Int = 5, + verbosity::Int = 0, unfoldingscheme::Symbol = :interleaved, -) - scheme = unfoldingscheme == :fused ? 
Cint(0) : Cint(1) - - f_ref = Ref{Any}(f) - out_qtci = Ref{Ptr{Cvoid}}(C_NULL) - sizes_c = Csize_t.(sizes) - - GC.@preserve f_ref begin - user_data = pointer_from_objref(f_ref) - trampoline_ptr = _get_trampoline_i64_ptr() - - status = C_API.t4a_quanticscrossinterpolate_discrete_f64( - sizes_c, - Csize_t(length(sizes)), - trampoline_ptr, - user_data, - tolerance, - Csize_t(max_bonddim), - Csize_t(max_iter), - scheme, - out_qtci, - ) + nsearchglobalpivot::Int = 5, + nsearch::Int = 100, + normalizeerror::Bool = true, +) where {V<:_QtciScalar, N} + # Build options + opts = _build_options(; + tolerance, maxbonddim, maxiter, nrandominitpivot, + verbosity, unfoldingscheme, nsearchglobalpivot, nsearch, normalizeerror, + ) + + try + sizes_c = Csize_t.(collect(size)) + scheme_c = _unfolding_to_cint(unfoldingscheme) + + # Prepare initial pivots + if initialpivots !== nothing && !isempty(initialpivots) + n_sites_pivot = length(initialpivots[1]) + n_pivots = length(initialpivots) + flat_pivots = Vector{Int64}(undef, n_sites_pivot * n_pivots) + for j in 1:n_pivots + for i in 1:n_sites_pivot + flat_pivots[i + n_sites_pivot * (j - 1)] = Int64(initialpivots[j][i]) + end + end + pivots_ptr = flat_pivots + pivots_n = Csize_t(n_pivots) + else + pivots_ptr = C_NULL + pivots_n = Csize_t(0) + end + + # Prepare output buffers + out_qtci = Ref{Ptr{Cvoid}}(C_NULL) + out_ranks = Vector{Csize_t}(undef, maxiter) + out_errors = Vector{Cdouble}(undef, maxiter) + out_n_iters = Ref{Csize_t}(0) + + f_ref = Ref{Any}(f) + GC.@preserve f_ref pivots_ptr begin + user_data = pointer_from_objref(f_ref) + trampoline_ptr = _get_trampoline_i64(V) + + crossinterp_fn = if V === Float64 + C_API.t4a_quanticscrossinterpolate_discrete_f64 + else + C_API.t4a_quanticscrossinterpolate_discrete_c64 + end + + status = crossinterp_fn( + sizes_c, + Csize_t(N), + trampoline_ptr, + user_data, + opts, + tolerance, + Csize_t(maxbonddim), + Csize_t(maxiter), + scheme_c, + pivots_ptr, + pivots_n, + out_qtci, + out_ranks, + 
out_errors, + out_n_iters, + ) + + C_API.check_status(status) + end + + n_iters = Int(out_n_iters[]) + ranks = Vector{Int}(out_ranks[1:n_iters]) + errors = Vector{Float64}(out_errors[1:n_iters]) + qtci = QuanticsTensorCI2{V}(out_qtci[]) - C_API.check_status(status) + return qtci, ranks, errors + finally + C_API.t4a_qtci_options_release(opts) end +end + +""" + quanticscrossinterpolate(F::Array{V}; kwargs...) -> (QuanticsTensorCI2{V}, Vector{Int}, Vector{Float64}) + +Perform quantics cross interpolation from a dense array. + +The array is wrapped as a function and the size tuple variant is called. +Grid sizes are taken from `size(F)`. - return QuanticsTensorCI2(out_qtci[]) +# Example +```julia +F = [Float64(i + j) for i in 1:8, j in 1:8] +qtci, ranks, errors = quanticscrossinterpolate(F) +``` +""" +function quanticscrossinterpolate( + F::Array{V}; + kwargs... +) where {V<:_QtciScalar} + sz = Base.size(F) + f = (indices...) -> F[indices...] + return quanticscrossinterpolate(V, f, sz; kwargs...) end end # module QuanticsTCI diff --git a/src/QuanticsTransform.jl b/src/QuanticsTransform.jl index 5b4c931..36c3ff1 100644 --- a/src/QuanticsTransform.jl +++ b/src/QuanticsTransform.jl @@ -22,7 +22,10 @@ using ..TreeTN: TreeTensorNetwork export LinearOperator export shift_operator, flip_operator, phase_rotation_operator, cumsum_operator, fourier_operator +export shift_operator_multivar, flip_operator_multivar, phase_rotation_operator_multivar +export affine_operator, binaryop_operator export apply +export set_input_space!, set_output_space!, set_iospaces! export BoundaryCondition, Periodic, Open # ============================================================================ @@ -42,6 +45,12 @@ Boundary condition for quantics operators. 
Open = 1 end +_bc_cint(bc::BoundaryCondition) = Cint(Int(bc)) + +function _bc_array_cint(bc::AbstractVector{<:BoundaryCondition}) + return Cint[Int(b) for b in bc] +end + # ============================================================================ # LinearOperator type # ============================================================================ @@ -83,7 +92,21 @@ Create a shift operator: f(x) = g(x + offset) mod 2^r. """ function shift_operator(r::Integer, offset::Integer; bc::BoundaryCondition=Periodic) out = Ref{Ptr{Cvoid}}(C_NULL) - status = C_API.t4a_qtransform_shift(Csize_t(r), Int64(offset), Cint(Int(bc)), out) + status = C_API.t4a_qtransform_shift(Csize_t(r), Int64(offset), _bc_cint(bc), out) + C_API.check_status(status) + return LinearOperator(out[]) +end + +""" + shift_operator_multivar(r::Integer, offset::Integer, nvariables::Integer, target_var::Integer; bc=Periodic) -> LinearOperator + +Create a shift operator acting on one variable in a multi-variable quantics system. +""" +function shift_operator_multivar(r::Integer, offset::Integer, nvariables::Integer, + target_var::Integer; bc::BoundaryCondition=Periodic) + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_qtransform_shift_multivar( + Csize_t(r), Int64(offset), _bc_cint(bc), Csize_t(nvariables), Csize_t(target_var), out) C_API.check_status(status) return LinearOperator(out[]) end @@ -99,7 +122,21 @@ Create a flip operator: f(x) = g(2^r - x). """ function flip_operator(r::Integer; bc::BoundaryCondition=Periodic) out = Ref{Ptr{Cvoid}}(C_NULL) - status = C_API.t4a_qtransform_flip(Csize_t(r), Cint(Int(bc)), out) + status = C_API.t4a_qtransform_flip(Csize_t(r), _bc_cint(bc), out) + C_API.check_status(status) + return LinearOperator(out[]) +end + +""" + flip_operator_multivar(r::Integer, nvariables::Integer, target_var::Integer; bc=Periodic) -> LinearOperator + +Create a flip operator acting on one variable in a multi-variable quantics system. 
+""" +function flip_operator_multivar(r::Integer, nvariables::Integer, target_var::Integer; + bc::BoundaryCondition=Periodic) + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_qtransform_flip_multivar( + Csize_t(r), _bc_cint(bc), Csize_t(nvariables), Csize_t(target_var), out) C_API.check_status(status) return LinearOperator(out[]) end @@ -120,6 +157,71 @@ function phase_rotation_operator(r::Integer, theta::Real) return LinearOperator(out[]) end +""" + phase_rotation_operator_multivar(r::Integer, theta::Real, nvariables::Integer, target_var::Integer) -> LinearOperator + +Create a phase rotation operator acting on one variable in a multi-variable quantics system. +""" +function phase_rotation_operator_multivar(r::Integer, theta::Real, nvariables::Integer, + target_var::Integer) + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_qtransform_phase_rotation_multivar( + Csize_t(r), Cdouble(theta), Csize_t(nvariables), Csize_t(target_var), out) + C_API.check_status(status) + return LinearOperator(out[]) +end + +""" + affine_operator(r::Integer, a_num::AbstractMatrix{<:Integer}, a_den::AbstractMatrix{<:Integer}, + b_num::AbstractVector{<:Integer}, b_den::AbstractVector{<:Integer}; bc) -> LinearOperator + +Create an affine transform with rational coefficients on quantized coordinates. +`bc` must contain one boundary condition per output variable. 
+""" +function affine_operator(r::Integer, + a_num::AbstractMatrix{<:Integer}, + a_den::AbstractMatrix{<:Integer}, + b_num::AbstractVector{<:Integer}, + b_den::AbstractVector{<:Integer}; + bc::AbstractVector{<:BoundaryCondition}) + size(a_num) == size(a_den) || error("a_num and a_den must have the same size") + m, n = size(a_num) + length(b_num) == m || error("b_num length must match the number of output variables") + length(b_den) == m || error("b_den length must match the number of output variables") + length(bc) == m || error("bc length must match the number of output variables") + + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_qtransform_affine( + Csize_t(r), + Int64.(vec(a_num)), + Int64.(vec(a_den)), + Int64.(collect(b_num)), + Int64.(collect(b_den)), + Csize_t(m), + Csize_t(n), + _bc_array_cint(bc), + out, + ) + C_API.check_status(status) + return LinearOperator(out[]) +end + +""" + binaryop_operator(r::Integer, a1::Integer, b1::Integer, a2::Integer, b2::Integer; + bc1=Periodic, bc2=Periodic) -> LinearOperator + +Create a two-output binary operator corresponding to +`(a1*x + b1*y, a2*x + b2*y)`. +""" +function binaryop_operator(r::Integer, a1::Integer, b1::Integer, a2::Integer, b2::Integer; + bc1::BoundaryCondition=Periodic, bc2::BoundaryCondition=Periodic) + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_qtransform_binaryop( + Csize_t(r), Int8(a1), Int8(b1), Int8(a2), Int8(b2), _bc_cint(bc1), _bc_cint(bc2), out) + C_API.check_status(status) + return LinearOperator(out[]) +end + """ cumsum_operator(r::Integer) -> LinearOperator @@ -160,6 +262,40 @@ end # Operator application # ============================================================================ +""" + set_input_space!(op::LinearOperator, state::TreeTensorNetwork) -> LinearOperator + +Reset the operator's true input site indices to match `state`. 
+""" +function set_input_space!(op::LinearOperator, state::TreeTensorNetwork) + status = C_API.t4a_linop_set_input_space(op.ptr, state.handle) + C_API.check_status(status) + return op +end + +""" + set_output_space!(op::LinearOperator, state::TreeTensorNetwork) -> LinearOperator + +Reset the operator's true output site indices to match `state`. +""" +function set_output_space!(op::LinearOperator, state::TreeTensorNetwork) + status = C_API.t4a_linop_set_output_space(op.ptr, state.handle) + C_API.check_status(status) + return op +end + +""" + set_iospaces!(op::LinearOperator, input_state::TreeTensorNetwork, output_state::TreeTensorNetwork=input_state) -> LinearOperator + +Reset the operator's true input and output site indices to match the given states. +""" +function set_iospaces!(op::LinearOperator, input_state::TreeTensorNetwork, + output_state::TreeTensorNetwork=input_state) + set_input_space!(op, input_state) + set_output_space!(op, output_state) + return op +end + """ apply(op::LinearOperator, state::TreeTensorNetwork; method=:naive, rtol=0.0, maxdim=0) -> TreeTensorNetwork diff --git a/src/SimpleTT.jl b/src/SimpleTT.jl index 447b358..535ebce 100644 --- a/src/SimpleTT.jl +++ b/src/SimpleTT.jl @@ -5,117 +5,250 @@ High-level Julia wrappers for SimpleTT tensor trains from tensor4all-simplett. SimpleTT is a simple tensor train (TT/MPS) library with statically determined shapes (site dimensions are fixed at construction time). + +Supports both `Float64` and `ComplexF64` element types. """ module SimpleTT using LinearAlgebra using ..C_API +import ..rank, ..linkdims, ..compress!, ..evaluate + +export SimpleTensorTrain, sitedims, linkdims, rank, compress!, evaluate, sitetensor, fulltensor, scale! 
+ +# ============================================================================ +# Type dispatch helpers +# ============================================================================ -export SimpleTensorTrain +"""Scalar types supported by SimpleTensorTrain.""" +const _SimpleTTScalar = Union{Float64, ComplexF64} + +_suffix(::Type{Float64}) = "f64" +_suffix(::Type{ComplexF64}) = "c64" """ - SimpleTensorTrain{T<:Real} + _api(T, name) -> Function -A simple tensor train (TT/MPS) with statically determined shapes. +Get the C API function for the given scalar type and operation name. +E.g., `_api(Float64, :clone)` returns `C_API.t4a_simplett_f64_clone`. +""" +_api(::Type{T}, name::Symbol) where {T<:_SimpleTTScalar} = + getfield(C_API, Symbol("t4a_simplett_", _suffix(T), "_", name)) -SimpleTensorTrain is a simple tensor train library where site dimensions are fixed -at construction time. +# ============================================================================ +# Type definition +# ============================================================================ -Currently only supports `Float64` values. """ -mutable struct SimpleTensorTrain{T<:Real} + SimpleTensorTrain{T} + +A simple tensor train (TT/MPS) with statically determined shapes. + +`T` can be `Float64` or `ComplexF64`. 
+""" +mutable struct SimpleTensorTrain{T<:_SimpleTTScalar} ptr::Ptr{Cvoid} - function SimpleTensorTrain{Float64}(ptr::Ptr{Cvoid}) + function SimpleTensorTrain{T}(ptr::Ptr{Cvoid}) where {T<:_SimpleTTScalar} ptr == C_NULL && error("Failed to create SimpleTensorTrain: null pointer") - tt = new{Float64}(ptr) + tt = new{T}(ptr) finalizer(tt) do obj - C_API.t4a_simplett_f64_release(obj.ptr) + _api(T, :release)(obj.ptr) end return tt end end -# Convenience constructor -SimpleTensorTrain(ptr::Ptr{Cvoid}) = SimpleTensorTrain{Float64}(ptr) +# ============================================================================ +# Constructors +# ============================================================================ """ - SimpleTensorTrain(site_dims::Vector{<:Integer}, value::Real) + SimpleTensorTrain(site_dims::Vector{<:Integer}, value::Float64) -Create a constant tensor train with the given site dimensions and value. +Create a constant Float64 tensor train with the given site dimensions and value. # Example ```julia tt = SimpleTensorTrain([2, 3, 4], 1.0) # All elements equal to 1.0 ``` """ -function SimpleTensorTrain(site_dims::Vector{<:Integer}, value::Real) +function SimpleTensorTrain(site_dims::Vector{<:Integer}, value::Float64) dims = Csize_t.(site_dims) - ptr = C_API.t4a_simplett_f64_constant(dims, Float64(value)) + ptr = C_API.t4a_simplett_f64_constant(dims, value) return SimpleTensorTrain{Float64}(ptr) end +""" + SimpleTensorTrain(site_dims::Vector{<:Integer}, value::ComplexF64) + +Create a constant ComplexF64 tensor train with the given site dimensions and value. 
+ +# Example +```julia +tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) +``` +""" +function SimpleTensorTrain(site_dims::Vector{<:Integer}, value::ComplexF64) + dims = Csize_t.(site_dims) + ptr = C_API.t4a_simplett_c64_constant(dims, real(value), imag(value)) + return SimpleTensorTrain{ComplexF64}(ptr) +end + +# Allow implicit conversion from Real to Float64 +function SimpleTensorTrain(site_dims::Vector{<:Integer}, value::Real) + return SimpleTensorTrain(site_dims, Float64(value)) +end + +# Allow implicit conversion from Complex to ComplexF64 (non-Float64 real part) +function SimpleTensorTrain(site_dims::Vector{<:Integer}, value::Complex) + return SimpleTensorTrain(site_dims, ComplexF64(value)) +end + +""" + SimpleTensorTrain(site_tensors::Vector{<:AbstractArray{Float64,3}}) + +Create a Float64 tensor train from a vector of 3D site tensors. +Each tensor has shape `(left_dim, site_dim, right_dim)`. +The first tensor must have `left_dim == 1` and the last must have `right_dim == 1`. +""" +function SimpleTensorTrain(site_tensors::Vector{<:AbstractArray{Float64,3}}) + n_sites = length(site_tensors) + n_sites > 0 || error("site_tensors must be non-empty") + + left_dims_vec = Csize_t[] + site_dims_vec = Csize_t[] + right_dims_vec = Csize_t[] + all_data = Cdouble[] + + for tensor in site_tensors + l, s, r = size(tensor) + push!(left_dims_vec, Csize_t(l)) + push!(site_dims_vec, Csize_t(s)) + push!(right_dims_vec, Csize_t(r)) + # Data is already column-major in Julia, matching C API expectation + append!(all_data, vec(tensor)) + end + + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_simplett_f64_from_site_tensors( + n_sites, left_dims_vec, site_dims_vec, right_dims_vec, + all_data, length(all_data), out_ptr + ) + C_API.check_status(status) + return SimpleTensorTrain{Float64}(out_ptr[]) +end + +""" + SimpleTensorTrain(site_tensors::Vector{<:AbstractArray{ComplexF64,3}}) + +Create a ComplexF64 tensor train from a vector of 3D site tensors. 
+Each tensor has shape `(left_dim, site_dim, right_dim)`. +Complex data is passed as interleaved (re, im) doubles. +""" +function SimpleTensorTrain(site_tensors::Vector{<:AbstractArray{ComplexF64,3}}) + n_sites = length(site_tensors) + n_sites > 0 || error("site_tensors must be non-empty") + + left_dims_vec = Csize_t[] + site_dims_vec = Csize_t[] + right_dims_vec = Csize_t[] + all_data = Cdouble[] + + for tensor in site_tensors + l, s, r = size(tensor) + push!(left_dims_vec, Csize_t(l)) + push!(site_dims_vec, Csize_t(s)) + push!(right_dims_vec, Csize_t(r)) + # Reinterpret ComplexF64 to interleaved doubles + flat = vec(tensor) + interleaved = reinterpret(Float64, flat) + append!(all_data, interleaved) + end + + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_simplett_c64_from_site_tensors( + n_sites, left_dims_vec, site_dims_vec, right_dims_vec, + all_data, length(all_data), out_ptr + ) + C_API.check_status(status) + return SimpleTensorTrain{ComplexF64}(out_ptr[]) +end + """ zeros(::Type{SimpleTensorTrain}, site_dims::Vector{<:Integer}) + zeros(::Type{SimpleTensorTrain{T}}, site_dims::Vector{<:Integer}) Create a zero tensor train with the given site dimensions. 
# Example ```julia -tt = zeros(SimpleTensorTrain, [2, 3, 4]) +tt = zeros(SimpleTensorTrain, [2, 3, 4]) # Float64 +tt = zeros(SimpleTensorTrain{ComplexF64}, [2, 3]) # ComplexF64 ``` """ -function Base.zeros(::Type{SimpleTensorTrain}, site_dims::Vector{<:Integer}) +function Base.zeros(::Type{SimpleTensorTrain{T}}, site_dims::Vector{<:Integer}) where {T<:_SimpleTTScalar} dims = Csize_t.(site_dims) - ptr = C_API.t4a_simplett_f64_zeros(dims) - return SimpleTensorTrain{Float64}(ptr) + ptr = _api(T, :zeros)(dims) + return SimpleTensorTrain{T}(ptr) end +Base.zeros(::Type{SimpleTensorTrain}, site_dims::Vector{<:Integer}) = + zeros(SimpleTensorTrain{Float64}, site_dims) + +# ============================================================================ +# Copy +# ============================================================================ + """ copy(tt::SimpleTensorTrain) Create a deep copy of the tensor train. """ -function Base.copy(tt::SimpleTensorTrain{Float64}) - new_ptr = C_API.t4a_simplett_f64_clone(tt.ptr) - return SimpleTensorTrain{Float64}(new_ptr) +function Base.copy(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} + new_ptr = _api(T, :clone)(tt.ptr) + return SimpleTensorTrain{T}(new_ptr) end +# ============================================================================ +# Basic queries +# ============================================================================ + """ length(tt::SimpleTensorTrain) -> Int Get the number of sites in the tensor train. """ -function Base.length(tt::SimpleTensorTrain{Float64}) +function Base.length(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} out_len = Ref{Csize_t}(0) - status = C_API.t4a_simplett_f64_len(tt.ptr, out_len) + status = _api(T, :len)(tt.ptr, out_len) C_API.check_status(status) return Int(out_len[]) end """ - site_dims(tt::SimpleTensorTrain) -> Vector{Int} + sitedims(tt::SimpleTensorTrain) -> Vector{Int} Get the site (physical) dimensions. 
""" -function site_dims(tt::SimpleTensorTrain{Float64}) +function sitedims(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} n = length(tt) dims = Vector{Csize_t}(undef, n) - status = C_API.t4a_simplett_f64_site_dims(tt.ptr, dims) + status = _api(T, :site_dims)(tt.ptr, dims) C_API.check_status(status) return Int.(dims) end """ - link_dims(tt::SimpleTensorTrain) -> Vector{Int} + linkdims(tt::SimpleTensorTrain) -> Vector{Int} Get the link (bond) dimensions. Returns n-1 values for n sites. """ -function link_dims(tt::SimpleTensorTrain{Float64}) +function linkdims(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} n = length(tt) n <= 1 && return Int[] dims = Vector{Csize_t}(undef, n - 1) - status = C_API.t4a_simplett_f64_link_dims(tt.ptr, dims) + status = _api(T, :link_dims)(tt.ptr, dims) C_API.check_status(status) return Int.(dims) end @@ -125,38 +258,60 @@ end Get the maximum bond dimension (rank). """ -function rank(tt::SimpleTensorTrain{Float64}) +function rank(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} out_rank = Ref{Csize_t}(0) - status = C_API.t4a_simplett_f64_rank(tt.ptr, out_rank) + status = _api(T, :rank)(tt.ptr, out_rank) C_API.check_status(status) return Int(out_rank[]) end +# ============================================================================ +# Evaluate (1-indexed) +# ============================================================================ + """ - evaluate(tt::SimpleTensorTrain, indices::Vector{<:Integer}) -> Float64 + evaluate(tt::SimpleTensorTrain{Float64}, indices::Vector{<:Integer}) -> Float64 -Evaluate the tensor train at a given multi-index (0-based indexing). +Evaluate the tensor train at a given multi-index (1-based indexing). 
# Example ```julia tt = SimpleTensorTrain([2, 3, 4], 2.0) -val = evaluate(tt, [0, 1, 2]) # Returns 2.0 +val = evaluate(tt, [1, 2, 3]) # Returns 2.0 ``` """ function evaluate(tt::SimpleTensorTrain{Float64}, indices::Vector{<:Integer}) - idx = Csize_t.(indices) + idx = Csize_t.(indices .- 1) # Convert to 0-based out_value = Ref{Cdouble}(0.0) status = C_API.t4a_simplett_f64_evaluate(tt.ptr, idx, out_value) C_API.check_status(status) return out_value[] end -# Callable interface -(tt::SimpleTensorTrain{Float64})(indices::Vector{<:Integer}) = evaluate(tt, indices) -(tt::SimpleTensorTrain{Float64})(indices::Integer...) = evaluate(tt, collect(indices)) +""" + evaluate(tt::SimpleTensorTrain{ComplexF64}, indices::Vector{<:Integer}) -> ComplexF64 +Evaluate the complex tensor train at a given multi-index (1-based indexing). """ - sum(tt::SimpleTensorTrain) -> Float64 +function evaluate(tt::SimpleTensorTrain{ComplexF64}, indices::Vector{<:Integer}) + idx = Csize_t.(indices .- 1) # Convert to 0-based + out_re = Ref{Cdouble}(0.0) + out_im = Ref{Cdouble}(0.0) + status = C_API.t4a_simplett_c64_evaluate(tt.ptr, idx, out_re, out_im) + C_API.check_status(status) + return ComplexF64(out_re[], out_im[]) +end + +# Callable interface (1-indexed) +(tt::SimpleTensorTrain{T})(indices::Vector{<:Integer}) where {T<:_SimpleTTScalar} = evaluate(tt, indices) +(tt::SimpleTensorTrain{T})(indices::Integer...) where {T<:_SimpleTTScalar} = evaluate(tt, collect(indices)) + +# ============================================================================ +# Sum, norm +# ============================================================================ + +""" + sum(tt::SimpleTensorTrain{Float64}) -> Float64 Compute the sum over all tensor train elements. """ @@ -167,35 +322,52 @@ function Base.sum(tt::SimpleTensorTrain{Float64}) return out_value[] end +""" + sum(tt::SimpleTensorTrain{ComplexF64}) -> ComplexF64 + +Compute the sum over all complex tensor train elements. 
+""" +function Base.sum(tt::SimpleTensorTrain{ComplexF64}) + out_re = Ref{Cdouble}(0.0) + out_im = Ref{Cdouble}(0.0) + status = C_API.t4a_simplett_c64_sum(tt.ptr, out_re, out_im) + C_API.check_status(status) + return ComplexF64(out_re[], out_im[]) +end + """ norm(tt::SimpleTensorTrain) -> Float64 Compute the Frobenius norm of the tensor train. """ -function LinearAlgebra.norm(tt::SimpleTensorTrain{Float64}) +function LinearAlgebra.norm(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} out_value = Ref{Cdouble}(0.0) - status = C_API.t4a_simplett_f64_norm(tt.ptr, out_value) + status = _api(T, :norm)(tt.ptr, out_value) C_API.check_status(status) return out_value[] end +# ============================================================================ +# Site tensor (1-indexed) +# ============================================================================ + """ - site_tensor(tt::SimpleTensorTrain, site::Integer) -> Array{Float64, 3} + sitetensor(tt::SimpleTensorTrain{Float64}, site::Integer) -> Array{Float64, 3} -Get the site tensor at a specific site (0-based indexing). -Returns array with shape (left_dim, site_dim, right_dim). +Get the site tensor at a specific site (1-based indexing). +Returns array with shape `(left_dim, site_dim, right_dim)`. """ -function site_tensor(tt::SimpleTensorTrain{Float64}, site::Integer) - # First get dimensions to allocate buffer +function sitetensor(tt::SimpleTensorTrain{Float64}, site::Integer) n = length(tt) - 0 <= site < n || error("Site index out of bounds: $site (n=$n)") + 1 <= site <= n || error("Site index out of bounds: $site (n=$n)") + c_site = site - 1 # Convert to 0-based - sdims = site_dims(tt) - ldims = link_dims(tt) + sdims = sitedims(tt) + ldims = linkdims(tt) - left_dim = site == 0 ? 1 : ldims[site] - site_dim = sdims[site + 1] # Julia 1-based - right_dim = site == n - 1 ? 1 : ldims[site + 1] + left_dim = site == 1 ? 1 : ldims[site - 1] + site_dim = sdims[site] + right_dim = site == n ? 
1 : ldims[site] total_size = left_dim * site_dim * right_dim data = Vector{Cdouble}(undef, total_size) @@ -204,30 +376,294 @@ function site_tensor(tt::SimpleTensorTrain{Float64}, site::Integer) out_site = Ref{Csize_t}(0) out_right = Ref{Csize_t}(0) - status = C_API.t4a_simplett_f64_site_tensor(tt.ptr, site, data, out_left, out_site, out_right) + status = C_API.t4a_simplett_f64_site_tensor(tt.ptr, c_site, data, out_left, out_site, out_right) C_API.check_status(status) - # Data is column-major (left, site, right) from Rust return reshape(data, (Int(out_left[]), Int(out_site[]), Int(out_right[]))) end +""" + sitetensor(tt::SimpleTensorTrain{ComplexF64}, site::Integer) -> Array{ComplexF64, 3} + +Get the site tensor at a specific site (1-based indexing). +Returns array with shape `(left_dim, site_dim, right_dim)`. +Complex data is stored as interleaved (re, im) doubles in the C API. +""" +function sitetensor(tt::SimpleTensorTrain{ComplexF64}, site::Integer) + n = length(tt) + 1 <= site <= n || error("Site index out of bounds: $site (n=$n)") + c_site = site - 1 # Convert to 0-based + + sdims = sitedims(tt) + ldims = linkdims(tt) + + left_dim = site == 1 ? 1 : ldims[site - 1] + site_dim = sdims[site] + right_dim = site == n ? 
1 : ldims[site] + + total_size = left_dim * site_dim * right_dim + # Buffer needs 2x doubles for interleaved (re, im) pairs + data = Vector{Cdouble}(undef, 2 * total_size) + + out_left = Ref{Csize_t}(0) + out_site = Ref{Csize_t}(0) + out_right = Ref{Csize_t}(0) + + status = C_API.t4a_simplett_c64_site_tensor(tt.ptr, c_site, data, out_left, out_site, out_right) + C_API.check_status(status) + + # Reinterpret interleaved doubles as ComplexF64 + complex_data = reinterpret(ComplexF64, data) + return reshape(complex_data, (Int(out_left[]), Int(out_site[]), Int(out_right[]))) +end + +# ============================================================================ +# Compress +# ============================================================================ + +""" + compress!(tt::SimpleTensorTrain; method::Symbol=:SVD, tolerance::Float64=1e-12, max_bonddim::Int=typemax(Int)) + +Compress the tensor train in-place. +`method` can be `:SVD`, `:LU`, or `:CI`. +""" +function compress!(tt::SimpleTensorTrain{T}; + method::Symbol=:SVD, + tolerance::Float64=1e-12, + max_bonddim::Int=typemax(Int)) where {T<:_SimpleTTScalar} + method_int = if method == :SVD + 0 + elseif method == :LU + 1 + elseif method == :CI + 2 + else + error("Unknown compression method: $method. Use :SVD, :LU, or :CI.") + end + status = _api(T, :compress)(tt.ptr, method_int, tolerance, max_bonddim) + C_API.check_status(status) + return tt +end + +# ============================================================================ +# Partial sum +# ============================================================================ + +""" + partial_sum(tt::SimpleTensorTrain, dims::Vector{<:Integer}) + +Compute a partial sum over specified dimensions (1-based indexing). +Returns a new `SimpleTensorTrain` with the summed-over dimensions removed. 
+""" +function partial_sum(tt::SimpleTensorTrain{T}, dims::Vector{<:Integer}) where {T<:_SimpleTTScalar} + c_dims = Csize_t.(dims .- 1) # Convert to 0-based + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + status = _api(T, :partial_sum)(tt.ptr, c_dims, length(c_dims), out_ptr) + C_API.check_status(status) + return SimpleTensorTrain{T}(out_ptr[]) +end + +# ============================================================================ +# Arithmetic: add, subtract, scale +# ============================================================================ + +""" + +(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) -> SimpleTensorTrain{T} + +Add two tensor trains. Returns a new tensor train. +""" +function Base.:+(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + status = _api(T, :add)(a.ptr, b.ptr, out_ptr) + C_API.check_status(status) + return SimpleTensorTrain{T}(out_ptr[]) +end + +""" + -(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) -> SimpleTensorTrain{T} + +Subtract two tensor trains. Returns a new tensor train. +Implemented as `a + (-one(T) * b)`. +""" +function Base.:-(a::SimpleTensorTrain{T}, b::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} + return a + (-one(T) * b) +end + +""" + scale!(tt::SimpleTensorTrain{Float64}, alpha::Real) + +Scale a Float64 tensor train in-place by a real factor. +""" +function scale!(tt::SimpleTensorTrain{Float64}, alpha::Real) + status = C_API.t4a_simplett_f64_scale(tt.ptr, Float64(alpha)) + C_API.check_status(status) + return tt +end + +""" + scale!(tt::SimpleTensorTrain{ComplexF64}, alpha::Number) + +Scale a ComplexF64 tensor train in-place by a complex factor. 
+""" +function scale!(tt::SimpleTensorTrain{ComplexF64}, alpha::Number) + c = ComplexF64(alpha) + status = C_API.t4a_simplett_c64_scale(tt.ptr, real(c), imag(c)) + C_API.check_status(status) + return tt +end + +""" + *(alpha::Number, tt::SimpleTensorTrain{T}) -> SimpleTensorTrain{T} + +Scalar multiplication (creates a copy, then scales). +""" +function Base.:*(alpha::Number, tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} + result = copy(tt) + scale!(result, alpha) + return result +end + +Base.:*(tt::SimpleTensorTrain{T}, alpha::Number) where {T<:_SimpleTTScalar} = alpha * tt + +# ============================================================================ +# Conjugation +# ============================================================================ + +""" + conj(tt::SimpleTensorTrain{ComplexF64}) -> SimpleTensorTrain{ComplexF64} + +Return a new tensor train with all site tensors element-wise conjugated. +""" +function Base.conj(tt::SimpleTensorTrain{ComplexF64}) + n = length(tt) + site_tensors = [conj(sitetensor(tt, i)) for i in 1:n] + return SimpleTensorTrain(site_tensors) +end + +# Float64 conjugation is a no-op (returns a copy for consistency). +Base.conj(tt::SimpleTensorTrain{Float64}) = copy(tt) + +# ============================================================================ +# Dot product +# ============================================================================ + +""" + dot(a::SimpleTensorTrain{Float64}, b::SimpleTensorTrain{Float64}) -> Float64 + +Compute the dot product (inner product) of two Float64 tensor trains. +""" +function LinearAlgebra.dot(a::SimpleTensorTrain{Float64}, b::SimpleTensorTrain{Float64}) + out_value = Ref{Cdouble}(0.0) + status = C_API.t4a_simplett_f64_dot(a.ptr, b.ptr, out_value) + C_API.check_status(status) + return out_value[] +end + +""" + dot(a::SimpleTensorTrain{ComplexF64}, b::SimpleTensorTrain{ComplexF64}) -> ComplexF64 + +Compute the dot product (inner product) of two ComplexF64 tensor trains. 
+Follows Julia convention: `dot(a, b) = sum(conj(a_i) * b_i)`. + +The Rust C API computes the bilinear form `sum(a_i * b_i)`, so we conjugate +the first argument before calling it. +""" +function LinearAlgebra.dot(a::SimpleTensorTrain{ComplexF64}, b::SimpleTensorTrain{ComplexF64}) + a_conj = conj(a) + out_re = Ref{Cdouble}(0.0) + out_im = Ref{Cdouble}(0.0) + status = C_API.t4a_simplett_c64_dot(a_conj.ptr, b.ptr, out_re, out_im) + C_API.check_status(status) + return ComplexF64(out_re[], out_im[]) +end + +# ============================================================================ +# Reverse +# ============================================================================ + +""" + reverse(tt::SimpleTensorTrain) -> SimpleTensorTrain + +Reverse the site ordering of the tensor train. Returns a new tensor train. +""" +function Base.reverse(tt::SimpleTensorTrain{T}) where {T<:_SimpleTTScalar} + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + status = _api(T, :reverse)(tt.ptr, out_ptr) + C_API.check_status(status) + return SimpleTensorTrain{T}(out_ptr[]) +end + +# ============================================================================ +# Full tensor +# ============================================================================ + +""" + fulltensor(tt::SimpleTensorTrain{Float64}) -> Array{Float64} + +Convert the tensor train to a full (dense) tensor. +Returns an Array with dimensions equal to `sitedims(tt)`. 
+""" +function fulltensor(tt::SimpleTensorTrain{Float64}) + # Query the required buffer size + out_len = Ref{Csize_t}(0) + status = C_API.t4a_simplett_f64_fulltensor(tt.ptr, C_NULL, 0, out_len) + C_API.check_status(status) + + total_size = Int(out_len[]) + data = Vector{Cdouble}(undef, total_size) + out_len2 = Ref{Csize_t}(0) + status = C_API.t4a_simplett_f64_fulltensor(tt.ptr, data, total_size, out_len2) + C_API.check_status(status) + + dims = sitedims(tt) + return reshape(data, Tuple(dims)) +end + +""" + fulltensor(tt::SimpleTensorTrain{ComplexF64}) -> Array{ComplexF64} + +Convert the complex tensor train to a full (dense) tensor. +Returns an Array with dimensions equal to `sitedims(tt)`. +""" +function fulltensor(tt::SimpleTensorTrain{ComplexF64}) + # Query the required buffer size (returns number of doubles = 2 * n_elements) + out_len = Ref{Csize_t}(0) + status = C_API.t4a_simplett_c64_fulltensor(tt.ptr, C_NULL, 0, out_len) + C_API.check_status(status) + + total_doubles = Int(out_len[]) + data = Vector{Cdouble}(undef, total_doubles) + out_len2 = Ref{Csize_t}(0) + status = C_API.t4a_simplett_c64_fulltensor(tt.ptr, data, total_doubles, out_len2) + C_API.check_status(status) + + # Reinterpret interleaved doubles as ComplexF64 + complex_data = reinterpret(ComplexF64, data) + dims = sitedims(tt) + return reshape(complex_data, Tuple(dims)) +end + +# ============================================================================ +# Display +# ============================================================================ + """ show(io::IO, tt::SimpleTensorTrain) Display tensor train information. 
""" -function Base.show(io::IO, tt::SimpleTensorTrain{T}) where T +function Base.show(io::IO, tt::SimpleTensorTrain{T}) where {T} n = length(tt) r = rank(tt) print(io, "SimpleTensorTrain{$T}(sites=$n, rank=$r)") end -function Base.show(io::IO, ::MIME"text/plain", tt::SimpleTensorTrain{T}) where T +function Base.show(io::IO, ::MIME"text/plain", tt::SimpleTensorTrain{T}) where {T} n = length(tt) println(io, "SimpleTensorTrain{$T}") println(io, " Sites: $n") - println(io, " Site dims: $(site_dims(tt))") - println(io, " Link dims: $(link_dims(tt))") + println(io, " Site dims: $(sitedims(tt))") + println(io, " Link dims: $(linkdims(tt))") println(io, " Max rank: $(rank(tt))") end diff --git a/src/Tensor4all.jl b/src/Tensor4all.jl index 109a225..b3e2cdd 100644 --- a/src/Tensor4all.jl +++ b/src/Tensor4all.jl @@ -566,10 +566,9 @@ function Tensor(inds::Vector{Index}, data::AbstractArray{ComplexF64}) r = length(inds) index_ptrs = [idx.ptr for idx in inds] dims_vec = Csize_t[dim(idx) for idx in inds] - data_re = Cdouble[real(z) for z in flat_data] - data_im = Cdouble[imag(z) for z in flat_data] + data_interleaved = collect(reinterpret(Float64, flat_data)) - ptr = C_API.t4a_tensor_new_dense_c64(r, index_ptrs, dims_vec, data_re, data_im) + ptr = C_API.t4a_tensor_new_dense_c64(r, index_ptrs, dims_vec, data_interleaved) return Tensor(ptr) end @@ -691,18 +690,16 @@ function data(t::Tensor) elseif kind == DenseC64 # Query length out_len = Ref{Csize_t}(0) - status = C_API.t4a_tensor_get_data_c64(t.ptr, nothing, nothing, 0, out_len) + status = C_API.t4a_tensor_get_data_c64(t.ptr, nothing, 0, out_len) C_API.check_status(status) - # Get data - buf_re = Vector{Cdouble}(undef, out_len[]) - buf_im = Vector{Cdouble}(undef, out_len[]) - status = C_API.t4a_tensor_get_data_c64(t.ptr, buf_re, buf_im, out_len[], out_len) + # Get interleaved re/im data + buf = Vector{Cdouble}(undef, 2 * out_len[]) + status = C_API.t4a_tensor_get_data_c64(t.ptr, buf, out_len[], out_len) C_API.check_status(status) 
- # Combine — data is already column-major from Rust - buf = [ComplexF64(r, i) for (r, i) in zip(buf_re, buf_im)] - return isempty(d) ? reshape(buf, 1) : reshape(buf, d...) + complex_buf = copy(reinterpret(ComplexF64, buf)) + return isempty(d) ? reshape(complex_buf, 1) : reshape(complex_buf, d...) else error("Unsupported storage kind for data extraction: $kind") @@ -894,6 +891,52 @@ end export save_itensor, load_itensor +# ============================================================================ +# Generic functions shared across submodules +# ============================================================================ +# These functions are defined here so that submodules (SimpleTT, TreeTN, +# QuanticsTCI, etc.) can extend them with methods for their own types, +# avoiding name collisions when multiple submodules are loaded together. + +""" + linkdims(obj) -> Vector{Int} + +Get the link (bond) dimensions. Dispatches to the appropriate method +depending on the type of `obj` (SimpleTensorTrain, TreeTensorNetwork, etc.). +""" +function linkdims end + +""" + compress!(obj; kwargs...) + +Compress a tensor network object in-place. +""" +function compress! end + +""" + evaluate(obj, indices) -> scalar + +Evaluate a tensor network object at the given multi-index. +Dispatches to the appropriate method depending on the type of `obj`. +""" +function evaluate end + +""" + maxbonderror(obj) -> Float64 + +Get the maximum bond error across all edges/bonds. +""" +function maxbonderror end + +""" + maxrank(obj) -> Int + +Get the maximum rank (bond dimension) across all edges/bonds. 
+""" +function maxrank end + +export linkdims, compress!, evaluate, maxbonderror, maxrank + # ============================================================================ # SimpleTT Submodule (Simple Tensor Train) # ============================================================================ @@ -902,13 +945,6 @@ export save_itensor, load_itensor # Use: using Tensor4all.SimpleTT include("SimpleTT.jl") -# ============================================================================ -# TensorCI Submodule (Tensor Cross Interpolation) -# ============================================================================ -# TensorCI provides tensor cross interpolation algorithms. -# Use: using Tensor4all.TensorCI -include("TensorCI.jl") - # ============================================================================ # TreeTN Submodule (Tree Tensor Network: MPS, MPO, TTN) # ============================================================================ @@ -916,12 +952,21 @@ include("TensorCI.jl") # Use: using Tensor4all.TreeTN include("TreeTN.jl") +# ============================================================================ +# TreeTCI Submodule (Tree Tensor Cross Interpolation) +# ============================================================================ +# TreeTCI provides tree-structured tensor cross interpolation. +# Use: using Tensor4all.TreeTCI +include("TreeTCI.jl") + # ============================================================================ # QuanticsGrids Submodule # ============================================================================ # Quantics grid types for coordinate conversions in QTT methods. 
# Use: using Tensor4all.QuanticsGrids include("QuanticsGrids.jl") +using .QuanticsGrids: DiscretizedGrid, InherentDiscreteGrid, localdimensions +export DiscretizedGrid, InherentDiscreteGrid, localdimensions # ============================================================================ # QuanticsTCI Submodule (Quantics Tensor Cross Interpolation) diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl new file mode 100644 index 0000000..965e58d --- /dev/null +++ b/src/TreeTCI.jl @@ -0,0 +1,738 @@ +""" + TreeTCI + +Tree-structured tensor cross interpolation via tensor4all-rs. + +Provides `TreeTciGraph` for defining tree topologies and `SimpleTreeTci` +for running TCI on arbitrary tree structures. Results are returned as +`TreeTN.TreeTensorNetwork`. + +All user-facing APIs use **1-indexed** indices (Julia convention). +Internal C API calls convert to 0-indexed. + +# Usage +```julia +using Tensor4all.TreeTCI + +graph = TreeTciGraph(4, [(0, 1), (1, 2), (2, 3)]) +f(batch) = [sum(Float64, batch[:, j]) for j in 1:size(batch, 2)] +tci, ranks, errors = crossinterpolate2(f, [3, 3, 3, 3], graph) +ttn = to_treetn(tci, f) +``` +""" +module TreeTCI + +using ..C_API +import ..TreeTN: TreeTensorNetwork +import ..evaluate, ..maxbonderror, ..maxrank + +export TreeTciGraph, SimpleTreeTci +export crossinterpolate2, evaluate +export bonddims, maxbonderror, maxrank, maxsamplevalue +export to_treetn + +const _TreeTciScalar = Union{Float64, ComplexF64} + +_suffix(::Type{Float64}) = "f64" +_suffix(::Type{ComplexF64}) = "c64" +_sym_for(::Type{T}, name::Symbol) where {T<:_TreeTciScalar} = + C_API._sym(Symbol("t4a_treetci_", _suffix(T), "_", name)) +_cross_sym_for(::Type{T}) where {T<:_TreeTciScalar} = + C_API._sym(Symbol("t4a_treetci_crossinterpolate2_", _suffix(T))) + +function _infer_scalar_type(f, local_dims::Vector{<:Integer}, initial_pivots::Vector{Vector{Int}}) + sample_indices = isempty(initial_pivots) ? 
ones(Int, length(local_dims)) : initial_pivots[1] + sample_values = f(reshape(sample_indices, :, 1)) + length(sample_values) == 1 || + error("TreeTCI batch callback must return exactly one value for a single-point batch") + sample_value = sample_values[1] + if sample_value isa Real + return Float64 + elseif sample_value isa Complex + return ComplexF64 + end + error("TreeTCI batch callback must return real or complex values, got $(typeof(sample_value))") +end + +# ============================================================================ +# TreeTciGraph +# ============================================================================ + +""" + TreeTciGraph(n_sites, edges) + +Define a tree graph structure for TreeTCI. + +# Arguments +- `n_sites::Int`: Number of sites +- `edges::Vector{Tuple{Int,Int}}`: Edge list (0-based site indices) + +# Examples +```julia +# Linear chain: 0-1-2-3 +graph = TreeTciGraph(4, [(0, 1), (1, 2), (2, 3)]) + +# Star graph: 0 at center +graph = TreeTciGraph(4, [(0, 1), (0, 2), (0, 3)]) + +# 7-site branching tree +graph = TreeTciGraph(7, [(0, 1), (1, 2), (1, 3), (3, 4), (4, 5), (4, 6)]) +``` +""" +mutable struct TreeTciGraph + ptr::Ptr{Cvoid} + n_sites::Int + + function TreeTciGraph(n_sites::Int, edges::Vector{Tuple{Int, Int}}) + n_edges = length(edges) + edges_flat = Vector{Csize_t}(undef, 2 * n_edges) + for (i, (u, v)) in enumerate(edges) + edges_flat[2 * i - 1] = Csize_t(u) + edges_flat[2 * i] = Csize_t(v) + end + + ptr = ccall( + C_API._sym(:t4a_treetci_graph_new), + Ptr{Cvoid}, + (Csize_t, Ptr{Csize_t}, Csize_t), + Csize_t(n_sites), + n_edges == 0 ? 
C_NULL : edges_flat, + Csize_t(n_edges), + ) + if ptr == C_NULL + error("Failed to create TreeTciGraph: $(C_API.last_error_message())") + end + + graph = new(ptr, n_sites) + finalizer(graph) do obj + if obj.ptr != C_NULL + ccall(C_API._sym(:t4a_treetci_graph_release), Cvoid, (Ptr{Cvoid},), obj.ptr) + obj.ptr = C_NULL + end + end + return graph + end +end + +# ============================================================================ +# Batch Eval Trampoline +# ============================================================================ + +# The Rust side passes 0-indexed batch values. The trampoline adds 1 before +# calling the user function so that user functions always see 1-indexed values. + +""" +Internal trampoline for f64 batch callbacks. + +The user function signature is: `f(batch::Matrix{Int}) -> Vector{Float64}` +where `batch` is column-major `(n_sites, n_points)` with **1-based** indices. + +The C API provides 0-based indices; we add 1 before calling the user function. +""" +function _treetci_batch_trampoline( + batch_data::Ptr{Csize_t}, + n_sites::Csize_t, + n_points::Csize_t, + results::Ptr{Cdouble}, + user_data::Ptr{Cvoid}, +)::Cint + try + f_ref = unsafe_pointer_to_objref(user_data)::Ref{Any} + f = f_ref[] + raw_batch = unsafe_wrap(Array, batch_data, (Int(n_sites), Int(n_points))) + # Convert 0-indexed (from C API) to 1-indexed (for user) + batch_1indexed = Int.(raw_batch) .+ 1 + vals = f(batch_1indexed) + length(vals) == Int(n_points) || + error("Batch callback returned $(length(vals)) values for $(Int(n_points)) points") + for i in 1:Int(n_points) + unsafe_store!(results, Float64(vals[i]), i) + end + return Cint(0) + catch err + @error "TreeTCI batch eval callback error" exception = (err, catch_backtrace()) + return Cint(-1) + end +end + +""" +Internal trampoline for c64 batch callbacks. + +The user function signature is: `f(batch::Matrix{Int}) -> Vector{ComplexF64}` +where `batch` is column-major `(n_sites, n_points)` with **1-based** indices. 
+ +The C API provides 0-based indices; we add 1 before calling the user function. +Results are written as interleaved doubles. +""" +function _treetci_batch_trampoline_c64( + batch_data::Ptr{Csize_t}, + n_sites::Csize_t, + n_points::Csize_t, + results::Ptr{Cdouble}, + user_data::Ptr{Cvoid}, +)::Cint + try + f_ref = unsafe_pointer_to_objref(user_data)::Ref{Any} + f = f_ref[] + raw_batch = unsafe_wrap(Array, batch_data, (Int(n_sites), Int(n_points))) + # Convert 0-indexed (from C API) to 1-indexed (for user) + batch_1indexed = Int.(raw_batch) .+ 1 + vals = ComplexF64.(f(batch_1indexed)) + length(vals) == Int(n_points) || + error("Batch callback returned $(length(vals)) values for $(Int(n_points)) points") + interleaved = reinterpret(Float64, vals) + for i in eachindex(interleaved) + unsafe_store!(results, interleaved[i], i) + end + return Cint(0) + catch err + @error "TreeTCI complex batch eval callback error" exception = (err, catch_backtrace()) + return Cint(-1) + end +end + +const _BATCH_TRAMPOLINE_PTR = Ref{Ptr{Cvoid}}(C_NULL) +const _BATCH_TRAMPOLINE_C64_PTR = Ref{Ptr{Cvoid}}(C_NULL) + +function _get_batch_trampoline(::Type{Float64}) + if _BATCH_TRAMPOLINE_PTR[] == C_NULL + _BATCH_TRAMPOLINE_PTR[] = @cfunction( + _treetci_batch_trampoline, + Cint, + (Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Cdouble}, Ptr{Cvoid}), + ) + end + return _BATCH_TRAMPOLINE_PTR[] +end + +function _get_batch_trampoline(::Type{ComplexF64}) + if _BATCH_TRAMPOLINE_C64_PTR[] == C_NULL + _BATCH_TRAMPOLINE_C64_PTR[] = @cfunction( + _treetci_batch_trampoline_c64, + Cint, + (Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Cdouble}, Ptr{Cvoid}), + ) + end + return _BATCH_TRAMPOLINE_C64_PTR[] +end + +# ============================================================================ +# Proposer helpers +# ============================================================================ + +const _PROPOSER_DEFAULT = Cint(0) +const _PROPOSER_SIMPLE = Cint(1) +const _PROPOSER_TRUNCATED_DEFAULT = Cint(2) + +function 
_proposer_to_cint(proposer::Symbol)::Cint + if proposer === :default + return _PROPOSER_DEFAULT + elseif proposer === :simple + return _PROPOSER_SIMPLE + elseif proposer === :truncated_default + return _PROPOSER_TRUNCATED_DEFAULT + end + error("Unknown proposer: $proposer. Use :default, :simple, or :truncated_default") +end + +# ============================================================================ +# SimpleTreeTci +# ============================================================================ + +""" + SimpleTreeTci{T<:Union{Float64, ComplexF64}}(local_dims, graph) + +Stateful TreeTCI object for tree-structured tensor cross interpolation. + +# Arguments +- `local_dims::Vector{Int}`: Local dimension at each site (length = graph.n_sites) +- `graph::TreeTciGraph`: Tree graph structure + +# Lifecycle +```julia +tci = SimpleTreeTci([2, 2, 2, 2], graph) +add_global_pivots!(tci, [ones(Int, 4)]) +for _ in 1:20 + sweep!(tci, f; tolerance=1e-8) + maxbonderror(tci) < 1e-8 && break +end +ttn = to_treetn(tci, f) +``` +""" +mutable struct SimpleTreeTci{T<:_TreeTciScalar} + ptr::Ptr{Cvoid} + graph::TreeTciGraph + local_dims::Vector{Int} + + function SimpleTreeTci{T}(local_dims::Vector{<:Integer}, graph::TreeTciGraph) where {T<:_TreeTciScalar} + dims_int = Int.(local_dims) + length(dims_int) == graph.n_sites || + error("local_dims length ($(length(dims_int))) != graph.n_sites ($(graph.n_sites))") + + dims_csize = Csize_t.(dims_int) + ptr = ccall( + _sym_for(T, :new), + Ptr{Cvoid}, + (Ptr{Csize_t}, Csize_t, Ptr{Cvoid}), + dims_csize, + Csize_t(length(dims_csize)), + graph.ptr, + ) + if ptr == C_NULL + error("Failed to create SimpleTreeTci: $(C_API.last_error_message())") + end + + tci = new{T}(ptr, graph, dims_int) + finalizer(tci) do obj + if obj.ptr != C_NULL + ccall(_sym_for(T, :release), Cvoid, (Ptr{Cvoid},), obj.ptr) + obj.ptr = C_NULL + end + end + return tci + end +end + +SimpleTreeTci(local_dims::Vector{<:Integer}, graph::TreeTciGraph) = + 
SimpleTreeTci{Float64}(local_dims, graph) + +# ============================================================================ +# Pivot management +# ============================================================================ + +""" + add_global_pivots!(tci, pivots) + +Add global pivots. Each pivot is a full multi-index over all sites (**1-based**). +Internally converts to 0-based for the C API. + +# Arguments +- `tci::SimpleTreeTci` +- `pivots::Vector{Vector{Int}}`: Each element has length `n_sites`, 1-based indices +""" +function add_global_pivots!(tci::SimpleTreeTci{T}, pivots::Vector{Vector{Int}}) where {T} + n_sites = length(tci.local_dims) + n_pivots = length(pivots) + n_pivots == 0 && return tci + + pivots_flat = Vector{Csize_t}(undef, n_sites * n_pivots) + for j in 1:n_pivots + pivot = pivots[j] + length(pivot) == n_sites || + error("Pivot $j has length $(length(pivot)), expected $n_sites") + for i in 1:n_sites + # Convert 1-indexed to 0-indexed for C API + pivots_flat[i + n_sites * (j - 1)] = Csize_t(pivot[i] - 1) + end + end + + C_API.check_status(ccall( + _sym_for(T, :add_global_pivots), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t), + tci.ptr, + pivots_flat, + Csize_t(n_sites), + Csize_t(n_pivots), + )) + return tci +end + +# ============================================================================ +# Sweep +# ============================================================================ + +""" + sweep!(tci, f; proposer=:default, tolerance=1e-8, max_bond_dim=0) + +Run one optimization iteration (visit all edges once). 
+ +# Arguments +- `tci::SimpleTreeTci` +- `f`: Batch evaluation function `f(batch::Matrix{Int}) -> Vector{T}` + where `batch` is column-major `(n_sites, n_points)` with **1-based** indices +- `proposer`: `:default`, `:simple`, or `:truncated_default` +- `tolerance`: Relative tolerance +- `max_bond_dim`: Maximum bond dimension (0 = unlimited) +""" +function sweep!( + tci::SimpleTreeTci{T}, + f; + proposer::Symbol = :default, + tolerance::Float64 = 1e-8, + max_bond_dim::Int = 0, +) where {T} + f_ref = Ref{Any}(f) + GC.@preserve f_ref begin + C_API.check_status(ccall( + _sym_for(T, :sweep), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Cint, Cdouble, Csize_t), + tci.ptr, + _get_batch_trampoline(T), + pointer_from_objref(f_ref), + _proposer_to_cint(proposer), + tolerance, + Csize_t(max_bond_dim), + )) + end + return tci +end + +# ============================================================================ +# State inspection +# ============================================================================ + +"""Maximum bond error across all edges.""" +function maxbonderror(tci::SimpleTreeTci{T}) where {T} + out = Ref{Cdouble}(0.0) + C_API.check_status(ccall( + _sym_for(T, :max_bond_error), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), + tci.ptr, + out, + )) + return out[] +end + +"""Maximum rank (bond dimension) across all edges.""" +function maxrank(tci::SimpleTreeTci{T}) where {T} + out = Ref{Csize_t}(0) + C_API.check_status(ccall( + _sym_for(T, :max_rank), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}), + tci.ptr, + out, + )) + return Int(out[]) +end + +"""Maximum observed sample value (for normalization).""" +function maxsamplevalue(tci::SimpleTreeTci{T}) where {T} + out = Ref{Cdouble}(0.0) + C_API.check_status(ccall( + _sym_for(T, :max_sample_value), + Cint, + (Ptr{Cvoid}, Ptr{Cdouble}), + tci.ptr, + out, + )) + return out[] +end + +"""Bond dimensions (ranks) at each edge.""" +function bonddims(tci::SimpleTreeTci{T}) where {T} + n_edges_ref = Ref{Csize_t}(0) + 
C_API.check_status(ccall( + _sym_for(T, :bond_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + tci.ptr, + Ptr{Csize_t}(C_NULL), + Csize_t(0), + n_edges_ref, + )) + + n_edges = Int(n_edges_ref[]) + buf = Vector{Csize_t}(undef, n_edges) + C_API.check_status(ccall( + _sym_for(T, :bond_dims), + Cint, + (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + tci.ptr, + buf, + Csize_t(n_edges), + n_edges_ref, + )) + return Int.(buf) +end + +# ============================================================================ +# Materialization +# ============================================================================ + +function _wrap_treetn(handle::Ptr{Cvoid}, n_sites::Int) + node_names = collect(1:n_sites) + node_map = Dict{Int, Int}(i => i - 1 for i in node_names) + return TreeTensorNetwork{Int}(handle, node_map, node_names) +end + +""" + to_treetn(tci, f; center_site=0) + +Convert converged TreeTCI state to a TreeTensorNetwork. + +# Arguments +- `tci::SimpleTreeTci`: Converged state +- `f`: Batch evaluation function (same as `sweep!`; receives 1-indexed batches) +- `center_site`: BFS root site for materialization (0-based) +""" +function to_treetn(tci::SimpleTreeTci{T}, f; center_site::Int = 0) where {T} + f_ref = Ref{Any}(f) + out_ptr = Ref{Ptr{Cvoid}}(C_NULL) + GC.@preserve f_ref begin + C_API.check_status(ccall( + _sym_for(T, :to_treetn), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Ptr{Ptr{Cvoid}}), + tci.ptr, + _get_batch_trampoline(T), + pointer_from_objref(f_ref), + Csize_t(center_site), + out_ptr, + )) + end + return _wrap_treetn(out_ptr[], length(tci.local_dims)) +end + +# ============================================================================ +# TreeTensorNetwork evaluation +# ============================================================================ + +""" + evaluate(ttn, indices::Vector{<:Integer}) -> T + +Evaluate a TreeTensorNetwork at a single multi-index (**1-based**). 
+Returns a scalar (`Float64` or `ComplexF64` depending on storage). + +# Example +```julia +val = evaluate(ttn, [1, 2, 3, 1]) +``` +""" +function evaluate(ttn::TreeTensorNetwork, indices::Vector{<:Integer}) + # Convert 1-indexed to 0-indexed for C API + indices_0 = Csize_t.(indices .- 1) + vals = _evaluate_batch(ttn, indices_0, length(indices), 1) + return vals[1] +end + +""" + evaluate(ttn, indices::Vector{Vector{T}}) -> Vector + +Evaluate a TreeTensorNetwork at multiple multi-indices (**1-based**). +Each element of `indices` is a multi-index of length `n_sites`. + +# Example +```julia +vals = evaluate(ttn, [[1,2,3,1], [2,1,2,2]]) +``` +""" +function evaluate(ttn::TreeTensorNetwork, indices::Vector{Vector{T}}) where {T<:Integer} + n_sites = length(indices[1]) + n_points = length(indices) + flat = Vector{Csize_t}(undef, n_sites * n_points) + for j in 1:n_points + length(indices[j]) == n_sites || + error("Index $j has length $(length(indices[j])), expected $n_sites") + for i in 1:n_sites + # Convert 1-indexed to 0-indexed for C API + flat[i + n_sites * (j - 1)] = Csize_t(indices[j][i] - 1) + end + end + return _evaluate_batch(ttn, flat, n_sites, n_points) +end + +""" + evaluate(ttn, batch::AbstractMatrix{<:Integer}) -> Vector + +Evaluate a TreeTensorNetwork at multiple multi-indices given as a matrix. +`batch` has shape `(n_sites, n_points)` — the rightmost dimension is the batch index. +Indices are **1-based**. 
+ +# Example +```julia +batch = [1 2; 2 1; 3 2; 1 2] # 4 sites, 2 points, 1-based +vals = evaluate(ttn, batch) +``` +""" +function evaluate(ttn::TreeTensorNetwork, batch::AbstractMatrix{<:Integer}) + n_sites, n_points = size(batch) + # Convert 1-indexed to 0-indexed for C API + flat = Csize_t.(vec(batch) .- 1) + return _evaluate_batch(ttn, flat, n_sites, n_points) +end + +"""Internal: call C API evaluate using IndexId-based interface and return typed results.""" +function _evaluate_batch(ttn::TreeTensorNetwork, flat::Vector{Csize_t}, n_sites::Int, n_points::Int) + # Step 1: Query the number of site indices + n_indices_ref = Ref{Csize_t}(0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetn_all_site_index_ids), Cint, + (Ptr{Cvoid}, Ptr{UInt64}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + ttn.handle, C_NULL, C_NULL, 0, n_indices_ref, + )) + n_indices = Int(n_indices_ref[]) + + # Step 2: Fetch index IDs and their vertex names + index_ids = Vector{UInt64}(undef, n_indices) + vertex_names = Vector{Csize_t}(undef, n_indices) + C_API.check_status(ccall( + C_API._sym(:t4a_treetn_all_site_index_ids), Cint, + (Ptr{Cvoid}, Ptr{UInt64}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + ttn.handle, index_ids, vertex_names, n_indices, n_indices_ref, + )) + + # Step 3: Reorder flat values to match index_ids order. + # `flat` is laid out by site order (0, 1, ..., n_sites-1) for each point. + # `vertex_names[i]` tells us which site (0-based) index i belongs to. + # flat values are already 0-indexed (caller does the conversion). 
+ reordered = Vector{Csize_t}(undef, n_indices * n_points) + for p in 0:(n_points - 1) + for i in 1:n_indices + site = Int(vertex_names[i]) # 0-based site number + reordered[i + n_indices * p] = flat[site + 1 + n_sites * p] # +1 for Julia 1-based array + end + end + + # Step 4: Call evaluate + out_re = Vector{Cdouble}(undef, n_points) + out_im = Vector{Cdouble}(undef, n_points) + C_API.check_status(ccall( + C_API._sym(:t4a_treetn_evaluate), Cint, + (Ptr{Cvoid}, Ptr{UInt64}, Csize_t, Ptr{Csize_t}, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}), + ttn.handle, index_ids, n_indices, reordered, n_points, out_re, out_im, + )) + + # Detect if complex by checking if any imaginary part is nonzero + if all(iszero, out_im) + return out_re + else + return ComplexF64.(out_re .+ im .* out_im) + end +end + +# ============================================================================ +# High-level convenience function +# ============================================================================ + +""" + crossinterpolate2([T], f, localdims, graph; kwargs...) -> (tci, ranks, errors) + +Run TreeTCI to convergence on a tree graph. + +The sweep loop runs in Julia, printing convergence info when `verbosity > 0`. +This matches the API style of `TensorCrossInterpolation.crossinterpolate2`. + +# Arguments +- `T`: Scalar type (`Float64` or `ComplexF64`). Inferred if omitted. +- `f`: Batch evaluation function -- `f(batch) -> Vector{T}` where `batch` is a + `Matrix{Int}` of shape `(n_sites, n_points)` in column-major layout. + `batch[i, j]` is the **1-based** local index at site `i` for evaluation point `j`. + The function must return a `Vector` of length `n_points`. 
+- `localdims::Union{Vector{Int}, NTuple{N,Int}}`: Local dimensions at each site +- `graph::TreeTciGraph`: Tree graph structure + +# Keyword Arguments +- `initialpivots::Vector{Vector{Int}} = [ones(Int, n)]`: Initial pivots (**1-based**) +- `tolerance::Float64 = 1e-8`: Target tolerance +- `maxbonddim::Int = typemax(Int)`: Maximum bond dimension +- `maxiter::Int = 20`: Maximum sweeps +- `verbosity::Int = 0`: 0=silent, 1=summary per loginterval, 2=bond dims and timing +- `loginterval::Int = 10`: Print every N iterations (when verbosity >= 1) +- `normalizeerror::Bool = true`: Normalize error by max sample value +- `proposer::Symbol = :default`: `:default`, `:simple`, or `:truncated_default` +- `center_site::Int = 0`: Materialization center site (0-based) + +# Returns +- `tci::SimpleTreeTci{T}`: The converged TCI state (call `to_treetn(tci, f)` to materialize) +- `ranks::Vector{Int}`: Max rank per iteration +- `errors::Vector{Float64}`: (Normalized) error per iteration + +# Example +```julia +using Tensor4all.TreeTCI + +# Define a star graph: site 0 connected to sites 1,2,3,4 +graph = TreeTciGraph(5, [(0,1), (0,2), (0,3), (0,4)]) + +# Batch evaluation function (1-based indices) +function f(batch) + n_sites, n_pts = size(batch) + [prod(Float64(batch[i, j]) for i in 1:n_sites) for j in 1:n_pts] +end + +# Run TCI +tci, ranks, errors = crossinterpolate2(f, fill(3, 5), graph; + tolerance=1e-10, verbosity=1) + +# Materialize to TreeTensorNetwork +ttn = to_treetn(tci, f) +``` +""" +function crossinterpolate2( + ::Type{T}, + f, + localdims::Union{Vector{<:Integer}, NTuple{N,Integer}}, + graph::TreeTciGraph; + initialpivots::Vector{Vector{Int}} = [ones(Int, graph.n_sites)], + tolerance::Float64 = 1e-8, + maxbonddim::Int = typemax(Int), + maxiter::Int = 20, + verbosity::Int = 0, + loginterval::Int = 10, + normalizeerror::Bool = true, + proposer::Symbol = :default, + center_site::Int = 0, +) where {T<:_TreeTciScalar, N} + dims_int = Int.(collect(localdims)) + n_sites = 
length(dims_int) + n_sites == graph.n_sites || + error("localdims length ($n_sites) != graph.n_sites ($(graph.n_sites))") + + bd = maxbonddim == typemax(Int) ? 0 : maxbonddim + + # Create state and add initial pivots (1-indexed; add_global_pivots! converts to 0-indexed) + tci = SimpleTreeTci{T}(dims_int, graph) + add_global_pivots!(tci, initialpivots) + + ranks = Int[] + errors = Float64[] + t_start = time() + + # Sweep loop in Julia + for iter in 1:maxiter + t_sweep_start = time() + sweep!(tci, f; proposer=proposer, tolerance=tolerance, max_bond_dim=bd) + t_sweep = time() - t_sweep_start + + r = maxrank(tci) + err = maxbonderror(tci) + msv = maxsamplevalue(tci) + normalized_err = (normalizeerror && msv > 0) ? err / msv : err + + push!(ranks, r) + push!(errors, normalized_err) + + should_log = iter % loginterval == 0 || iter == 1 || normalized_err < tolerance + if verbosity >= 1 && should_log + @info "TreeTCI" iteration=iter rank=r error=normalized_err maxsamplevalue=msv + end + if verbosity >= 2 && should_log + bd_vec = bonddims(tci) + elapsed = time() - t_start + @info "TreeTCI detail" iteration=iter bonddims=bd_vec sweep_sec=round(t_sweep; digits=3) elapsed_sec=round(elapsed; digits=3) + end + + if normalized_err < tolerance + break + end + end + + return tci, ranks, errors +end + +function crossinterpolate2( + f, + localdims::Union{Vector{<:Integer}, NTuple{N,Integer}}, + graph::TreeTciGraph; + kwargs..., +) where {N} + pivots = get(kwargs, :initialpivots, Vector{Int}[]) + T = _infer_scalar_type(f, collect(localdims), pivots) + return crossinterpolate2(T, f, localdims, graph; kwargs...) 
+end + +end # module TreeTCI diff --git a/src/TreeTN.jl b/src/TreeTN.jl index 1d5bffe..5a2f2b6 100644 --- a/src/TreeTN.jl +++ b/src/TreeTN.jl @@ -27,7 +27,8 @@ using LinearAlgebra # Import from parent module import ..Tensor4all: Index, Tensor, dim, id, tags, indices, rank, dims, data import ..Tensor4all: hascommoninds, commoninds, uniqueinds, HasCommonIndsPredicate -import ..Tensor4all: C_API +import ..Tensor4all: C_API, linkdims +import ..SimpleTT: SimpleTensorTrain, sitetensor # ============================================================================ # CanonicalForm Enum @@ -154,6 +155,128 @@ function MPS(tensors::Vector{Tensor}) return TreeTensorNetwork{Int}(out[], node_map, node_names) end +""" + MPS(tt::SimpleTensorTrain{T}) where T + +Convert a SimpleTensorTrain to an MPS (TreeTensorNetwork{Int}). + +Extracts site tensors from the SimpleTT and builds Tensor objects with +appropriate site and link indices. +""" +function MPS(tt::SimpleTensorTrain{T}) where T + n = length(tt) + n == 0 && error("Cannot create MPS from empty SimpleTensorTrain") + + tensors = Tensor[] + links = Index[] + + for i in 1:n + st = sitetensor(tt, i) # shape (left, site, right) + left_dim, site_dim, right_dim = size(st) + + site_idx = Index(site_dim) + inds = Index[] + + if i > 1 + push!(inds, links[end]) # left link from previous + end + push!(inds, site_idx) + if i < n + link = Index(right_dim; tags="Link,l=$i") + push!(links, link) + push!(inds, link) + end + + # Remove singleton boundary dimensions + if i == 1 && i == n + # Single site: shape (1, site, 1) -> (site,) + d = reshape(st, site_dim) + elseif i == 1 + # First site: shape (1, site, right) -> (site, right) + d = reshape(st, site_dim, right_dim) + elseif i == n + # Last site: shape (left, site, 1) -> (left, site) + d = reshape(st, left_dim, site_dim) + else + # Middle: shape (left, site, right) + d = st + end + + push!(tensors, Tensor(inds, d)) + end + + return MPS(tensors) +end + +""" + 
SimpleTensorTrain(mps::TreeTensorNetwork{Int}) + +Convert an MPS (TreeTensorNetwork{Int}) to a SimpleTensorTrain. + +For each site, extracts tensor data ordered as (left_link, site, right_link) +and reshapes to 3D arrays with shape (left_dim, site_dim, right_dim). +""" +function SimpleTensorTrain(mps::TreeTensorNetwork{Int}) + n = nv(mps) + n == 0 && error("Cannot create SimpleTensorTrain from empty MPS") + + site_tensors = Array{Float64,3}[] + first_tensor = mps[1] + first_data = data(first_tensor) + is_complex = eltype(first_data) <: Complex + + if is_complex + site_tensors_c = Array{ComplexF64,3}[] + end + + for i in 1:n + tensor = mps[i] + si = siteinds(mps, i) + site_dim = isempty(si) ? 1 : dim(si[1]) + + if n == 1 + left_dim = 1 + right_dim = 1 + elseif i == 1 + left_dim = 1 + right_dim = linkdim(mps, 1) + elseif i == n + left_dim = linkdim(mps, n - 1) + right_dim = 1 + else + left_dim = linkdim(mps, i - 1) + right_dim = linkdim(mps, i) + end + + # Build desired index order: left_link, site, right_link + desired_inds = Index[] + if i > 1 + push!(desired_inds, linkind(mps, i - 1)) + end + append!(desired_inds, si) + if i < n + push!(desired_inds, linkind(mps, i)) + end + + arr = Array(tensor, desired_inds) + + # Reshape to 3D (left, site, right) + st = reshape(arr, left_dim, site_dim, right_dim) + + if is_complex + push!(site_tensors_c, Array{ComplexF64,3}(st)) + else + push!(site_tensors, Array{Float64,3}(st)) + end + end + + if is_complex + return SimpleTensorTrain(site_tensors_c) + else + return SimpleTensorTrain(site_tensors) + end +end + # Note: MPO(tensors::Vector{Tensor}) is not defined separately because # MPO === MPS === TreeTensorNetwork{Int}, so MPS(tensors) works for both. # Defining a separate function would overwrite the MPS constructor. 
diff --git a/test/runtests.jl b/test/runtests.jl index 878eaab..57145a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,4 +19,10 @@ skip_hdf5 = get(ENV, "T4A_SKIP_HDF5_TESTS", "") == "1" if !skip_hdf5 include("test_hdf5_itensors_compat.jl") end + include("test_simplett.jl") + include("test_conversions.jl") + include("test_quanticsgrids.jl") + include("test_quanticstci.jl") + include("test_quanticstransform.jl") + include("test_treetci.jl") end diff --git a/test/test_conversions.jl b/test/test_conversions.jl new file mode 100644 index 0000000..fff7205 --- /dev/null +++ b/test/test_conversions.jl @@ -0,0 +1,105 @@ +using Test +using Tensor4all +using Tensor4all.SimpleTT +using Tensor4all.TreeTN + +import Tensor4all.SimpleTT: sitedims, linkdims, evaluate, sitetensor, fulltensor +import Tensor4all.TreeTN: MPS, nv, inner, linkdims as ttn_linkdims +using LinearAlgebra: norm + +@testset "SimpleTT <-> TreeTN Conversions" begin + @testset "SimpleTT -> MPS -> SimpleTT round-trip" begin + @testset "rank-1 constant" begin + tt = SimpleTensorTrain([2, 3, 4], 1.5) + mps = MPS(tt) + @test nv(mps) == 3 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 3 + @test sitedims(tt2) == [2, 3, 4] + + # Check values match + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "higher rank" begin + t1 = randn(1, 2, 3) + t2 = randn(3, 4, 2) + t3 = randn(2, 3, 1) + tt = SimpleTensorTrain([t1, t2, t3]) + mps = MPS(tt) + @test nv(mps) == 3 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 3 + @test sitedims(tt2) == [2, 4, 3] + @test linkdims(tt2) == [3, 2] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "single site" begin + tt = SimpleTensorTrain([3], 2.0) + mps = MPS(tt) + @test nv(mps) == 1 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 1 + @test sitedims(tt2) == [3] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "two sites" begin + tt = 
SimpleTensorTrain([2, 5], 3.0) + mps = MPS(tt) + @test nv(mps) == 2 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 2 + @test sitedims(tt2) == [2, 5] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + end + + @testset "ComplexF64 round-trip" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0 + 2.0im) + mps = MPS(tt) + @test nv(mps) == 3 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 3 + @test sitedims(tt2) == [2, 3, 4] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "MPS -> SimpleTT -> MPS" begin + sites = [Tensor4all.Index(2) for _ in 1:4] + mps = TreeTN.random_mps(sites; linkdims=3) + + tt = SimpleTensorTrain(mps) + @test length(tt) == 4 + @test sitedims(tt) == [2, 2, 2, 2] + + mps2 = MPS(tt) + @test nv(mps2) == 4 + + # Check that the dense tensor representation matches + dense1 = TreeTN.to_dense(mps) + arr1 = data(dense1) + arr2 = fulltensor(tt) + @test arr1 ≈ arr2 + end +end diff --git a/test/test_quanticsgrids.jl b/test/test_quanticsgrids.jl new file mode 100644 index 0000000..5f2644a --- /dev/null +++ b/test/test_quanticsgrids.jl @@ -0,0 +1,22 @@ +using Test +using Tensor4all + +@testset "QuanticsGrids" begin + @testset "top-level exports" begin + @test isdefined(Tensor4all, :DiscretizedGrid) + @test isdefined(Tensor4all, :InherentDiscreteGrid) + @test isdefined(Tensor4all, :localdimensions) + end + + @testset "grouped unfolding" begin + grid = Tensor4all.QuanticsGrids.DiscretizedGrid( + 2, [2, 2], [0.0, 0.0], [1.0, 1.0]; unfolding=:grouped) + + q = Tensor4all.QuanticsGrids.origcoord_to_quantics(grid, [0.25, 0.75]) + x = Tensor4all.QuanticsGrids.quantics_to_origcoord(grid, q) + + @test length(q) == 4 + @test x ≈ [0.25, 0.75] atol=0.3 + @test Tensor4all.QuanticsGrids.localdimensions(grid) == fill(2, 4) + end +end diff --git a/test/test_quanticstci.jl b/test/test_quanticstci.jl new file mode 100644 index 0000000..c92fd5d --- /dev/null +++ b/test/test_quanticstci.jl @@ -0,0 +1,191 
@@ +using Test +using Tensor4all +using Tensor4all.QuanticsGrids +using Tensor4all.QuanticsTCI +using Tensor4all.SimpleTT: SimpleTensorTrain +import Tensor4all: evaluate, linkdims, maxbonderror, maxrank +import Tensor4all.QuanticsTCI: integral, to_tensor_train + +@testset "QuanticsTCI" begin + @testset "Float64 - continuous grid" begin + # Create a 1D grid with 8 bits (256 points) over [0, 1) + grid = DiscretizedGrid(1, 8, [0.0], [1.0]) + + # Interpolate f(x) = x^2 + qtci, ranks, errors = quanticscrossinterpolate(Float64, x -> x^2, grid; + tolerance=1e-8, maxiter=50) + + @testset "return types" begin + @test qtci isa QuanticsTensorCI2{Float64} + @test ranks isa Vector{Int} + @test errors isa Vector{Float64} + @test length(ranks) > 0 + @test length(errors) > 0 + @test length(ranks) == length(errors) + end + + @testset "evaluate" begin + # Evaluate at a known point using grid indices (1-indexed, one per original dimension) + gi = origcoord_to_grididx(grid, [0.5]) + val = evaluate(qtci, gi) + @test val isa Float64 + @test val ≈ 0.25 atol=1e-4 + end + + @testset "callable interface" begin + gi = origcoord_to_grididx(grid, [0.25]) + val = qtci(gi...) 
+ @test val isa Float64 + @test val ≈ 0.0625 atol=1e-4 + end + + @testset "sum" begin + s = sum(qtci) + @test s isa Float64 + @test isfinite(s) + end + + @testset "integral" begin + # integral of x^2 from 0 to 1 = 1/3 + val = integral(qtci) + @test val isa Float64 + # Left Riemann sum with 256 points has O(1/N) error ≈ 0.002 + @test val ≈ 1/3 atol=0.01 + end + + @testset "maxbonderror" begin + err = maxbonderror(qtci) + @test err isa Float64 + @test err >= 0.0 + end + + @testset "maxrank" begin + r = maxrank(qtci) + @test r isa Int + @test r >= 1 + end + + @testset "linkdims" begin + ld = linkdims(qtci) + @test ld isa Vector{Int} + @test all(d -> d >= 1, ld) + end + + @testset "to_tensor_train" begin + tt = to_tensor_train(qtci) + @test tt isa SimpleTensorTrain{Float64} + end + + @testset "display" begin + buf = IOBuffer() + show(buf, qtci) + s = String(take!(buf)) + @test occursin("QuanticsTensorCI2{Float64}", s) + end + end + + @testset "Float64 - discrete (size tuple)" begin + # Interpolate f(i, j) = i + j on an 8x8 grid + qtci, ranks, errors = quanticscrossinterpolate( + Float64, (i, j) -> Float64(i + j), (8, 8); + tolerance=1e-10, maxiter=50) + + @test qtci isa QuanticsTensorCI2{Float64} + @test length(ranks) > 0 + @test length(errors) > 0 + + # The function i+j should be exactly representable with low rank + @test maxrank(qtci) <= 4 + + @testset "linkdims" begin + ld = linkdims(qtci) + @test ld isa Vector{Int} + end + end + + @testset "Float64 - from Array" begin + # Create a simple 8x8 array + F = Float64[i + j for i in 1:8, j in 1:8] + qtci, ranks, errors = quanticscrossinterpolate(F; + tolerance=1e-10, maxiter=50) + + @test qtci isa QuanticsTensorCI2{Float64} + @test length(ranks) > 0 + end + + @testset "ComplexF64 - continuous grid" begin + grid = DiscretizedGrid(1, 8, [0.0], [1.0]) + + # Interpolate f(x) = exp(im * x) + qtci, ranks, errors = quanticscrossinterpolate( + ComplexF64, x -> exp(im * x), grid; + tolerance=1e-8, maxiter=50) + + @test qtci isa 
QuanticsTensorCI2{ComplexF64} + @test ranks isa Vector{Int} + @test errors isa Vector{Float64} + + @testset "evaluate" begin + gi = origcoord_to_grididx(grid, [0.5]) + val = evaluate(qtci, gi) + @test val isa ComplexF64 + @test val ≈ exp(im * 0.5) atol=1e-4 + end + + @testset "sum" begin + s = sum(qtci) + @test s isa ComplexF64 + end + + @testset "integral" begin + val = integral(qtci) + @test val isa ComplexF64 + end + + @testset "to_tensor_train" begin + tt = to_tensor_train(qtci) + @test tt isa SimpleTensorTrain{ComplexF64} + end + + @testset "maxbonderror and maxrank" begin + @test maxbonderror(qtci) isa Float64 + @test maxrank(qtci) isa Int + end + + @testset "display" begin + buf = IOBuffer() + show(buf, qtci) + s = String(take!(buf)) + @test occursin("QuanticsTensorCI2{ComplexF64}", s) + end + end + + @testset "ComplexF64 - discrete (size tuple)" begin + qtci, ranks, errors = quanticscrossinterpolate( + ComplexF64, (i, j) -> ComplexF64(i + im * j), (8, 8); + tolerance=1e-8, maxiter=50) + + @test qtci isa QuanticsTensorCI2{ComplexF64} + @test length(ranks) > 0 + end + + @testset "kwargs: options are passed" begin + grid = DiscretizedGrid(1, 4, [0.0], [1.0]) + + # Test that verbosity and other kwargs don't error + qtci, ranks, errors = quanticscrossinterpolate( + Float64, x -> x, grid; + tolerance=1e-4, + maxbonddim=10, + maxiter=5, + nrandominitpivot=2, + verbosity=0, + nsearchglobalpivot=2, + nsearch=10, + normalizeerror=false) + + @test qtci isa QuanticsTensorCI2{Float64} + # With maxiter=5, we should have at most 5 iterations + @test length(ranks) <= 5 + end +end diff --git a/test/test_quanticstransform.jl b/test/test_quanticstransform.jl new file mode 100644 index 0000000..ed83d8a --- /dev/null +++ b/test/test_quanticstransform.jl @@ -0,0 +1,110 @@ +using Test +using Tensor4all +using Tensor4all: dim +using Tensor4all.SimpleTT: SimpleTensorTrain +using Tensor4all.TreeTN: MPS, siteinds +using Tensor4all.QuanticsTransform: + LinearOperator, + 
affine_operator, + apply, + binaryop_operator, + flip_operator_multivar, + phase_rotation_operator_multivar, + set_iospaces!, + shift_operator, + shift_operator_multivar + +const CAPI = Tensor4all.C_API + +@testset "QuanticsTransform C API bindings" begin + @testset "multivar constructors" begin + out = Ref{Ptr{Cvoid}}(C_NULL) + + status = CAPI.t4a_qtransform_shift_multivar( + Csize_t(4), Int64(1), Cint(0), Csize_t(3), Csize_t(1), out) + @test status == 0 + @test out[] != C_NULL + CAPI.t4a_linop_release(out[]) + + out[] = C_NULL + status = CAPI.t4a_qtransform_flip_multivar( + Csize_t(4), Cint(1), Csize_t(3), Csize_t(2), out) + @test status == 0 + @test out[] != C_NULL + CAPI.t4a_linop_release(out[]) + + out[] = C_NULL + status = CAPI.t4a_qtransform_phase_rotation_multivar( + Csize_t(4), Cdouble(pi / 3), Csize_t(3), Csize_t(0), out) + @test status == 0 + @test out[] != C_NULL + CAPI.t4a_linop_release(out[]) + end + + @testset "affine and binaryop constructors" begin + out = Ref{Ptr{Cvoid}}(C_NULL) + + a_num = Int64[1, 1, 0, 0, 1, 1] + a_den = fill(Int64(1), 6) + b_num = Int64[0, 0, 0] + b_den = fill(Int64(1), 3) + bc = Cint[1, 1, 0] + + status = CAPI.t4a_qtransform_affine( + Csize_t(4), a_num, a_den, b_num, b_den, Csize_t(3), Csize_t(2), bc, out) + @test status == 0 + @test out[] != C_NULL + CAPI.t4a_linop_release(out[]) + + out[] = C_NULL + status = CAPI.t4a_qtransform_binaryop( + Csize_t(4), Int8(1), Int8(1), Int8(1), Int8(-1), Cint(1), Cint(0), out) + @test status == 0 + @test out[] != C_NULL + CAPI.t4a_linop_release(out[]) + end + + @testset "mapping rewrite enables high-level apply" begin + tt = SimpleTensorTrain([2, 2, 2], 1.0) + mps = MPS(tt) + op = shift_operator(3, 1) + + set_iospaces!(op, mps) + result = apply(op, mps; method=:naive) + + @test result isa Tensor4all.TreeTN.TreeTensorNetwork + end + + @testset "high-level multivar wrappers construct operators" begin + @test shift_operator_multivar(3, 1, 2, 0) isa LinearOperator + @test 
flip_operator_multivar(3, 2, 1; bc=Tensor4all.QuanticsTransform.Open) isa LinearOperator + @test phase_rotation_operator_multivar(3, pi / 4, 2, 1) isa LinearOperator + @test binaryop_operator(3, 1, 1, 1, -1) isa LinearOperator + end + + @testset "affine wrapper supports explicit output space" begin + input_mps = MPS(SimpleTensorTrain(fill(4, 3), 1.0)) + output_mps = MPS(SimpleTensorTrain(fill(8, 3), 0.0)) + + a_num = Int64[ + 1 -1 + 1 0 + 0 1 + ] + a_den = ones(Int64, 3, 2) + b_num = Int64[0, 0, 0] + b_den = ones(Int64, 3) + bc = [ + Tensor4all.QuanticsTransform.Open, + Tensor4all.QuanticsTransform.Periodic, + Tensor4all.QuanticsTransform.Periodic, + ] + + op = affine_operator(3, a_num, a_den, b_num, b_den; bc=bc) + set_iospaces!(op, input_mps, output_mps) + result = apply(op, input_mps; method=:naive) + + @test result isa Tensor4all.TreeTN.TreeTensorNetwork + @test dim(siteinds(result, 1)[1]) == 8 + end +end diff --git a/test/test_simplett.jl b/test/test_simplett.jl index 1c34ebd..5bb99c3 100644 --- a/test/test_simplett.jl +++ b/test/test_simplett.jl @@ -1,91 +1,378 @@ using Test using Tensor4all using Tensor4all.SimpleTT - -# Import functions from SimpleTT module -import Tensor4all.SimpleTT: site_dims, link_dims, rank, evaluate, site_tensor +import Tensor4all.SimpleTT: sitedims, linkdims, rank, evaluate, sitetensor, fulltensor, scale! 
+using LinearAlgebra: dot, norm @testset "SimpleTT" begin - @testset "constant tensor train" begin - # Create a constant tensor train - tt = SimpleTensorTrain([2, 3, 4], 1.5) + @testset "Float64" begin + @testset "Construction" begin + @testset "constant tensor train" begin + tt = SimpleTensorTrain([2, 3, 4], 1.5) + @test length(tt) == 3 + @test sitedims(tt) == [2, 3, 4] + @test rank(tt) == 1 + end - @test length(tt) == 3 - @test site_dims(tt) == [2, 3, 4] - @test rank(tt) == 1 # Constant has rank 1 + @testset "zeros tensor train" begin + tt = zeros(SimpleTensorTrain, [2, 3]) + @test length(tt) == 2 + @test sitedims(tt) == [2, 3] + @test sum(tt) == 0.0 + end - # Sum should be value * product of dimensions - expected_sum = 1.5 * 2 * 3 * 4 - @test sum(tt) ≈ expected_sum - end + @testset "from site tensors" begin + t1 = randn(1, 2, 3) + t2 = randn(3, 4, 1) + tt = SimpleTensorTrain([t1, t2]) + @test length(tt) == 2 + @test sitedims(tt) == [2, 4] + @test linkdims(tt) == [3] + end + end - @testset "zeros tensor train" begin - tt = zeros(SimpleTensorTrain, [2, 3]) + @testset "Accessors" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0) + @test length(tt) == 3 + @test sitedims(tt) == [2, 3, 4] - @test length(tt) == 2 - @test site_dims(tt) == [2, 3] - @test sum(tt) == 0.0 - end + ldims = linkdims(tt) + @test length(ldims) == 2 # n_sites - 1 + @test all(d -> d == 1, ldims) # rank-1 constant - @testset "evaluate" begin - tt = SimpleTensorTrain([2, 3, 4], 2.0) + @test rank(tt) == 1 + end - # All elements should be 2.0 - @test evaluate(tt, [0, 0, 0]) ≈ 2.0 - @test evaluate(tt, [1, 2, 3]) ≈ 2.0 + @testset "Evaluation (1-indexed)" begin + tt = SimpleTensorTrain([2, 3, 4], 2.0) - # Test callable interface - @test tt([0, 1, 2]) ≈ 2.0 - @test tt(0, 1, 2) ≈ 2.0 - end + # 1-indexed evaluation + @test evaluate(tt, [1, 1, 1]) ≈ 2.0 + @test evaluate(tt, [2, 3, 4]) ≈ 2.0 + @test evaluate(tt, [1, 2, 3]) ≈ 2.0 - @testset "copy" begin - tt1 = SimpleTensorTrain([2, 3], 3.0) - tt2 = copy(tt1) + # 
Callable interface (1-indexed) + @test tt([1, 1, 1]) ≈ 2.0 + @test tt(1, 2, 3) ≈ 2.0 + end - @test length(tt2) == length(tt1) - @test site_dims(tt2) == site_dims(tt1) - @test sum(tt2) ≈ sum(tt1) - end + @testset "1-indexing verification" begin + tt = SimpleTensorTrain([2, 3, 4], 1.5) - @testset "link_dims" begin - tt = SimpleTensorTrain([2, 3, 4], 1.0) + # evaluate with [1,1,1] should work (first element) + @test evaluate(tt, [1, 1, 1]) ≈ 1.5 - ldims = link_dims(tt) - @test length(ldims) == 2 # n_sites - 1 + # sitetensor(tt, 1) returns the first site tensor + t1 = sitetensor(tt, 1) + @test size(t1, 2) == 2 # first site has dim 2 + end - # For rank-1 constant, link dims should all be 1 - @test all(d -> d == 1, ldims) - end + @testset "Site tensor" begin + tt = SimpleTensorTrain([2, 3], 1.0) + + # First site (1-indexed) + t1 = sitetensor(tt, 1) + @test size(t1, 1) == 1 # left dim + @test size(t1, 2) == 2 # site dim + @test size(t1, 3) == 1 # right dim (rank-1, single site link) + + # Last site (1-indexed) + t2 = sitetensor(tt, 2) + @test size(t2, 1) == 1 # left dim + @test size(t2, 2) == 3 # site dim + @test size(t2, 3) == 1 # right dim + end + + @testset "Arithmetic" begin + tt1 = SimpleTensorTrain([2, 3], 1.0) + tt2 = SimpleTensorTrain([2, 3], 2.0) + + # Addition + tt3 = tt1 + tt2 + @test sum(tt3) ≈ sum(tt1) + sum(tt2) + + # Subtraction + tt4 = tt1 - tt2 + @test sum(tt4) ≈ sum(tt1) - sum(tt2) + + # Scalar multiplication + tt5 = 3.0 * tt1 + @test sum(tt5) ≈ 3.0 * sum(tt1) + + tt6 = tt1 * 3.0 + @test sum(tt6) ≈ 3.0 * sum(tt1) + + # Dot product + @test dot(tt1, tt1) ≈ norm(tt1)^2 + @test dot(tt1, tt2) ≈ 2.0 * dot(tt1, tt1) + end + + @testset "In-place scale!" 
begin + tt = SimpleTensorTrain([2, 3], 1.0) + original_sum = sum(tt) + scale!(tt, 2.5) + @test sum(tt) ≈ 2.5 * original_sum + end - @testset "site_tensor" begin - tt = SimpleTensorTrain([2, 3], 1.0) + @testset "reverse" begin + tt = SimpleTensorTrain([2, 3, 4], 1.5) + tt_rev = reverse(tt) + @test length(tt_rev) == 3 + @test sitedims(tt_rev) == [4, 3, 2] + @test sum(tt_rev) ≈ sum(tt) + end - # Get site tensor at site 0 - t0 = site_tensor(tt, 0) - @test size(t0, 1) == 1 # left dim - @test size(t0, 2) == 2 # site dim - @test size(t0, 3) == 1 # right dim + @testset "fulltensor" begin + tt = SimpleTensorTrain([2, 3], 1.5) + arr = fulltensor(tt) + @test size(arr) == (2, 3) + @test all(x -> x ≈ 1.5, arr) + end - # Get site tensor at site 1 - t1 = site_tensor(tt, 1) - @test size(t1, 1) == 1 # left dim - @test size(t1, 2) == 3 # site dim - @test size(t1, 3) == 1 # right dim + @testset "copy" begin + tt1 = SimpleTensorTrain([2, 3], 3.0) + tt2 = copy(tt1) + @test length(tt2) == length(tt1) + @test sitedims(tt2) == sitedims(tt1) + @test sum(tt2) ≈ sum(tt1) + + # Ensure deep copy: modifying tt2 doesn't affect tt1 + scale!(tt2, 0.0) + @test sum(tt1) ≈ 3.0 * 2 * 3 + @test sum(tt2) ≈ 0.0 + end + + @testset "norm" begin + tt = SimpleTensorTrain([2, 3], 1.5) + # norm^2 = sum of squares = 1.5^2 * 2 * 3 = 2.25 * 6 = 13.5 + @test norm(tt) ≈ sqrt(1.5^2 * 2 * 3) + end + + @testset "sum" begin + tt = SimpleTensorTrain([2, 3, 4], 1.5) + @test sum(tt) ≈ 1.5 * 2 * 3 * 4 + end + + @testset "compress!" 
begin + # Create a TT by adding two rank-1 TTs (result has rank 2) + tt1 = SimpleTensorTrain([2, 3, 4], 1.0) + tt2 = SimpleTensorTrain([2, 3, 4], 2.0) + tt = tt1 + tt2 + @test rank(tt) == 2 + + original_sum = sum(tt) + compress!(tt; method=:SVD, tolerance=1e-12) + @test rank(tt) == 1 # Should compress back to rank 1 + @test sum(tt) ≈ original_sum + end + + @testset "show" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0) + + io = IOBuffer() + show(io, tt) + s = String(take!(io)) + @test occursin("SimpleTensorTrain", s) + @test occursin("3", s) # sites + + show(io, MIME"text/plain"(), tt) + s = String(take!(io)) + @test occursin("Sites:", s) + end + + @testset "from site tensors - evaluation" begin + # Create site tensors and verify evaluate matches direct contraction + t1 = randn(1, 2, 3) + t2 = randn(3, 4, 1) + tt = SimpleTensorTrain([t1, t2]) + + arr = fulltensor(tt) + # Check a few evaluations match the full tensor + for i in 1:2, j in 1:4 + @test evaluate(tt, [i, j]) ≈ arr[i, j] + end + end end - @testset "show" begin - tt = SimpleTensorTrain([2, 3, 4], 1.0) + @testset "ComplexF64" begin + @testset "Construction" begin + @testset "constant tensor train" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + @test length(tt) == 2 + @test sitedims(tt) == [2, 3] + @test rank(tt) == 1 + end + + @testset "zeros tensor train" begin + tt = zeros(SimpleTensorTrain{ComplexF64}, [2, 3]) + @test length(tt) == 2 + @test sitedims(tt) == [2, 3] + @test sum(tt) == 0.0 + 0.0im + end + + @testset "from site tensors" begin + t1 = randn(ComplexF64, 1, 2, 3) + t2 = randn(ComplexF64, 3, 4, 1) + tt = SimpleTensorTrain([t1, t2]) + @test length(tt) == 2 + @test sitedims(tt) == [2, 4] + @test linkdims(tt) == [3] + end + end + + @testset "Accessors" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0 + 0.0im) + @test length(tt) == 3 + @test sitedims(tt) == [2, 3, 4] + + ldims = linkdims(tt) + @test length(ldims) == 2 + @test all(d -> d == 1, ldims) + + @test rank(tt) == 1 + end + + @testset 
"Evaluation (1-indexed)" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0 + 2.0im) + + @test evaluate(tt, [1, 1, 1]) ≈ 1.0 + 2.0im + @test evaluate(tt, [2, 3, 4]) ≈ 1.0 + 2.0im + + # Callable interface + @test tt([1, 1, 1]) ≈ 1.0 + 2.0im + @test tt(1, 2, 3) ≈ 1.0 + 2.0im + end + + @testset "1-indexing verification" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + + @test evaluate(tt, [1, 1]) ≈ 1.0 + 2.0im + + t1 = sitetensor(tt, 1) + @test size(t1, 2) == 2 # first site has dim 2 + end + + @testset "Site tensor" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 0.0im) + + t1 = sitetensor(tt, 1) + @test size(t1, 1) == 1 + @test size(t1, 2) == 2 + @test size(t1, 3) == 1 + @test eltype(t1) == ComplexF64 + + t2 = sitetensor(tt, 2) + @test size(t2, 1) == 1 + @test size(t2, 2) == 3 + @test size(t2, 3) == 1 + @test eltype(t2) == ComplexF64 + end + + @testset "Arithmetic" begin + tt1 = SimpleTensorTrain([2, 3], 1.0 + 1.0im) + tt2 = SimpleTensorTrain([2, 3], 2.0 + 0.5im) + + # Addition + tt3 = tt1 + tt2 + @test sum(tt3) ≈ sum(tt1) + sum(tt2) + + # Subtraction + tt4 = tt1 - tt2 + @test sum(tt4) ≈ sum(tt1) - sum(tt2) + + # Scalar multiplication (complex scalar) + tt5 = (2.0 + 1.0im) * tt1 + @test sum(tt5) ≈ (2.0 + 1.0im) * sum(tt1) + + tt6 = tt1 * (2.0 + 1.0im) + @test sum(tt6) ≈ (2.0 + 1.0im) * sum(tt1) + + # Dot product: dot(a, b) = sum(conj(a) .* b) + @test dot(tt1, tt1) ≈ norm(tt1)^2 + end + + @testset "In-place scale!" 
begin + tt = SimpleTensorTrain([2, 3], 1.0 + 1.0im) + original_sum = sum(tt) + scale!(tt, 2.0 + 0.5im) + @test sum(tt) ≈ (2.0 + 0.5im) * original_sum + end + + @testset "reverse" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0 + 2.0im) + tt_rev = reverse(tt) + @test length(tt_rev) == 3 + @test sitedims(tt_rev) == [4, 3, 2] + @test sum(tt_rev) ≈ sum(tt) + end + + @testset "fulltensor" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + arr = fulltensor(tt) + @test size(arr) == (2, 3) + @test eltype(arr) == ComplexF64 + @test all(x -> x ≈ 1.0 + 2.0im, arr) + end + + @testset "copy" begin + tt1 = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + tt2 = copy(tt1) + @test length(tt2) == length(tt1) + @test sitedims(tt2) == sitedims(tt1) + @test sum(tt2) ≈ sum(tt1) + + # Ensure deep copy + scale!(tt2, 0.0 + 0.0im) + @test sum(tt1) ≈ (1.0 + 2.0im) * 2 * 3 + @test sum(tt2) ≈ 0.0 + 0.0im + end + + @testset "norm" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + # norm^2 = sum of |z|^2 = |1+2i|^2 * 2 * 3 = 5 * 6 = 30 + @test norm(tt) ≈ sqrt(abs2(1.0 + 2.0im) * 2 * 3) + end + + @testset "sum" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + @test sum(tt) ≈ (1.0 + 2.0im) * 2 * 3 + end + + @testset "compress!" 
begin + tt1 = SimpleTensorTrain([2, 3, 4], 1.0 + 0.0im) + tt2 = SimpleTensorTrain([2, 3, 4], 0.0 + 2.0im) + tt = tt1 + tt2 + @test rank(tt) == 2 + + original_sum = sum(tt) + compress!(tt; method=:SVD, tolerance=1e-12) + @test rank(tt) == 1 + @test sum(tt) ≈ original_sum + end + + @testset "show" begin + tt = SimpleTensorTrain([2, 3], 1.0 + 2.0im) + + io = IOBuffer() + show(io, tt) + s = String(take!(io)) + @test occursin("SimpleTensorTrain", s) + @test occursin("ComplexF64", s) + + show(io, MIME"text/plain"(), tt) + s = String(take!(io)) + @test occursin("Sites:", s) + end - # Test that show doesn't error - io = IOBuffer() - show(io, tt) - s = String(take!(io)) - @test occursin("SimpleTensorTrain", s) - @test occursin("3", s) # sites + @testset "from site tensors - evaluation" begin + t1 = randn(ComplexF64, 1, 2, 3) + t2 = randn(ComplexF64, 3, 4, 1) + tt = SimpleTensorTrain([t1, t2]) - show(io, MIME"text/plain"(), tt) - s = String(take!(io)) - @test occursin("Sites:", s) + arr = fulltensor(tt) + for i in 1:2, j in 1:4 + @test evaluate(tt, [i, j]) ≈ arr[i, j] + end + end end end diff --git a/test/test_treetci.jl b/test/test_treetci.jl new file mode 100644 index 0000000..1f54caf --- /dev/null +++ b/test/test_treetci.jl @@ -0,0 +1,145 @@ +using Test +using Tensor4all.TreeTCI +import Tensor4all: evaluate, maxbonderror, maxrank +import Tensor4all.TreeTCI: maxsamplevalue, + bonddims, to_treetn, sweep!, add_global_pivots! 
+ +@testset "TreeTCI" begin + @testset "1-indexed crossinterpolate2 - linear chain" begin + # Linear chain: 0-1-2-3 + graph = TreeTciGraph(4, [(0, 1), (1, 2), (2, 3)]) + + # Batch evaluation function: product of indices (1-based) + # f(batch) where batch is (n_sites, n_points) with 1-based indices + function f_product(batch) + n_sites, n_pts = size(batch) + [prod(Float64(batch[i, j]) for i in 1:n_sites) for j in 1:n_pts] + end + + tci, ranks, errors = crossinterpolate2(Float64, f_product, [3, 3, 3, 3], graph; + tolerance=1e-10, maxiter=20, + initialpivots=[ones(Int, 4)]) + + @test tci isa SimpleTreeTci{Float64} + @test length(ranks) > 0 + @test length(errors) > 0 + @test ranks isa Vector{Int} + @test errors isa Vector{Float64} + + @testset "state inspection" begin + @test maxrank(tci) isa Int + @test maxrank(tci) >= 1 + @test maxbonderror(tci) isa Float64 + @test maxsamplevalue(tci) isa Float64 + @test maxsamplevalue(tci) > 0.0 + + bd = bonddims(tci) + @test bd isa Vector{Int} + @test length(bd) == 3 # n_edges = n_sites - 1 for linear chain + end + + @testset "materialize to TreeTN" begin + ttn = to_treetn(tci, f_product) + @test ttn !== nothing + + @testset "evaluate with 1-based indices" begin + # f(1,1,1,1) = 1*1*1*1 = 1.0 + val = evaluate(ttn, [1, 1, 1, 1]) + @test val ≈ 1.0 atol=1e-8 + + # f(2,3,1,2) = 2*3*1*2 = 12.0 + val = evaluate(ttn, [2, 3, 1, 2]) + @test val ≈ 12.0 atol=1e-6 + + # f(3,3,3,3) = 3*3*3*3 = 81.0 + val = evaluate(ttn, [3, 3, 3, 3]) + @test val ≈ 81.0 atol=1e-6 + end + + @testset "batch evaluate with 1-based indices" begin + batch = [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]] + vals = evaluate(ttn, batch) + @test vals[1] ≈ 1.0 atol=1e-8 + @test vals[2] ≈ 16.0 atol=1e-6 + @test vals[3] ≈ 81.0 atol=1e-6 + end + + @testset "matrix evaluate with 1-based indices" begin + # 4 sites, 2 points + batch_mat = [1 2; 1 2; 1 2; 1 2] + vals = evaluate(ttn, batch_mat) + @test vals[1] ≈ 1.0 atol=1e-8 + @test vals[2] ≈ 16.0 atol=1e-6 + end + end + end + + 
@testset "1-indexed crossinterpolate2 - sum function" begin + graph = TreeTciGraph(3, [(0, 1), (1, 2)]) + + # Sum function (1-based indices) + function f_sum(batch) + n_sites, n_pts = size(batch) + [Base.sum(Float64(batch[i, j]) for i in 1:n_sites) for j in 1:n_pts] + end + + tci, ranks, errors = crossinterpolate2(Float64, f_sum, [4, 4, 4], graph; + tolerance=1e-10, maxiter=20) + + @test tci isa SimpleTreeTci{Float64} + + ttn = to_treetn(tci, f_sum) + + # f(1,1,1) = 1+1+1 = 3.0 + @test evaluate(ttn, [1, 1, 1]) ≈ 3.0 atol=1e-8 + # f(4,4,4) = 4+4+4 = 12.0 + @test evaluate(ttn, [4, 4, 4]) ≈ 12.0 atol=1e-6 + # f(1,2,3) = 1+2+3 = 6.0 + @test evaluate(ttn, [1, 2, 3]) ≈ 6.0 atol=1e-6 + end + + @testset "default initialpivots is ones" begin + graph = TreeTciGraph(3, [(0, 1), (1, 2)]) + + function f_const(batch) + n_sites, n_pts = size(batch) + fill(42.0, n_pts) + end + + # Should work without specifying initialpivots (default is ones) + tci, ranks, errors = crossinterpolate2(Float64, f_const, [2, 2, 2], graph; + tolerance=1e-8, maxiter=5) + + @test tci isa SimpleTreeTci{Float64} + end + + @testset "type inference" begin + graph = TreeTciGraph(2, [(0, 1)]) + + # Float64 inference + f_real(batch) = fill(1.0, size(batch, 2)) + tci, _, _ = crossinterpolate2(f_real, [2, 2], graph; maxiter=3) + @test tci isa SimpleTreeTci{Float64} + end + + @testset "star graph topology" begin + # Star graph: site 0 connected to sites 1,2,3 + graph = TreeTciGraph(4, [(0, 1), (0, 2), (0, 3)]) + + function f_star(batch) + n_sites, n_pts = size(batch) + [Float64(batch[1, j]) * Float64(batch[2, j]) + Float64(batch[3, j]) + Float64(batch[4, j]) + for j in 1:n_pts] + end + + tci, ranks, errors = crossinterpolate2(Float64, f_star, [3, 3, 3, 3], graph; + tolerance=1e-10, maxiter=20, initialpivots=[ones(Int, 4)]) + + ttn = to_treetn(tci, f_star) + + # f(1,1,1,1) = 1*1 + 1 + 1 = 3.0 + @test evaluate(ttn, [1, 1, 1, 1]) ≈ 3.0 atol=1e-6 + # f(2,3,1,1) = 2*3 + 1 + 1 = 8.0 + @test evaluate(ttn, [2, 3, 1, 1]) ≈ 
8.0 atol=1e-6 + end +end