From 8a1d82adabb521558f6bd3d7b638bdd0e52b8337 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Wed, 25 Mar 2026 11:31:47 +0900 Subject: [PATCH 1/2] docs: improve API documentation for library consumers - Enable compile-testing on Differentiable doc examples (was `ignore`) - Document Send+Sync requirement on ReverseRule and ForwardRule traits - Add Wirtinger/CR-calculus convention section for complex scalars - Document rrule first-argument convention (result vs input x) - Clarify result vs x parameter on individual rrule functions - Add Getting Started section to README with dependency and usage examples Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 46 +++++++++++++++++++++++ crates/chainrules-core/src/lib.rs | 12 ++++-- crates/chainrules/README.md | 26 +++++++++++++ crates/chainrules/src/unary/basic.rs | 2 + crates/chainrules/src/unary/exp_log.rs | 8 ++++ crates/chainrules/src/unary/hyperbolic.rs | 6 +++ crates/chainrules/src/unary/trig.rs | 4 ++ 7 files changed, 100 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 78a4c1c..572e44d 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,52 @@ It intentionally does **not** ship a tape, traced value type, or any other AD engine runtime. Those live in separate engine crates such as [`tidu-rs`](https://github.com/tensor4all/tidu-rs). 
+## Getting Started + +Add `chainrules` (or `chainrules-core` alone if you only need the traits) as a +git dependency: + +```toml +[dependencies] +chainrules = { git = "https://github.com/tensor4all/chainrules-rs" } + +# For complex scalar support, also add: +num-complex = "0.4" +``` + +### Using scalar rules + +```rust +use chainrules::{exp_frule, sin_rrule, powf_frule}; + +// Forward rule: returns (primal, tangent) +let (y, dy) = exp_frule(1.0_f64, 1.0); +assert!((y - 1.0_f64.exp()).abs() < 1e-12); + +// Reverse rule: returns input cotangent +let dx = sin_rrule(0.5_f64, 1.0); +assert!((dx - 0.5_f64.cos()).abs() < 1e-12); +``` + +### Implementing custom AD types + +```rust +use chainrules::{Differentiable, ReverseRule, AdResult, NodeId}; + +#[derive(Clone)] +struct MyVec(Vec<f64>); + +impl Differentiable for MyVec { + type Tangent = MyVec; + fn zero_tangent(&self) -> MyVec { MyVec(vec![0.0; self.0.len()]) } + fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec { + MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect()) + } + fn num_elements(&self) -> usize { self.0.len() } + fn seed_cotangent(&self) -> MyVec { MyVec(vec![1.0; self.0.len()]) } +} +``` + ## Design Goals - Keep differentiation rules reusable across projects and AD engines diff --git a/crates/chainrules-core/src/lib.rs b/crates/chainrules-core/src/lib.rs index 883f108..81f2440 100644 --- a/crates/chainrules-core/src/lib.rs +++ b/crates/chainrules-core/src/lib.rs @@ -21,7 +21,7 @@ //! //! Implementing `Differentiable` for a custom type: //! -//! ```ignore +//! ``` //! use chainrules_core::Differentiable; //! //! 
#[derive(Clone)] @@ -61,11 +61,9 @@ /// /// # Examples /// -/// ```ignore +/// ``` /// use chainrules_core::Differentiable; /// -/// // Tensor implements Differentiable with Tangent = Tensor -/// // (defined in tenferro-tensor crate) /// fn example<V: Differentiable>(x: &V) { /// let zero = x.zero_tangent(); /// let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent()); @@ -257,6 +255,9 @@ pub enum SavePolicy { /// /// The type parameter `V` is the differentiable value type (e.g., `Tensor`). /// +/// Implementors must be `Send + Sync` because rule objects may be stored on an +/// AD tape that is shared across threads. +/// /// # Examples /// /// Custom reverse rule for scalar multiplication `output = a * b`: @@ -334,6 +335,9 @@ pub trait ReverseRule<V: Differentiable>: Send + Sync { /// /// The type parameter `V` is the differentiable value type (e.g., `Tensor`). /// +/// Implementors must be `Send + Sync` because rule objects may be stored on an +/// AD tape that is shared across threads. +/// /// # Examples /// /// Custom forward rule for scalar multiplication `output = a * b`: diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index d3106d2..b1df25a 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -61,6 +61,32 @@ repository-local tests: in repository-local formula tests such as `tests/smooth_basis_tests.rs` - complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses +## rrule first-argument convention + +Some rrule helpers accept the **forward result** as their first argument +(when the derivative can be expressed in terms of the output), while others +accept the **input `x`** (when the derivative depends on the original +input). 
The parameter name in each function signature tells you which: + +| First parameter | Functions | +|-----------------|-----------| +| `result` | `exp_rrule`, `expm1_rrule`, `exp2_rrule`, `exp10_rrule`, `sqrt_rrule`, `cbrt_rrule`, `inv_rrule`, `tanh_rrule`, `tan_rrule` | +| `x` | `log_rrule`, `log1p_rrule`, `log2_rrule`, `log10_rrule`, `sin_rrule`, `cos_rrule`, `sinh_rrule`, `cosh_rrule`, all inverse-trig/hyperbolic rrules, `powf_rrule`, `powi_rrule`, `pow_rrule`, Julia-compat trig/hyperbolic rrules | +| both inputs | `mul_rrule(x, y, …)`, `div_rrule(x, y, …)`, `atan2_rrule(y, x, …)`, `hypot_rrule(x, y, …)`, `min_rrule(x, y, …)`, `max_rrule(x, y, …)` | +| cotangent only | `add_rrule`, `sub_rrule`, `conj_rrule`, `real_rrule`, `imag_rrule` | + +## Complex scalar convention + +For complex scalars (`Complex64`, `Complex32`): + +- **Forward-mode** (frule): uses the standard JVP convention on **C ≅ R²** +- **Reverse-mode** (rrule): uses the **conjugate-Wirtinger** convention for + real-valued losses — gradients include `conj(df/dz)` + +For real scalars `conj` is the identity, so the convention is invisible. + +See the validation section below for how each convention is tested. + ## Examples ```rust diff --git a/crates/chainrules/src/unary/basic.rs b/crates/chainrules/src/unary/basic.rs index 7b76987..b4f0c93 100644 --- a/crates/chainrules/src/unary/basic.rs +++ b/crates/chainrules/src/unary/basic.rs @@ -28,6 +28,8 @@ pub fn sqrt_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `sqrt`. +/// +/// Takes the forward **result** `sqrt(x)`, not the input `x`. pub fn sqrt_rrule(result: S, cotangent: S) -> S { cotangent / (S::from_i32(2) * result.conj()) } diff --git a/crates/chainrules/src/unary/exp_log.rs b/crates/chainrules/src/unary/exp_log.rs index 7ebabbe..2861939 100644 --- a/crates/chainrules/src/unary/exp_log.rs +++ b/crates/chainrules/src/unary/exp_log.rs @@ -17,6 +17,8 @@ pub fn exp_frule(x: S, dx: S) -> (S, S) { (y, dx * y) } /// Reverse rule for `exp`. 
+/// +/// Takes the forward **result** `exp(x)`, not the input `x`. pub fn exp_rrule(result: S, cotangent: S) -> S { cotangent * result.conj() } @@ -31,6 +33,8 @@ pub fn expm1_frule(x: S, dx: S) -> (S, S) { (y, dx * scale) } /// Reverse rule for `exp(x) - 1`. +/// +/// Takes the forward **result** `expm1(x)`, not the input `x`. pub fn expm1_rrule(result: S, cotangent: S) -> S { cotangent * (result + one::()).conj() } @@ -71,6 +75,8 @@ pub fn log_frule(x: S, dx: S) -> (S, S) { (y, dy) } /// Reverse rule for `log`. +/// +/// Takes the original **input** `x`, not the result. pub fn log_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / x).conj() } @@ -85,6 +91,8 @@ pub fn log1p_frule(x: S, dx: S) -> (S, S) { (y, dy) } /// Reverse rule for `log(1 + x)`. +/// +/// Takes the original **input** `x`, not the result. pub fn log1p_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / (one::() + x)).conj() } diff --git a/crates/chainrules/src/unary/hyperbolic.rs b/crates/chainrules/src/unary/hyperbolic.rs index 07cfca5..e2291d6 100644 --- a/crates/chainrules/src/unary/hyperbolic.rs +++ b/crates/chainrules/src/unary/hyperbolic.rs @@ -14,6 +14,8 @@ pub fn tanh_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `tanh`. +/// +/// Takes the forward **result** `tanh(x)`, not the input `x`. pub fn tanh_rrule(result: S, cotangent: S) -> S { cotangent * (one::() - result * result).conj() } @@ -30,6 +32,8 @@ pub fn sinh_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `sinh`. +/// +/// Takes the original **input** `x`, not the result. pub fn sinh_rrule(x: S, cotangent: S) -> S { cotangent * x.cosh().conj() } @@ -46,6 +50,8 @@ pub fn cosh_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `cosh`. +/// +/// Takes the original **input** `x`, not the result. 
pub fn cosh_rrule(x: S, cotangent: S) -> S { cotangent * x.sinh().conj() } diff --git a/crates/chainrules/src/unary/trig.rs b/crates/chainrules/src/unary/trig.rs index e740b1a..4beb137 100644 --- a/crates/chainrules/src/unary/trig.rs +++ b/crates/chainrules/src/unary/trig.rs @@ -13,6 +13,8 @@ pub fn sin_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `sin`. +/// +/// Takes the original **input** `x`, not the result. pub fn sin_rrule(x: S, cotangent: S) -> S { cotangent * x.cos().conj() } @@ -29,6 +31,8 @@ pub fn cos_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `cos`. +/// +/// Takes the original **input** `x`, not the result. pub fn cos_rrule(x: S, cotangent: S) -> S { cotangent * (-x.sin()).conj() } From 7e0323dd473c4fedcdb5e1e0825e168c36c63754 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Wed, 25 Mar 2026 11:36:30 +0900 Subject: [PATCH 2/2] chore: trigger CI