add joint optimize W and alpha #15
base: main
Conversation
| _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); | ||
| img_tol_int = sum(abs2, X) | ||
| @test size(W_gsvd, 2) == 10 | ||
| @test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) |
I deleted this test because standard NMF also generates perfect results on my machine, whereas on RIS it generates bad results. One question: should we keep the standard NMF result in this repo's documentation?
I might need to have you walk me through the issues.
| _, W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) | ||
| _, W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) | ||
| @test sum(abs2, W_gsvd_1-W_gsvd_2) <= 1e-12 | ||
| @test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12 |
This test is also interesting. On RIS, when I include the test file, it fails with sum(abs2, W_gsvd_1-W_gsvd_2) = 1e-7. However, when I copy and paste these lines into the REPL, it always passes. On my machine and on GitHub, these tests always pass.
It probably depends on the specific random numbers. I think the seed gets saved in the testset, if that helps debug.
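One way to make a seed-dependent failure like this reproducible, as a minimal sketch: fix the seed explicitly before generating the inputs. The seed value `1234` and the matrix sizes are arbitrary assumptions for illustration, and `gsvdnmf` itself is not called here.

```julia
using Random

Random.seed!(1234)          # arbitrary seed, chosen only for illustration
W1 = rand(5, 3); H1 = rand(3, 4)

Random.seed!(1234)          # re-seeding replays the same random stream
W2 = rand(5, 3); H2 = rand(3, 4)

# With the same seed, the generated inputs are identical.
same_inputs = (W1 == W2) && (H1 == H2)
```

Running the failing test file and the REPL session from the same explicit seed would show whether the 1e-7 discrepancy comes from the random inputs or from something else in the environment.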
| else | ||
| # @show alg | ||
| W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f) | ||
| W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW) |
This version is ready for review. I have not written the documentation yet; I will update it after we merge this code change (or once the code change is finished).
I added a keyword argument that selects a new method, our approach of jointly optimizing Wadd and alpha. The default approach still computes Wadd and alpha separately.
| end | ||
|
|
||
| function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...) | ||
| m, n = size(W0) |
This is the main change to the function. A split of the change is added.
timholy
left a comment
A few things to think about.
| # @show alg | ||
| W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW) | ||
| if alg == :multmse | ||
| @show alg |
delete
| W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW) | ||
| if alg == :multmse | ||
| @show alg | ||
| W_recover, H_recover = max.(W_recover, 1e-5), max.(H_recover, 1e-5) |
Should the 1e-5 be hard-coded or a kwarg?
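As a rough sketch of the keyword-argument option, the floor could be exposed with `1e-5` as its default. The helper name `clamp_floor` and the keyword `lowerbound` are hypothetical names used only for illustration; they are not part of the package.

```julia
# Hypothetical helper: `clamp_floor` and `lowerbound` are illustrative names only.
# Raises every entry of W and H to at least `lowerbound`.
function clamp_floor(W::AbstractMatrix, H::AbstractMatrix; lowerbound::Real = 1e-5)
    return max.(W, lowerbound), max.(H, lowerbound)
end

W = [0.0 2.0; 1e-8 0.5]
H = [1.0 0.0; 0.0 3.0]

Wc, Hc = clamp_floor(W, H)                      # default floor of 1e-5
Wz, Hz = clamp_floor(W, H; lowerbound = 1e-3)   # caller-chosen floor
```

Keeping the current value as the default leaves existing calls unchanged while letting callers tune the floor.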
| m, r0 = size(W0) | ||
| k = size(Hadd, 1) | ||
| b = zeros(Float64, m*k + r0) | ||
| b[1:m*k] = vec(X * Hadd') | ||
| b[m*k+1:end] = diag(W0' * X * H0') |
| m, r0 = size(W0) | |
| k = size(Hadd, 1) | |
| b = zeros(Float64, m*k + r0) | |
| b[1:m*k] = vec(X * Hadd') | |
| b[m*k+1:end] = diag(W0' * X * H0') | |
| b = vcat(vec(X * Hadd'), diag(W0' * X * H0')) |
Though computing all of W0' * X * H0' and then keeping only the diagonal seems wasteful?
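A sketch of one way to get only the diagonal: since `X * H0'` has the same shape as `W0`, entry `i` of the diagonal is the dot product of their `i`-th columns. The sizes and random inputs below are assumptions chosen only to exercise the shapes.

```julia
using LinearAlgebra

# Hypothetical small inputs, chosen only to match the shapes in the snippet.
m, n, r0 = 6, 8, 3
X  = rand(m, n)
W0 = rand(m, r0)
H0 = rand(r0, n)

# Reference: forms the full r0×r0 product and discards the off-diagonal entries.
d_full = diag(W0' * X * H0')

# Diagonal only: XH is m×r0, like W0, so multiply elementwise and sum each column.
XH = X * H0'
d_cols = vec(sum(W0 .* XH; dims=1))

# The same computation written as explicit column dot products.
d_dot = [dot(view(W0, :, i), view(XH, :, i)) for i in 1:r0]
```

The product `X * H0'` is still required, so the saving is limited to the final multiplication by `W0'`; whether that matters depends on how large `r0` is relative to `n`.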
| X = W*H | ||
| standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X))) | ||
| _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4, initW=:joint); | ||
| img_tol_int = sum(abs2, X) |
This doesn't seem used, though you recompute the same quantity a couple lines down.
| version = "1.0.0" | ||
|
|
||
| [deps] | ||
| Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" |
Don't forget to update [compat]
not ready for review