From 902535020da9c2ac484130f78a6667ada5bf288f Mon Sep 17 00:00:00 2001 From: youdongguo Date: Thu, 5 Feb 2026 00:00:37 -0600 Subject: [PATCH 1/2] add joint optimize W and alpha --- Project.toml | 2 + src/GsvdInitialization.jl | 102 ++++++++++++++++++++++++++++++++++++-- test/runtests.jl | 6 +-- 3 files changed, 103 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 8dee87a..6c8ede5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,9 +4,11 @@ authors = ["youdongguo <1010705897@qq.com> and contributors"] version = "1.0.0" [deps] +Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" NonNegLeastSquares = "b7351bd1-99d9-5c5d-8786-f205a815c4d7" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" [compat] diff --git a/src/GsvdInitialization.jl b/src/GsvdInitialization.jl index 7c6e6b3..e74c92d 100644 --- a/src/GsvdInitialization.jl +++ b/src/GsvdInitialization.jl @@ -2,6 +2,7 @@ module GsvdInitialization using LinearAlgebra, NMF, TSVD using NonNegLeastSquares +using Kronecker, SparseArrays export gsvdnmf, gsvdrecover @@ -33,6 +34,7 @@ Other keyword arguments are passed to `NMF.nnmf`. function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; n2 = size(first(f), 2), tol_nmf=1e-4, + alg = :cd, kwargs...) n1 = size(W, 2) kadd = n2 - n1 @@ -42,9 +44,14 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; if kadd == 0 return W, H else - W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f) - result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover) - return result_recover.W, result_recover.H + # @show alg + W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f) + if alg == :multmse + @show alg + W_recover, H_recover = max.(W_recover, 1e-5), max.(H_recover, 1e-5) + end + result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover)) + return result_recover, result_recover.W, result_recover.H end end gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...) = gsvdnmf(X, W, H, tsvd(X, n2); kwargs...) @@ -126,6 +133,38 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad end end +function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...) + m, n = size(W0) + kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components")) + if kadd == 0 + return W0, H0, 0 + else + U0, S0, V0 = f + U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n] + Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd) + if initW == :standard + Wadd, a = init_W(X, W0, H0, Hadd) + Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) + W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] + cs = Wcols_modification(X, W0_1, H0_1) + W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 + elseif initW == :joint + W0_2, H0_2 = gsvdrecover_Wa(X, W0, H0, Hadd; kwargs...) + else + throw(ArgumentError("Unknown initW method: $initW")) + end + return abs.(W0_2), abs.(H0_2), Λ + end +end + +function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray; kwargs...) + m = size(W0, 1) + Hadd_nn = truncatepos(Hadd', X, W0, H0)' + Wadd, a = init_Wa(X, W0, H0, Hadd_nn; kwargs...) + W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd], [H0; Hadd_nn] + return abs.(W0_1), abs.(H0_1), Λ +end + function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int) _, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0)); inv_RQt = inv(R*Q') @@ -138,7 +177,45 @@ function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::Abs H_index = sortperm(Λ, rev = true)[1:kadd] Hadd = inv_RQt[:, H_index] Hadd_1 = V0*Hadd - return Hadd_1', Λ[H_index] + return Hadd_1', Λ +end + +function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; kwargs...) where T + m = size(X, 1) + kadd = size(Hadd, 1) + G = gram_sp_C(W0, H0, Hadd)[1] + b = gram_b(X, W0, H0, Hadd) + θ = nonneg_lsq(G, b; alg=:fnnls, gram=true, kwargs...) + Wadd = reshape(θ[1:m*kadd], m, kadd) + α = θ[m*kadd+1:end] + return Wadd, α +end + +function gram_sp_C(W0, H0, Hadd) + m, r0 = size(W0) + k = size(Hadd, 1) + mk = m*k + W0W0, H0H0 = W0'*W0, H0*H0' + P = Hadd*H0' + HH = Hadd*Hadd' + G22 = sparse(W0W0.*H0H0) + G12 = zeros(Float64, mk, r0) + for j in 1:r0 + G12[:,j] .= vec(W0[:,j] * P[:,j]') + end + G12 = sparse(G12) + G11 = kronecker(HH, sparse(I, m, m)) + G = [G11 G12; G12' G22] + return G, G11, G12, G22 +end + +function gram_b(X, W0, H0, Hadd) + m, r0 = size(W0) + k = size(Hadd, 1) + b = zeros(Float64, m*k + r0) + b[1:m*k] = vec(X * Hadd') + b[m*k+1:end] = diag(W0' * X * H0') + return b end function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T @@ -176,4 +253,21 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac return β[:] end +function truncatepos(Y, X, W, H) + ΔX = max.(zero(eltype(X)), X - W*H) + Yout = similar(Y) + for j in axes(Y, 2) + y = view(Y, :, j) + yp = max.(y, zero(eltype(y))) + ym = max.(-y, zero(eltype(y))) + if sum(ΔX * yp) >= sum(ΔX * ym) + Yout[:, j] = yp + else + Yout[:, j] = ym + end + end + return Yout +end + + end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e6d0f7e..2a43fa4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,15 +12,15 @@ W_GT, H_GT = generate_ground_truth() H = H_GT X = W*H standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X))) - W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); + _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); img_tol_int = sum(abs2, X) @test size(W_gsvd, 2) == 10 @test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 X = rand(30, 20) - W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) - W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) + _, W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) + _, W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) @test sum(abs2, W_gsvd_1-W_gsvd_2) <= 1e-12 @test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12 end From 8a2396dd8ed6add6a06ffea17086eaa00099cf54 Mon Sep 17 00:00:00 2001 From: youdongguo Date: Thu, 5 Feb 2026 12:57:59 -0600 Subject: [PATCH 2/2] add test for joint optimize W and alpha --- .vscode/settings.json | 4 +++- src/GsvdInitialization.jl | 35 +++++++++-------------------------- test/runtests.jl | 33 +++++++++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 9e26dfe..a24a35b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1 +1,3 @@ -{} \ No newline at end of file +{ + "julia.environmentPath": "/Users/youdongguo/.julia/dev/GsvdInitialization" +} \ No newline at end of file diff --git a/src/GsvdInitialization.jl b/src/GsvdInitialization.jl index e74c92d..390dac1 100644 --- a/src/GsvdInitialization.jl +++ b/src/GsvdInitialization.jl @@ -35,6 +35,7 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; n2 = size(first(f), 2), tol_nmf=1e-4, alg = :cd, + initW = :standard, kwargs...) n1 = size(W, 2) kadd = n2 - n1 @@ -45,7 +46,7 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; return W, H else # @show alg - W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f) + W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW) if alg == :multmse @show alg W_recover, H_recover = max.(W_recover, 1e-5), max.(H_recover, 1e-5) @@ -80,13 +81,13 @@ Keyword arguments: Other keyword arguments are passed to `NMF.nnmf`. """ -function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, kwargs...) +function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, initW = :standard, kwargs...) n1, n2 = ncomponents f = tsvd(X, n2) W0, H0 = NMF.nndsvd(X, n1; initdata = (U = f[1], S = f[2], V = f[3])) result_initial_nmf = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) W_initial_nmf, H_initial_nmf = result_initial_nmf.W, result_initial_nmf.H - return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final) + return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final, initW=initW) end gsvdnmf(X::AbstractMatrix, ncomponents_final::Integer; kwargs...) = gsvdnmf(X, ncomponents_final-1 => ncomponents_final; kwargs...) @@ -115,24 +116,6 @@ Arguments: `f`: SVD (or Truncated SVD) of `X` """ -function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple) - m, n = size(W0) - kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components")) - if kadd == 0 - return W0, H0, 0 - else - U0, S0, V0 = f - U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n] - Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd) - Wadd, a = init_W(X, W0, H0, Hadd) - Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) - W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] - cs = Wcols_modification(X, W0_1, H0_1) - W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 - return abs.(W0_2), abs.(H0_2), Λ - end -end - function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...) m, n = size(W0) kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components")) @@ -157,12 +140,12 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad end end -function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray; kwargs...) +function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray) m = size(W0, 1) Hadd_nn = truncatepos(Hadd', X, W0, H0)' - Wadd, a = init_Wa(X, W0, H0, Hadd_nn; kwargs...) + Wadd, a = init_Wa(X, W0, H0, Hadd_nn) W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd], [H0; Hadd_nn] - return abs.(W0_1), abs.(H0_1), Λ + return abs.(W0_1), abs.(H0_1) end function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int) @@ -180,12 +163,12 @@ function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::Abs return Hadd_1', Λ end -function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; kwargs...) where T +function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}) where T m = size(X, 1) kadd = size(Hadd, 1) G = gram_sp_C(W0, H0, Hadd)[1] b = gram_b(X, W0, H0, Hadd) - θ = nonneg_lsq(G, b; alg=:fnnls, gram=true, kwargs...) + θ = nonneg_lsq(G, b; alg=:fnnls, gram=true) Wadd = reshape(θ[1:m*kadd], m, kadd) α = θ[m*kadd+1:end] return Wadd, α diff --git a/test/runtests.jl b/test/runtests.jl index 2a43fa4..bf063d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,6 @@ W_GT, H_GT = generate_ground_truth() _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); img_tol_int = sum(abs2, X) @test size(W_gsvd, 2) == 10 - @test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 X = rand(30, 20) @@ -25,7 +24,7 @@ W_GT, H_GT = generate_ground_truth() @test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12 end -@testset "GsvdInitialization.jl" begin +@testset "GsvdInitialization" begin W, H = rand(10, 3), rand(3, 8) X = W*H U, S, V = svd(X) @@ -54,3 +53,33 @@ end @test β.*β0 ≈ ones(3) end + +@testset "joint optimize W and alpha" begin + W = W_GT + H = H_GT + X = W*H + standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X))) + _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4, initW=:joint); + img_tol_int = sum(abs2, X) + @test size(W_gsvd, 2) == 10 + @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 + + W, H = rand(10, 3), rand(3, 8) + X = W*H + U, S, V = svd(X) + + W0, H0 = copy(W), copy(H) + Hadd = rand(2, 8) + Wadd, a = GsvdInitialization.init_Wa(X, W0, H0, Hadd) + @test a ≈ ones(size(W0, 2)) + @test norm(Wadd) <= 1e-8 + + G = GsvdInitialization.gram_sp_C(W0, H0, Hadd)[1] + b = GsvdInitialization.gram_b(X, W0, H0, Hadd) + Wadd = rand(10, 2) + α = rand(3) + θ = vcat(vec(Wadd), α) + E = θ'*G*θ-2*b'*θ+sum(abs2, X) + @test abs(E-sum(abs2, X-[repeat(α', size(W0, 1)).*W0 Wadd]*[H0;Hadd])) <= 1e-12 + +end