Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ PartitionedMPSs = "0.5.2, 0.6"
QuanticsTCI = "0.7"
SparseIR = "^0.96, 0.97, 1"
StaticArrays = "1"
TensorCrossInterpolation = "0.9.18"
julia = "1"

[extras]
Expand Down
17 changes: 13 additions & 4 deletions src/affine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,16 @@ function affine_transform_tensors(
# dimensions (links) vary
tensors = Vector{Array{Bool,4}}(undef, R)

# Track sign separately and work with absolute value so that
# right-shifting always terminates (fixes negative b handling).
bsign = sign.(b)
b = abs.(b)

# The initial carry is zero
carry = [zero(SVector{M,Int})]
for r in R:-1:1
# Figure out the current bit to add from the shift term and shift
bcurr = SVector{M,Int}((copysign(b_, abs(b_)) & 1 for b_ in b))
bcurr = (b .& 1) .* bsign

# Get tensor.
new_carry, data = affine_transform_core(A, bcurr, s, carry)
Expand All @@ -138,12 +143,12 @@ function affine_transform_tensors(
b = @. b >> 1
end

if boundary == OpenBoundaryConditions() && maximum(abs, b) > 0
if boundary == OpenBoundaryConditions() && maximum(b) > 0
# Extend the tensors to the left until we have no more nonzero bits in b
# This is equivalent to a larger domain.
tensors_ext = Array{Bool,4}[]
while maximum(abs, b) > 0
bcurr = SVector{M,Int}((copysign(b_, abs(b_)) & 1 for b_ in b))
while maximum(b) > 0
bcurr = (b .& 1) .* bsign
new_carry, data = affine_transform_core(A, bcurr, s, carry; activebit=false)
pushfirst!(tensors_ext, data)

Expand Down Expand Up @@ -215,6 +220,10 @@ function affine_transform_core(
# if s is odd, then there is a unique y which solves satisfies
# above condition (simply the lowest bit)
y = @. Bool(z & 1)
if !activebit && any(y)
# y must be zero when bits are inactive; skip dead-end carry
continue
end
y_index = digits_to_number(y) + 1

# Correct z and compute carry
Expand Down
Loading