Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
{}
{
"julia.environmentPath": "/Users/youdongguo/.julia/dev/GsvdInitialization"
}
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ authors = ["youdongguo <1010705897@qq.com> and contributors"]
version = "1.0.0"

[deps]
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to update [compat]

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386"
NonNegLeastSquares = "b7351bd1-99d9-5c5d-8786-f205a815c4d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e"

[compat]
Expand Down
101 changes: 89 additions & 12 deletions src/GsvdInitialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GsvdInitialization

using LinearAlgebra, NMF, TSVD
using NonNegLeastSquares
using Kronecker, SparseArrays

export gsvdnmf,
gsvdrecover
Expand Down Expand Up @@ -33,6 +34,8 @@ Other keyword arguments are passed to `NMF.nnmf`.
function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f;
n2 = size(first(f), 2),
tol_nmf=1e-4,
alg = :cd,
initW = :standard,
kwargs...)
n1 = size(W, 2)
kadd = n2 - n1
Expand All @@ -42,9 +45,14 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f;
if kadd == 0
return W, H
else
W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f)
result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover)
return result_recover.W, result_recover.H
# @show alg
W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This version is ready for review. I have not written the document yet and after we merge this code change (or we finish the code change), I will update the documentation.

i added a keyword argument to add a new method for our jointly optimizing Wadd and alpha approach. The default approach is separately computing Wadd and alpha.

if alg == :multmse
@show alg
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete

W_recover, H_recover = max.(W_recover, 1e-5), max.(H_recover, 1e-5)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the 1e-5 be hard-coded or a kwarg?

end
result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover))
return result_recover, result_recover.W, result_recover.H
end
end
gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...) = gsvdnmf(X, W, H, tsvd(X, n2); kwargs...)
Expand Down Expand Up @@ -73,13 +81,13 @@ Keyword arguments:

Other keyword arguments are passed to `NMF.nnmf`.
"""
function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, kwargs...)
function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, initW = :standard, kwargs...)
n1, n2 = ncomponents
f = tsvd(X, n2)
W0, H0 = NMF.nndsvd(X, n1; initdata = (U = f[1], S = f[2], V = f[3]))
result_initial_nmf = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0))
W_initial_nmf, H_initial_nmf = result_initial_nmf.W, result_initial_nmf.H
return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final)
return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final, initW=initW)
end
gsvdnmf(X::AbstractMatrix, ncomponents_final::Integer; kwargs...) = gsvdnmf(X, ncomponents_final-1 => ncomponents_final; kwargs...)

Expand Down Expand Up @@ -108,7 +116,7 @@ Arguments:

`f`: SVD (or Truncated SVD) of `X`
"""
function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple)
function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...)
m, n = size(W0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main function change. A split of change is added.

kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components"))
if kadd == 0
Expand All @@ -117,15 +125,29 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad
U0, S0, V0 = f
U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n]
Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd)
Wadd, a = init_W(X, W0, H0, Hadd)
Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
cs = Wcols_modification(X, W0_1, H0_1)
W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1
if initW == :standard
Wadd, a = init_W(X, W0, H0, Hadd)
Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
cs = Wcols_modification(X, W0_1, H0_1)
W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1
elseif initW == :joint
W0_2, H0_2 = gsvdrecover_Wa(X, W0, H0, Hadd; kwargs...)
else
throw(ArgumentError("Unknown initW method: $initW"))
end
return abs.(W0_2), abs.(H0_2), Λ
end
end

function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray)
m = size(W0, 1)
Hadd_nn = truncatepos(Hadd', X, W0, H0)'
Wadd, a = init_Wa(X, W0, H0, Hadd_nn)
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd], [H0; Hadd_nn]
return abs.(W0_1), abs.(H0_1)
end

function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int)
_, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0));
inv_RQt = inv(R*Q')
Expand All @@ -138,7 +160,45 @@ function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::Abs
H_index = sortperm(Λ, rev = true)[1:kadd]
Hadd = inv_RQt[:, H_index]
Hadd_1 = V0*Hadd
return Hadd_1', Λ[H_index]
return Hadd_1', Λ
end

function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}) where T
m = size(X, 1)
kadd = size(Hadd, 1)
G = gram_sp_C(W0, H0, Hadd)[1]
b = gram_b(X, W0, H0, Hadd)
θ = nonneg_lsq(G, b; alg=:fnnls, gram=true)
Wadd = reshape(θ[1:m*kadd], m, kadd)
α = θ[m*kadd+1:end]
return Wadd, α
end

