diff --git a/crates/chainrules-core/src/lib.rs b/crates/chainrules-core/src/lib.rs index 81f2440..bf4d2f2 100644 --- a/crates/chainrules-core/src/lib.rs +++ b/crates/chainrules-core/src/lib.rs @@ -301,11 +301,34 @@ pub trait ReverseRule: Send + Sync { /// Returns input node IDs this rule depends on. fn inputs(&self) -> Vec; + /// Computes the forward tangent of this operation's output. + /// + /// Given a closure that returns the tangent for each input node + /// (or `None` if the input has no tangent), returns the output tangent. + /// + /// The default implementation returns [`AutodiffError::HvpNotSupported`]. + /// Operations that support deferred HVP override this method. + fn forward_tangents<'t>( + &self, + input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>, + ) -> AdResult> + where + V::Tangent: 't, + { + let _ = input_tangents; + Err(AutodiffError::HvpNotSupported) + } + /// Computes pullback with tangent propagation for HVP. /// - /// Given an output cotangent and its tangent, returns + /// Given an output cotangent, its tangent, and a closure providing input + /// tangents by node ID, returns /// `(node_id, input_cotangent, input_cotangent_tangent)` triples. /// + /// The `input_tangents` closure provides access to forward-propagated + /// tangents for each input node, enabling deferred tangent injection + /// without storing tangents in the rule struct. + /// /// The default implementation returns [`AutodiffError::HvpNotSupported`]. /// Operations that support forward-over-reverse HVP override this method. /// @@ -313,18 +336,24 @@ pub trait ReverseRule: Send + Sync { /// /// ```ignore /// // Called internally by hvp(); users rarely call this directly. - /// let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?; + /// let results = rule.pullback_with_tangents( + /// &cotangent, &cotangent_tangent, &|node| tangents_vec[node.index()].as_ref(), + /// )?; /// for (node_id, grad, grad_tangent) in results { /// // grad: standard cotangent for this input /// // grad_tangent: cotangent tangent for HVP /// } /// ``` - fn pullback_with_tangents( + fn pullback_with_tangents<'t>( &self, cotangent: &V::Tangent, cotangent_tangent: &V::Tangent, - ) -> AdResult>> { - let _ = (cotangent, cotangent_tangent); + input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>, + ) -> AdResult>> + where + V::Tangent: 't, + { + let _ = (cotangent, cotangent_tangent, input_tangents); Err(AutodiffError::HvpNotSupported) } } diff --git a/crates/chainrules-core/tests/core_tests.rs b/crates/chainrules-core/tests/core_tests.rs index 9ddb6fa..42aff46 100644 --- a/crates/chainrules-core/tests/core_tests.rs +++ b/crates/chainrules-core/tests/core_tests.rs @@ -287,7 +287,18 @@ fn reverse_rule_inputs() { #[test] fn reverse_rule_default_hvp_returns_error() { let rule = DummyRule; - let result = rule.pullback_with_tangents(&1.0, &1.0); + let result = rule.pullback_with_tangents(&1.0, &1.0, &|_| None); + assert!(result.is_err()); + match result.unwrap_err() { + AutodiffError::HvpNotSupported => {} + other => panic!("expected HvpNotSupported, got {other:?}"), + } +} + +#[test] +fn reverse_rule_default_forward_tangents_returns_error() { + let rule = DummyRule; + let result = rule.forward_tangents(&|_| None); assert!(result.is_err()); match result.unwrap_err() { AutodiffError::HvpNotSupported => {}