diff --git a/src/C_API.jl b/src/C_API.jl index 6f985ef..5028aaa 100644 --- a/src/C_API.jl +++ b/src/C_API.jl @@ -1455,7 +1455,7 @@ function t4a_simplett_c64_site_tensor( ptr, Csize_t(site), out_data, - Csize_t(length(out_data) ÷ 2), + Csize_t(length(out_data)), out_left_dim, out_site_dim, out_right_dim diff --git a/test/test_conversions.jl b/test/test_conversions.jl new file mode 100644 index 0000000..fff7205 --- /dev/null +++ b/test/test_conversions.jl @@ -0,0 +1,105 @@ +using Test +using Tensor4all +using Tensor4all.SimpleTT +using Tensor4all.TreeTN + +import Tensor4all.SimpleTT: sitedims, linkdims, evaluate, sitetensor, fulltensor +import Tensor4all.TreeTN: MPS, nv, inner, linkdims as ttn_linkdims +using LinearAlgebra: norm + +@testset "SimpleTT <-> TreeTN Conversions" begin + @testset "SimpleTT -> MPS -> SimpleTT round-trip" begin + @testset "rank-1 constant" begin + tt = SimpleTensorTrain([2, 3, 4], 1.5) + mps = MPS(tt) + @test nv(mps) == 3 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 3 + @test sitedims(tt2) == [2, 3, 4] + + # Check values match + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "higher rank" begin + t1 = randn(1, 2, 3) + t2 = randn(3, 4, 2) + t3 = randn(2, 3, 1) + tt = SimpleTensorTrain([t1, t2, t3]) + mps = MPS(tt) + @test nv(mps) == 3 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 3 + @test sitedims(tt2) == [2, 4, 3] + @test linkdims(tt2) == [3, 2] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "single site" begin + tt = SimpleTensorTrain([3], 2.0) + mps = MPS(tt) + @test nv(mps) == 1 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 1 + @test sitedims(tt2) == [3] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "two sites" begin + tt = SimpleTensorTrain([2, 5], 3.0) + mps = MPS(tt) + @test nv(mps) == 2 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 2 + @test sitedims(tt2) == [2, 5] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + end + + @testset "ComplexF64 round-trip" begin + tt = SimpleTensorTrain([2, 3, 4], 1.0 + 2.0im) + mps = MPS(tt) + @test nv(mps) == 3 + + tt2 = SimpleTensorTrain(mps) + @test length(tt2) == 3 + @test sitedims(tt2) == [2, 3, 4] + + arr1 = fulltensor(tt) + arr2 = fulltensor(tt2) + @test arr1 ≈ arr2 + end + + @testset "MPS -> SimpleTT -> MPS" begin + sites = [Tensor4all.Index(2) for _ in 1:4] + mps = TreeTN.random_mps(sites; linkdims=3) + + tt = SimpleTensorTrain(mps) + @test length(tt) == 4 + @test sitedims(tt) == [2, 2, 2, 2] + + mps2 = MPS(tt) + @test nv(mps2) == 4 + + # Check that the dense tensor representation matches + dense1 = TreeTN.to_dense(mps) + arr1 = data(dense1) + arr2 = fulltensor(tt) + @test arr1 ≈ arr2 + end +end