Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/ScaleInvariantAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ScaleInvariantAnalysis
using LinearAlgebra
using SparseArrays

export condscale, divmag, dotabs, symscale
export condscale, divmag, dotabs, matrixscale, symscale

include("utils.jl")

Expand Down Expand Up @@ -51,9 +51,49 @@ function symscale(A::AbstractMatrix; exact::Bool=false)
offset = sum(sumlogA) / (2 * sum(nz))
return exp.(sumlogA ./ nz .- offset)
end
return exp.(cholesky(Diagonal(nz) + (!iszero).(A)) \ sumlogA)
return exp.(cholesky(Diagonal(nz) + isnz(A)) \ sumlogA)
end

"""
    a, b = matrixscale(A; exact=false)

Given a matrix `A`, return vectors `a` and `b` representing the "scale of each
axis," so that `|A[i,j]| ~ a[i] * b[j]` for all `i, j`. `a[i]` and `b[j]` are
nonnegative, and are zero only if `A[i, j] = 0` for all `j` or all `i`,
respectively. Such all-zero rows and columns are assigned a scale of exactly
zero and are excluded from the fit.

With `exact=true`, `a` and `b` solve the optimization problem

    min ∑_{i,j : A[i,j] ≠ 0} (log(|A[i,j]| / (a[i] * b[j])))²
    s.t. ∑_i nA[i] * log(a[i]) = ∑_j mA[j] * log(b[j])

where `nA` and `mA` are the number of nonzeros in each row and column,
respectively. Up to multiplication by a scalar, these vectors are covariant
under changes of scale but not general linear transformations.

With `exact=false`, the pattern of nonzeros in `A` is approximated as `u * v'`,
where `sum(u) * v[j] = mA[j]` and `sum(v) * u[i] = nA[i]`. This results in an
`O(m*n)` rather than `O((m+n)^3)` algorithm.
"""
function matrixscale(A::AbstractMatrix; exact::Bool=false)
    Base.require_one_based_indexing(A)
    ax1, ax2 = axes(A, 1), axes(A, 2)
    # s/t: per-row/per-column sums of log|A[i,j]|; ns/mt: per-row/per-column nonzero counts
    (s, ns), (t, mt) = _matrixscale(A, ax1, ax2)
    m, n = length(ax1), length(ax2)
    # When every entry is nonzero the closed-form expression is used even if `exact=true`
    if !exact || (all(==(n), ns) && all(==(m), mt))
        z = sum(ns)
        @assert sum(mt) == z "Inconsistent nonzero counts in rows and columns"
        offa, offb = sum(s) / (2z), sum(t) / (2z)
        # A row/column without nonzeros has count 0; return scale 0 for it rather than
        # evaluating exp(0/0) = NaN. (If `z == 0` the offsets are NaN but never used.)
        a = map((si, ni) -> iszero(ni) ? zero(si) : exp(si / ni - offa), s, ns)
        b = map((tj, mj) -> iszero(mj) ? zero(tj) : exp(tj / mj - offb), t, mt)
        return a, b
    end
    p = vcat(ns, -mt)
    W = isnz(A)
    M = Diagonal(vcat(ns, mt)) + odblocks(W) + p * p'
    rhs = vcat(s, t)
    # Rows/columns of `A` with no nonzeros contribute an all-zero row and column to `M`,
    # which would make the factorization fail; solve only for the remaining unknowns and
    # leave the scale of the empty rows/columns at zero.
    keep = vcat(ns .> 0, mt .> 0)
    a12 = zeros(eltype(rhs), m + n)
    if any(keep)
        a12[keep] = exp.(cholesky(M[keep, keep]) \ rhs[keep])
    end
    return a12[1:m], a12[m+1:end]
end


# Quotient `n / d`, except that a zero denominator yields the zero-valued quotient
# `zero(n) / oneunit(d)` instead of dividing by zero.
function ratio_nz(n, d)
    if iszero(d)
        return zero(n) / oneunit(d)
    end
    return n / d
end