function gram_sp_C(W0, H0, Hadd)
m, r0 = size(W0)
k = size(Hadd, 1)
mk = m*k
W0W0, H0H0 = W0'*W0, H0*H0'
P = Hadd*H0'
HH = Hadd*Hadd'
G22 = sparse(W0W0.*H0H0)
G12 = zeros(Float64, mk, r0)
for j in 1:r0
G12[:,j] .= vec(W0[:,j] * P[:,j]')
end
G12 = sparse(G12)
G11 = kronecker(HH, sparse(I, m, m))
G = [G11 G12; G12' G22]
return G, G11, G12, G22
end

function gram_b(X, W0, H0, Hadd)
m, r0 = size(W0)
k = size(Hadd, 1)
b = zeros(Float64, m*k + r0)
b[1:m*k] = vec(X * Hadd')
b[m*k+1:end] = diag(W0' * X * H0')
Comment on lines +196 to +200
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
m, r0 = size(W0)
k = size(Hadd, 1)
b = zeros(Float64, m*k + r0)
b[1:m*k] = vec(X * Hadd')
b[m*k+1:end] = diag(W0' * X * H0')
b = vcat(vec(X * Hadd'), diag(W0' * X * H0'))

Though computing all of W0' * X * X0' and then keeping only the diagonal seems wasteful?

return b
end

function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T
Expand Down Expand Up @@ -176,4 +236,21 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac
return β[:]
end

function truncatepos(Y, X, W, H)
ΔX = max.(zero(eltype(X)), X - W*H)
Yout = similar(Y)
for j in axes(Y, 2)
y = view(Y, :, j)
yp = max.(y, zero(eltype(y)))
ym = max.(-y, zero(eltype(y)))
if sum(ΔX * yp) >= sum(ΔX * ym)
Yout[:, j] = yp
else
Yout[:, j] = ym
end
end
return Yout
end


end
39 changes: 34 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,19 @@ W_GT, H_GT = generate_ground_truth()
H = H_GT
X = W*H
standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X)))
W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4);
_, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4);
img_tol_int = sum(abs2, X)
@test size(W_gsvd, 2) == 10
@test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted this test because the standard NMF also generate perfect results on my machine. On RIS, it generate a bad results. One question: in the document of this repo, should we keep the standard NMF result?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might need to have you walk me through the issues.

@test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10

X = rand(30, 20)
W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd)
W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd)
_, W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd)
_, W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd)
@test sum(abs2, W_gsvd_1-W_gsvd_2) <= 1e-12
@test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is also interesting. One RIS, when I include the test file, it cannot pass and the sum(abs2, W_gsvd_1-W_gsvd_2)=1e-7. However, when I copy paste these lines in REPL, it always passes. On my machine and github, these tests always pass.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It probably depends on the specific random numbers. I think the seed gets saved in the testset, if that helps debug.

end

@testset "GsvdInitialization.jl" begin
@testset "GsvdInitialization" begin
W, H = rand(10, 3), rand(3, 8)
X = W*H
U, S, V = svd(X)
Expand Down Expand Up @@ -54,3 +53,33 @@ end
@test β.*β0 ≈ ones(3)

end

@testset "joint optimize W and alpha" begin
W = W_GT
H = H_GT
X = W*H
standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X)))
_, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4, initW=:joint);
img_tol_int = sum(abs2, X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem used, though you recompute the same quantity a couple lines down.

@test size(W_gsvd, 2) == 10
@test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10

W, H = rand(10, 3), rand(3, 8)
X = W*H
U, S, V = svd(X)

W0, H0 = copy(W), copy(H)
Hadd = rand(2, 8)
Wadd, a = GsvdInitialization.init_Wa(X, W0, H0, Hadd)
@test a ≈ ones(size(W0, 2))
@test norm(Wadd) <= 1e-8

G = GsvdInitialization.gram_sp_C(W0, H0, Hadd)[1]
b = GsvdInitialization.gram_b(X, W0, H0, Hadd)
Wadd = rand(10, 2)
α = rand(3)
θ = vcat(vec(Wadd), α)
E = θ'*G*θ-2*b'*θ+sum(abs2, X)
@test abs(E-sum(abs2, X-[repeat(α', size(W0, 1)).*W0 Wadd]*[H0;Hadd])) <= 1e-12

end