Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/TreeTCI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,16 +532,46 @@ function evaluate(ttn::TreeTensorNetwork, batch::AbstractMatrix{<:Integer})
return _evaluate_batch(ttn, flat, n_sites, n_points)
end

"""Internal: call C API batch evaluate and return typed results."""
"""Internal: call C API evaluate using IndexId-based interface and return typed results."""
function _evaluate_batch(ttn::TreeTensorNetwork, flat::Vector{Csize_t}, n_sites::Int, n_points::Int)
# Step 1: Query the number of site indices
n_indices_ref = Ref{Csize_t}(0)
C_API.check_status(ccall(
C_API._sym(:t4a_treetn_all_site_index_ids), Cint,
(Ptr{Cvoid}, Ptr{UInt64}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}),
ttn.handle, C_NULL, C_NULL, 0, n_indices_ref,
))
n_indices = Int(n_indices_ref[])

# Step 2: Fetch index IDs and their vertex names
index_ids = Vector{UInt64}(undef, n_indices)
vertex_names = Vector{Csize_t}(undef, n_indices)
C_API.check_status(ccall(
C_API._sym(:t4a_treetn_all_site_index_ids), Cint,
(Ptr{Cvoid}, Ptr{UInt64}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}),
ttn.handle, index_ids, vertex_names, n_indices, n_indices_ref,
))

# Step 3: Reorder flat values to match index_ids order.
# `flat` is laid out by site order (0, 1, …, n_sites-1) for each point.
# `vertex_names[i]` tells us which site (0-based) index i belongs to.
reordered = Vector{Csize_t}(undef, n_indices * n_points)
for p in 0:(n_points - 1)
for i in 1:n_indices
site = Int(vertex_names[i]) # 0-based site number
reordered[i + n_indices * p] = flat[site + 1 + n_sites * p] # +1 for Julia 1-based
end
end

# Step 4: Call evaluate
out_re = Vector{Cdouble}(undef, n_points)
out_im = Vector{Cdouble}(undef, n_points)
C_API.check_status(ccall(
C_API._sym(:t4a_treetn_evaluate_batch),
Cint,
(Ptr{Cvoid}, Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}),
ttn.handle, flat, Csize_t(n_sites), Csize_t(n_points), out_re, out_im,
C_API._sym(:t4a_treetn_evaluate), Cint,
(Ptr{Cvoid}, Ptr{UInt64}, Csize_t, Ptr{Csize_t}, Csize_t, Ptr{Cdouble}, Ptr{Cdouble}),
ttn.handle, index_ids, n_indices, reordered, n_points, out_re, out_im,
))

# Detect if complex by checking if any imaginary part is nonzero
if all(iszero, out_im)
return out_re
Expand Down
Loading