(Related to #11)
I'm trying to wrap my head around getting gradients with `kron`/`kronecker`.
- Is it sufficient to define custom AD rules for the vec trick with ChainRulesCore.jl?
  ```julia
  using Kronecker: KroneckerProduct
  import ChainRulesCore: rrule  # extend ChainRulesCore's rrule instead of defining a new local function

  function rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
      function times_vec_pullback(ΔΩ)
          # ...
      end
      return K*x, times_vec_pullback
  end

  function rrule(::typeof(*), K::KroneckerProduct, X::AbstractMatrix)
      function times_mat_pullback(ΔΩ)
          # ...
      end
      return K*X, times_mat_pullback
  end
  ```
- Do we also need to define a rule for the constructor to get gradients?
  ```julia
  using Kronecker: kronecker
  import ChainRulesCore: rrule

  function rrule(::typeof(kronecker), A::AbstractMatrix, B::AbstractMatrix)
      function kronecker_pullback(ΔΩ)
          # ...
      end
      return kronecker(A, B), kronecker_pullback
  end
  ```
- Should the pullbacks also be lazy? I found this to be a decent overview of vectorized derivatives. Would the pullbacks then just be reshaping rules applied to these vectorized derivatives?
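For reference, this is my own derivation of the vec-trick identity and the pullbacks it implies (not taken from Kronecker.jl), assuming $A \in \mathbb{R}^{m \times n}$, $B \in \mathbb{R}^{p \times q}$, $x = \operatorname{vec}(X)$ with $X \in \mathbb{R}^{q \times n}$, column-major `vec` as in Julia, and an output cotangent $\bar{y} = \operatorname{vec}(\bar{Y})$:

```math
\begin{aligned}
y &= (A \otimes B)\,\operatorname{vec}(X) = \operatorname{vec}(B X A^\top), \qquad Y = B X A^\top \in \mathbb{R}^{p \times m},\\
\bar{x} &= (A \otimes B)^\top \bar{y} = (A^\top \otimes B^\top)\operatorname{vec}(\bar{Y}) = \operatorname{vec}(B^\top \bar{Y} A),\\
\bar{A} &= \bar{Y}^\top B X, \qquad \bar{B} = \bar{Y} A X^\top.
\end{aligned}
```

If this is right, every pullback is a reshape of $\bar{y}$ followed by small matrix products, so none of them needs the dense $A \otimes B$.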
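A minimal sketch of the vector case in plain Julia (only `LinearAlgebra`, no ChainRulesCore or Kronecker.jl), assuming the factors `A` and `B` are available as ordinary matrices. `kron_vec_with_pullback` is a hypothetical helper name; it returns the product and a pullback closure, mirroring the shape of an `rrule`, and never forms `kron(A, B)`:

```julia
using LinearAlgebra

# Hypothetical helper (not Kronecker.jl API): y = kron(A, B) * x via the vec trick,
# returning a pullback closure in the same shape as an rrule.
# Assumes A is m×n, B is p×q, x has length n*q, and vec is column-major (Julia's default).
function kron_vec_with_pullback(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector)
    m, n = size(A)
    p, q = size(B)
    length(x) == n * q || throw(DimensionMismatch("length(x) must equal n * q"))
    X = reshape(x, q, n)              # x = vec(X)
    y = vec(B * X * transpose(A))     # (A ⊗ B) vec(X) = vec(B X Aᵀ)

    function kron_vec_pullback(ybar::AbstractVector)
        Ybar = reshape(ybar, p, m)              # ybar = vec(Ybar)
        xbar = vec(transpose(B) * Ybar * A)     # (Aᵀ ⊗ Bᵀ) ybar, kron never formed
        Abar = transpose(Ybar) * B * X          # cotangent for the first factor
        Bbar = Ybar * A * transpose(X)          # cotangent for the second factor
        return Abar, Bbar, xbar
    end

    return y, kron_vec_pullback
end
```

Comparing against the dense `kron(A, B)` on small random matrices (e.g. `xbar` against `transpose(kron(A, B)) * ybar`) seems like a cheap way to validate the formulas before wrapping them in an actual `rrule`.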
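The matrix case could be sketched the same way, column by column (again plain Julia, with hypothetical helper names `kron_times_mat` and `kron_times_mat_pullback`). What it tries to illustrate is that the cotangent for `X` is itself a Kronecker-times-matrix product with transposed factors, so the lazy forward helper can be reused inside the pullback:

```julia
using LinearAlgebra

# Hypothetical helper (not Kronecker.jl API): Y = kron(A, B) * X, one column at a time,
# using (A ⊗ B) vec(Xj) = vec(B * Xj * Aᵀ) so kron(A, B) is never materialized.
function kron_times_mat(A::AbstractMatrix, B::AbstractMatrix, X::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    size(X, 1) == n * q || throw(DimensionMismatch("size(X, 1) must equal n * q"))
    T = promote_type(eltype(A), eltype(B), eltype(X))
    Y = Matrix{T}(undef, m * p, size(X, 2))
    for j in axes(X, 2)
        Xj = reshape(X[:, j], q, n)
        Y[:, j] = vec(B * Xj * transpose(A))
    end
    return Y
end

# Hypothetical pullback for Y = kron(A, B) * X given the output cotangent Ybar.
# The X cotangent is again a Kronecker-times-matrix product: (Aᵀ ⊗ Bᵀ) * Ybar.
# The A and B cotangents sum one contribution per column of X.
function kron_times_mat_pullback(A::AbstractMatrix, B::AbstractMatrix,
                                 X::AbstractMatrix, Ybar::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    Xbar = kron_times_mat(transpose(A), transpose(B), Ybar)
    T = promote_type(eltype(A), eltype(B), eltype(X), eltype(Ybar))
    Abar = zeros(T, m, n)
    Bbar = zeros(T, p, q)
    for j in axes(X, 2)
        Xj = reshape(X[:, j], q, n)
        Yj = reshape(Ybar[:, j], p, m)
        Abar .+= transpose(Yj) * B * Xj
        Bbar .+= Yj * A * transpose(Xj)
    end
    return Abar, Bbar, Xbar
end
```

If `K` is a struct holding the factors, `Abar` and `Bbar` are what a structural tangent for those factors would carry. Whether that removes the need for a separate constructor rule probably depends on how Kronecker.jl and the AD system treat the struct, which I haven't checked.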