From a362ee86290bc27095efa67703def3e05312709f Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Fri, 20 Mar 2026 07:57:35 +0900 Subject: [PATCH] Re-export chainrules-core from chainrules and add Neg bound to ScalarAd Downstream crates can now depend on `chainrules` alone to get core traits (Differentiable, ReverseRule, ForwardRule, etc.). Adding `Neg` to `ScalarAd` removes the `neg_one()` helper and `from_i32(-1) *` workarounds, making negation idiomatic (`-x` instead of `-1 * x`). Co-Authored-By: Claude Opus 4.6 --- crates/chainrules/src/binary.rs | 6 +++--- crates/chainrules/src/lib.rs | 5 +++++ crates/chainrules/src/scalar_ad.rs | 10 ++++++++-- crates/chainrules/src/unary/mod.rs | 4 ---- crates/chainrules/src/unary/trig.rs | 10 +++++----- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/crates/chainrules/src/binary.rs b/crates/chainrules/src/binary.rs index 3c1756c..9d21d4c 100644 --- a/crates/chainrules/src/binary.rs +++ b/crates/chainrules/src/binary.rs @@ -91,7 +91,7 @@ pub fn sub_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { /// assert_eq!(dy, -2.0_f64); /// ``` pub fn sub_rrule(cotangent: S) -> (S, S) { - (cotangent, S::from_i32(-1) * cotangent) + (cotangent, -cotangent) } /// Primal `mul`. @@ -176,7 +176,7 @@ pub fn div_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { let primal = x / y; let inv_y = S::from_i32(1) / y; let dfdx = inv_y.conj(); - let dfdy = (S::from_i32(-1) * x * inv_y * inv_y).conj(); + let dfdy = (-(x * inv_y * inv_y)).conj(); let tangent = dx * dfdx + dy * dfdy; (primal, tangent) } @@ -200,6 +200,6 @@ pub fn div_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { pub fn div_rrule(x: S, y: S, cotangent: S) -> (S, S) { let inv_y = S::from_i32(1) / y; let dfdx = inv_y.conj(); - let dfdy = (S::from_i32(-1) * x * inv_y * inv_y).conj(); + let dfdy = (-(x * inv_y * inv_y)).conj(); (cotangent * dfdx, cotangent * dfdy) } diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 9cb3863..9fb6fa8 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -34,6 +34,11 @@ //! assert_eq!(dx, 12.0); //! ``` +pub use chainrules_core::{ + AdResult, AutodiffError, Differentiable, ForwardRule, NodeId, PullbackEntry, + PullbackWithTangentsEntry, ReverseRule, SavePolicy, +}; + mod binary; mod power; mod real_ops; diff --git a/crates/chainrules/src/scalar_ad.rs b/crates/chainrules/src/scalar_ad.rs index 1cc5870..7b0a6dd 100644 --- a/crates/chainrules/src/scalar_ad.rs +++ b/crates/chainrules/src/scalar_ad.rs @@ -1,4 +1,4 @@ -use core::ops::{Add, Div, Mul, Sub}; +use core::ops::{Add, Div, Mul, Neg, Sub}; use num_complex::{Complex32, Complex64}; use num_traits::Float; @@ -16,7 +16,13 @@ use num_traits::Float; /// takes_scalar(1.0_f64); /// ``` pub trait ScalarAd: - Copy + PartialEq + Add + Sub + Mul + Div + Copy + + PartialEq + + Neg + + Add + + Sub + + Mul + + Div { /// Real exponent type for `powf`. type Real: Copy + Float; diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index b90eed0..27488f4 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -9,10 +9,6 @@ fn one() -> S { S::from_i32(1) } -fn neg_one() -> S { - S::from_i32(-1) -} - pub use basic::{conj, conj_frule, conj_rrule, sqrt, sqrt_frule, sqrt_rrule}; pub use exp_log::{ exp, exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, log, log1p, log1p_frule, diff --git a/crates/chainrules/src/unary/trig.rs b/crates/chainrules/src/unary/trig.rs index 4c6a259..e5cdf3c 100644 --- a/crates/chainrules/src/unary/trig.rs +++ b/crates/chainrules/src/unary/trig.rs @@ -1,4 +1,4 @@ -use crate::unary::{neg_one, one}; +use crate::unary::one; use crate::ScalarAd; /// Primal `sin`. @@ -25,12 +25,12 @@ pub fn cos(x: S) -> S { /// Forward rule for `cos`. pub fn cos_frule(x: S, dx: S) -> (S, S) { let y = x.cos(); - (y, dx * (neg_one::() * x.sin()).conj()) + (y, dx * (-x.sin()).conj()) } /// Reverse rule for `cos`. pub fn cos_rrule(x: S, cotangent: S) -> S { - cotangent * (neg_one::() * x.sin()).conj() + cotangent * (-x.sin()).conj() } fn inverse_sqrt_one_minus_square(x: S) -> S { @@ -62,13 +62,13 @@ pub fn acos(x: S) -> S { /// Forward rule for `acos`. pub fn acos_frule(x: S, dx: S) -> (S, S) { let y = x.acos(); - let scale = neg_one::() * inverse_sqrt_one_minus_square(x); + let scale = -inverse_sqrt_one_minus_square(x); (y, dx * scale.conj()) } /// Reverse rule for `acos`. pub fn acos_rrule(x: S, cotangent: S) -> S { - cotangent * (neg_one::() * inverse_sqrt_one_minus_square(x)).conj() + cotangent * (-inverse_sqrt_one_minus_square(x)).conj() } /// Primal `atan`.