From 1550f85489d58fed54a4e41ce20f726ca478e1c7 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Mon, 30 Mar 2026 22:32:27 +0900 Subject: [PATCH] refactor: migrate TreeTCI evaluate to IndexId-based C API Replace t4a_treetn_evaluate_batch with the new IndexId-based t4a_treetn_all_site_index_ids + t4a_treetn_evaluate flow. The new _evaluate_batch first queries index IDs and vertex names, reorders values to match the index ordering, then calls evaluate. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/TreeTCI.jl | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index 4f0e659..781221c 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -532,16 +532,46 @@ function evaluate(ttn::TreeTensorNetwork, batch::AbstractMatrix{<:Integer}) return _evaluate_batch(ttn, flat, n_sites, n_points) end -"""Internal: call C API batch evaluate and return typed results.""" +"""Internal: call C API evaluate using IndexId-based interface and return typed results.""" function _evaluate_batch(ttn::TreeTensorNetwork, flat::Vector{Csize_t}, n_sites::Int, n_points::Int) + # Step 1: Query the number of site indices + n_indices_ref = Ref{Csize_t}(0) + C_API.check_status(ccall( + C_API._sym(:t4a_treetn_all_site_index_ids), Cint, + (Ptr{Cvoid}, Ptr{UInt64}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + ttn.handle, C_NULL, C_NULL, 0, n_indices_ref, + )) + n_indices = Int(n_indices_ref[]) + + # Step 2: Fetch index IDs and their vertex names + index_ids = Vector{UInt64}(undef, n_indices) + vertex_names = Vector{Csize_t}(undef, n_indices) + C_API.check_status(ccall( + C_API._sym(:t4a_treetn_all_site_index_ids), Cint, + (Ptr{Cvoid}, Ptr{UInt64}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}), + ttn.handle, index_ids, vertex_names, n_indices, n_indices_ref, + )) + + # Step 3: Reorder flat values to match index_ids order. + # `flat` is laid out by site order (0, 1, …, n_sites-1) for each point. + # `vertex_names[i]` tells us which site (0-based) index i belongs to. + reordered = Vector{Csize_t}(undef, n_indices * n_points) + for p in 0:(n_points - 1) + for i in 1:n_indices + site = Int(vertex_names[i]) # 0-based site number + reordered[i + n_indices * p] = flat[site + 1 + n_sites * p] # +1 for Julia 1-based + end + end + + # Step 4: Call evaluate out_re = Vector{Cdouble}(undef, n_points) out_im = Vector{Cdouble}(undef, n_points) C_API.check_status(ccall( - C_API._sym(:t4a_treetn_evaluate_batch), - Cint, - (Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}), - ttn.handle, flat, Csize_t(n_sites), Csize_t(n_points), out_re, out_im, + C_API._sym(:t4a_treetn_evaluate), Cint, + (Ptr{Cvoid}, Ptr{UInt64}, Csize_t, Ptr{Csize_t}, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}), + ttn.handle, index_ids, n_indices, reordered, n_points, out_re, out_im, )) + # Detect if complex by checking if any imaginary part is nonzero if all(iszero, out_im) return out_re