-
Notifications
You must be signed in to change notification settings - Fork 0
add joint optimize W and alpha #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,3 @@ | ||
| {} | ||
| { | ||
| "julia.environmentPath": "/Users/youdongguo/.julia/dev/GsvdInitialization" | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,7 @@ module GsvdInitialization | |||||||||||||
|
|
||||||||||||||
| using LinearAlgebra, NMF, TSVD | ||||||||||||||
| using NonNegLeastSquares | ||||||||||||||
| using Kronecker, SparseArrays | ||||||||||||||
|
|
||||||||||||||
| export gsvdnmf, | ||||||||||||||
| gsvdrecover | ||||||||||||||
|
|
@@ -33,6 +34,8 @@ Other keyword arguments are passed to `NMF.nnmf`. | |||||||||||||
| function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; | ||||||||||||||
| n2 = size(first(f), 2), | ||||||||||||||
| tol_nmf=1e-4, | ||||||||||||||
| alg = :cd, | ||||||||||||||
| initW = :standard, | ||||||||||||||
| kwargs...) | ||||||||||||||
| n1 = size(W, 2) | ||||||||||||||
| kadd = n2 - n1 | ||||||||||||||
|
|
@@ -42,9 +45,14 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; | |||||||||||||
| if kadd == 0 | ||||||||||||||
| return W, H | ||||||||||||||
| else | ||||||||||||||
| W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f) | ||||||||||||||
| result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover) | ||||||||||||||
| return result_recover.W, result_recover.H | ||||||||||||||
| # @show alg | ||||||||||||||
| W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW) | ||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This version is ready for review. I have not written the document yet and after we merge this code change (or we finish the code change), I will update the documentation. i added a keyword argument to add a new method for our jointly optimizing Wadd and alpha approach. The default approach is separately computing Wadd and alpha. |
||||||||||||||
| if alg == :multmse | ||||||||||||||
| @show alg | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete |
||||||||||||||
| W_recover, H_recover = max.(W_recover, 1e-5), max.(H_recover, 1e-5) | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the 1e-5 be hard-coded or a kwarg? |
||||||||||||||
| end | ||||||||||||||
| result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover)) | ||||||||||||||
| return result_recover, result_recover.W, result_recover.H | ||||||||||||||
| end | ||||||||||||||
| end | ||||||||||||||
| gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...) = gsvdnmf(X, W, H, tsvd(X, n2); kwargs...) | ||||||||||||||
|
|
@@ -73,13 +81,13 @@ Keyword arguments: | |||||||||||||
|
|
||||||||||||||
| Other keyword arguments are passed to `NMF.nnmf`. | ||||||||||||||
| """ | ||||||||||||||
| function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, kwargs...) | ||||||||||||||
| function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, initW = :standard, kwargs...) | ||||||||||||||
| n1, n2 = ncomponents | ||||||||||||||
| f = tsvd(X, n2) | ||||||||||||||
| W0, H0 = NMF.nndsvd(X, n1; initdata = (U = f[1], S = f[2], V = f[3])) | ||||||||||||||
| result_initial_nmf = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) | ||||||||||||||
| W_initial_nmf, H_initial_nmf = result_initial_nmf.W, result_initial_nmf.H | ||||||||||||||
| return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final) | ||||||||||||||
| return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final, initW=initW) | ||||||||||||||
| end | ||||||||||||||
| gsvdnmf(X::AbstractMatrix, ncomponents_final::Integer; kwargs...) = gsvdnmf(X, ncomponents_final-1 => ncomponents_final; kwargs...) | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -108,7 +116,7 @@ Arguments: | |||||||||||||
|
|
||||||||||||||
| `f`: SVD (or Truncated SVD) of `X` | ||||||||||||||
| """ | ||||||||||||||
| function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple) | ||||||||||||||
| function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...) | ||||||||||||||
| m, n = size(W0) | ||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the main function change. A split of change is added. |
||||||||||||||
| kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components")) | ||||||||||||||
| if kadd == 0 | ||||||||||||||
|
|
@@ -117,15 +125,29 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad | |||||||||||||
| U0, S0, V0 = f | ||||||||||||||
| U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n] | ||||||||||||||
| Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd) | ||||||||||||||
| Wadd, a = init_W(X, W0, H0, Hadd) | ||||||||||||||
| Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) | ||||||||||||||
| W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] | ||||||||||||||
| cs = Wcols_modification(X, W0_1, H0_1) | ||||||||||||||
| W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 | ||||||||||||||
| if initW == :standard | ||||||||||||||
| Wadd, a = init_W(X, W0, H0, Hadd) | ||||||||||||||
| Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) | ||||||||||||||
| W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] | ||||||||||||||
| cs = Wcols_modification(X, W0_1, H0_1) | ||||||||||||||
| W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 | ||||||||||||||
| elseif initW == :joint | ||||||||||||||
| W0_2, H0_2 = gsvdrecover_Wa(X, W0, H0, Hadd; kwargs...) | ||||||||||||||
| else | ||||||||||||||
| throw(ArgumentError("Unknown initW method: $initW")) | ||||||||||||||
| end | ||||||||||||||
| return abs.(W0_2), abs.(H0_2), Λ | ||||||||||||||
| end | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray) | ||||||||||||||
| m = size(W0, 1) | ||||||||||||||
| Hadd_nn = truncatepos(Hadd', X, W0, H0)' | ||||||||||||||
| Wadd, a = init_Wa(X, W0, H0, Hadd_nn) | ||||||||||||||
| W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd], [H0; Hadd_nn] | ||||||||||||||
| return abs.(W0_1), abs.(H0_1) | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int) | ||||||||||||||
| _, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0)); | ||||||||||||||
| inv_RQt = inv(R*Q') | ||||||||||||||
|
|
@@ -138,7 +160,45 @@ function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::Abs | |||||||||||||
| H_index = sortperm(Λ, rev = true)[1:kadd] | ||||||||||||||
| Hadd = inv_RQt[:, H_index] | ||||||||||||||
| Hadd_1 = V0*Hadd | ||||||||||||||
| return Hadd_1', Λ[H_index] | ||||||||||||||
| return Hadd_1', Λ | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}) where T | ||||||||||||||
| m = size(X, 1) | ||||||||||||||
| kadd = size(Hadd, 1) | ||||||||||||||
| G = gram_sp_C(W0, H0, Hadd)[1] | ||||||||||||||
| b = gram_b(X, W0, H0, Hadd) | ||||||||||||||
| θ = nonneg_lsq(G, b; alg=:fnnls, gram=true) | ||||||||||||||
| Wadd = reshape(θ[1:m*kadd], m, kadd) | ||||||||||||||
| α = θ[m*kadd+1:end] | ||||||||||||||
| return Wadd, α | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function gram_sp_C(W0, H0, Hadd) | ||||||||||||||
| m, r0 = size(W0) | ||||||||||||||
| k = size(Hadd, 1) | ||||||||||||||
| mk = m*k | ||||||||||||||
| W0W0, H0H0 = W0'*W0, H0*H0' | ||||||||||||||
| P = Hadd*H0' | ||||||||||||||
| HH = Hadd*Hadd' | ||||||||||||||
| G22 = sparse(W0W0.*H0H0) | ||||||||||||||
| G12 = zeros(Float64, mk, r0) | ||||||||||||||
| for j in 1:r0 | ||||||||||||||
| G12[:,j] .= vec(W0[:,j] * P[:,j]') | ||||||||||||||
| end | ||||||||||||||
| G12 = sparse(G12) | ||||||||||||||
| G11 = kronecker(HH, sparse(I, m, m)) | ||||||||||||||
| G = [G11 G12; G12' G22] | ||||||||||||||
| return G, G11, G12, G22 | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function gram_b(X, W0, H0, Hadd) | ||||||||||||||
| m, r0 = size(W0) | ||||||||||||||
| k = size(Hadd, 1) | ||||||||||||||
| b = zeros(Float64, m*k + r0) | ||||||||||||||
| b[1:m*k] = vec(X * Hadd') | ||||||||||||||
| b[m*k+1:end] = diag(W0' * X * H0') | ||||||||||||||
|
Comment on lines
+196
to
+200
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Though computing all of |
||||||||||||||
| return b | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T | ||||||||||||||
|
|
@@ -176,4 +236,21 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac | |||||||||||||
| return β[:] | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| function truncatepos(Y, X, W, H) | ||||||||||||||
| ΔX = max.(zero(eltype(X)), X - W*H) | ||||||||||||||
| Yout = similar(Y) | ||||||||||||||
| for j in axes(Y, 2) | ||||||||||||||
| y = view(Y, :, j) | ||||||||||||||
| yp = max.(y, zero(eltype(y))) | ||||||||||||||
| ym = max.(-y, zero(eltype(y))) | ||||||||||||||
| if sum(ΔX * yp) >= sum(ΔX * ym) | ||||||||||||||
| Yout[:, j] = yp | ||||||||||||||
| else | ||||||||||||||
| Yout[:, j] = ym | ||||||||||||||
| end | ||||||||||||||
| end | ||||||||||||||
| return Yout | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| end | ||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,20 +12,19 @@ W_GT, H_GT = generate_ground_truth() | |
| H = H_GT | ||
| X = W*H | ||
| standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X))) | ||
| W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); | ||
| _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); | ||
| img_tol_int = sum(abs2, X) | ||
| @test size(W_gsvd, 2) == 10 | ||
| @test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I deleted this test because the standard NMF also generate perfect results on my machine. On RIS, it generate a bad results. One question: in the document of this repo, should we keep the standard NMF result?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might need to have you walk me through the issues. |
||
| @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 | ||
|
|
||
| X = rand(30, 20) | ||
| W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) | ||
| W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) | ||
| _, W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) | ||
| _, W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) | ||
| @test sum(abs2, W_gsvd_1-W_gsvd_2) <= 1e-12 | ||
| @test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12 | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is also interesting. One RIS, when I
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It probably depends on the specific random numbers. I think the seed gets saved in the testset, if that helps debug. |
||
| end | ||
|
|
||
| @testset "GsvdInitialization.jl" begin | ||
| @testset "GsvdInitialization" begin | ||
| W, H = rand(10, 3), rand(3, 8) | ||
| X = W*H | ||
| U, S, V = svd(X) | ||
|
|
@@ -54,3 +53,33 @@ end | |
| @test β.*β0 ≈ ones(3) | ||
|
|
||
| end | ||
|
|
||
| @testset "joint optimize W and alpha" begin | ||
| W = W_GT | ||
| H = H_GT | ||
| X = W*H | ||
| standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X))) | ||
| _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4, initW=:joint); | ||
| img_tol_int = sum(abs2, X) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't seem used, though you recompute the same quantity a couple lines down. |
||
| @test size(W_gsvd, 2) == 10 | ||
| @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 | ||
|
|
||
| W, H = rand(10, 3), rand(3, 8) | ||
| X = W*H | ||
| U, S, V = svd(X) | ||
|
|
||
| W0, H0 = copy(W), copy(H) | ||
| Hadd = rand(2, 8) | ||
| Wadd, a = GsvdInitialization.init_Wa(X, W0, H0, Hadd) | ||
| @test a ≈ ones(size(W0, 2)) | ||
| @test norm(Wadd) <= 1e-8 | ||
|
|
||
| G = GsvdInitialization.gram_sp_C(W0, H0, Hadd)[1] | ||
| b = GsvdInitialization.gram_b(X, W0, H0, Hadd) | ||
| Wadd = rand(10, 2) | ||
| α = rand(3) | ||
| θ = vcat(vec(Wadd), α) | ||
| E = θ'*G*θ-2*b'*θ+sum(abs2, X) | ||
| @test abs(E-sum(abs2, X-[repeat(α', size(W0, 1)).*W0 Wadd]*[H0;Hadd])) <= 1e-12 | ||
|
|
||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't forget to update
[compat]