Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ julia> derivative([f],x,y,x) #take derivative wrt x, then y, then x

[:, :, 2] =
(((((x * y) + 1) * exp((x * y))) * y) + (y * exp((x * y))))

julia> higher_order_derivatives(f, [x, y], 3; symmetric=false); #all mixed partials, order 3
```

Compute derivative of a function and make executable
Expand Down Expand Up @@ -208,4 +210,4 @@ julia> jtv_exe([1.0,2.0,3.0,4.0])
2-element Vector{Float64}:
-1.4116362015446517
-0.04368042858415033
```
```
201 changes: 201 additions & 0 deletions src/Jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,207 @@ end
export sparse_hessian


"""
Internal bookkeeping for repeated symbolic differentiation.

Holds the `variables` being differentiated against, a per-node `cache`
mapping each source node to its already-computed partials (keyed by
variable multi-index), and the `symmetric` flag controlling whether mixed
partials are canonicalized to one shared entry.
"""
struct DerivativeContext
    variables::Vector{Node}
    cache::IdDict{Node,Dict{Tuple{Vararg{Int}},Node}}
    symmetric::Bool
end

# Convenience constructor: copy the variables into a `Vector` and start
# with an empty cache.
function DerivativeContext(variables::AbstractVector{<:Node}; symmetric::Bool=true)
    fresh_cache = IdDict{Node,Dict{Tuple{Vararg{Int}},Node}}()
    return DerivativeContext(collect(variables), fresh_cache, symmetric)
end

# Return the canonical form of a differentiation multi-index. When mixed
# partials are treated as symmetric, the order of differentiation does not
# matter, so the index is sorted ascending; otherwise it is returned as-is.
@inline function _canonical_multiindex(ctx::DerivativeContext, idx::Tuple{Vararg{Int}})
    n = length(idx)
    (!ctx.symmetric || n <= 1) && return idx
    if n == 2
        # Fast path for the common Hessian case: at most one swap needed.
        lo, hi = idx
        return lo <= hi ? idx : (hi, lo)
    end
    # General case: sort a mutable copy and rebuild the tuple.
    return Tuple(sort!(collect(idx)))
end

# Fetch (lazily creating) the per-node dictionary of cached partials.
@inline function _cache_dict!(ctx::DerivativeContext, node::Node)
    return get!(() -> Dict{Tuple{Vararg{Int}},Node}(), ctx.cache, node)
end

# Look up the cached partial of `node` for multi-index `idx` (canonicalized
# first). Returns `nothing` on a cache miss.
@inline function _cache_get(ctx::DerivativeContext, node::Node, idx::Tuple{Vararg{Int}})
    entry = get(ctx.cache, node, nothing)
    isnothing(entry) && return nothing
    return get(entry, _canonical_multiindex(ctx, idx), nothing)
end

# Cache `value` as the partial of `node` for multi-index `idx`. Storing under
# the canonical index means symmetric lookups all share a single entry.
@inline function _cache_set!(ctx::DerivativeContext, node::Node, idx::Tuple{Vararg{Int}}, value::Node)
    _cache_dict!(ctx, node)[_canonical_multiindex(ctx, idx)] = value
    return value
end

# Register `node` (once) in the work list for the next differentiation order.
# Constant nodes are skipped: their derivatives are identically zero, so
# there is no point carrying them forward.
@inline function _maybe_add_order_node!(nodes::Vector{Node}, index_map::IdDict{Node,Int}, node::Node)
    is_constant(node) && return nothing
    haskey(index_map, node) && return nothing
    push!(nodes, node)
    index_map[node] = length(nodes)
    return nothing
end

# Deduplicate a list of multi-indices in place and return it. Sorting first
# lets `unique!` remove duplicates in a single linear pass over the vector.
function _finalize_indices!(indices::Vector{Tuple{Vararg{Int}}})
    return unique!(sort!(indices))
end

# Differentiate all `roots` with respect to all `variables` in one shared
# derivative-graph pass, returning an `length(roots) × length(variables)`
# matrix of derivative nodes.
function _batched_jacobian(roots::Vector{Node}, variables::Vector{Node})
    # Nothing to differentiate when either dimension is empty; return the
    # correctly shaped zero-sized matrix instead of building a graph.
    if isempty(roots) || isempty(variables)
        return Matrix{Node}(undef, length(roots), length(variables))
    end
    return _symbolic_jacobian!(DerivativeGraph(roots), variables)
end

"""
higher_order_derivatives(
terms::AbstractArray{<:Node},
variables::AbstractVector{<:Node},
order::Integer;
symmetric::Bool=true
)

Compute all order-`order` partial derivatives of `terms` with respect to `variables`.
Returns an array with shape `(size(terms)..., length(variables), ..., length(variables))`,
with `order` trailing variable dimensions. When `terms` is a scalar `Node`, a convenience
method returns an array with only the variable dimensions (e.g., a Hessian for `order=2`).

If `symmetric=true`, mixed partials are canonicalized so `∂^2 f / ∂x∂y` and
`∂^2 f / ∂y∂x` map to the same cached node.
"""
function higher_order_derivatives(
terms::AbstractArray{<:Node},
variables::AbstractVector{<:Node},
order::Integer;
symmetric::Bool=true
)
if order < 1
throw(ErrorException("higher_order_derivatives requires order >= 1."))
end

roots = vec(terms)
root_shape = size(terms)
vars = collect(variables)
nroots = length(roots)
nvars = length(vars)

if nvars == 0
out_shape = (root_shape..., ntuple(_ -> 0, order)...)
return Array{Node}(undef, out_shape)
end

ctx = DerivativeContext(vars; symmetric=symmetric)

current_indices = [Tuple{Vararg{Int}}[] for _ in 1:nroots]
order_nodes = Node[]
order_index = IdDict{Node,Int}()

jac = _batched_jacobian(roots, vars)
for r in 1:nroots
for v in 1:nvars
dnode = jac[r, v]
idx = _canonical_multiindex(ctx, (v,))
_cache_set!(ctx, roots[r], idx, dnode)
push!(current_indices[r], idx)
_maybe_add_order_node!(order_nodes, order_index, dnode)
end
end
for r in 1:nroots
_finalize_indices!(current_indices[r])
end

for ord in 2:order
if !isempty(order_nodes)
order_jac = _batched_jacobian(order_nodes, vars)
for node_idx in eachindex(order_nodes)
node = order_nodes[node_idx]
for v in 1:nvars
dnode = order_jac[node_idx, v]
_cache_set!(ctx, node, (v,), dnode)
end
end
end

next_indices = [Tuple{Vararg{Int}}[] for _ in 1:nroots]
next_nodes = Node[]
next_index = IdDict{Node,Int}()

for r in 1:nroots
root = roots[r]
for idx in current_indices[r]
base_node = _cache_get(ctx, root, idx)
if base_node === nothing
continue
end
for v in 1:nvars
new_idx = _canonical_multiindex(ctx, (idx..., v))
dnode = _cache_get(ctx, base_node, (v,))
if dnode === nothing
dnode = zero(Node)
_cache_set!(ctx, base_node, (v,), dnode)
end
_cache_set!(ctx, root, new_idx, dnode)
push!(next_indices[r], new_idx)
_maybe_add_order_node!(next_nodes, next_index, dnode)
end
end
end

for r in 1:nroots
_finalize_indices!(next_indices[r])
end
current_indices = next_indices
order_nodes = next_nodes
order_index = next_index
end

out_shape = (root_shape..., ntuple(_ -> nvars, order)...)
result = Array{Node}(undef, out_shape)
root_positions = CartesianIndices(root_shape)
root_linear = LinearIndices(root_shape)
var_ranges = ntuple(_ -> 1:nvars, order)

for rpos in root_positions
root = roots[root_linear[rpos]]
for idx in Iterators.product(var_ranges...)
idx_tuple = Tuple(idx)
val = _cache_get(ctx, root, idx_tuple)
if val === nothing
val = zero(Node)
end
result[rpos, idx_tuple...] = val
end
end

return result
end

"""
    higher_order_derivatives(term::Node, variables, order; symmetric=true)

Scalar convenience method: compute all order-`order` partials of a single
`term` and drop the leading singleton term dimension, so the result carries
only the `order` variable dimensions (e.g. a Hessian matrix for `order=2`).
"""
function higher_order_derivatives(
    term::Node,
    variables::AbstractVector{<:Node},
    order::Integer;
    symmetric::Bool=true
)
    stacked = higher_order_derivatives([term], variables, order; symmetric=symmetric)
    return dropdims(stacked; dims=1)
end
export higher_order_derivatives




"""
Expand Down
123 changes: 123 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1816,6 +1816,129 @@ end
)
end

@testitem "higher_order_derivatives order1 matches jacobian" begin
import FastDifferentiation as FD
FD.@variables x y

f = [x^2 + y, x * y]
hod = FD.higher_order_derivatives(f, [x, y], 1)
jac = FD.jacobian(f, [x, y])

hod_exe = FD.make_function(hod, [x, y])
jac_exe = FD.make_function(jac, [x, y])

@test isapprox(hod_exe([1.5, -2.0]), jac_exe([1.5, -2.0]))
end

@testitem "higher_order_derivatives order2 matches hessian" begin
import FastDifferentiation as FD
FD.@variables x y z

f = x^2 * y^2 * z^2
hod = FD.higher_order_derivatives(f, [x, y, z], 2)
h = FD.hessian(f, [x, y, z])

hod_exe = FD.make_function(hod, [x, y, z])
h_exe = FD.make_function(h, [x, y, z])

@test isapprox(hod_exe([1, 2, 3]), h_exe([1, 2, 3]))
end

@testitem "higher_order_derivatives symmetry and order3 entry" begin
import FastDifferentiation as FD
FD.@variables x y

f = x^3 * y^2
hod2 = FD.higher_order_derivatives(f, [x, y], 2; symmetric=true)
@test hod2[1, 2] === hod2[2, 1]

hod3 = FD.higher_order_derivatives(f, [x, y], 3; symmetric=false)
dxyx = FD.derivative([f], x, y, x)[1]

hod3_exe = FD.make_function(hod3, [x, y])
dxyx_exe = FD.make_function([dxyx], [x, y])

@test isapprox(hod3_exe([2.0, -1.5])[1, 2, 1], dxyx_exe([2.0, -1.5])[1])
end

@testitem "higher_order_derivatives performance (opt-in)" begin
import FastDifferentiation as FD

if get(ENV, "FD_PERF", "") != "1"
@info "Skipping performance test. Set FD_PERF=1 to run."
return
end

function old_all_partials_scalar(f, vars, order)
n = length(vars)
result = Array{FD.Node}(undef, ntuple(_ -> n, order)...)
for idx in Iterators.product(ntuple(_ -> 1:n, order)...)
var_tuple = ntuple(j -> vars[idx[j]], order)
result[idx...] = FD.derivative([f], var_tuple...)[1]
end
return result
end

function old_all_partials_array(terms::AbstractArray{<:FD.Node}, vars, order)
n = length(vars)
result = Array{FD.Node}(undef, size(terms)..., ntuple(_ -> n, order)...)
term_axes = axes(terms)
for idx in Iterators.product(ntuple(_ -> 1:n, order)...)
var_tuple = ntuple(j -> vars[idx[j]], order)
deriv = FD.derivative(terms, var_tuple...)
view(result, term_axes..., idx...) .= deriv
end
return result
end

function timed_min(fn; repeats=3)
times = Float64[]
for _ in 1:repeats
GC.gc()
push!(times, @elapsed fn())
end
return minimum(times)
end

function perf_case(label, old_fn, new_fn; repeats=3)
old_fn()
new_fn()
old_t = timed_min(old_fn; repeats=repeats)
new_t = timed_min(new_fn; repeats=repeats)
ratio = old_t / new_t
@info "higher_order_derivatives perf" label=label old_seconds=old_t new_seconds=new_t speedup=ratio
return ratio
end

FD.clear_cache()

FD.@variables t
A = [t t^2; 3t^2 5]
ratio_column = perf_case(
"Examples/columnderivative.jl",
() -> old_all_partials_array(A, [t], 2),
() -> FD.higher_order_derivatives(A, [t], 2; symmetric=true)
)

FD.@variables x y z
vars_xyz = [x, y, z]
f_quad = x^2 + y^2 + z^2
ratio_hess_quad = perf_case(
"Examples/hessian.jl: x^2+y^2+z^2",
() -> old_all_partials_scalar(f_quad, vars_xyz, 2),
() -> FD.higher_order_derivatives(f_quad, vars_xyz, 2; symmetric=true)
)

f_prod = x * y * z
ratio_hess_prod = perf_case(
"Examples/hessian.jl: x*y*z",
() -> old_all_partials_scalar(f_prod, vars_xyz, 2),
() -> FD.higher_order_derivatives(f_prod, vars_xyz, 2; symmetric=true)
)

@test all(isfinite.([ratio_column, ratio_hess_quad, ratio_hess_prod]))
end

@testitem "sparse hessian" begin
using SparseArrays
import FastDifferentiation as FD
Expand Down
Loading