"""
Expand Down
26 changes: 25 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,28 @@ function _symscale(A, ax)
return sumlogA, nz
end

# TODO: implement _symscale for SparseMatrixCSC
# Single pass over `A` that accumulates, for every row index in `ax1` and every column
# index in `ax2`, the sum of `log(abs(A[i, j]))` over the nonzero entries together with the
# number of nonzero entries. Returns `(rowsums, rowcounts), (colsums, colcounts)`.
function _matrixscale(A, ax1, ax2)
    rowlog = fill!(similar(A, Float64, ax1), 0)
    rowcnt = fill!(similar(A, Int, ax1), 0)
    collog = fill!(similar(A, Float64, ax2), 0)
    colcnt = fill!(similar(A, Int, ax2), 0)
    # `i` is the inner (fastest-varying) index of this combined loop
    for j in ax2, i in ax1
        mag = abs(A[i, j])
        if !iszero(mag)
            l = log(mag)
            rowlog[i] += l
            rowcnt[i] += 1
            collog[j] += l
            colcnt[j] += 1
        end
    end
    return (rowlog, rowcnt), (collog, colcnt)
end

# Elementwise mask of `A`: `true` where the entry is nonzero, `false` where it is zero.
function isnz(A)
    return broadcast(!iszero, A)
end

# Build the `(m+n) × (m+n)` block matrix `[0 Anz; Anz' 0]` from an `m × n` matrix `Anz`:
# `Anz` and its adjoint occupy the off-diagonal blocks, the diagonal blocks are zero.
function odblocks(Anz::AbstractMatrix{T}) where T
    nrows, ncols = size(Anz)
    upper = hcat(zeros(T, nrows, nrows), Anz)
    lower = hcat(Anz', zeros(T, ncols, ncols))
    return vcat(upper, lower)
end

# TODO: implementations for SparseMatrixCSC
32 changes: 32 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,43 @@ function test_scaleinv(f, A::AbstractMatrix, p::Int; iter=10, rtol=sqrt(eps(floa
@test npass ≥ iter - 1
end

# Test helper: check that the scalings `a` (rows) and `b` (columns) leave no net
# log-residual in any column or row of `A`.
#
# For each column, and then for each row, the residuals
# `log|A[i,j]| - log(a[i]) - log(b[j])` over the nonzero entries are summed, and a `@test`
# requires the magnitude of that sum to be at most `rtol` times `∑ |log|A[i,j]||` taken
# over all nonzero entries of `A`.
function test_sumlog(A, a, b; rtol=1e-6)
    rows, cols = axes(A, 1), axes(A, 2)
    loga, logb = log.(a), log.(b)
    resid(i, j) = log(abs(A[i, j])) - loga[i] - logb[j]
    Aref = sum(abs(log(abs(A[i, j]))) for i in rows, j in cols if A[i, j] != 0)
    for j in cols
        colsum = foldl(+, (resid(i, j) for i in rows if A[i, j] != 0); init=0.0)
        @test abs(colsum) ≤ rtol * Aref
    end
    for i in rows
        rowsum = foldl(+, (resid(i, j) for j in cols if A[i, j] != 0); init=0.0)
        @test abs(rowsum) ≤ rtol * Aref
    end
end

@testset "ScaleInvariantAnalysis.jl" begin
@test symscale([2.0 1.0; 1.0 3.0]) ≈ symscale([2.0 1.0; 1.0 3.0]; exact=true) ≈ exp.([3 1; 1 3] \ [log(2.0); log(3.0)])
@test symscale([1.0 -0.2; -0.2 0]; exact=true) ≈ [1, 0.2]
@test symscale([1.0 0; 0 2]; exact=true) ≈ [1, sqrt(2)]
test_scaleinv(A -> symscale(A; exact=true), [2.0 1.0; 1.0 3.0], 1)
a, b = matrixscale([2.0 1.0; 1.0 3.0]; exact=true)
@test a ≈ b ≈ symscale([2.0 1.0; 1.0 3.0]; exact=true)
a′, b′ = matrixscale([2.0 1.0; 1.0 3.0])
@test a′ ≈ a && b′ ≈ b
A = [0.0 1.0; -2.0 0.0]
a, b = matrixscale(A; exact=true)
test_sumlog(A, a, b)
a′, b′ = matrixscale(A)
@test sum(log, a) ≈ sum(log, b) ≈ sum(log, a′) ≈ sum(log, b′)

@test condscale([1 0; 0 1e-8]) ≈ 1
A = [1.0 -0.2; -0.2 0]
Expand Down