diff --git a/docs/src/examples.md b/docs/src/examples.md
index 32782c4..96b50d0 100644
--- a/docs/src/examples.md
+++ b/docs/src/examples.md
@@ -66,6 +66,8 @@ julia> derivative([f],x,y,x) #take derivative wrt x, then y, then x
 [:, :, 2] =
 (((((x * y) + 1) * exp((x * y))) * y) + (y * exp((x * y))))
+
+julia> higher_order_derivatives(f, [x, y], 3; symmetric=false); #all mixed partials, order 3
 ```
 
 Compute derivative of a function and make executable
@@ -208,4 +210,4 @@ julia> jtv_exe([1.0,2.0,3.0,4.0])
 2-element Vector{Float64}:
  -1.4116362015446517
  -0.04368042858415033
-```
\ No newline at end of file
+```
diff --git a/src/Jacobian.jl b/src/Jacobian.jl
index e1580ca..3b703af 100644
--- a/src/Jacobian.jl
+++ b/src/Jacobian.jl
@@ -344,6 +344,207 @@ end
 export sparse_hessian
 
+struct DerivativeContext # scratch state for a single higher_order_derivatives call
+    variables::Vector{Node}                           # differentiation variables, in caller order
+    cache::IdDict{Node,Dict{Tuple{Vararg{Int}},Node}} # node => (variable multi-index => derivative node)
+    symmetric::Bool                                   # canonicalize mixed partials when true
+end
+
+DerivativeContext(variables::AbstractVector{<:Node}; symmetric::Bool=true) =
+    DerivativeContext(collect(variables), IdDict{Node,Dict{Tuple{Vararg{Int}},Node}}(), symmetric)
+
+@inline function _canonical_multiindex(ctx::DerivativeContext, idx::Tuple{Vararg{Int}}) # sort so ∂x∂y and ∂y∂x share one cache key
+    if !ctx.symmetric || length(idx) <= 1 # nothing to canonicalize for order <= 1 or when symmetry is off
+        return idx
+    end
+    if length(idx) == 2
+        a, b = idx
+        return a <= b ? idx : (b, a) # 2-tuple fast path avoids collect/sort! allocation
+    end
+    tmp = collect(idx)
+    sort!(tmp)
+    return Tuple(tmp)
+end
+
+@inline function _cache_dict!(ctx::DerivativeContext, node::Node) # lazily create the per-node index dictionary
+    return get!(ctx.cache, node) do
+        Dict{Tuple{Vararg{Int}},Node}()
+    end
+end
+
+@inline function _cache_get(ctx::DerivativeContext, node::Node, idx::Tuple{Vararg{Int}}) # returns nothing on a cache miss
+    dict = get(ctx.cache, node, nothing)
+    dict === nothing && return nothing
+    return get(dict, _canonical_multiindex(ctx, idx), nothing)
+end
+
+@inline function _cache_set!(ctx::DerivativeContext, node::Node, idx::Tuple{Vararg{Int}}, value::Node)
+    dict = _cache_dict!(ctx, node)
+    dict[_canonical_multiindex(ctx, idx)] = value # canonical key so symmetric lookups alias the same entry
+    return value
+end
+
+@inline function _maybe_add_order_node!(nodes::Vector{Node}, index_map::IdDict{Node,Int}, node::Node)
+    if is_constant(node) # constants differentiate to zero; no need to batch them
+        return
+    end
+    if !haskey(index_map, node) # dedupe by node identity so each node is differentiated once per order
+        push!(nodes, node)
+        index_map[node] = length(nodes)
+    end
+end
+
+function _finalize_indices!(indices::Vector{Tuple{Vararg{Int}}}) # canonicalization can produce repeats; sort then dedupe in place
+    sort!(indices)
+    unique!(indices)
+    return indices
+end
+
+function _batched_jacobian(roots::Vector{Node}, variables::Vector{Node}) # one DerivativeGraph pass for many roots — the core batching win
+    nroots = length(roots)
+    nvars = length(variables)
+    if nroots == 0 || nvars == 0 # DerivativeGraph presumably needs nonempty roots — return an empty (possibly 0-dim) matrix
+        return Matrix{Node}(undef, nroots, nvars)
+    end
+    graph = DerivativeGraph(roots)
+    return _symbolic_jacobian!(graph, variables)
+end
+
+"""
+    higher_order_derivatives(
+        terms::AbstractArray{<:Node},
+        variables::AbstractVector{<:Node},
+        order::Integer;
+        symmetric::Bool=true
+    )
+
+Compute all order-`order` partial derivatives of `terms` with respect to `variables`.
+Returns an array with shape `(size(terms)..., length(variables), ..., length(variables))`,
+with `order` trailing variable dimensions. When `terms` is a scalar `Node`, a convenience
+method returns an array with only the variable dimensions (e.g., a Hessian for `order=2`).
+
+If `symmetric=true`, mixed partials are canonicalized so `∂^2 f / ∂x∂y` and
+`∂^2 f / ∂y∂x` map to the same cached node. 
+"""
+function higher_order_derivatives(
+    terms::AbstractArray{<:Node},
+    variables::AbstractVector{<:Node},
+    order::Integer;
+    symmetric::Bool=true
+)
+    if order < 1
+        throw(ErrorException("higher_order_derivatives requires order >= 1."))
+    end
+
+    roots = vec(terms)       # flatten so every term is addressed by one linear index
+    root_shape = size(terms) # remembered so the output can carry the original term dimensions
+    vars = collect(variables)
+    nroots = length(roots)
+    nvars = length(vars)
+
+    if nvars == 0 # no variables: the trailing derivative dimensions are all zero-length
+        out_shape = (root_shape..., ntuple(_ -> 0, order)...)
+        return Array{Node}(undef, out_shape)
+    end
+
+    ctx = DerivativeContext(vars; symmetric=symmetric)
+
+    current_indices = [Tuple{Vararg{Int}}[] for _ in 1:nroots] # per root: multi-indices present at the current order
+    order_nodes = Node[]                                       # distinct non-constant derivative nodes to differentiate next
+    order_index = IdDict{Node,Int}()
+
+    jac = _batched_jacobian(roots, vars) # order 1: a single batched Jacobian seeds the cache
+    for r in 1:nroots
+        for v in 1:nvars
+            dnode = jac[r, v]
+            idx = _canonical_multiindex(ctx, (v,))
+            _cache_set!(ctx, roots[r], idx, dnode)
+            push!(current_indices[r], idx)
+            _maybe_add_order_node!(order_nodes, order_index, dnode)
+        end
+    end
+    for r in 1:nroots
+        _finalize_indices!(current_indices[r])
+    end
+
+    for ord in 2:order # each pass differentiates all distinct order-(ord-1) nodes in one batch
+        if !isempty(order_nodes)
+            order_jac = _batched_jacobian(order_nodes, vars)
+            for node_idx in eachindex(order_nodes)
+                node = order_nodes[node_idx]
+                for v in 1:nvars
+                    _cache_set!(ctx, node, (v,), order_jac[node_idx, v]) # first derivative of each batched node
+                end
+            end
+        end
+
+        next_indices = [Tuple{Vararg{Int}}[] for _ in 1:nroots]
+        next_nodes = Node[]
+        next_index = IdDict{Node,Int}()
+
+        for r in 1:nroots
+            root = roots[r]
+            for idx in current_indices[r]
+                base_node = _cache_get(ctx, root, idx)
+                if base_node === nothing # defensive: should not happen for finalized indices
+                    continue
+                end
+                for v in 1:nvars
+                    new_idx = _canonical_multiindex(ctx, (idx..., v)) # extend the multi-index by one differentiation
+                    dnode = _cache_get(ctx, base_node, (v,))
+                    if dnode === nothing # miss means base_node was constant, whose derivative is zero
+                        dnode = zero(Node)
+                        _cache_set!(ctx, base_node, (v,), dnode)
+                    end
+                    _cache_set!(ctx, root, new_idx, dnode) # symmetric collisions overwrite with an equal partial
+                    push!(next_indices[r], new_idx)
+                    _maybe_add_order_node!(next_nodes, next_index, dnode)
+                end
+            end
+        end
+
+        for r in 1:nroots
+            _finalize_indices!(next_indices[r])
+        end
+
+        current_indices = next_indices
+        order_nodes = next_nodes
+        order_index = next_index
+    end
+
+    out_shape = (root_shape..., ntuple(_ -> nvars, order)...) # term dims first, then `order` variable dims
+    result = Array{Node}(undef, out_shape)
+    root_positions = CartesianIndices(root_shape)
+    root_linear = LinearIndices(root_shape)
+    var_ranges = ntuple(_ -> 1:nvars, order)
+
+    for rpos in root_positions # assemble dense output; symmetric lookups all hit the canonical cache entry
+        root = roots[root_linear[rpos]]
+        for idx in Iterators.product(var_ranges...)
+            idx_tuple = Tuple(idx)
+            val = _cache_get(ctx, root, idx_tuple)
+            if val === nothing # any remaining miss is a vanished (zero) partial
+                val = zero(Node)
+            end
+            result[rpos, idx_tuple...] = val
+        end
+    end
+
+    return result
+end
+
+function higher_order_derivatives( # scalar convenience: wrap, then drop the singleton term dimension
+    term::Node,
+    variables::AbstractVector{<:Node},
+    order::Integer;
+    symmetric::Bool=true
+)
+    result = higher_order_derivatives([term], variables, order; symmetric=symmetric)
+    return dropdims(result; dims=1)
+end
+export higher_order_derivatives
+
+
 """
diff --git a/test/runtests.jl b/test/runtests.jl
index 2a8f333..89319ea 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1816,6 +1816,129 @@ end
 )
 end
 
+@testitem "higher_order_derivatives order1 matches jacobian" begin
+    import FastDifferentiation as FD
+    FD.@variables x y
+
+    f = [x^2 + y, x * y]
+    hod = FD.higher_order_derivatives(f, [x, y], 1) # order 1 must reduce to the plain Jacobian
+    jac = FD.jacobian(f, [x, y])
+
+    hod_exe = FD.make_function(hod, [x, y])
+    jac_exe = FD.make_function(jac, [x, y])
+
+    @test isapprox(hod_exe([1.5, -2.0]), jac_exe([1.5, -2.0]))
+end
+
+@testitem "higher_order_derivatives order2 matches hessian" begin
+    import FastDifferentiation as FD
+    FD.@variables x y z
+
+    f = x^2 * y^2 * z^2
+    hod = FD.higher_order_derivatives(f, [x, y, z], 2) # order 2 on a scalar must reduce to the Hessian
+    h = FD.hessian(f, [x, y, z])
+
+    hod_exe = FD.make_function(hod, [x, y, z])
+    h_exe = FD.make_function(h, [x, y, z])
+
+    @test isapprox(hod_exe([1, 2, 3]), h_exe([1, 2, 3]))
+end
+
+@testitem "higher_order_derivatives symmetry and order3 entry" begin
+    import FastDifferentiation as FD
+    FD.@variables x y
+
+    f = x^3 * y^2
+    hod2 = 
+        FD.higher_order_derivatives(f, [x, y], 2; symmetric=true)
+    @test hod2[1, 2] === hod2[2, 1] # symmetric mode must alias mixed partials to the identical node
+
+    hod3 = FD.higher_order_derivatives(f, [x, y], 3; symmetric=false)
+    dxyx = FD.derivative([f], x, y, x)[1] # reference: sequential single-variable differentiation
+
+    hod3_exe = FD.make_function(hod3, [x, y])
+    dxyx_exe = FD.make_function([dxyx], [x, y])
+
+    @test isapprox(hod3_exe([2.0, -1.5])[1, 2, 1], dxyx_exe([2.0, -1.5])[1])
+end
+
+@testitem "higher_order_derivatives performance (opt-in)" begin
+    import FastDifferentiation as FD
+
+    if get(ENV, "FD_PERF", "") != "1" # timing is environment-sensitive, so keep this opt-in
+        @info "Skipping performance test. Set FD_PERF=1 to run."
+        return
+    end
+
+    function old_all_partials_scalar(f, vars, order) # baseline: one derivative() call per multi-index
+        n = length(vars)
+        result = Array{FD.Node}(undef, ntuple(_ -> n, order)...)
+        for idx in Iterators.product(ntuple(_ -> 1:n, order)...)
+            var_tuple = ntuple(j -> vars[idx[j]], order)
+            result[idx...] = FD.derivative([f], var_tuple...)[1]
+        end
+        return result
+    end
+
+    function old_all_partials_array(terms::AbstractArray{<:FD.Node}, vars, order) # array-valued baseline
+        n = length(vars)
+        result = Array{FD.Node}(undef, size(terms)..., ntuple(_ -> n, order)...)
+        term_axes = axes(terms)
+        for idx in Iterators.product(ntuple(_ -> 1:n, order)...)
+            var_tuple = ntuple(j -> vars[idx[j]], order)
+            deriv = FD.derivative(terms, var_tuple...)
+            view(result, term_axes..., idx...) .= deriv
+        end
+        return result
+    end
+
+    function timed_min(fn; repeats=3) # min of several runs; GC between runs to reduce noise
+        times = Float64[]
+        for _ in 1:repeats
+            GC.gc()
+            push!(times, @elapsed fn())
+        end
+        return minimum(times)
+    end
+
+    function perf_case(label, old_fn, new_fn; repeats=3)
+        old_fn() # warm up both paths so compilation is excluded from the timings
+        new_fn()
+        old_t = timed_min(old_fn; repeats=repeats)
+        new_t = timed_min(new_fn; repeats=repeats)
+        ratio = old_t / new_t
+        @info "higher_order_derivatives perf" label=label old_seconds=old_t new_seconds=new_t speedup=ratio
+        return ratio
+    end
+
+    FD.clear_cache()
+
+    FD.@variables t
+    A = [t t^2; 3t^2 5]
+    ratio_column = perf_case(
+        "Examples/columnderivative.jl",
+        () -> old_all_partials_array(A, [t], 2),
+        () -> FD.higher_order_derivatives(A, [t], 2; symmetric=true)
+    )
+
+    FD.@variables x y z
+    vars_xyz = [x, y, z]
+    f_quad = x^2 + y^2 + z^2
+    ratio_hess_quad = perf_case(
+        "Examples/hessian.jl: x^2+y^2+z^2",
+        () -> old_all_partials_scalar(f_quad, vars_xyz, 2),
+        () -> FD.higher_order_derivatives(f_quad, vars_xyz, 2; symmetric=true)
+    )
+
+    f_prod = x * y * z
+    ratio_hess_prod = perf_case(
+        "Examples/hessian.jl: x*y*z",
+        () -> old_all_partials_scalar(f_prod, vars_xyz, 2),
+        () -> FD.higher_order_derivatives(f_prod, vars_xyz, 2; symmetric=true)
+    )
+
+    @test all(isfinite.([ratio_column, ratio_hess_quad, ratio_hess_prod])) # asserts timings ran, not speedup thresholds
+end
+
 @testitem "sparse hessian" begin
     using SparseArrays
     import FastDifferentiation as FD