166 changes: 105 additions & 61 deletions README.md
# tidu-rs

`tidu-rs` is a general-purpose automatic-differentiation engine with a
value-centered, linearize-first public API.

It originated in the tensor4all stack, but it is designed to work with any
downstream differentiable value type that implements the core AD traits from
`chainrules-core`.

The name **tidu** comes from the Chinese word **梯度**, written in pinyin as
*tīdù*, which means "gradient".
## Getting Started

Add `tidu` and its companion crate `chainrules` (which provides scalar
differentiation rules such as `powf_rrule` and `sin_rrule`) to your
`Cargo.toml`:

```toml
tidu = { git = "https://github.com/tensor4all/tidu-rs" }
chainrules = { git = "https://github.com/tensor4all/chainrules-rs" }
```

`tidu` re-exports the core AD traits needed by the normal public surface,
including `Differentiable`, `AdResult`, and `AutodiffError`. The intended
public extension points are `Value`, `LinearizableOp`, `LinearizedOp`,
`Schema`, `SlotSchema`, `CheckpointMode`, `AdExecutionPolicy`,
`CheckpointHint`, and `with_ad_policy(...)`.

## Quick Example

Compute the gradient of `f(x) = x^3` at `x = 2` and inspect a local
directional derivative from the same linearized object:

```rust
use tidu::{LinearizableOp, LinearizedOp, Schema, SlotSchema, Value};

// The custom op f(x) = x^3.
#[derive(Clone, Copy)]
struct Cube;

// First-order data captured at the linearization point x.
struct CubeLinearized {
    x: f64,
}

impl LinearizedOp<f64> for CubeLinearized {
    fn jvp(&self, input_tangents: &[Option<f64>]) -> tidu::AdResult<Vec<Option<f64>>> {
        // Directional derivative: dy = 3x^2 * dx.
        Ok(vec![input_tangents[0].map(|dx| 3.0 * self.x * self.x * dx)])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<f64>],
        input_grad_mask: &[bool],
    ) -> tidu::AdResult<Vec<Option<f64>>> {
        assert_eq!(input_grad_mask, &[true]);
        let g = output_cotangents[0].unwrap_or(0.0);
        Ok(vec![Some(3.0 * self.x * self.x * g)])
    }
}

impl LinearizableOp<f64> for Cube {
    type Linearized = CubeLinearized;

    fn primal(&self, inputs: &[&f64]) -> tidu::AdResult<Vec<f64>> {
        Ok(vec![*inputs[0] * *inputs[0] * *inputs[0]])
    }

    fn input_schema(&self, _inputs: &[&f64]) -> tidu::AdResult<Schema> {
        Ok(Schema {
            slots: vec![SlotSchema {
                differentiable: true,
                auxiliary: false,
            }],
        })
    }

    fn output_schema(&self, _inputs: &[&f64], _outputs: &[f64]) -> tidu::AdResult<Schema> {
        Ok(Schema {
            slots: vec![SlotSchema {
                differentiable: true,
                auxiliary: false,
            }],
        })
    }

    fn linearize(
        &self,
        inputs: &[&f64],
        _outputs: &[f64],
    ) -> tidu::AdResult<Self::Linearized> {
        Ok(CubeLinearized { x: *inputs[0] })
    }
}

// Reverse mode: dy/dx = 3 * 2^2 = 12.
let x = Value::new(2.0).with_requires_grad(true);
let y = Cube.apply_one(&[&x]).unwrap();
y.backward().unwrap();
assert_eq!(x.grad().unwrap().unwrap(), 12.0);

// The same linearized object also gives a directional derivative (JVP).
let lin = Cube.linearize(&[x.primal()], &[*y.primal()]).unwrap();
assert_eq!(lin.jvp(&[Some(1.0)]).unwrap(), vec![Some(12.0)]);
```
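The value 12.0 can be checked without any AD machinery. A standalone sanity check (plain Rust, no tidu APIs; the function names here are ours, not part of the crate) compares the analytic derivative 3x² against a central finite difference:

```rust
// Standalone check: f(x) = x^3, so f'(2) should be 3 * 2^2 = 12.
fn f(x: f64) -> f64 {
    x * x * x
}

fn analytic_grad(x: f64) -> f64 {
    3.0 * x * x
}

// Central finite difference: (f(x+h) - f(x-h)) / 2h ≈ f'(x) + O(h^2).
fn finite_diff(x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

fn main() {
    let x = 2.0;
    let fd = finite_diff(x, 1e-6);
    assert!((fd - analytic_grad(x)).abs() < 1e-5);
    println!("analytic = {}, finite difference = {}", analytic_grad(x), fd);
}
```

The same cross-check pattern is handy when writing `jvp`/`vjp` for your own ops.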

Checkpointing is controlled with a small public policy scope:

```rust
use tidu::{AdExecutionPolicy, CheckpointMode, with_ad_policy};

let policy = AdExecutionPolicy {
checkpoint_mode: CheckpointMode::Conservative,
};

with_ad_policy(policy, || -> tidu::AdResult<()> {
// Record and differentiate values inside this scope.
Ok(())
})
.unwrap();
```
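The scope nests: an inner `with_ad_policy` temporarily overrides an outer one and is popped on exit. A minimal standalone sketch of that scoping mechanism (simplified here to a plain `i32` "policy", so `with_policy` and `current` are illustrative names, not tidu APIs):

```rust
use std::cell::RefCell;

// Thread-local stack; the bottom entry plays the role of the default policy.
thread_local! {
    static STACK: RefCell<Vec<i32>> = RefCell::new(vec![0]);
}

fn current() -> i32 {
    STACK.with(|s| *s.borrow().last().unwrap())
}

fn with_policy<R>(p: i32, f: impl FnOnce() -> R) -> R {
    // RAII guard pops on scope exit, even if `f` panics.
    struct Guard;
    impl Drop for Guard {
        fn drop(&mut self) {
            STACK.with(|s| {
                s.borrow_mut().pop();
            });
        }
    }
    STACK.with(|s| s.borrow_mut().push(p));
    let _guard = Guard;
    f()
}

fn main() {
    assert_eq!(current(), 0);
    with_policy(1, || {
        assert_eq!(current(), 1);
        with_policy(2, || assert_eq!(current(), 2));
        assert_eq!(current(), 1); // inner scope popped on exit
    });
    assert_eq!(current(), 0); // back to the default
}
```

The guard-based pop is what makes the scope exception-safe rather than relying on a manual "restore" call.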

`CheckpointHint` is an advanced retain-vs-replay hint for custom ops. Most
downstream code only needs `CheckpointMode`, `AdExecutionPolicy`, and
`with_ad_policy(...)`.

See the [crate-level rustdoc](https://tensor4all.org/tidu-rs/tidu/) for
`Value`, `LinearizableOp`, `LinearizedOp`, and checkpoint policy examples.

## Architecture

```text
┌─────────────────────────────────────────────────────┐
│ Value<V>                                            │
│   Public value handle for eager reverse-mode AD.    │
│   Exposes with_requires_grad, backward, and grad(). │
├─────────────────────────────────────────────────────┤
│ LinearizableOp<V>                                   │
│   High-level custom op API: primal + linearize.     │
│   The normal extension path for downstream users.   │
├─────────────────────────────────────────────────────┤
│ LinearizedOp<V>                                     │
│   Shared first-order object exposing jvp + vjp.     │
│   Retained or replayed internally by the runtime.   │
├─────────────────────────────────────────────────────┤
│ CheckpointMode / AdExecutionPolicy scope            │
│   Small public policy surface over retain/replay.   │
└─────────────────────────────────────────────────────┘

Traits (from chainrules-core, re-exported by tidu):
Differentiable — tangent algebra for a value type
```

## What Lives Here

- `tidu`: eager reverse mode with a linearize-first core
- `Value`, `LinearizableOp`, and `LinearizedOp`
- checkpoint policy scope via `CheckpointMode`, `AdExecutionPolicy`, and `with_ad_policy(...)`
- retained or replayed linearizations kept internal to the runtime

## Layering

That split is deliberate:
## Design Goals

- Keep the engine generic over downstream differentiable value types
- Preserve strict layering between rules and runtime execution
- Keep the normal public API torch-like and value-centered
- Make future forward-on-reverse straightforward without exposing higher-order execution now
- Prefer root-cause fixes, DRY abstractions, and small focused modules

## Testing
2 changes: 1 addition & 1 deletion crates/tidu/Cargo.toml
```toml
edition.workspace = true
license.workspace = true
authors.workspace = true
publish.workspace = true
description = "Linearize-first automatic-differentiation engine built on chainrules-core."

[dependencies]
chainrules-core.workspace = true
```
81 changes: 81 additions & 0 deletions crates/tidu/src/checkpoint.rs
```rust
use std::cell::RefCell;

/// How aggressively linearizations are dropped and rebuilt by replay.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointMode {
    Off,
    Conservative,
    Aggressive,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdExecutionPolicy {
    pub checkpoint_mode: CheckpointMode,
}

impl Default for AdExecutionPolicy {
    fn default() -> Self {
        Self {
            checkpoint_mode: CheckpointMode::Off,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StorageDecision {
    Retain,
    Replay,
}

/// Public hint used by [`crate::LinearizableOp::checkpoint_hint`] to guide
/// retain-vs-replay policy decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointHint {
    CheapReplay,
    ExpensiveReplay,
    MustRetain,
}

thread_local! {
    static POLICY_STACK: RefCell<Vec<AdExecutionPolicy>> =
        RefCell::new(vec![AdExecutionPolicy::default()]);
}

/// RAII guard that pops the policy pushed by [`with_ad_policy`].
struct PolicyScopeGuard;

impl PolicyScopeGuard {
    fn push(policy: AdExecutionPolicy) -> Self {
        POLICY_STACK.with(|stack| stack.borrow_mut().push(policy));
        Self
    }
}

impl Drop for PolicyScopeGuard {
    fn drop(&mut self) {
        POLICY_STACK.with(|stack| {
            let popped = stack.borrow_mut().pop();
            debug_assert!(popped.is_some());
        });
    }
}

/// Runs `f` with `policy` installed as the current thread-local policy.
pub fn with_ad_policy<R>(policy: AdExecutionPolicy, f: impl FnOnce() -> R) -> R {
    let _guard = PolicyScopeGuard::push(policy);
    f()
}

pub(crate) fn current_ad_policy() -> AdExecutionPolicy {
    POLICY_STACK.with(|stack| stack.borrow().last().copied().unwrap_or_default())
}

/// Combines the active policy with an op's hint into a retain/replay decision.
pub(crate) fn storage_decision(
    policy: AdExecutionPolicy,
    checkpoint_hint: CheckpointHint,
) -> StorageDecision {
    match (policy.checkpoint_mode, checkpoint_hint) {
        (_, CheckpointHint::MustRetain) => StorageDecision::Retain,
        (CheckpointMode::Off, _) => StorageDecision::Retain,
        (CheckpointMode::Conservative, CheckpointHint::CheapReplay) => StorageDecision::Replay,
        (CheckpointMode::Conservative, CheckpointHint::ExpensiveReplay) => StorageDecision::Retain,
        (CheckpointMode::Aggressive, _) => StorageDecision::Replay,
    }
}
```
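To make the retain-vs-replay table concrete, here is a standalone mirror of the `storage_decision` match, with the enums redeclared locally (shortened names, purely for illustration) and the expected outcomes spelled out:

```rust
// Standalone mirror of storage_decision in checkpoint.rs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode { Off, Conservative, Aggressive }

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Hint { CheapReplay, ExpensiveReplay, MustRetain }

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Decision { Retain, Replay }

fn decide(mode: Mode, hint: Hint) -> Decision {
    match (mode, hint) {
        // MustRetain always wins, regardless of mode.
        (_, Hint::MustRetain) => Decision::Retain,
        // With checkpointing off, everything is retained.
        (Mode::Off, _) => Decision::Retain,
        // Conservative replays only ops that declare replay as cheap.
        (Mode::Conservative, Hint::CheapReplay) => Decision::Replay,
        (Mode::Conservative, Hint::ExpensiveReplay) => Decision::Retain,
        // Aggressive replays everything it is allowed to.
        (Mode::Aggressive, _) => Decision::Replay,
    }
}

fn main() {
    assert_eq!(decide(Mode::Off, Hint::CheapReplay), Decision::Retain);
    assert_eq!(decide(Mode::Conservative, Hint::CheapReplay), Decision::Replay);
    assert_eq!(decide(Mode::Conservative, Hint::ExpensiveReplay), Decision::Retain);
    assert_eq!(decide(Mode::Aggressive, Hint::MustRetain), Decision::Retain);
    assert_eq!(decide(Mode::Aggressive, Hint::ExpensiveReplay), Decision::Replay);
}
```

Putting `MustRetain` first in the match makes the override order explicit: an op's hard requirement beats any mode, and only then does the mode decide.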