AD rules that apply to KroneckerProducts #92

@elisno

(Related to #11)

I'm trying to wrap my head around getting gradients with kron/kronecker.

  1. Is it sufficient to define custom AD rules for the vec-trick with ChainRulesCore.jl?
function rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
    function times_vec_pullback(ΔΩ)
        ...
    end
    return K*x, times_vec_pullback
end

function rrule(::typeof(*), K::KroneckerProduct, X::AbstractMatrix)
    function times_mat_pullback(ΔΩ)
        ...
    end
    return K*X, times_mat_pullback
end
  2. Do we also need to define rules for the constructor to get gradients?
function rrule(::typeof(kronecker), A::AbstractMatrix, B::AbstractMatrix)
    function kronecker_pullback(ΔΩ)
        ...
    end
    return kronecker(A, B), kronecker_pullback
end
  3. Should the pullbacks also be lazy? I found this to be a decent overview of finding vectorized derivatives. Would the pullbacks then just be reshape rules for these vectorized derivatives?
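
For reference, here is a sketch of what the elided pullbacks could compute. It is not taken from Kronecker.jl or from any existing rule. It assumes real-valued matrices, column-major `vec`, and the hypothetical sizes A (m×n) and B (p×q). The symbols X, Y, and the barred cotangents are introduced here for illustration only. With x = vec(X) for X of size q×n, the vec-trick gives (A ⊗ B) x = vec(B X Aᵀ). Reshaping the incoming cotangent ΔΩ to a p×m matrix Ȳ, the pullback never needs to form a dense Kronecker product:

```math
\begin{aligned}
Y &= B X A^{\top}, \qquad \bar{Y} = \operatorname{reshape}(\Delta\Omega,\; p,\; m), \\
\bar{X} &= B^{\top} \bar{Y} A
  \quad\Longleftrightarrow\quad
  \bar{x} = \operatorname{vec}(\bar{X}) = (A \otimes B)^{\top} \Delta\Omega, \\
\bar{A} &= \bar{Y}^{\top} B X, \\
\bar{B} &= \bar{Y} A X^{\top}.
\end{aligned}
```

For the constructor, if a dense cotangent K̄ of size mp×nq does arrive, partitioning it into p×q blocks K̄⁽ⁱʲ⁾ suggests the following rule, again assuming real entries:

```math
\begin{aligned}
\bar{A}_{ij} &= \big\langle \bar{K}^{(ij)},\, B \big\rangle
  = \sum_{k,l} \bar{K}^{(ij)}_{kl}\, B_{kl}, \\
\bar{B} &= \sum_{i,j} A_{ij}\, \bar{K}^{(ij)}.
\end{aligned}
```

Under these assumptions the cotangent with respect to x is itself a lazy Kronecker product applied to ΔΩ, and the cotangents with respect to A and B are ordinary matrix products of reshaped arrays, which is consistent with the "reshape rules" reading in question 3. Complex inputs would need conjugates in the appropriate places.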
