Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Libdl
using RustToolChain: cargo

# Configuration
const TENSOR4ALL_RS_FALLBACK_COMMIT = "4ee57fee0a71d385576c11d42850304548c6949d"
const TENSOR4ALL_RS_FALLBACK_COMMIT = "3f05ea81177c64b5f351b99fdfd23325e732fc62"
const TENSOR4ALL_RS_REPO = "https://github.com/tensor4all/tensor4all-rs.git"

# Paths
Expand Down
15 changes: 11 additions & 4 deletions ext/Tensor4allITensorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ Both Rust and ITensors use UInt64 IDs natively, so conversion is direct.

## Memory Order

Rust uses row-major, Julia/ITensors uses column-major.
Conversion is handled automatically.
Rust and Julia/ITensors use column-major order.
Conversion preserves the logical layout.
"""
module Tensor4allITensorsExt

Expand All @@ -40,12 +40,13 @@ IDs are natively UInt64 in both systems. Tags are preserved.
function ITensors.Index(idx::Tensor4all.Index)
d = Tensor4all.dim(idx)
t = Tensor4all.tags(idx)
p = Tensor4all.plev(idx)
id64 = Tensor4all.id(idx)

# Create ITensors.Index with explicit ID using full constructor
# Index(id, space, dir, tags, plev)
tagset = isempty(t) ? ITensors.TagSet("") : ITensors.TagSet(t)
return ITensors.Index(id64, d, ITensors.Neither, tagset, 0)
return ITensors.Index(id64, d, ITensors.Neither, tagset, p)
end

# ============================================================================
Expand All @@ -65,12 +66,18 @@ cause an error.
function Tensor4all.Index(idx::ITensors.Index)
d = ITensors.dim(idx)
id64 = ITensors.id(idx)
p = ITensors.plev(idx)

# Get tags as comma-separated string
tag_set = ITensors.tags(idx)
tags_str = _tags_to_string(tag_set)

return Tensor4all.Index(d, id64; tags=tags_str)
t4a_idx = Tensor4all.Index(d, id64; tags=tags_str)
if p != 0
status = Tensor4all.C_API.t4a_index_set_plev(t4a_idx.ptr, Int64(p))
Tensor4all.C_API.check_status(status)
end
return t4a_idx
end

"""
Expand Down
60 changes: 60 additions & 0 deletions src/C_API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,21 @@ function t4a_index_get_tags(ptr::Ptr{Cvoid}, buf, buf_len::Integer, out_len::Ref
)
end

"""
t4a_index_get_plev(ptr::Ptr{Cvoid}, out_plev::Ref{Int64}) -> Cint

Get the prime level of an index.
"""
function t4a_index_get_plev(ptr::Ptr{Cvoid}, out_plev::Ref{Int64})
return ccall(
_sym(:t4a_index_get_plev),
Cint,
(Ptr{Cvoid}, Ptr{Int64}),
ptr,
out_plev
)
end

# ============================================================================
# Index modifiers
# ============================================================================
Expand Down Expand Up @@ -322,6 +337,35 @@ function t4a_index_has_tag(ptr::Ptr{Cvoid}, tag::AbstractString)
)
end

"""
t4a_index_set_plev(ptr::Ptr{Cvoid}, plev::Integer) -> Cint

Set the prime level of an index.
"""
function t4a_index_set_plev(ptr::Ptr{Cvoid}, plev::Integer)
return ccall(
_sym(:t4a_index_set_plev),
Cint,
(Ptr{Cvoid}, Int64),
ptr,
Int64(plev)
)
end

"""
t4a_index_prime(ptr::Ptr{Cvoid}) -> Cint

Increment the prime level of an index by 1.
"""
function t4a_index_prime(ptr::Ptr{Cvoid})
return ccall(
_sym(:t4a_index_prime),
Cint,
(Ptr{Cvoid},),
ptr
)
end

# ============================================================================
# Storage kind enum
# ============================================================================
Expand Down Expand Up @@ -541,6 +585,22 @@ function t4a_tensor_onehot(rank::Integer, index_ptrs::Vector{Ptr{Cvoid}}, vals::
)
end

"""
t4a_tensor_contract(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out::Ref{Ptr{Cvoid}})

Contract two tensors and write the resulting tensor handle to `out`.
"""
function t4a_tensor_contract(a::Ptr{Cvoid}, b::Ptr{Cvoid}, out::Ref{Ptr{Cvoid}})
return ccall(
_sym(:t4a_tensor_contract),
Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Ptr{Cvoid}}),
a,
b,
out
)
end

# ============================================================================
# TreeTN lifecycle functions
# ============================================================================
Expand Down
Loading
Loading