Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/C_API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1455,7 +1455,7 @@ function t4a_simplett_c64_site_tensor(
ptr,
Csize_t(site),
out_data,
Csize_t(length(out_data) ÷ 2),
Csize_t(length(out_data)),
out_left_dim,
out_site_dim,
out_right_dim
Expand Down
105 changes: 105 additions & 0 deletions test/test_conversions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using Test
using Tensor4all
using Tensor4all.SimpleTT
using Tensor4all.TreeTN

import Tensor4all.SimpleTT: sitedims, linkdims, evaluate, sitetensor, fulltensor
import Tensor4all.TreeTN: MPS, nv, inner, linkdims as ttn_linkdims
using LinearAlgebra: norm

@testset "SimpleTT <-> TreeTN Conversions" begin
@testset "SimpleTT -> MPS -> SimpleTT round-trip" begin
@testset "rank-1 constant" begin
tt = SimpleTensorTrain([2, 3, 4], 1.5)
mps = MPS(tt)
@test nv(mps) == 3

tt2 = SimpleTensorTrain(mps)
@test length(tt2) == 3
@test sitedims(tt2) == [2, 3, 4]

# Check values match
arr1 = fulltensor(tt)
arr2 = fulltensor(tt2)
@test arr1 ≈ arr2
end

@testset "higher rank" begin
t1 = randn(1, 2, 3)
t2 = randn(3, 4, 2)
t3 = randn(2, 3, 1)
tt = SimpleTensorTrain([t1, t2, t3])
mps = MPS(tt)
@test nv(mps) == 3

tt2 = SimpleTensorTrain(mps)
@test length(tt2) == 3
@test sitedims(tt2) == [2, 4, 3]
@test linkdims(tt2) == [3, 2]

arr1 = fulltensor(tt)
arr2 = fulltensor(tt2)
@test arr1 ≈ arr2
end

@testset "single site" begin
tt = SimpleTensorTrain([3], 2.0)
mps = MPS(tt)
@test nv(mps) == 1

tt2 = SimpleTensorTrain(mps)
@test length(tt2) == 1
@test sitedims(tt2) == [3]

arr1 = fulltensor(tt)
arr2 = fulltensor(tt2)
@test arr1 ≈ arr2
end

@testset "two sites" begin
tt = SimpleTensorTrain([2, 5], 3.0)
mps = MPS(tt)
@test nv(mps) == 2

tt2 = SimpleTensorTrain(mps)
@test length(tt2) == 2
@test sitedims(tt2) == [2, 5]

arr1 = fulltensor(tt)
arr2 = fulltensor(tt2)
@test arr1 ≈ arr2
end
end

@testset "ComplexF64 round-trip" begin
tt = SimpleTensorTrain([2, 3, 4], 1.0 + 2.0im)
mps = MPS(tt)
@test nv(mps) == 3

tt2 = SimpleTensorTrain(mps)
@test length(tt2) == 3
@test sitedims(tt2) == [2, 3, 4]

arr1 = fulltensor(tt)
arr2 = fulltensor(tt2)
@test arr1 ≈ arr2
end

@testset "MPS -> SimpleTT -> MPS" begin
sites = [Tensor4all.Index(2) for _ in 1:4]
mps = TreeTN.random_mps(sites; linkdims=3)

tt = SimpleTensorTrain(mps)
@test length(tt) == 4
@test sitedims(tt) == [2, 2, 2, 2]

mps2 = MPS(tt)
@test nv(mps2) == 4

# Check that the dense tensor representation matches
dense1 = TreeTN.to_dense(mps)
arr1 = data(dense1)
arr2 = fulltensor(tt)
@test arr1 ≈ arr2
end
end
Loading