Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,52 @@ It intentionally does **not** ship a tape, traced value type, or any other AD
engine runtime. Those live in separate engine crates such as
[`tidu-rs`](https://github.com/tensor4all/tidu-rs).

## Getting Started

Add `chainrules` (or `chainrules-core` alone if you only need the traits) as a
git dependency:

```toml
[dependencies]
chainrules = { git = "https://github.com/tensor4all/chainrules-rs" }

# For complex scalar support, also add:
num-complex = "0.4"
```

### Using scalar rules

```rust
use chainrules::{exp_frule, sin_rrule, powf_frule};

// Forward rule: returns (primal, tangent)
let (y, dy) = exp_frule(1.0_f64, 1.0);
assert!((y - 1.0_f64.exp()).abs() < 1e-12);

// Reverse rule: returns input cotangent
let dx = sin_rrule(0.5_f64, 1.0);
assert!((dx - 0.5_f64.cos()).abs() < 1e-12);
```

### Implementing custom AD types

```rust
use chainrules::{Differentiable, ReverseRule, AdResult, NodeId};

#[derive(Clone)]
struct MyVec(Vec<f64>);

impl Differentiable for MyVec {
type Tangent = MyVec;
fn zero_tangent(&self) -> MyVec { MyVec(vec![0.0; self.0.len()]) }
fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
}
fn num_elements(&self) -> usize { self.0.len() }
fn seed_cotangent(&self) -> MyVec { MyVec(vec![1.0; self.0.len()]) }
}
```

## Design Goals

- Keep differentiation rules reusable across projects and AD engines
Expand Down
12 changes: 8 additions & 4 deletions crates/chainrules-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//!
//! Implementing `Differentiable` for a custom type:
//!
//! ```ignore
//! ```
//! use chainrules_core::Differentiable;
//!
//! #[derive(Clone)]
Expand Down Expand Up @@ -61,11 +61,9 @@
///
/// # Examples
///
/// ```ignore
/// ```
/// use chainrules_core::Differentiable;
///
/// // Tensor<f64> implements Differentiable with Tangent = Tensor<f64>
/// // (defined in tenferro-tensor crate)
/// fn example<V: Differentiable>(x: &V) {
/// let zero = x.zero_tangent();
/// let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
Expand Down Expand Up @@ -257,6 +255,9 @@ pub enum SavePolicy {
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// Implementors must be `Send + Sync` because rule objects may be stored on an
/// AD tape that is shared across threads.
///
/// # Examples
///
/// Custom reverse rule for scalar multiplication `output = a * b`:
Expand Down Expand Up @@ -334,6 +335,9 @@ pub trait ReverseRule<V: Differentiable>: Send + Sync {
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// Implementors must be `Send + Sync` because rule objects may be stored on an
/// AD tape that is shared across threads.
///
/// # Examples
///
/// Custom forward rule for scalar multiplication `output = a * b`:
Expand Down
26 changes: 26 additions & 0 deletions crates/chainrules/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,32 @@ repository-local tests:
in repository-local formula tests such as `tests/smooth_basis_tests.rs`
- complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses

## rrule first-argument convention

Some rrule helpers accept the **forward result** as their first argument
(when the derivative can be expressed in terms of the output), while others
accept the **input `x`** (when the derivative depends on the original
input). The parameter name in each function signature tells you which:

| First parameter | Functions |
|-----------------|-----------|
| `result` | `exp_rrule`, `expm1_rrule`, `exp2_rrule`, `exp10_rrule`, `sqrt_rrule`, `cbrt_rrule`, `inv_rrule`, `tanh_rrule`, `tan_rrule` |
| `x` | `log_rrule`, `log1p_rrule`, `log2_rrule`, `log10_rrule`, `sin_rrule`, `cos_rrule`, `sinh_rrule`, `cosh_rrule`, all inverse-trig/hyperbolic rrules, `powf_rrule`, `powi_rrule`, `pow_rrule`, Julia-compat trig/hyperbolic rrules |
| both inputs | `mul_rrule(x, y, …)`, `div_rrule(x, y, …)`, `atan2_rrule(y, x, …)`, `hypot_rrule(x, y, …)`, `min_rrule(x, y, …)`, `max_rrule(x, y, …)` |
| cotangent only | `add_rrule`, `sub_rrule`, `conj_rrule`, `real_rrule`, `imag_rrule` |

## Complex scalar convention

For complex scalars (`Complex64`, `Complex32`):

- **Forward-mode** (frule): uses the standard JVP convention on **C ≅ R²**
- **Reverse-mode** (rrule): uses the **conjugate-Wirtinger** convention for
real-valued losses — gradients include `conj(df/dz)`

For real scalars `conj` is the identity, so the convention is invisible.

See the validation section below for how each convention is tested.

## Examples

```rust
Expand Down
2 changes: 2 additions & 0 deletions crates/chainrules/src/unary/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub fn sqrt_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
}

/// Reverse rule for `sqrt`.
///
/// Takes the forward **result** `sqrt(x)`, not the input `x`.
pub fn sqrt_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
cotangent / (S::from_i32(2) * result.conj())
}
8 changes: 8 additions & 0 deletions crates/chainrules/src/unary/exp_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub fn exp_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
(y, dx * y)
}
/// Reverse rule for `exp`.
///
/// Takes the forward **result** `exp(x)`, not the input `x`.
pub fn exp_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
cotangent * result.conj()
}
Expand All @@ -31,6 +33,8 @@ pub fn expm1_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
(y, dx * scale)
}
/// Reverse rule for `exp(x) - 1`.
///
/// Takes the forward **result** `expm1(x)`, not the input `x`.
pub fn expm1_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
cotangent * (result + one::<S>()).conj()
}
Expand Down Expand Up @@ -71,6 +75,8 @@ pub fn log_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
(y, dy)
}
/// Reverse rule for `log`.
///
/// Takes the original **input** `x`, not the result.
pub fn log_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
cotangent * (one::<S>() / x).conj()
}
Expand All @@ -85,6 +91,8 @@ pub fn log1p_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
(y, dy)
}
/// Reverse rule for `log(1 + x)`.
///
/// Takes the original **input** `x`, not the result.
pub fn log1p_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
cotangent * (one::<S>() / (one::<S>() + x)).conj()
}
Expand Down
6 changes: 6 additions & 0 deletions crates/chainrules/src/unary/hyperbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pub fn tanh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
}

/// Reverse rule for `tanh`.
///
/// Takes the forward **result** `tanh(x)`, not the input `x`.
pub fn tanh_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
cotangent * (one::<S>() - result * result).conj()
}
Expand All @@ -30,6 +32,8 @@ pub fn sinh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
}

/// Reverse rule for `sinh`.
///
/// Takes the original **input** `x`, not the result.
pub fn sinh_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
cotangent * x.cosh().conj()
}
Expand All @@ -46,6 +50,8 @@ pub fn cosh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
}

/// Reverse rule for `cosh`.
///
/// Takes the original **input** `x`, not the result.
pub fn cosh_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
cotangent * x.sinh().conj()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/chainrules/src/unary/trig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub fn sin_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
}

/// Reverse rule for `sin`.
///
/// Takes the original **input** `x`, not the result.
pub fn sin_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
cotangent * x.cos().conj()
}
Expand All @@ -29,6 +31,8 @@ pub fn cos_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
}

/// Reverse rule for `cos`.
///
/// Takes the original **input** `x`, not the result.
pub fn cos_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
cotangent * (-x.sin()).conj()
}
Expand Down
Loading