Commit 43ea7ff

shinaoka and claude authored
docs: improve API documentation for library consumers (#5)
* docs: improve API documentation for library consumers

- Enable compile-testing on Differentiable doc examples (was `ignore`)
- Document Send+Sync requirement on ReverseRule and ForwardRule traits
- Add Wirtinger/CR-calculus convention section for complex scalars
- Document rrule first-argument convention (result vs input x)
- Clarify result vs x parameter on individual rrule functions
- Add Getting Started section to README with dependency and usage examples

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore: trigger CI

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c45a230 commit 43ea7ff

7 files changed

Lines changed: 100 additions & 4 deletions

README.md

Lines changed: 46 additions & 0 deletions
@@ -16,6 +16,52 @@ It intentionally does **not** ship a tape, traced value type, or any other AD
 engine runtime. Those live in separate engine crates such as
 [`tidu-rs`](https://github.com/tensor4all/tidu-rs).
 
+## Getting Started
+
+Add `chainrules` (or `chainrules-core` alone if you only need the traits) as a
+git dependency:
+
+```toml
+[dependencies]
+chainrules = { git = "https://github.com/tensor4all/chainrules-rs" }
+
+# For complex scalar support, also add:
+num-complex = "0.4"
+```
+
+### Using scalar rules
+
+```rust
+use chainrules::{exp_frule, sin_rrule, powf_frule};
+
+// Forward rule: returns (primal, tangent)
+let (y, dy) = exp_frule(1.0_f64, 1.0);
+assert!((y - 1.0_f64.exp()).abs() < 1e-12);
+
+// Reverse rule: returns input cotangent
+let dx = sin_rrule(0.5_f64, 1.0);
+assert!((dx - 0.5_f64.cos()).abs() < 1e-12);
+```
+
+### Implementing custom AD types
+
+```rust
+use chainrules::{Differentiable, ReverseRule, AdResult, NodeId};
+
+#[derive(Clone)]
+struct MyVec(Vec<f64>);
+
+impl Differentiable for MyVec {
+    type Tangent = MyVec;
+    fn zero_tangent(&self) -> MyVec { MyVec(vec![0.0; self.0.len()]) }
+    fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
+        MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
+    }
+    fn num_elements(&self) -> usize { self.0.len() }
+    fn seed_cotangent(&self) -> MyVec { MyVec(vec![1.0; self.0.len()]) }
+}
+```
+
 ## Design Goals
 
 - Keep differentiation rules reusable across projects and AD engines
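
Note on the "Using scalar rules" example in this hunk: the two assertions check a forward rule that returns `(primal, tangent)` and a reverse rule that returns the input cotangent. The sketch below reproduces that behaviour with only the standard library. The `_f64` functions and `central_diff` are hypothetical stand-ins written for this note, not the crate's generic `ScalarAd` implementations; they assume real `f64` inputs, where `conj` is the identity.

```rust
/// Hypothetical f64-only stand-in for `exp_frule`: returns (primal, tangent).
fn exp_frule_f64(x: f64, dx: f64) -> (f64, f64) {
    let y = x.exp();
    // d/dx exp(x) = exp(x), so the tangent is dx scaled by the primal.
    (y, dx * y)
}

/// Hypothetical f64-only stand-in for `sin_rrule`: takes the input `x`
/// and the output cotangent, returns the input cotangent.
fn sin_rrule_f64(x: f64, cotangent: f64) -> f64 {
    cotangent * x.cos()
}

/// Central finite difference, used here only as an independent check.
fn central_diff(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

fn main() {
    // Forward mode: seed tangent 1.0 gives the derivative directly.
    let (y, dy) = exp_frule_f64(1.0, 1.0);
    assert!((y - 1.0_f64.exp()).abs() < 1e-12);
    assert!((dy - central_diff(f64::exp, 1.0, 1e-6)).abs() < 1e-6);

    // Reverse mode: seed cotangent 1.0 gives the same derivative for a scalar.
    let dx = sin_rrule_f64(0.5, 1.0);
    assert!((dx - central_diff(f64::sin, 0.5, 1e-6)).abs() < 1e-6);

    println!("forward and reverse rules agree with finite differences");
}
```

For a scalar function with unit seeds, both modes yield the plain derivative, which is why the README can compare against `exp()` and `cos()` directly.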

crates/chainrules-core/src/lib.rs

Lines changed: 8 additions & 4 deletions
@@ -21,7 +21,7 @@
 //!
 //! Implementing `Differentiable` for a custom type:
 //!
-//! ```ignore
+//! ```
 //! use chainrules_core::Differentiable;
 //!
 //! #[derive(Clone)]
@@ -61,11 +61,9 @@
 ///
 /// # Examples
 ///
-/// ```ignore
+/// ```
 /// use chainrules_core::Differentiable;
 ///
-/// // Tensor<f64> implements Differentiable with Tangent = Tensor<f64>
-/// // (defined in tenferro-tensor crate)
 /// fn example<V: Differentiable>(x: &V) {
 ///     let zero = x.zero_tangent();
 ///     let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
@@ -257,6 +255,9 @@ pub enum SavePolicy {
 ///
 /// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
 ///
+/// Implementors must be `Send + Sync` because rule objects may be stored on an
+/// AD tape that is shared across threads.
+///
 /// # Examples
 ///
 /// Custom reverse rule for scalar multiplication `output = a * b`:
@@ -334,6 +335,9 @@ pub trait ReverseRule<V: Differentiable>: Send + Sync {
 ///
 /// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
 ///
+/// Implementors must be `Send + Sync` because rule objects may be stored on an
+/// AD tape that is shared across threads.
+///
 /// # Examples
 ///
 /// Custom forward rule for scalar multiplication `output = a * b`:
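
Note on the `Send + Sync` requirement documented in the hunks above: a supertrait bound of this kind is what lets boxed rule objects sit behind an `Arc` and be used from several threads. The sketch below illustrates the mechanism with a deliberately simplified trait. `ScalarRule`, `MulByConst`, and `pullback_on_threads` are hypothetical names invented for this note; they are not part of `chainrules-core`, and the real `ReverseRule` trait has a different shape.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical, simplified rule trait. The `Send + Sync` supertraits make
/// `Box<dyn ScalarRule>` itself `Send + Sync`.
trait ScalarRule: Send + Sync {
    fn pullback(&self, cotangent: f64) -> f64;
}

/// Hypothetical rule for `y = factor * x`.
struct MulByConst {
    factor: f64,
}

impl ScalarRule for MulByConst {
    fn pullback(&self, cotangent: f64) -> f64 {
        cotangent * self.factor
    }
}

/// Applies each stored rule on its own thread. This only compiles because
/// the trait objects are `Send + Sync`; dropping the supertraits makes
/// `thread::spawn` reject the closure.
fn pullback_on_threads(tape: Arc<Vec<Box<dyn ScalarRule>>>, cotangent: f64) -> Vec<f64> {
    let handles: Vec<_> = (0..tape.len())
        .map(|i| {
            let tape = Arc::clone(&tape);
            thread::spawn(move || tape[i].pullback(cotangent))
        })
        .collect();
    handles
        .into_iter()
        .map(|h| h.join().expect("worker thread panicked"))
        .collect()
}

fn main() {
    let rules: Vec<Box<dyn ScalarRule>> = vec![
        Box::new(MulByConst { factor: 2.0 }),
        Box::new(MulByConst { factor: -3.0 }),
    ];
    let grads = pullback_on_threads(Arc::new(rules), 1.0);
    assert_eq!(grads, vec![2.0, -3.0]);
}
```

A rule type that holds an `Rc` or a `RefCell` would fail to implement such a trait, so the bound surfaces thread-safety problems at the `impl` site rather than where the tape is shared.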

crates/chainrules/README.md

Lines changed: 26 additions & 0 deletions
@@ -61,6 +61,32 @@ repository-local tests:
 in repository-local formula tests such as `tests/smooth_basis_tests.rs`
 - complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses
 
+## rrule first-argument convention
+
+Some rrule helpers accept the **forward result** as their first argument
+(when the derivative can be expressed in terms of the output), while others
+accept the **input `x`** (when the derivative depends on the original
+input). The parameter name in each function signature tells you which:
+
+| First parameter | Functions |
+|-----------------|-----------|
+| `result` | `exp_rrule`, `expm1_rrule`, `exp2_rrule`, `exp10_rrule`, `sqrt_rrule`, `cbrt_rrule`, `inv_rrule`, `tanh_rrule`, `tan_rrule` |
+| `x` | `log_rrule`, `log1p_rrule`, `log2_rrule`, `log10_rrule`, `sin_rrule`, `cos_rrule`, `sinh_rrule`, `cosh_rrule`, all inverse-trig/hyperbolic rrules, `powf_rrule`, `powi_rrule`, `pow_rrule`, Julia-compat trig/hyperbolic rrules |
+| both inputs | `mul_rrule(x, y, …)`, `div_rrule(x, y, …)`, `atan2_rrule(y, x, …)`, `hypot_rrule(x, y, …)`, `min_rrule(x, y, …)`, `max_rrule(x, y, …)` |
+| cotangent only | `add_rrule`, `sub_rrule`, `conj_rrule`, `real_rrule`, `imag_rrule` |
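
Note on the table above: because both kinds of helper take a scalar of the same type, passing the input where the forward result is expected still compiles and silently returns the wrong gradient. The sketch below shows this with `f64`-only stand-ins whose bodies follow the `sqrt_rrule`, `exp_rrule`, and `log_rrule` formulas in this commit's source diffs, with `conj` dropped since it is the identity on reals. The `_f64` names are hypothetical and exist only for this note.

```rust
/// Result-first: expects `sqrt(x)`, mirroring `cotangent / (2 * result)`.
fn sqrt_rrule_f64(result: f64, cotangent: f64) -> f64 {
    cotangent / (2.0 * result)
}

/// Result-first: expects `exp(x)`, mirroring `cotangent * result`.
fn exp_rrule_f64(result: f64, cotangent: f64) -> f64 {
    cotangent * result
}

/// Input-first: expects `x`, mirroring `cotangent * (1 / x)`.
fn log_rrule_f64(x: f64, cotangent: f64) -> f64 {
    cotangent * (1.0 / x)
}

fn main() {
    let x = 4.0_f64;

    // Correct: run the forward pass, then hand its result to the rrule.
    let y = x.sqrt();
    let dx = sqrt_rrule_f64(y, 1.0);
    assert!((dx - 0.25).abs() < 1e-12); // d/dx sqrt(x) at 4 is 1/(2*2)

    // Mistake: passing the input `x` type-checks but halves the gradient here.
    let wrong = sqrt_rrule_f64(x, 1.0);
    assert!((wrong - 0.125).abs() < 1e-12);

    // exp is also result-first: the gradient equals the forward result.
    let e = x.exp();
    assert!((exp_rrule_f64(e, 1.0) - e).abs() < 1e-12);

    // log is input-first: no forward result is needed at all.
    assert!((log_rrule_f64(x, 1.0) - 0.25).abs() < 1e-12);
}
```

Result-first rules let an engine reuse the value it already computed in the forward pass, which is presumably why the split exists; the cost is that the caller has to check the parameter name.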
+
+## Complex scalar convention
+
+For complex scalars (`Complex64`, `Complex32`):
+
+- **Forward-mode** (frule): uses the standard JVP convention on **C ≅ R²**
+- **Reverse-mode** (rrule): uses the **conjugate-Wirtinger** convention for
+  real-valued losses — gradients include `conj(df/dz)`
+
+For real scalars `conj` is the identity, so the convention is invisible.
+
+See the validation section below for how each convention is tested.
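
Note on the two conventions above: for a holomorphic function they are adjoint to each other under the real inner product `Re(conj(a) * b)` on C ≅ R². The sketch below checks that identity for `exp`, using the `exp_frule` and `exp_rrule` formulas from this commit's source diffs. `C64`, `real_inner`, and the `_c` functions are hypothetical minimal stand-ins written for this note so the example needs no external crate; they are not `num_complex::Complex64` or the crate's own code. The adjoint check is one common consistency test and not necessarily the one the repository's validation uses.

```rust
/// Hypothetical minimal complex type, just enough for this check.
#[derive(Clone, Copy, Debug, PartialEq)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    fn conj(self) -> C64 {
        C64 { re: self.re, im: -self.im }
    }
    fn mul(self, o: C64) -> C64 {
        C64 {
            re: self.re * o.re - self.im * o.im,
            im: self.re * o.im + self.im * o.re,
        }
    }
    fn exp(self) -> C64 {
        let r = self.re.exp();
        C64 { re: r * self.im.cos(), im: r * self.im.sin() }
    }
}

/// Real inner product on C viewed as R^2: Re(conj(a) * b).
fn real_inner(a: C64, b: C64) -> f64 {
    a.re * b.re + a.im * b.im
}

/// Forward rule (JVP): tangent is `dx * exp(x)`.
fn exp_frule_c(x: C64, dx: C64) -> (C64, C64) {
    let y = x.exp();
    (y, dx.mul(y))
}

/// Reverse rule (conjugate-Wirtinger): `cotangent * conj(result)`.
fn exp_rrule_c(result: C64, cotangent: C64) -> C64 {
    cotangent.mul(result.conj())
}

fn main() {
    let z = C64 { re: 0.3, im: -1.2 };
    let dz = C64 { re: 0.7, im: 0.4 }; // arbitrary input tangent
    let ybar = C64 { re: -0.5, im: 2.0 }; // arbitrary output cotangent

    let (y, dy) = exp_frule_c(z, dz);
    let zbar = exp_rrule_c(y, ybar);

    // Adjoint identity: <ybar, J dz> == <J^T ybar, dz>.
    assert!((real_inner(ybar, dy) - real_inner(zbar, dz)).abs() < 1e-12);
}
```

Removing the `conj` in `exp_rrule_c` breaks the identity for any `z` with a nonzero imaginary part, which is a quick way to see that the conjugate is part of the convention and not a cosmetic detail.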
+
 ## Examples
 
 ```rust

crates/chainrules/src/unary/basic.rs

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ pub fn sqrt_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
 }
 
 /// Reverse rule for `sqrt`.
+///
+/// Takes the forward **result** `sqrt(x)`, not the input `x`.
 pub fn sqrt_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
     cotangent / (S::from_i32(2) * result.conj())
 }

crates/chainrules/src/unary/exp_log.rs

Lines changed: 8 additions & 0 deletions
@@ -17,6 +17,8 @@ pub fn exp_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
     (y, dx * y)
 }
 /// Reverse rule for `exp`.
+///
+/// Takes the forward **result** `exp(x)`, not the input `x`.
 pub fn exp_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
     cotangent * result.conj()
 }
@@ -31,6 +33,8 @@ pub fn expm1_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
     (y, dx * scale)
 }
 /// Reverse rule for `exp(x) - 1`.
+///
+/// Takes the forward **result** `expm1(x)`, not the input `x`.
 pub fn expm1_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
     cotangent * (result + one::<S>()).conj()
 }
@@ -71,6 +75,8 @@ pub fn log_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
     (y, dy)
 }
 /// Reverse rule for `log`.
+///
+/// Takes the original **input** `x`, not the result.
 pub fn log_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
     cotangent * (one::<S>() / x).conj()
 }
@@ -85,6 +91,8 @@ pub fn log1p_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
     (y, dy)
 }
 /// Reverse rule for `log(1 + x)`.
+///
+/// Takes the original **input** `x`, not the result.
 pub fn log1p_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
     cotangent * (one::<S>() / (one::<S>() + x)).conj()
 }

crates/chainrules/src/unary/hyperbolic.rs

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,8 @@ pub fn tanh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
 }
 
 /// Reverse rule for `tanh`.
+///
+/// Takes the forward **result** `tanh(x)`, not the input `x`.
 pub fn tanh_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
     cotangent * (one::<S>() - result * result).conj()
 }
@@ -30,6 +32,8 @@ pub fn sinh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
 }
 
 /// Reverse rule for `sinh`.
+///
+/// Takes the original **input** `x`, not the result.
 pub fn sinh_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
     cotangent * x.cosh().conj()
 }
@@ -46,6 +50,8 @@ pub fn cosh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
 }
 
 /// Reverse rule for `cosh`.
+///
+/// Takes the original **input** `x`, not the result.
 pub fn cosh_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
     cotangent * x.sinh().conj()
 }

crates/chainrules/src/unary/trig.rs

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,8 @@ pub fn sin_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
 }
 
 /// Reverse rule for `sin`.
+///
+/// Takes the original **input** `x`, not the result.
 pub fn sin_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
     cotangent * x.cos().conj()
 }
@@ -29,6 +31,8 @@ pub fn cos_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
 }
 
 /// Reverse rule for `cos`.
+///
+/// Takes the original **input** `x`, not the result.
 pub fn cos_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
     cotangent * (-x.sin()).conj()
 }
