Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ authors = ["Hiroshi Shinaoka <h.shinaoka@gmail.com>", "GiggleLiu <cacate0129@gma
publish = false

[workspace.dependencies]
chainrules-core = { git = "https://github.com/tensor4all/chainrules-rs", rev = "c45a23093df0b4d8967888a8b1b7bb40ffbfeb18" }
chainrules = { git = "https://github.com/tensor4all/chainrules-rs", rev = "c45a23093df0b4d8967888a8b1b7bb40ffbfeb18" }
chainrules-core = { git = "https://github.com/tensor4all/chainrules-rs", branch = "deferred-hvp-tangents" }
chainrules = { git = "https://github.com/tensor4all/chainrules-rs", branch = "deferred-hvp-tangents" }
41 changes: 38 additions & 3 deletions crates/tidu/src/engine/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ impl<V: Differentiable> AutogradGraph<V> {
output_node: NodeId,
seed: V::Tangent,
seed_tangent: V::Tangent,
tangents: &[Option<V::Tangent>],
) -> AdResult<(Vec<Option<V::Tangent>>, Vec<Option<V::Tangent>>)>
where
V::Tangent: Clone + Differentiable<Tangent = V::Tangent>,
Expand All @@ -172,7 +173,10 @@ impl<V: Differentiable> AutogradGraph<V> {
continue;
};
let cot_tan = cot_tangents[i].take().unwrap_or_else(|| cot.zero_tangent());
let input_grads = rule.pullback_with_tangents(&cot, &cot_tan)?;
let input_tangents_fn = |node: NodeId| -> Option<&V::Tangent> {
tangents.get(node.index()).and_then(|t| t.as_ref())
};
let input_grads = rule.pullback_with_tangents(&cot, &cot_tan, &input_tangents_fn)?;
for (node_id, grad, grad_tan) in input_grads {
let idx = node_id.index();
match cotangents[idx].take() {
Expand All @@ -193,20 +197,51 @@ impl<V: Differentiable> AutogradGraph<V> {
Ok((cotangents, cot_tangents))
}

/// Two-phase HVP: forward tangent propagation, then a reverse pass.
///
/// Phase 1 walks nodes 0..=output_node and calls `forward_tangents` on
/// each op, building a `Vec<Option<V::Tangent>>`. Leaves are looked up
/// in `leaf_tangents`.
///
/// Phase 2 runs `compute_cotangents_with_tangents`, which wraps the
/// tangents vec in a lookup closure handed to `pullback_with_tangents`.
pub(crate) fn hvp_from(
&self,
output_node: NodeId,
seed: V::Tangent,
seed_tangent: V::Tangent,
leaf_tangents: &std::collections::HashMap<NodeId, V::Tangent>,
) -> AdResult<HvpResult<V>>
where
V::Tangent: Clone + Differentiable<Tangent = V::Tangent>,
{
let n = self.nodes.len();
if output_node.index() >= n {
return Err(AutodiffError::MissingNode);
}

// Phase 1: Forward tangent propagation.
let mut tangents: Vec<Option<V::Tangent>> = vec![None; n];
for i in 0..=output_node.index() {
let node = &self.nodes[i];
if node.is_leaf {
// Look up in leaf_tangents HashMap.
tangents[i] = leaf_tangents.get(&NodeId::new(i)).cloned();
} else if let Some(rule) = node.rule.as_ref() {
let tangents_fn = |node: NodeId| -> Option<&V::Tangent> {
tangents.get(node.index()).and_then(|t| t.as_ref())
};
tangents[i] = rule.forward_tangents(&tangents_fn)?;
}
}

// Phase 2: Reverse pass with tangents.
let (mut cotangents, mut cot_tangents) =
self.compute_cotangents_with_tangents(output_node, seed, seed_tangent)?;
self.compute_cotangents_with_tangents(output_node, seed, seed_tangent, &tangents)?;

let mut gradients = Gradients::new();
let mut hvp = Gradients::new();
for i in 0..self.nodes.len() {
for i in 0..n {
if !self.nodes[i].is_leaf {
continue;
}
Expand Down
35 changes: 26 additions & 9 deletions crates/tidu/src/engine/results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,18 @@ impl<V: Differentiable> PullbackPlan<V> {
/// Result of a forward-over-reverse HVP computation.
///
/// Contains both the standard gradient and the Hessian-vector
/// product H*v, where v is the tangent direction set on leaf values
/// via [`crate::Tape::leaf_with_tangent`].
/// product H*v, where v is the tangent direction passed as a
/// `HashMap<NodeId, V::Tangent>` to [`crate::Tape::hvp`].
///
/// # Examples
///
/// ```rust
/// use std::collections::HashMap;
/// use tidu::{AdResult, HvpResult, NodeId, ReverseRule, Tape};
///
/// struct SquareRuleHvp {
/// input: NodeId,
/// x: f64,
/// dx: f64,
/// }
///
/// impl ReverseRule<f64> for SquareRuleHvp {
Expand All @@ -219,31 +219,48 @@ impl<V: Differentiable> PullbackPlan<V> {
/// vec![self.input]
/// }
///
/// fn pullback_with_tangents(
/// fn forward_tangents<'t>(
/// &self,
/// input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
/// ) -> AdResult<Option<f64>>
/// where
/// f64: 't,
/// {
/// let dx = input_tangents(self.input).copied().unwrap_or(0.0);
/// Ok(Some(2.0 * self.x * dx))
/// }
///
/// fn pullback_with_tangents<'t>(
/// &self,
/// cotangent: &f64,
/// cotangent_tangent: &f64,
/// ) -> AdResult<Vec<(NodeId, f64, f64)>> {
/// input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
/// ) -> AdResult<Vec<(NodeId, f64, f64)>>
/// where
/// f64: 't,
/// {
/// let dx = input_tangents(self.input).copied().unwrap_or(0.0);
/// Ok(vec![(
/// self.input,
/// 2.0 * self.x * *cotangent,
/// 2.0 * self.dx * *cotangent + 2.0 * self.x * *cotangent_tangent,
/// 2.0 * dx * *cotangent + 2.0 * self.x * *cotangent_tangent,
/// )])
/// }
/// }
///
/// let tape = Tape::<f64>::new();
/// let x = tape.leaf_with_tangent(3.0, 1.0).unwrap();
/// let x = tape.leaf(3.0);
/// let y = tape.record_op(
/// 9.0,
/// Box::new(SquareRuleHvp {
/// input: x.node_id().unwrap(),
/// x: 3.0,
/// dx: 1.0,
/// }),
/// None,
/// );
/// let result: HvpResult<f64> = tape.hvp(&y).unwrap();
/// let mut leaf_tangents = HashMap::new();
/// leaf_tangents.insert(x.node_id().unwrap(), 1.0);
/// let result: HvpResult<f64> = tape.hvp(&y, &leaf_tangents).unwrap();
/// assert_eq!(*result.gradients.get(x.node_id().unwrap()).unwrap(), 6.0);
/// assert_eq!(*result.hvp.get(x.node_id().unwrap()).unwrap(), 2.0);
/// ```
Expand Down
18 changes: 13 additions & 5 deletions crates/tidu/src/engine/tape.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard};

use crate::engine::{AutogradGraph, Gradients, TrackedValue};
Expand Down Expand Up @@ -223,13 +224,19 @@ impl<V: Differentiable> Tape<V> {
///
/// Requires:
/// - A **scalar** loss (`num_elements() == 1`).
/// - Tangents set on leaves via [`Tape::leaf_with_tangent`] — these
/// define the direction vector **v** in H·v.
/// - Each rule must implement [`ReverseRule::pullback_with_tangents`]
/// (the default returns `Err(HvpNotSupported)`).
/// - `leaf_tangents` maps leaf [`NodeId`] values to tangent directions
/// **v** for the Hessian-vector product H·v. Leaves not present in
/// the map are treated as having zero tangent.
/// - Each rule must implement [`ReverseRule::forward_tangents`] and
/// [`ReverseRule::pullback_with_tangents`]
/// (the defaults return `Err(HvpNotSupported)`).
///
/// Returns an [`HvpResult`] containing both the gradient and the HVP.
pub fn hvp(&self, loss: &TrackedValue<V>) -> AdResult<HvpResult<V>>
pub fn hvp(
&self,
loss: &TrackedValue<V>,
leaf_tangents: &HashMap<NodeId, V::Tangent>,
) -> AdResult<HvpResult<V>>
where
V::Tangent: Differentiable<Tangent = V::Tangent>,
{
Expand All @@ -244,6 +251,7 @@ impl<V: Differentiable> Tape<V> {
loss_node,
loss.value.seed_cotangent(),
loss.value.zero_tangent(),
leaf_tangents,
)
}

Expand Down
58 changes: 39 additions & 19 deletions crates/tidu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,26 @@
//!
//! ## Scalar Hessian-Vector Product
//!
//! A Hessian-vector product (HVP) computes **H·v** the product of the
//! Hessian of a scalar function with a tangent direction **v** without
//! A Hessian-vector product (HVP) computes **H*v** -- the product of the
//! Hessian of a scalar function with a tangent direction **v** -- without
//! materialising the full Hessian matrix. `tidu` achieves this via
//! forward-over-reverse mode.
//!
//! To enable HVP, implement [`ReverseRule::pullback_with_tangents`] on your
//! rule. The default implementation returns `Err(HvpNotSupported)`, so it is
//! only required when you need second-order derivatives.
//! To enable HVP, implement [`ReverseRule::forward_tangents`] and
//! [`ReverseRule::pullback_with_tangents`] on your rule. The default
//! implementations return `Err(HvpNotSupported)`, so they are only
//! required when you need second-order derivatives.
//!
//! Tangent directions are passed as a `HashMap<NodeId, V::Tangent>` to
//! [`Tape::hvp`] rather than being stored on leaves or rules.
//!
//! ```rust
//! use std::collections::HashMap;
//! use tidu::{AdResult, HvpResult, NodeId, ReverseRule, Tape};
//!
//! struct SquareRuleHvp {
//! input: NodeId,
//! x: f64,
//! dx: f64, // tangent of x, must match the tangent passed to leaf_with_tangent
//! }
//!
//! impl ReverseRule<f64> for SquareRuleHvp {
Expand All @@ -123,39 +127,55 @@
//! vec![self.input]
//! }
//!
//! // Forward tangent propagation: d(x^2) = 2*x*dx.
//! fn forward_tangents<'t>(
//! &self,
//! input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
//! ) -> AdResult<Option<f64>>
//! where
//! f64: 't,
//! {
//! let dx = input_tangents(self.input).copied().unwrap_or(0.0);
//! Ok(Some(2.0 * self.x * dx))
//! }
//!
//! // Forward-over-reverse: differentiates the pullback itself.
//! // `cotangent` is the standard reverse-mode adjoint.
//! // `cotangent_tangent` is its tangent component from the forward pass.
//! // Returns (node, gradient, gradient_tangent) triples.
//! fn pullback_with_tangents(
//! // `input_tangents` provides the forward tangent for each input node.
//! fn pullback_with_tangents<'t>(
//! &self,
//! cotangent: &f64,
//! cotangent_tangent: &f64,
//! ) -> AdResult<Vec<(NodeId, f64, f64)>> {
//! input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
//! ) -> AdResult<Vec<(NodeId, f64, f64)>>
//! where
//! f64: 't,
//! {
//! let dx = input_tangents(self.input).copied().unwrap_or(0.0);
//! Ok(vec![(
//! self.input,
//! 2.0 * self.x * *cotangent,
//! 2.0 * self.dx * *cotangent + 2.0 * self.x * *cotangent_tangent,
//! 2.0 * dx * *cotangent + 2.0 * self.x * *cotangent_tangent,
//! )])
//! }
//! }
//!
//! let tape = Tape::<f64>::new();
//! // Set tangent v = 1.0 on the leaf for the HVP direction.
//! let x = tape.leaf_with_tangent(3.0, 1.0).unwrap();
//! let x = tape.leaf(3.0);
//! let y = tape.record_op(
//! 9.0, // forward value: 3.0^2
//! Box::new(SquareRuleHvp {
//! input: x.node_id().unwrap(),
//! x: 3.0,
//! dx: 1.0,
//! }),
//! None, // no output tangent (only needed for HVP)
//! None,
//! );
//! let result: HvpResult<f64> = tape.hvp(&y).unwrap();
//! // Gradient: d(x^2)/dx at x=3 → 6.0
//! // Pass tangent direction v = 1.0 via HashMap.
//! let mut leaf_tangents = HashMap::new();
//! leaf_tangents.insert(x.node_id().unwrap(), 1.0);
//! let result: HvpResult<f64> = tape.hvp(&y, &leaf_tangents).unwrap();
//! // Gradient: d(x^2)/dx at x=3 = 6.0
//! assert_eq!(*result.gradients.get(x.node_id().unwrap()).unwrap(), 6.0);
//! // HVP: H·v = d²(x²)/dx² · 1.0 = 2.0
//! // HVP: H*v = d^2(x^2)/dx^2 * 1.0 = 2.0
//! assert_eq!(*result.hvp.get(x.node_id().unwrap()).unwrap(), 2.0);
//! ```
//!
Expand Down
Loading
Loading