From 752c1dcb6b3c0259e07e94ffb09cd7384626f7a7 Mon Sep 17 00:00:00 2001 From: NepomukRitz Date: Mon, 6 Apr 2026 17:32:09 +0900 Subject: [PATCH] Add exact affine pullback operator, fix SimpleTT buffer bug, and add tests - Expose new affine pullback operator for quantics tensor trains - Fix buffer length bug in SimpleTT complex site-tensor extraction - Add comprehensive tests for affine transforms - Aligns with new backend and C API features Reviewer: Hiroshi Shinaoka --- src/C_API.jl | 28 ++++++- src/QuanticsTransform.jl | 63 +++++++++++++- test/runtests.jl | 1 + test/test_quanticstransform.jl | 147 +++++++++++++++++++++++++++++++++ 4 files changed, 237 insertions(+), 2 deletions(-) create mode 100644 test/test_quanticstransform.jl diff --git a/src/C_API.jl b/src/C_API.jl index 6f985ef..0219dd1 100644 --- a/src/C_API.jl +++ b/src/C_API.jl @@ -1455,7 +1455,7 @@ function t4a_simplett_c64_site_tensor( ptr, Csize_t(site), out_data, - Csize_t(length(out_data) ÷ 2), + Csize_t(length(out_data)), out_left_dim, out_site_dim, out_right_dim @@ -2237,6 +2237,32 @@ function t4a_qtransform_fourier(r::Csize_t, forward::Cint, maxbonddim::Csize_t, ) end +""" + t4a_qtransform_affine_pullback(r, m, n, a_num, a_den, b_num, b_den, bc, out) -> Cint + +Create an affine pullback operator: f(y) = g(A*y + b). +The affine matrix is MxN in column-major order, with rational entries encoded by +numerator and denominator arrays. +""" +function t4a_qtransform_affine_pullback( + r::Csize_t, + m::Csize_t, + n::Csize_t, + a_num::Vector{Int64}, + a_den::Vector{Int64}, + b_num::Vector{Int64}, + b_den::Vector{Int64}, + bc::Vector{Cint}, + out, +) + return ccall( + _sym(:t4a_qtransform_affine_pullback), + Cint, + (Csize_t, Csize_t, Csize_t, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}, Ptr{Cint}, Ptr{Ptr{Cvoid}}), + r, m, n, a_num, a_den, b_num, b_den, bc, out, + ) +end + """ t4a_linop_apply(op, state, method, rtol, maxdim, out) -> Cint diff --git a/src/QuanticsTransform.jl b/src/QuanticsTransform.jl index 5b4c931..2acfcd7 100644 --- a/src/QuanticsTransform.jl +++ b/src/QuanticsTransform.jl @@ -21,7 +21,8 @@ using ..C_API using ..TreeTN: TreeTensorNetwork export LinearOperator -export shift_operator, flip_operator, phase_rotation_operator, cumsum_operator, fourier_operator +export AffineParams +export shift_operator, flip_operator, phase_rotation_operator, cumsum_operator, fourier_operator, affine_pullback_operator export apply export BoundaryCondition, Periodic, Open @@ -67,6 +68,29 @@ mutable struct LinearOperator end end +""" + AffineParams(a, b) + +Affine pullback parameters representing `f(y) = g(A*y + b)`. + +- `a`: source-dimension by output-dimension affine matrix +- `b`: source-dimension shift vector +""" +struct AffineParams + a::Matrix{Rational{Int64}} + b::Vector{Rational{Int64}} + + function AffineParams(a::AbstractMatrix, b::AbstractVector) + size(a, 1) == length(b) || error("Affine matrix row count must match shift length") + a_rat = Rational{Int64}[x isa Rational ? Rational{Int64}(Int64(numerator(x)), Int64(denominator(x))) : Rational{Int64}(Int64(x), 1) for x in a] + b_rat = Rational{Int64}[x isa Rational ? Rational{Int64}(Int64(numerator(x)), Int64(denominator(x))) : Rational{Int64}(Int64(x), 1) for x in b] + return new(reshape(a_rat, size(a)), b_rat) + end +end + +source_ndims(params::AffineParams) = size(params.a, 1) +output_ndims(params::AffineParams) = size(params.a, 2) + # ============================================================================ # Operator construction functions # ============================================================================ @@ -156,6 +180,43 @@ function fourier_operator(r::Integer; forward::Bool=true, maxbonddim::Integer=0, return LinearOperator(out[]) end +""" + affine_pullback_operator(r::Integer, params::AffineParams; bc=fill(Periodic, source_ndims(params))) -> LinearOperator + +Create an affine pullback operator implementing `f(y) = g(A*y + b)`. + +The input state has `source_ndims(params)` variables and the output state has +`output_ndims(params)` variables. Boundary conditions apply to the transformed +source coordinates `A*y + b`. +""" +function affine_pullback_operator( + r::Integer, + params::AffineParams; + bc::AbstractVector{<:BoundaryCondition}=fill(Periodic, source_ndims(params)), +) + length(bc) == source_ndims(params) || error("Boundary condition length must match source dimension") + a_num = Int64[numerator(value) for value in vec(params.a)] + a_den = Int64[denominator(value) for value in vec(params.a)] + b_num = Int64[numerator(value) for value in params.b] + b_den = Int64[denominator(value) for value in params.b] + bc_int = Cint[Int(value) for value in bc] + + out = Ref{Ptr{Cvoid}}(C_NULL) + status = C_API.t4a_qtransform_affine_pullback( + Csize_t(r), + Csize_t(source_ndims(params)), + Csize_t(output_ndims(params)), + a_num, + a_den, + b_num, + b_den, + bc_int, + out, + ) + C_API.check_status(status) + return LinearOperator(out[]) +end + # ============================================================================ # Operator application # ============================================================================ diff --git a/test/runtests.jl b/test/runtests.jl index 97bfcd9..4a5d38a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ skip_hdf5 = get(ENV, "T4A_SKIP_HDF5_TESTS", "") == "1" include("test_treetn.jl") include("test_treetci.jl") include("test_simplett.jl") + include("test_quanticstransform.jl") if !skip_hdf5 include("test_hdf5.jl") end diff --git a/test/test_quanticstransform.jl b/test/test_quanticstransform.jl new file mode 100644 index 0000000..913e81c --- /dev/null +++ b/test/test_quanticstransform.jl @@ -0,0 +1,147 @@ +using Test +using Tensor4all +using Tensor4all.QuanticsTransform +using Tensor4all.TreeTN: MPS, to_dense + +function _product_mps(vectors::Vector{<:AbstractVector{ComplexF64}}) + nsites = length(vectors) + site_inds = [Tensor4all.Index(length(vector)) for vector in vectors] + arrays = Array{ComplexF64}[] + + for (site, vector) in enumerate(vectors) + if nsites == 1 + push!(arrays, reshape(collect(vector), length(vector))) + elseif site == 1 + push!(arrays, reshape(collect(vector), length(vector), 1)) + elseif site == nsites + push!(arrays, reshape(collect(vector), 1, length(vector))) + else + push!(arrays, reshape(collect(vector), 1, length(vector), 1)) + end + end + + return MPS(arrays, site_inds) +end + +_dense_state(mps) = ComplexF64.(Tensor4all.data(to_dense(mps))) + +function _decode_coords(level_digits::Vector{Int}, nvars::Int) + r = length(level_digits) + coords = zeros(Int, nvars) + for var in 1:nvars + value = 0 + for digit in level_digits + value = (value << 1) | ((digit >> (var - 1)) & 1) + end + coords[var] = value + end + return coords +end + +function _encode_coords(coords::Vector{Int}, r::Int) + digits = zeros(Int, r) + for level in 1:r + bitpos = r - level + digit = 0 + for (var, coord) in enumerate(coords) + digit |= ((coord >> bitpos) & 1) << (var - 1) + end + digits[level] = digit + end + return digits +end + +function _expected_pullback(source_dense, a::AbstractMatrix{<:Integer}, b::AbstractVector{<:Integer}, bc::Vector{BoundaryCondition}) + source_ndims, output_ndims = size(a) + r = ndims(source_dense) + output_dim = 1 << output_ndims + source_size = 1 << r + expected = zeros(ComplexF64, ntuple(_ -> output_dim, r)) + + for output_index in CartesianIndices(expected) + output_digits = Int[index - 1 for index in Tuple(output_index)] + output_coords = _decode_coords(output_digits, output_ndims) + source_coords = vec(a * output_coords .+ b) + + valid = true + for i in eachindex(source_coords) + if bc[i] == Periodic + source_coords[i] = mod(source_coords[i], source_size) + elseif source_coords[i] < 0 || source_coords[i] >= source_size + valid = false + break + end + end + + if valid + source_digits = _encode_coords(source_coords, r) + expected[output_index] = source_dense[CartesianIndex((source_digits .+ 1)...)] + end + end + + return expected +end + +@testset "QuanticsTransform" begin + @testset "affine pullback identity" begin + params = AffineParams([1 0; 0 1], [0, 0]) + op = affine_pullback_operator(2, params; bc=[Periodic, Periodic]) + state = _product_mps([ + ComplexF64[1, 2, 3, 4], + ComplexF64[5, 6, 7, 8], + ]) + result = apply(op, state) + @test _dense_state(result) ≈ _dense_state(state) + end + + @testset "affine pullback 2d shear" begin + a = [1 0; 1 1] + b = [0, 0] + bc = [Periodic, Periodic] + params = AffineParams(a, b) + op = affine_pullback_operator(2, params; bc=bc) + state = _product_mps([ + ComplexF64[1, 3, 5, 7], + ComplexF64[2, 4, 6, 8], + ]) + result = apply(op, state) + + source_dense = _dense_state(state) + expected = _expected_pullback(source_dense, a, b, bc) + @test _dense_state(result) ≈ expected + end + + @testset "affine pullback embedding" begin + a = reshape([1, 0], 1, 2) + b = [0] + bc = [Open] + params = AffineParams(a, b) + op = affine_pullback_operator(2, params; bc=bc) + state = _product_mps([ + ComplexF64[1, 2], + ComplexF64[3, 4], + ]) + result = apply(op, state) + + source_dense = _dense_state(state) + expected = _expected_pullback(source_dense, a, b, bc) + @test _dense_state(result) ≈ expected + end + + @testset "affine pullback open shift" begin + a = reshape([1], 1, 1) + b = [1] + bc = [Open] + params = AffineParams(a, b) + op = affine_pullback_operator(2, params; bc=bc) + state = _product_mps([ + ComplexF64[1, 2], + ComplexF64[3, 4], + ]) + result = apply(op, state) + + source_dense = _dense_state(state) + expected = _expected_pullback(source_dense, a, b, bc) + @test _dense_state(result) ≈ expected + end +end \ No newline at end of file