Skip to content

Set AD rules#93

Draft
elisno wants to merge 3 commits intoMichielStock:masterfrom
elisno:elisno/chainrulescore
Draft

Set AD rules#93
elisno wants to merge 3 commits intoMichielStock:masterfrom
elisno:elisno/chainrulescore

Conversation

@elisno
Copy link
Copy Markdown
Contributor

@elisno elisno commented Apr 3, 2021

Resolves #92.

@codecov-io
Copy link
Copy Markdown

codecov-io commented Apr 3, 2021

Codecov Report

Merging #93 (acc9a90) into master (4967a5f) will decrease coverage by 1.25%.
The diff coverage is 69.23%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #93      +/-   ##
==========================================
- Coverage   90.96%   89.71%   -1.26%     
==========================================
  Files          11       11              
  Lines         620      632      +12     
==========================================
+ Hits          564      567       +3     
- Misses         56       65       +9     
Impacted Files Coverage Δ
src/chainrules.jl 69.23% <69.23%> (ø)
src/vectrick.jl 90.55% <0.00%> (-3.94%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4967a5f...acc9a90. Read the comment docs.

@MichielStock
Copy link
Copy Markdown
Owner

MichielStock commented Apr 6, 2021

Your code does not seem to work for some examples and gives the wrong result for others.

For example:

gradient((A, B)->sum(AB), A, B)
gradient((A, B)->sum(kron(A,B)), A, B)

Most of our Kronecker functions fall back on regular function Zygote etc should be able to handle fine. It works for logdet but not for tr and sum (which work with the native kron, e.g. gradient((A, B)-> sum(kron(A, B)), A, B). Not sure why or how we can make ChainRulesCore fall back on this underlying code.

Do you have a reference for your gradients?

elisno added 2 commits April 6, 2021 14:02
Adds a testing-function for different 'output' dimensions of each factor in the Kronecker product. It defines linear regression models with the sum of squared residuals as a loss function. Currently only works for residuals of scalar outputs. Tests are broken for outputs of higher dimensions.
@elisno
Copy link
Copy Markdown
Contributor Author

elisno commented Apr 6, 2021

Your code does not seem to work for some examples and gives the wrong result for others.

You're right, I started with the following loss function:

function loss(A, B, X)
    Z = kron(A, B)*X - y
    L = 0.5 * tr(Z' * Z)
    return L
end

where y has size (1, num_samples).
I wrote kronecker_product_pullback in the rrule with this in mind, but forgot that each sample in y can have a higher dimension.

In test/testchainrules.jl, I make a comparison of Zygote.gradient with hand-written gradients for this trivial case. I do another comparison with kronecker.

I decided to leave similar tests for higher-dimensions, but leave them with @test_broken for now.

@elisno
Copy link
Copy Markdown
Contributor Author

elisno commented Apr 6, 2021

I've been experimenting with KroneckerSum as well.

I managed to get the correct values for the pullback:

function ChainRulesCore.frule((_, ΔA, ΔB), ::KroneckerSum, A::AbstractMatrix, B::AbstractMatrix)
    Ω = (A  B)
    ∂Ω = (ΔA  ΔB)
    return Ω, ∂Ω
end

function ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    function kronecker_sum_pullback(ΔΩ)
        ∂A = nB .* A + Diagonal(fill(tr(B), nA))
        ∂B = nA .* B + Diagonal(fill(tr(A), nB))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return (A  B), kronecker_sum_pullback
end

nA = 3
nB = 2
Ar = rand(nA,nA)
Br = rand(nB,nB)
Y_lazy, back_lazy = Zygote._pullback(, Ar, Br)
Y, back = Zygote._pullback((x,y) -> kron(x, Diagonal(ones(nB))) + kron(Diagonal(ones(nA)), y), Ar, Br)
julia> back(Y)[2:end] .≈ back_lazy(Y_lazy)[2:end]
(true, true)

Of course, this isn't useful for computing the gradient in more complicated expressions, since ΔΩ is not used in computing either ∂A or ∂B in the rrule.

@elisno
Copy link
Copy Markdown
Contributor Author

elisno commented Apr 6, 2021

Note that:

ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)

overwrites

ChainRulesCore.rrule(::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix)

Should I use something else instead of ::typeof(KroneckerProduct)/::typeof(KroneckerSum)?

@MichielStock
Copy link
Copy Markdown
Owner

Still stuck on this, why does computing gradients work for logdet but not tr or sum. It should just fall back to the simple shortcuts, for which adjoints already exist?

@MichielStock
Copy link
Copy Markdown
Owner

Technically, it only makes sense to define the adjoints for those function where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form

@elisno
Copy link
Copy Markdown
Contributor Author

elisno commented Apr 12, 2021

Still stuck on this, why does computing gradients work for logdet

Can you provide a MWE for logdet?

@elisno
Copy link
Copy Markdown
Contributor Author

elisno commented Apr 12, 2021

Technically, it only makes sense to define the adjoints for those function where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form

Maybe I misunderstood, but doesn't this only provide the frules?

@codecov-commenter
Copy link
Copy Markdown

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

❌ Patch coverage is 69.23077% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 89.71%. Comparing base (4967a5f) to head (acc9a90).
⚠️ Report is 93 commits behind head on master.

Files with missing lines Patch % Lines
src/chainrules.jl 69.23% 4 Missing ⚠️
❗ Your organization needs to install the Codecov GitHub app to enable full functionality.
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #93      +/-   ##
==========================================
- Coverage   90.96%   89.71%   -1.26%     
==========================================
  Files          11       11              
  Lines         620      632      +12     
==========================================
+ Hits          564      567       +3     
- Misses         56       65       +9     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AD rules that apply to KroneckerProducts

4 participants