Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions crates/chainrules-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,30 +301,59 @@ pub trait ReverseRule<V: Differentiable>: Send + Sync {
/// Returns input node IDs this rule depends on.
fn inputs(&self) -> Vec<NodeId>;

/// Computes the forward tangent of this operation's output.
///
/// Given a closure that returns the tangent for each input node
/// (or `None` if the input has no tangent), returns the output tangent.
///
/// The default implementation returns [`AutodiffError::HvpNotSupported`].
/// Operations that support deferred HVP override this method.
fn forward_tangents<'t>(
&self,
input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
) -> AdResult<Option<V::Tangent>>
where
V::Tangent: 't,
{
let _ = input_tangents;
Err(AutodiffError::HvpNotSupported)
}

/// Computes pullback with tangent propagation for HVP.
///
/// Given an output cotangent and its tangent, returns
/// Given an output cotangent, its tangent, and a closure providing input
/// tangents by node ID, returns
/// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
///
/// The `input_tangents` closure provides access to forward-propagated
/// tangents for each input node, enabling deferred tangent injection
/// without storing tangents in the rule struct.
///
/// The default implementation returns [`AutodiffError::HvpNotSupported`].
/// Operations that support forward-over-reverse HVP override this method.
///
/// # Examples
///
/// ```ignore
/// // Called internally by hvp(); users rarely call this directly.
/// let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
/// let results = rule.pullback_with_tangents(
/// &cotangent, &cotangent_tangent, &|node| tangents_vec[node.index()].as_ref(),
/// )?;
/// for (node_id, grad, grad_tangent) in results {
/// // grad: standard cotangent for this input
/// // grad_tangent: cotangent tangent for HVP
/// }
/// ```
fn pullback_with_tangents(
fn pullback_with_tangents<'t>(
&self,
cotangent: &V::Tangent,
cotangent_tangent: &V::Tangent,
) -> AdResult<Vec<PullbackWithTangentsEntry<V>>> {
let _ = (cotangent, cotangent_tangent);
input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
) -> AdResult<Vec<PullbackWithTangentsEntry<V>>>
where
V::Tangent: 't,
{
let _ = (cotangent, cotangent_tangent, input_tangents);
Err(AutodiffError::HvpNotSupported)
}
}
Expand Down
13 changes: 12 additions & 1 deletion crates/chainrules-core/tests/core_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,18 @@ fn reverse_rule_inputs() {
#[test]
fn reverse_rule_default_hvp_returns_error() {
let rule = DummyRule;
let result = rule.pullback_with_tangents(&1.0, &1.0);
let result = rule.pullback_with_tangents(&1.0, &1.0, &|_| None);
assert!(result.is_err());
match result.unwrap_err() {
AutodiffError::HvpNotSupported => {}
other => panic!("expected HvpNotSupported, got {other:?}"),
}
}

#[test]
fn reverse_rule_default_forward_tangents_returns_error() {
let rule = DummyRule;
let result = rule.forward_tangents(&|_| None);
assert!(result.is_err());
match result.unwrap_err() {
AutodiffError::HvpNotSupported => {}
Expand Down
Loading