From 88e3659bf5786e6d5782bee7d166e1bb008bbfe9 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 7 Apr 2026 11:21:23 +0900 Subject: [PATCH 1/5] feat: add prime level (plev) support to Index type Add plev across the Julia wrapper: - C_API.jl: t4a_index_get_plev, t4a_index_set_plev, t4a_index_prime wrappers - Tensor4all.jl: plev(), prime(), noprime(), setprime() functions - Updated ==, hash, sim, show to include plev - commoninds/uniqueinds/replaceinds/hasind/hasinds now plev-aware - ITensors extension preserves plev in both conversion directions - Tests for all plev functionality Co-Authored-By: Claude Opus 4.6 (1M context) --- ext/Tensor4allITensorsExt.jl | 15 +++-- src/C_API.jl | 44 ++++++++++++++ src/Tensor4all.jl | 115 +++++++++++++++++++++++++++-------- test/itensors_ext_test.jl | 14 +++++ test/test_index.jl | 45 ++++++++++++++ 5 files changed, 202 insertions(+), 31 deletions(-) diff --git a/ext/Tensor4allITensorsExt.jl b/ext/Tensor4allITensorsExt.jl index 6856868..eea6594 100644 --- a/ext/Tensor4allITensorsExt.jl +++ b/ext/Tensor4allITensorsExt.jl @@ -18,8 +18,8 @@ Both Rust and ITensors use UInt64 IDs natively, so conversion is direct. ## Memory Order -Rust uses row-major, Julia/ITensors uses column-major. -Conversion is handled automatically. +Rust and Julia/ITensors use column-major order. +Conversion preserves the logical layout. """ module Tensor4allITensorsExt @@ -40,12 +40,13 @@ IDs are natively UInt64 in both systems. Tags are preserved. function ITensors.Index(idx::Tensor4all.Index) d = Tensor4all.dim(idx) t = Tensor4all.tags(idx) + p = Tensor4all.plev(idx) id64 = Tensor4all.id(idx) # Create ITensors.Index with explicit ID using full constructor # Index(id, space, dir, tags, plev) tagset = isempty(t) ? ITensors.TagSet("") : ITensors.TagSet(t) - return ITensors.Index(id64, d, ITensors.Neither, tagset, 0) + return ITensors.Index(id64, d, ITensors.Neither, tagset, p) end # ============================================================================ @@ -65,12 +66,18 @@ cause an error. function Tensor4all.Index(idx::ITensors.Index) d = ITensors.dim(idx) id64 = ITensors.id(idx) + p = ITensors.plev(idx) # Get tags as comma-separated string tag_set = ITensors.tags(idx) tags_str = _tags_to_string(tag_set) - return Tensor4all.Index(d, id64; tags=tags_str) + t4a_idx = Tensor4all.Index(d, id64; tags=tags_str) + if p != 0 + status = Tensor4all.C_API.t4a_index_set_plev(t4a_idx.ptr, Int64(p)) + Tensor4all.C_API.check_status(status) + end + return t4a_idx end """ diff --git a/src/C_API.jl b/src/C_API.jl index d51c4d1..e8f5b69 100644 --- a/src/C_API.jl +++ b/src/C_API.jl @@ -272,6 +272,21 @@ function t4a_index_get_tags(ptr::Ptr{Cvoid}, buf, buf_len::Integer, out_len::Ref ) end +""" + t4a_index_get_plev(ptr::Ptr{Cvoid}, out_plev::Ref{Int64}) -> Cint + +Get the prime level of an index. +""" +function t4a_index_get_plev(ptr::Ptr{Cvoid}, out_plev::Ref{Int64}) + return ccall( + _sym(:t4a_index_get_plev), + Cint, + (Ptr{Cvoid}, Ptr{Int64}), + ptr, + out_plev + ) +end + # ============================================================================ # Index modifiers # ============================================================================ @@ -322,6 +337,35 @@ function t4a_index_has_tag(ptr::Ptr{Cvoid}, tag::AbstractString) ) end +""" + t4a_index_set_plev(ptr::Ptr{Cvoid}, plev::Integer) -> Cint + +Set the prime level of an index. +""" +function t4a_index_set_plev(ptr::Ptr{Cvoid}, plev::Integer) + return ccall( + _sym(:t4a_index_set_plev), + Cint, + (Ptr{Cvoid}, Int64), + ptr, + Int64(plev) + ) +end + +""" + t4a_index_prime(ptr::Ptr{Cvoid}) -> Cint + +Increment the prime level of an index by 1. +""" +function t4a_index_prime(ptr::Ptr{Cvoid}) + return ccall( + _sym(:t4a_index_prime), + Cint, + (Ptr{Cvoid},), + ptr + ) +end + # ============================================================================ # Storage kind enum # ============================================================================ diff --git a/src/Tensor4all.jl b/src/Tensor4all.jl index dc1094a..7fe9d59 100644 --- a/src/Tensor4all.jl +++ b/src/Tensor4all.jl @@ -57,7 +57,7 @@ include("Algorithm.jl") # Re-export public API # Core types (tensor4all-core-common, tensor4all-core-tensor) -export Index, dim, tags, id, hastag +export Index, dim, tags, id, hastag, plev export Tensor, rank, dims, indices, storage_kind, data export StorageKind, DenseF64, DenseC64, DiagF64, DiagC64 @@ -86,6 +86,7 @@ ITensors.jl's `Index{Int}` (no quantum number symmetry). - `id(i::Index)` - Get the unique ID as UInt64 - `tags(i::Index)` - Get tags as comma-separated string - `hastag(i::Index, tag::AbstractString)` - Check if index has a tag +- `plev(i::Index)` - Get the prime level """ mutable struct Index ptr::Ptr{Cvoid} @@ -179,24 +180,39 @@ function hastag(i::Index, tag::AbstractString) return result == 1 end +""" + plev(i::Index) -> Int + +Get the prime level of an index. +""" +function plev(i::Index) + out_plev = Ref{Int64}(0) + status = C_API.t4a_index_get_plev(i.ptr, out_plev) + C_API.check_status(status) + return Int(out_plev[]) +end + # Show method function Base.show(io::IO, i::Index) d = dim(i) t = tags(i) + p = plev(i) id_val = id(i) id_hex = string(id_val, base=16) - id_short = length(id_hex) >= 8 ? id_hex[end-7:end] : id_hex # Last 8 hex digits + id_short = length(id_hex) >= 8 ? id_hex[end-7:end] : id_hex + prime_str = p > 0 ? repeat("'", p) : "" if isempty(t) - print(io, "(dim=$d|id=...$id_short)") + print(io, "(dim=$d|id=...$id_short)$prime_str") else - print(io, "(dim=$d|id=...$id_short|\"$t\")") + print(io, "(dim=$d|id=...$id_short|\"$t\")$prime_str") end end function Base.show(io::IO, ::MIME"text/plain", i::Index) println(io, "Tensor4all.Index") - println(io, " dim: ", dim(i)) - println(io, " id: ", string(id(i), base=16)) + println(io, " dim: ", dim(i)) + println(io, " id: ", string(id(i), base=16)) + println(io, " plev: ", plev(i)) t = tags(i) if !isempty(t) println(io, " tags: ", t) @@ -215,15 +231,17 @@ function Base.deepcopy(i::Index) return Index(ptr) end -# Equality based on ID + tags (matching ITensors.jl semantics with plev=0) +# Equality based on ID + tags + plev (matching ITensors.jl semantics) function Base.:(==)(i1::Index, i2::Index) - return id(i1) == id(i2) && tags(i1) == tags(i2) + return id(i1) == id(i2) && tags(i1) == tags(i2) && plev(i1) == plev(i2) end function Base.hash(i::Index, h::UInt) - return hash(tags(i), hash(id(i), h)) + return hash(plev(i), hash(tags(i), hash(id(i), h))) end +_index_key(i::Index) = (id(i), tags(i), plev(i)) + # ============================================================================ # Index Utilities # ============================================================================ @@ -235,13 +253,55 @@ Create a new index with the same dimension and tags but a new unique ID. This is useful for creating "similar" indices that won't contract with the original. """ function sim(i::Index) - return Index(dim(i); tags=tags(i)) + j = Index(dim(i); tags=tags(i)) + p = plev(i) + if p != 0 + status = C_API.t4a_index_set_plev(j.ptr, Int64(p)) + C_API.check_status(status) + end + return j +end + +""" + prime(i::Index, n::Integer=1) -> Index + +Return a copy of the index with prime level incremented by `n`. +""" +function prime(i::Index, n::Integer=1) + j = copy(i) + status = C_API.t4a_index_set_plev(j.ptr, Int64(plev(i) + n)) + C_API.check_status(status) + return j +end + +""" + noprime(i::Index) -> Index + +Return a copy of the index with prime level set to 0. +""" +function noprime(i::Index) + j = copy(i) + status = C_API.t4a_index_set_plev(j.ptr, Int64(0)) + C_API.check_status(status) + return j +end + +""" + setprime(i::Index, n::Integer) -> Index + +Return a copy of the index with prime level set to `n`. +""" +function setprime(i::Index, n::Integer) + j = copy(i) + status = C_API.t4a_index_set_plev(j.ptr, Int64(n)) + C_API.check_status(status) + return j end """ hascommoninds(inds1, inds2) -> Bool -Check if two collections of indices have any common indices (by ID). +Check if two collections of indices have any common indices. # Example ```julia @@ -251,9 +311,9 @@ hascommoninds([i], [k]) # false ``` """ function hascommoninds(inds1, inds2) - ids1 = Set(id(i) for i in inds1) + keys1 = Set(_index_key(i) for i in inds1) for i in inds2 - if id(i) in ids1 + if _index_key(i) in keys1 return true end end @@ -263,7 +323,7 @@ end """ commoninds(inds1, inds2) -> Vector{Index} -Return the indices that appear in both collections (by ID). +Return the indices that appear in both collections. Returns indices from inds1. # Example @@ -273,8 +333,8 @@ commoninds([i, j], [j, k]) # [j] ``` """ function commoninds(inds1, inds2) - ids2 = Set(id(i) for i in inds2) - return [i for i in inds1 if id(i) in ids2] + keys2 = Set(_index_key(i) for i in inds2) + return [i for i in inds1 if _index_key(i) in keys2] end # Alias for ITensors compatibility @@ -293,7 +353,7 @@ end """ uniqueinds(inds1, inds2) -> Vector{Index} -Return the indices in inds1 that do not appear in inds2 (by ID). +Return the indices in inds1 that do not appear in inds2. # Example ```julia @@ -302,8 +362,8 @@ uniqueinds([i, j], [j, k]) # [i] ``` """ function uniqueinds(inds1, inds2) - ids2 = Set(id(i) for i in inds2) - return [i for i in inds1 if !(id(i) in ids2)] + keys2 = Set(_index_key(i) for i in inds2) + return [i for i in inds1 if !(_index_key(i) in keys2)] end """ @@ -333,8 +393,8 @@ Replace indices in `inds` according to the mapping old_inds → new_inds. """ function replaceinds(inds, old_inds, new_inds) length(old_inds) == length(new_inds) || error("old_inds and new_inds must have same length") - id_map = Dict(id(o) => n for (o, n) in zip(old_inds, new_inds)) - return [get(id_map, id(i), i) for i in inds] + index_map = Dict(_index_key(o) => n for (o, n) in zip(old_inds, new_inds)) + return [get(index_map, _index_key(i), i) for i in inds] end """ @@ -348,6 +408,7 @@ end export sim, hascommoninds, commoninds, common_inds, commonind export uniqueinds, uniqueind, noncommoninds, replaceinds, replaceind +export prime, noprime, setprime # ============================================================================ # Curried/Predicate Index Functions @@ -363,12 +424,12 @@ A predicate type for checking if an object has common indices with a given set. Used internally by `hascommoninds(is)` curried form. """ struct HasCommonIndsPredicate - target_ids::Set{UInt64} + target_keys::Set{Tuple{UInt64, String, Int}} end function (p::HasCommonIndsPredicate)(x) for idx in indices(x) - if id(idx) in p.target_ids + if _index_key(idx) in p.target_keys return true end end @@ -389,7 +450,7 @@ findfirst(hascommoninds(sites[2:2]), tt) # Returns 2 ``` """ function hascommoninds(is::Vector{Index}) - return HasCommonIndsPredicate(Set(id(i) for i in is)) + return HasCommonIndsPredicate(Set(_index_key(i) for i in is)) end hascommoninds(i::Index) = hascommoninds([i]) @@ -406,7 +467,7 @@ tt = random_tt(sites; linkdims=2) findfirst(hasind(sites[1]), tt) # Returns 1 ``` """ -hasind(i::Index) = x -> any(idx -> id(idx) == id(i), indices(x)) +hasind(i::Index) = x -> any(idx -> idx == i, indices(x)) """ hasinds(is) -> Function @@ -424,9 +485,9 @@ hasinds([i])(t) # true function hasinds(is) return function(x) x_inds = indices(x) - x_ids = Set(id(idx) for idx in x_inds) + x_keys = Set(_index_key(idx) for idx in x_inds) for i in is - if !(id(i) in x_ids) + if !(_index_key(i) in x_keys) return false end end diff --git a/test/itensors_ext_test.jl b/test/itensors_ext_test.jl index e12e8b7..049cd9e 100644 --- a/test/itensors_ext_test.jl +++ b/test/itensors_ext_test.jl @@ -12,6 +12,10 @@ t4a_id = Tensor4all.id(t4a_idx) expected_id = UInt64(t4a_id & 0xFFFFFFFFFFFFFFFF) @test ITensors.id(it_idx) == expected_id + + primed = Tensor4all.setprime(t4a_idx, 2) + primed_it = ITensors.Index(primed) + @test ITensors.plev(primed_it) == 2 end @testset "ITensors.Index → Tensor4all.Index" begin @@ -28,6 +32,10 @@ t4a_id = Tensor4all.id(t4a_idx) @test UInt64(t4a_id & 0xFFFFFFFFFFFFFFFF) == it_id @test UInt64(t4a_id >> 64) == 0 # Upper bits should be 0 + + primed_it = ITensors.Index(it_id, 3, ITensors.Neither, ITensors.TagSet("Link,l=2"), 2) + primed_t4a = Tensor4all.Index(primed_it) + @test Tensor4all.plev(primed_t4a) == 2 end @testset "Roundtrip conversion" begin @@ -38,6 +46,7 @@ @test Tensor4all.dim(orig) == Tensor4all.dim(back) @test Tensor4all.tags(orig) == Tensor4all.tags(back) + @test Tensor4all.plev(orig) == Tensor4all.plev(back) # IDs match in lower 64 bits (upper bits may differ after roundtrip) orig_lo = UInt64(Tensor4all.id(orig) & 0xFFFFFFFFFFFFFFFF) back_lo = UInt64(Tensor4all.id(back) & 0xFFFFFFFFFFFFFFFF) @@ -51,6 +60,11 @@ @test ITensors.dim(it_orig) == ITensors.dim(it_back) @test ITensors.id(it_orig) == ITensors.id(it_back) @test ITensors.hastags(it_back, "Bond") + + primed_orig = Tensor4all.setprime(orig, 3) + primed_it = ITensors.Index(primed_orig) + primed_back = Tensor4all.Index(primed_it) + @test Tensor4all.plev(primed_back) == 3 end @testset "Convert function" begin diff --git a/test/test_index.jl b/test/test_index.jl index 64b7acc..b2f4bc8 100644 --- a/test/test_index.jl +++ b/test/test_index.jl @@ -59,4 +59,49 @@ @test_throws ArgumentError T4AIndex(0) @test_throws ArgumentError T4AIndex(-1) end + + @testset "prime level" begin + i = T4AIndex(5; tags="Site") + + # Default plev is 0 + @test Tensor4all.plev(i) == 0 + + # prime + ip = Tensor4all.prime(i) + @test Tensor4all.plev(ip) == 1 + @test Tensor4all.id(ip) == Tensor4all.id(i) + + # double prime + ipp = Tensor4all.prime(ip) + @test Tensor4all.plev(ipp) == 2 + + # noprime + i0 = Tensor4all.noprime(ipp) + @test Tensor4all.plev(i0) == 0 + + # setprime + i3 = Tensor4all.setprime(i, 3) + @test Tensor4all.plev(i3) == 3 + + # equality includes plev + @test i != ip + @test i == Tensor4all.noprime(ip) + + # hash includes plev + @test hash(i) != hash(ip) + @test hash(i) == hash(Tensor4all.noprime(ip)) + + # index matching includes plev + @test !Tensor4all.hascommoninds([i], [ip]) + @test isempty(Tensor4all.commoninds([i], [ip])) + + # sim preserves plev + ip_sim = Tensor4all.sim(ip) + @test Tensor4all.plev(ip_sim) == 1 + @test Tensor4all.id(ip_sim) != Tensor4all.id(ip) + + # display shows prime + s = sprint(show, ip) + @test occursin("'", s) + end end From d05449dffc01638909c446dba1778485c70394ba Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 7 Apr 2026 11:43:36 +0900 Subject: [PATCH 2/5] feat: add TensorTrain alias and is_chain/assert_chain utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TensorTrain = TreeTensorNetwork{Int} (same type as MPS/MPO) - is_chain(ttn): runtime check for chain topology - Int variant: verifies vertices are 1:n with sequential connectivity - Generic variant: checks path graph properties (degree ≤ 2, connected) - _assert_chain(ttn): throws ArgumentError for non-chain topology - Tests: 121/121 TreeTN tests pass Co-Authored-By: Claude Opus 4.6 (1M context) --- src/Tensor4all.jl | 2 ++ src/TreeTN.jl | 84 +++++++++++++++++++++++++++++++++++++++++++-- test/test_treetn.jl | 62 +++++++++++++++++++++++++++++++-- 3 files changed, 144 insertions(+), 4 deletions(-) diff --git a/src/Tensor4all.jl b/src/Tensor4all.jl index dc1094a..f0fe3b7 100644 --- a/src/Tensor4all.jl +++ b/src/Tensor4all.jl @@ -905,6 +905,8 @@ include("SimpleTT.jl") # Tree tensor network functionality is in a separate submodule. # Use: using Tensor4all.TreeTN include("TreeTN.jl") +using .TreeTN: MPS, MPO, TensorTrain, random_mps, random_tt, is_chain +export MPS, MPO, TensorTrain, random_mps, random_tt, is_chain # ============================================================================ # QuanticsGrids Submodule diff --git a/src/TreeTN.jl b/src/TreeTN.jl index 0b04edf..01203ed 100644 --- a/src/TreeTN.jl +++ b/src/TreeTN.jl @@ -111,7 +111,7 @@ function _from_c_vertex(ttn::TreeTensorNetwork{V}, idx::Integer) where V end # ============================================================================ -# MPS/MPO type aliases and constructors +# MPS/MPO/TensorTrain type aliases and constructors # ============================================================================ """ @@ -130,7 +130,15 @@ Vertices are 1-indexed (Julia convention). """ const MPO = TreeTensorNetwork{Int} -export MPS, MPO +""" + TensorTrain + +Type alias for `TreeTensorNetwork{Int}`. +Same type as MPS and MPO; the distinction is semantic, not type-level. +""" +const TensorTrain = TreeTensorNetwork{Int} + +export MPS, MPO, TensorTrain """ MPS(tensors::Vector{Tensor}) @@ -295,6 +303,78 @@ end export neighbors +""" + is_chain(ttn::TreeTensorNetwork) -> Bool + +Check if a TreeTensorNetwork has chain (linear) topology. + +For `TreeTensorNetwork{Int}`, also verifies that vertices are named 1, 2, ..., n +in sequential order and that vertex `i` is connected to vertex `i+1`. + +For other vertex types, only checks that the topology is a path graph: +the graph is connected, has exactly two endpoints for `n > 1`, and every +vertex has degree at most 2. +""" +function is_chain(ttn::TreeTensorNetwork) + n = nv(ttn) + n <= 1 && return true + + verts = vertices(ttn) + degrees = Dict(v => length(neighbors(ttn, v)) for v in verts) + + any(>(2), values(degrees)) && return false + count(==(1), values(degrees)) == 2 || return false + count(==(2), values(degrees)) == n - 2 || return false + + visited = Set{eltype(verts)}() + stack = [first(verts)] + + while !isempty(stack) + v = pop!(stack) + v in visited && continue + push!(visited, v) + append!(stack, filter(w -> !(w in visited), neighbors(ttn, v))) + end + + return length(visited) == n +end + +function is_chain(ttn::TreeTensorNetwork{Int}) + n = nv(ttn) + n <= 1 && return true + + sort(vertices(ttn)) == collect(1:n) || return false + + for i in 1:n + expected = if i == 1 + [2] + elseif i == n + [n - 1] + else + [i - 1, i + 1] + end + sort(neighbors(ttn, i)) == expected || return false + end + + return true +end + +export is_chain + +""" + _assert_chain(ttn::TreeTensorNetwork) + +Assert that a TreeTensorNetwork has chain topology. +Used internally by chain-specific operations. +""" +function _assert_chain(ttn::TreeTensorNetwork) + is_chain(ttn) || throw(ArgumentError( + "Operation requires a chain (linear) topology, but the TreeTensorNetwork " * + "has $(nv(ttn)) vertices with non-chain connectivity." + )) + return nothing +end + """ getindex(ttn::TreeTensorNetwork{V}, v::V) -> Tensor diff --git a/test/test_treetn.jl b/test/test_treetn.jl index 5026072..f6ff0c7 100644 --- a/test/test_treetn.jl +++ b/test/test_treetn.jl @@ -1,11 +1,11 @@ using Test using Tensor4all: Index as T4AIndex, Tensor as T4ATensor using Tensor4all: dim, rank -using Tensor4all.TreeTN: MPS, MPO, TreeTensorNetwork +using Tensor4all.TreeTN: MPS, MPO, TensorTrain, TreeTensorNetwork using Tensor4all.TreeTN: nv, ne, linkdims, maxbonddim, linkind, linkinds, linkdim using Tensor4all.TreeTN: canonical_form, Unitary, LU, CI using Tensor4all.TreeTN: orthogonalize!, truncate!, inner -using Tensor4all.TreeTN: contract, to_dense, truncate +using Tensor4all.TreeTN: contract, to_dense, truncate, is_chain using Tensor4all.TreeTN: random_mps, random_tt using Tensor4all.TreeTN: findsite, findsites, siteinds, siteind using LinearAlgebra @@ -401,6 +401,64 @@ using LinearAlgebra @test length(mps_alias) == 5 end + @testset "TensorTrain alias and is_chain" begin + @test TensorTrain === MPS + @test TensorTrain === MPO + @test TensorTrain === TreeTensorNetwork{Int} + @test Tensor4all.TensorTrain === TensorTrain + @test Tensor4all.MPS === MPS + + sites = [T4AIndex(2) for _ in 1:5] + mps = random_mps(sites; linkdims=2) + @test is_chain(mps) + @test Tensor4all.is_chain(mps) + + single_site = random_mps([T4AIndex(2)]) + @test is_chain(single_site) + + sym_s1 = T4AIndex(2) + sym_l12 = T4AIndex(3) + sym_s2 = T4AIndex(2) + sym_l23 = T4AIndex(3) + sym_s3 = T4AIndex(2) + + sym_t1 = T4ATensor([sym_s1, sym_l12], rand(2, 3)) + sym_t2 = T4ATensor([sym_l12, sym_s2, sym_l23], rand(3, 2, 3)) + sym_t3 = T4ATensor([sym_l23, sym_s3], rand(3, 2)) + sym_chain = TreeTensorNetwork{Symbol}([:left => sym_t1, :middle => sym_t2, :right => sym_t3]) + @test is_chain(sym_chain) + + int_s1 = T4AIndex(2) + int_l12 = T4AIndex(3) + int_s2 = T4AIndex(2) + int_l23 = T4AIndex(3) + int_s3 = T4AIndex(2) + + int_t1 = T4ATensor([int_s1, int_l12], rand(2, 3)) + int_t2 = T4ATensor([int_l12, int_s2, int_l23], rand(3, 2, 3)) + int_t3 = T4ATensor([int_l23, int_s3], rand(3, 2)) + misnamed_chain = TreeTensorNetwork{Int}([2 => int_t1, 3 => int_t2, 4 => int_t3]) + @test !is_chain(misnamed_chain) + + br_s1 = T4AIndex(2) + br_l12 = T4AIndex(3) + br_s2 = T4AIndex(2) + br_l23 = T4AIndex(3) + br_l24 = T4AIndex(3) + br_s3 = T4AIndex(2) + br_s4 = T4AIndex(2) + + br_t1 = T4ATensor([br_s1, br_l12], rand(2, 3)) + br_t2 = T4ATensor([br_l12, br_s2, br_l23, br_l24], rand(3, 2, 3, 3)) + br_t3 = T4ATensor([br_l23, br_s3], rand(3, 2)) + br_t4 = T4ATensor([br_l24, br_s4], rand(3, 2)) + branched_ttn = TreeTensorNetwork{Int}([1 => br_t1, 2 => br_t2, 3 => br_t3, 4 => br_t4]) + @test !is_chain(branched_ttn) + + @test Tensor4all.TreeTN._assert_chain(mps) === nothing + @test_throws ArgumentError Tensor4all.TreeTN._assert_chain(branched_ttn) + end + @testset "MPS setindex!" begin using Tensor4all: indices From 2deee1b3f411dccef1891c347cceae50cbb52414 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 7 Apr 2026 12:03:22 +0900 Subject: [PATCH 3/5] feat: add tensor contraction, is_mps/mpo_like, diag_embed/diag_trace MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - contract(t1::Tensor, t2::Tensor): tensor-tensor contraction via C API - t1 * t2 operator for tensors - is_mps_like(tt): check each vertex has 1 site index - is_mpo_like(tt): check each vertex has 2 site indices - diag_embed(t, idx): duplicate idx as primed copy (MPS→MPO-like) - diag_trace(t, idx, idx'): extract diagonal (MPO→MPS-like) Co-Authored-By: Claude Opus 4.6 (1M context) --- src/C_API.jl | 16 +++++++ src/Tensor4all.jl | 107 ++++++++++++++++++++++++++++++++++++++++++-- src/TreeTN.jl | 32 ++++++++++++- test/test_tensor.jl | 37 +++++++++++++++ test/test_treetn.jl | 7 +++ 5 files changed, 195 insertions(+), 4 deletions(-) diff --git a/src/C_API.jl b/src/C_API.jl index e8f5b69..eef68fd 100644 --- a/src/C_API.jl +++ b/src/C_API.jl @@ -585,6 +585,22 @@ function t4a_tensor_onehot(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, vals:: ) end +""" + t4a_tensor_contract(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out::Ref{Ptr{Cvoid}}) + +Contract two tensors and write the resulting tensor handle to `out`. +""" +function t4a_tensor_contract(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out::Ref{Ptr{Cvoid}}) + return ccall( + _sym(:t4a_tensor_contract), + Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Ptr{Cvoid}}), + a, + b, + out + ) +end + # ============================================================================ # TreeTN lifecycle functions # ============================================================================ diff --git a/src/Tensor4all.jl b/src/Tensor4all.jl index 919cf6e..4666bad 100644 --- a/src/Tensor4all.jl +++ b/src/Tensor4all.jl @@ -58,7 +58,7 @@ include("Algorithm.jl") # Re-export public API # Core types (tensor4all-core-common, tensor4all-core-tensor) export Index, dim, tags, id, hastag, plev -export Tensor, rank, dims, indices, storage_kind, data +export Tensor, rank, dims, indices, storage_kind, data, contract, diag_embed, diag_trace export StorageKind, DenseF64, DenseC64, DiagF64, DiagC64 # Re-export Algorithm submodule and utilities @@ -911,6 +911,21 @@ function Base.deepcopy(t::Tensor) return Tensor(ptr) end +""" + contract(t1::Tensor, t2::Tensor) -> Tensor + +Contract two tensors by matching common indices (same ID + tags + plev). +Returns a new tensor with the non-contracted indices from both inputs. +""" +function contract(t1::Tensor, t2::Tensor) + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_tensor_contract(t1.ptr, t2.ptr, out) + C_API.check_status(status) + return Tensor(out[]) +end + +Base.:*(t1::Tensor, t2::Tensor) = contract(t1, t2) + # ============================================================================ # HDF5 Save/Load for Tensor (ITensors.jl compatible) # ============================================================================ @@ -950,6 +965,92 @@ function load_itensor(filepath::AbstractString, name::AbstractString) return Tensor(out[]) end +""" + diag_embed(t::Tensor, idx::Index) -> Tensor + +Create a new tensor where `idx` is duplicated as `prime(idx)`. +The result is diagonal in `idx` and `prime(idx)`: +only elements where `idx == prime(idx)` are nonzero. + +This is used to convert MPS-like site tensors to MPO-like site tensors. +""" +function diag_embed(t::Tensor, idx::Index) + t_inds = indices(t) + + pos = findfirst(i -> i == idx, t_inds) + pos === nothing && error("Index not found in tensor") + + d = dim(idx) + idx_prime = prime(idx) + + arr = data(t) + t_dims = dims(t) + + new_dims = collect(t_dims) + insert!(new_dims, pos + 1, d) + new_arr = zeros(eltype(arr), new_dims...) + + n_axes = length(t_dims) + for idx_val in 1:d + src_slices = [i == pos ? idx_val : Colon() for i in 1:n_axes] + dst_slices = Any[i == pos ? idx_val : Colon() for i in 1:n_axes] + insert!(dst_slices, pos + 1, idx_val) + new_arr[dst_slices...] = arr[src_slices...] + end + + new_inds = copy(t_inds) + insert!(new_inds, pos + 1, idx_prime) + + return Tensor(new_inds, new_arr) +end + +""" + diag_trace(t::Tensor, idx::Index, idx_prime::Index) -> Tensor + +Extract the diagonal of a tensor in `idx` and `idx_prime`, keeping only `idx`. +This is the inverse of `diag_embed`: it reduces an MPO-like site tensor +back to MPS-like by tracing out the primed index. + +Requires `dim(idx) == dim(idx_prime)`. +""" +function diag_trace(t::Tensor, idx::Index, idx_prime::Index) + dim(idx) == dim(idx_prime) || error("Dimensions must match: $(dim(idx)) vs $(dim(idx_prime))") + + t_inds = indices(t) + d = dim(idx) + + pos1 = findfirst(i -> i == idx, t_inds) + pos2 = findfirst(i -> i == idx_prime, t_inds) + pos1 === nothing && error("idx not found in tensor") + pos2 === nothing && error("idx_prime not found in tensor") + + if pos1 > pos2 + pos1, pos2 = pos2, pos1 + idx, idx_prime = idx_prime, idx + end + + arr = data(t) + t_dims = dims(t) + n_axes = length(t_dims) + + new_dims = [t_dims[i] for i in 1:n_axes if i != pos2] + new_arr = zeros(eltype(arr), new_dims...) + + for idx_val in 1:d + src_slices = [i == pos1 ? idx_val : (i == pos2 ? idx_val : Colon()) for i in 1:n_axes] + dst_slices = Any[] + for i in 1:n_axes + i == pos2 && continue + push!(dst_slices, i == pos1 ? idx_val : Colon()) + end + new_arr[dst_slices...] = arr[src_slices...] + end + + new_inds = [t_inds[i] for i in 1:n_axes if i != pos2] + + return Tensor(new_inds, new_arr) +end + export save_itensor, load_itensor # ============================================================================ @@ -966,8 +1067,8 @@ include("SimpleTT.jl") # Tree tensor network functionality is in a separate submodule. # Use: using Tensor4all.TreeTN include("TreeTN.jl") -using .TreeTN: MPS, MPO, TensorTrain, random_mps, random_tt, is_chain -export MPS, MPO, TensorTrain, random_mps, random_tt, is_chain +using .TreeTN: MPS, MPO, TensorTrain, random_mps, random_tt, is_chain, is_mps_like, is_mpo_like +export MPS, MPO, TensorTrain, random_mps, random_tt, is_chain, is_mps_like, is_mpo_like # ============================================================================ # QuanticsGrids Submodule diff --git a/src/TreeTN.jl b/src/TreeTN.jl index 01203ed..8286f9e 100644 --- a/src/TreeTN.jl +++ b/src/TreeTN.jl @@ -25,7 +25,7 @@ module TreeTN using LinearAlgebra # Import from parent module -import ..Tensor4all: Index, Tensor, dim, id, tags, indices, rank, dims, data +import ..Tensor4all: Index, Tensor, dim, id, tags, indices, rank, dims, data, contract import ..Tensor4all: hascommoninds, commoninds, uniqueinds, HasCommonIndsPredicate import ..Tensor4all: C_API import ..SimpleTT: SimpleTensorTrain, site_tensor @@ -375,6 +375,36 @@ function _assert_chain(ttn::TreeTensorNetwork) return nothing end +""" + is_mps_like(tt::TreeTensorNetwork{Int}) -> Bool + +Check if a chain TensorTrain is MPS-like: each vertex has exactly 1 site index. +""" +function is_mps_like(tt::TreeTensorNetwork{Int}) + _assert_chain(tt) + for v in vertices(tt) + length(siteinds(tt, v)) != 1 && return false + end + return true +end + +export is_mps_like + +""" + is_mpo_like(tt::TreeTensorNetwork{Int}) -> Bool + +Check if a chain TensorTrain is MPO-like: each vertex has exactly 2 site indices. +""" +function is_mpo_like(tt::TreeTensorNetwork{Int}) + _assert_chain(tt) + for v in vertices(tt) + length(siteinds(tt, v)) != 2 && return false + end + return true +end + +export is_mpo_like + """ getindex(ttn::TreeTensorNetwork{V}, v::V) -> Tensor diff --git a/test/test_tensor.jl b/test/test_tensor.jl index d407808..b66f2e8 100644 --- a/test/test_tensor.jl +++ b/test/test_tensor.jl @@ -220,4 +220,41 @@ d = Tensor4all.data(t) @test d[] ≈ 1.0 end + + @testset "tensor contraction" begin + i = T4AIndex(2; tags="i") + j = T4AIndex(3; tags="j") + k = T4AIndex(4; tags="k") + + A = Tensor4all.Tensor([i, j], ones(2, 3)) + B = Tensor4all.Tensor([j, k], ones(3, 4)) + + C = Tensor4all.contract(A, B) + @test Tensor4all.rank(C) == 2 + d = Tensor4all.data(C) + @test all(x -> abs(x - 3.0) < 1e-12, d) + + C2 = A * B + @test Tensor4all.data(C2) ≈ Tensor4all.data(C) + end + + @testset "diag_embed and diag_trace" begin + i = T4AIndex(3; tags="i") + j = T4AIndex(2; tags="j") + + arr = reshape(collect(1.0:6.0), 3, 2) + t = Tensor4all.Tensor([i, j], arr) + + t_diag = Tensor4all.diag_embed(t, i) + @test Tensor4all.rank(t_diag) == 3 + inds_diag = Tensor4all.indices(t_diag) + @test Tensor4all.dim(inds_diag[1]) == 3 + @test Tensor4all.dim(inds_diag[2]) == 3 + @test Tensor4all.plev(inds_diag[2]) == 1 + + i_prime = Tensor4all.prime(i) + t_back = Tensor4all.diag_trace(t_diag, i, i_prime) + @test Tensor4all.rank(t_back) == 2 + @test Tensor4all.data(t_back) ≈ arr + end end diff --git a/test/test_treetn.jl b/test/test_treetn.jl index f6ff0c7..9ab1ebb 100644 --- a/test/test_treetn.jl +++ b/test/test_treetn.jl @@ -459,6 +459,13 @@ using LinearAlgebra @test_throws ArgumentError Tensor4all.TreeTN._assert_chain(branched_ttn) end + @testset "is_mps_like and is_mpo_like" begin + sites = [Tensor4all.Index(2) for _ in 1:4] + mps = Tensor4all.random_mps(sites; linkdims=2) + @test Tensor4all.is_mps_like(mps) + @test !Tensor4all.is_mpo_like(mps) + end + @testset "MPS setindex!" begin using Tensor4all: indices From 4da7820ddf9c7ed053d0473111c916af1accf077 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 7 Apr 2026 13:59:03 +0900 Subject: [PATCH 4/5] chore: update tensor4all-rs pin to include plev + tensor contraction Update TENSOR4ALL_RS_FALLBACK_COMMIT to latest main which includes plev support and t4a_tensor_contract C API. Co-Authored-By: Claude Opus 4.6 (1M context) --- deps/build.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/build.jl b/deps/build.jl index b4314c0..aa91fc2 100644 --- a/deps/build.jl +++ b/deps/build.jl @@ -12,7 +12,7 @@ using Libdl using RustToolChain: cargo # Configuration -const TENSOR4ALL_RS_FALLBACK_COMMIT = "4ee57fee0a71d385576c11d42850304548c6949d" +const TENSOR4ALL_RS_FALLBACK_COMMIT = "3f05ea81177c64b5f351b99fdfd23325e732fc62" const TENSOR4ALL_RS_REPO = "https://github.com/tensor4all/tensor4all-rs.git" # Paths From f5b3a7d4a37a413f69e483e67c7c5f99b4044421 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 7 Apr 2026 14:08:43 +0900 Subject: [PATCH 5/5] fix: update build script test to match new pinned commit Co-Authored-By: Claude Opus 4.6 (1M context) --- test/test_build_script.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_build_script.jl b/test/test_build_script.jl index 1e364ca..6a18b67 100644 --- a/test/test_build_script.jl +++ b/test/test_build_script.jl @@ -3,7 +3,7 @@ using Test @testset "build.jl" begin script = read(joinpath(dirname(@__DIR__), "deps", "build.jl"), String) - @test occursin("const TENSOR4ALL_RS_FALLBACK_COMMIT = \"4ee57fee0a71d385576c11d42850304548c6949d\"", script) + @test occursin("const TENSOR4ALL_RS_FALLBACK_COMMIT = \"3f05ea81177c64b5f351b99fdfd23325e732fc62\"", script) @test occursin("checkout --detach", script) @test !occursin("--branch", script) end