diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..7c436b4 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,190 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to the Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by the Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding any notices that do not + pertain to any part of the Derivative Works, 
in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS + +Copyright 2026 Hiroshi Shinaoka, GiggleLiu, and contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..98eeb52 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Hiroshi Shinaoka, GiggleLiu, and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md index 7b8569b..78a4c1c 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ specific AD engine. It contains: -- `chainrules-core`: core AD traits and error types -- `chainrules`: reusable scalar `frule`/`rrule` helpers and related utilities +- `chainrules-core`: engine-independent AD protocol +- `chainrules`: shared scalar rule basis It intentionally does **not** ship a tape, traced value type, or any other AD engine runtime. Those live in separate engine crates such as @@ -25,9 +25,9 @@ engine runtime. Those live in separate engine crates such as ## Repository Layout -- [`crates/chainrules-core`](crates/chainrules-core): `Differentiable`, - `ReverseRule`, `ForwardRule`, `AutodiffError`, and related core types -- [`crates/chainrules`](crates/chainrules): scalar rule implementations such as +- [`crates/chainrules-core`](crates/chainrules-core): protocol-only crate for + `Differentiable`, `ReverseRule`, `ForwardRule`, and `AutodiffError` +- [`crates/chainrules`](crates/chainrules): shared scalar rules such as `exp`, `log1p`, `sin`, `atanh`, `powf`, and `atan2` - [`third_party/tensor-ad-oracles`](third_party/tensor-ad-oracles): vendored oracle data used to validate scalar rules against published references @@ -48,8 +48,33 @@ engine that executes those rules over a tape. The boundary is deliberate: - `tidu-rs` can evolve independently as an engine - downstream tensor libraries can swap engines without rewriting scalar rules +## Crate Roles + +`chainrules-core` does not provide function rules. +`chainrules` provides stateless scalar `foo`, `foo_frule`, and `foo_rrule` +helpers. + +`chainrules` is a landing zone for scalar rules ported or adapted from Julia's +`ChainRules.jl` where they fit this repository boundary, but `chainrules-rs` is +not a full port of `ChainRules.jl`. + +See the crate READMEs for the supported scalar function inventory and examples. 
+ ## Testing +Scalar rules are checked in complementary ways: + +- formula and behavior tests in `crates/chainrules/tests/scalarops_tests.rs` +- compatibility and edge-case tests such as + `crates/chainrules/tests/julia_compat_trig_tests.rs` and + `crates/chainrules/tests/complex_helper_tests.rs` +- oracle replay tests in `crates/chainrules/tests/oracle_scalar_rules.rs` + against vendored published cases from `third_party/tensor-ad-oracles`, + including direct float64 replay and selected direct Complex64 replay for + `tan`, `exp2`, and `log2`; complex + forward-mode checks use the standard JVP convention on `C ~= R^2`, while + complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses + ```bash cargo test --workspace --release cargo llvm-cov --workspace --json --output-path coverage.json diff --git a/crates/chainrules-core/README.md b/crates/chainrules-core/README.md new file mode 100644 index 0000000..d288b82 --- /dev/null +++ b/crates/chainrules-core/README.md @@ -0,0 +1,31 @@ +# chainrules-core + +`chainrules-core` defines the engine-independent AD protocol used by +`chainrules-rs`. + +It is intentionally small and does not provide function rules. The crate exists +to define the traits and error types that downstream AD engines and rule +libraries build on. + +## What It Provides + +- `Differentiable` +- `ReverseRule` +- `ForwardRule` +- `AutodiffError` +- `NodeId` +- `SavePolicy` + +## Example + +```rust +use chainrules_core::NodeId; + +let id = NodeId::new(7); +assert_eq!(id.index(), 7); +``` + +## Notes + +This crate is protocol-only. Shared scalar `frule` and `rrule` helpers live in +[`chainrules`](../chainrules). diff --git a/crates/chainrules-core/src/lib.rs b/crates/chainrules-core/src/lib.rs index ddcde00..883f108 100644 --- a/crates/chainrules-core/src/lib.rs +++ b/crates/chainrules-core/src/lib.rs @@ -1,3 +1,5 @@ +#![doc = include_str!("../README.md")] + //! Core AD trait definitions (like Julia's ChainRulesCore.jl). //! //! 
This crate defines the interface for automatic differentiation without diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md new file mode 100644 index 0000000..d3106d2 --- /dev/null +++ b/crates/chainrules/README.md @@ -0,0 +1,86 @@ +# chainrules + +`chainrules` provides a shared scalar rule basis for Rust automatic +differentiation crates. + +It is designed for reusable scalar calculus, not for tapes, traced values, or +tensor-specific execution engines. The crate focuses on stateless helpers that +can be called from downstream AD runtimes and tensor libraries. + +## What It Provides + +- stateless scalar primal helpers +- stateless scalar `foo_frule` helpers +- stateless scalar `foo_rrule` helpers +- real/complex projection helpers for common scalar formulas + +Supported scalar domains: + +- `f32` +- `f64` +- `Complex32` +- `Complex64` + +## Supported Functions + +Current shipped scalar families: + +- arithmetic: `add`, `sub`, `mul`, `div` +- powers and roots: `powf`, `powi`, `sqrt` +- exponentials and logs: `exp`, `expm1`, `log`, `log1p` +- trigonometric: `sin`, `cos`, `asin`, `acos`, `atan` +- hyperbolic: `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` +- Julia-compatible trigonometric helpers: `sec`, `csc`, `cot`, `sinpi`, `cospi`, `sincospi`, `sind`, `cosd`, `tand` +- Julia-compatible hyperbolic helpers: `sech`, `csch`, `coth` +- non-smooth real helpers: `round`, `floor`, `ceil`, `sign`, `min`, `max` +- smooth helpers: `cbrt`, `inv`, `exp2`, `exp10`, `log2`, `log10`, `hypot`, `pow`, `sincos`, `tan` +- complex and projection helpers: `conj`, `abs`, `abs2`, `angle`, `real`, `imag`, `complex` +- real-valued binary helpers: `atan2` + +This crate is a landing zone for scalar rules ported or adapted from Julia's +`ChainRules.jl` where they fit this crate boundary, but it is not a full port +of `ChainRules.jl`. + +## Validation + +Rules in this crate are not accepted on provenance alone. 
They are checked with +repository-local tests: + +- `tests/scalarops_tests.rs` covers direct formulas, edge cases, and smooth + real/complex behavior +- `tests/julia_compat_trig_tests.rs` covers Julia migration helpers, including + landmark real inputs and representative Complex64 behavior +- `tests/nonsmooth_scalar_tests.rs` covers the documented zero-gradient and + tie-routing policies for non-smooth helpers +- `tests/complex_helper_tests.rs` covers the projection helpers and complex + constructor surface +- `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from + `../../third_party/tensor-ad-oracles`, with direct float64 replay and + selected direct Complex64 replay for `tan`, `exp2`, and `log2` +- complex forward-mode checks use the standard JVP convention on `C ~= R^2` + in repository-local formula tests such as `tests/smooth_basis_tests.rs` +- complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses + +## Examples + +```rust +use chainrules::{powf, powf_frule, powf_rrule}; + +let y = powf(2.0_f64, 3.0_f64); +assert_eq!(y, 8.0_f64); + +let (y, dy) = powf_frule(2.0_f64, 3.0_f64, 1.0_f64); +assert_eq!(y, 8.0_f64); +assert_eq!(dy, 12.0_f64); + +let dx = powf_rrule(2.0_f64, 3.0_f64, 1.0_f64); +assert_eq!(dx, 12.0_f64); +``` + +## Notes + +This crate is the landing zone for shared scalar rule logic, including +Julia-style convenience functions when they help migration. + +It does not define tensor, array, broadcast, reduction, or engine-specific +rules. 
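Note on the convention split described in the README above: complex forward-mode rules use the plain JVP on `C ~= R^2`, while reverse-mode rules stay conjugate-Wirtinger. The sketch below illustrates that split for scalar multiplication. It is a minimal illustration, not the crate's code: `C64`, `dot`, `mul_jvp`, and `mul_vjp` are hypothetical names introduced here (the crate itself works with `Complex64`), and the reverse formula is an assumption modeled on the `.conj()` pattern visible in `powf_rrule`.

```rust
use std::ops::{Add, Mul};

/// Hypothetical stand-in for a complex scalar, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Complex conjugate.
    fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Real inner product when C is viewed as R^2.
    fn dot(self, other: Self) -> f64 {
        self.re * other.re + self.im * other.im
    }
}

impl Add for C64 {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        Self::new(self.re + o.re, self.im + o.im)
    }
}

impl Mul for C64 {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Self::new(
            self.re * o.re - self.im * o.im,
            self.re * o.im + self.im * o.re,
        )
    }
}

/// Forward mode (JVP): the tangent of `x * y` is `dx * y + dy * x`,
/// with no conjugation.
fn mul_jvp(x: C64, y: C64, dx: C64, dy: C64) -> (C64, C64) {
    (x * y, dx * y + dy * x)
}

/// Reverse mode (assumed conjugate convention): each input cotangent is the
/// output cotangent times the conjugate of the other factor.
fn mul_vjp(x: C64, y: C64, cotangent: C64) -> (C64, C64) {
    (cotangent * y.conj(), cotangent * x.conj())
}
```

Under these assumptions the two rules are adjoint with respect to the real inner product: for any `dx`, `dy`, and `cotangent`, `cotangent.dot(tangent)` equals `gx.dot(dx) + gy.dot(dy)`. A forward rule that conjugated `y` would break that identity for complex inputs, which is consistent with the `binary.rs` change below dropping `conj()` from `mul_frule` and `div_frule`.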
diff --git a/crates/chainrules/src/binary.rs b/crates/chainrules/src/binary.rs index 9d21d4c..9ce701a 100644 --- a/crates/chainrules/src/binary.rs +++ b/crates/chainrules/src/binary.rs @@ -122,7 +122,7 @@ pub fn mul(x: S, y: S) -> S { /// ``` pub fn mul_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { let primal = x * y; - let tangent = dx * y.conj() + dy * x.conj(); + let tangent = dx * y + dy * x; (primal, tangent) } @@ -175,8 +175,8 @@ pub fn div(x: S, y: S) -> S { pub fn div_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { let primal = x / y; let inv_y = S::from_i32(1) / y; - let dfdx = inv_y.conj(); - let dfdy = (-(x * inv_y * inv_y)).conj(); + let dfdx = inv_y; + let dfdy = -(x * inv_y * inv_y); let tangent = dx * dfdx + dy * dfdy; (primal, tangent) } diff --git a/crates/chainrules/src/binary_special.rs b/crates/chainrules/src/binary_special.rs new file mode 100644 index 0000000..98fa04d --- /dev/null +++ b/crates/chainrules/src/binary_special.rs @@ -0,0 +1,180 @@ +use num_traits::Float; + +fn select_first_for_min(x: R, y: R) -> bool { + !x.is_nan() && (y.is_nan() || x < y) +} + +fn select_first_for_max(x: R, y: R) -> bool { + !x.is_nan() && (y.is_nan() || x > y) +} + +/// Primal `hypot`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot; +/// +/// assert_eq!(hypot(3.0_f64, 4.0_f64), 5.0); +/// ``` +pub fn hypot(x: R, y: R) -> R { + x.hypot(y) +} + +/// Forward rule for `hypot`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot_frule; +/// +/// let (r, dr) = hypot_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64); +/// assert_eq!(r, 5.0); +/// assert!((dr - 0.5).abs() < 1e-12); +/// ``` +pub fn hypot_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let r = x.hypot(y); + let inv_r = R::one() / r; + (r, dx * (x * inv_r) + dy * (y * inv_r)) +} + +/// Reverse rule for `hypot`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot_rrule; +/// +/// let (dx, dy) = hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64); +/// assert!((dx - 0.6).abs() < 1e-12); +/// assert!((dy - 0.8).abs() < 1e-12); +/// ``` +pub fn hypot_rrule(x: R, y: R, cotangent: R) -> (R, R) { + let r = x.hypot(y); + let inv_r = R::one() / r; + (cotangent * (x * inv_r), cotangent * (y * inv_r)) +} + +/// Primal `min`. +/// +/// The primal follows `Float::min`. For differentiation, ties route the +/// tangent/cotangent to the second argument. If exactly one input is `NaN`, +/// the non-`NaN` input receives the gradient. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min; +/// +/// assert_eq!(min(1.5_f64, 2.5_f64), 1.5); +/// assert_eq!(min(2.0_f64, 2.0_f64), 2.0); +/// ``` +pub fn min(x: R, y: R) -> R { + x.min(y) +} + +/// Forward rule for `min`. +/// +/// When `x == y`, the tangent comes from `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min_frule; +/// +/// let (z, dz) = min_frule(1.0_f64, 2.0_f64, 0.25, 0.5); +/// assert_eq!(z, 1.0); +/// assert_eq!(dz, 0.25); +/// ``` +pub fn min_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let z = x.min(y); + if select_first_for_min(x, y) { + (z, dx) + } else { + (z, dy) + } +} + +/// Reverse rule for `min`. +/// +/// When `x == y`, the cotangent goes to `y`. If exactly one input is `NaN`, +/// the non-`NaN` input receives the cotangent. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min_rrule; +/// +/// let (dx, dy) = min_rrule(1.0_f64, 2.0_f64, 0.5); +/// assert_eq!(dx, 0.5); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn min_rrule(x: R, y: R, cotangent: R) -> (R, R) { + if select_first_for_min(x, y) { + (cotangent, R::zero()) + } else { + (R::zero(), cotangent) + } +} + +/// Primal `max`. +/// +/// The primal follows `Float::max`. For differentiation, ties route the +/// tangent/cotangent to the second argument. If exactly one input is `NaN`, +/// the non-`NaN` input receives the gradient. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::max; +/// +/// assert_eq!(max(1.5_f64, 2.5_f64), 2.5); +/// assert_eq!(max(2.0_f64, 2.0_f64), 2.0); +/// ``` +pub fn max(x: R, y: R) -> R { + x.max(y) +} + +/// Forward rule for `max`. +/// +/// When `x == y`, the tangent comes from `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max_frule; +/// +/// let (z, dz) = max_frule(1.0_f64, 2.0_f64, 0.25, 0.5); +/// assert_eq!(z, 2.0); +/// assert_eq!(dz, 0.5); +/// ``` +pub fn max_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let z = x.max(y); + if select_first_for_max(x, y) { + (z, dx) + } else { + (z, dy) + } +} + +/// Reverse rule for `max`. +/// +/// When `x == y`, the cotangent goes to `y`. If exactly one input is `NaN`, +/// the non-`NaN` input receives the cotangent. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max_rrule; +/// +/// let (dx, dy) = max_rrule(1.0_f64, 2.0_f64, 0.5); +/// assert_eq!(dx, 0.0); +/// assert_eq!(dy, 0.5); +/// ``` +pub fn max_rrule(x: R, y: R, cotangent: R) -> (R, R) { + if select_first_for_max(x, y) { + (cotangent, R::zero()) + } else { + (R::zero(), cotangent) + } +} diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 9fb6fa8..c326045 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -1,38 +1,4 @@ -//! Engine-independent scalar AD helper rules for elementary operations. -//! -//! This crate provides stateless primal/frule/rrule helpers for scalar -//! operations. It is designed to be shared across AD engines and higher-level -//! tensor libraries. -//! -//! Public helper families: -//! -//! - `add`, `sub`, `mul`, `div` -//! - `conj` -//! - `sqrt` -//! - `exp`, `expm1`, `log`, `log1p` -//! - `sin`, `cos` -//! - `sinh`, `cosh` -//! - `asin`, `acos`, `atan`, `asinh`, `acosh`, `atanh` -//! - `powf` (fixed real exponent) -//! - `powi` (fixed integer exponent) -//! - `atan2` (real scalars) -//! -//! 
The [`ScalarAd`] trait also provides the scalar method basis used by the -//! higher-level wrappers, including `expm1`, `log1p`, `sin`, `cos`, and -//! `tanh`. -//! -//! # Examples -//! -//! ```rust -//! use chainrules::{powf_frule, powf_rrule}; -//! -//! let (y, dy) = powf_frule(2.0_f64, 3.0, 1.0); -//! assert_eq!(y, 8.0); -//! assert_eq!(dy, 12.0); -//! -//! let dx = powf_rrule(2.0_f64, 3.0, 1.0); -//! assert_eq!(dx, 12.0); -//! ``` +#![doc = include_str!("../README.md")] pub use chainrules_core::{ AdResult, AutodiffError, Differentiable, ForwardRule, NodeId, PullbackEntry, @@ -40,6 +6,7 @@ pub use chainrules_core::{ }; mod binary; +mod binary_special; mod power; mod real_ops; mod scalar_ad; @@ -51,19 +18,31 @@ pub use binary::{ sub_frule, sub_rrule, }; #[doc(inline)] +pub use binary_special::{max, max_frule, max_rrule, min, min_frule, min_rrule}; +#[doc(inline)] pub use power::{powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule}; #[doc(inline)] pub use real_ops::{atan2, atan2_frule, atan2_rrule}; #[doc(inline)] -pub use scalar_ad::{handle_r_to_c_f32, handle_r_to_c_f64, ScalarAd}; +pub use scalar_ad::ScalarAd; #[doc(inline)] pub use unary::{ - acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, asin_rrule, - asinh, asinh_frule, asinh_rrule, atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, - conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, cosh_frule, cosh_rrule, exp, - exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, log, log1p, log1p_frule, log1p_rrule, - log_frule, log_rrule, sin, sin_frule, sin_rrule, sinh, sinh_frule, sinh_rrule, sqrt, - sqrt_frule, sqrt_rrule, tanh, tanh_frule, tanh_rrule, + abs, abs2, abs2_frule, abs2_rrule, acos, acos_frule, acos_rrule, acosh, acosh_frule, + acosh_rrule, angle, angle_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, + atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, cbrt, cbrt_frule, cbrt_rrule, + ceil, ceil_frule, 
ceil_rrule, complex, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, + cosd, cosd_frule, cosd_rrule, cosh, cosh_frule, cosh_rrule, cospi, cospi_frule, cospi_rrule, + cot, cot_frule, cot_rrule, coth, coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, + csch_frule, csch_rrule, exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, + exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, floor, floor_frule, floor_rrule, hypot, + hypot_frule, hypot_rrule, imag, imag_rrule, inv, inv_frule, inv_rrule, log, log10, log10_frule, + log10_rrule, log1p, log1p_frule, log1p_rrule, log2, log2_frule, log2_rrule, log_frule, + log_rrule, pow, pow_frule, pow_rrule, real, real_rrule, round, round_frule, round_rrule, sec, + sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sign, sign_frule, sign_rrule, sin, + sin_frule, sin_rrule, sincos, sincos_frule, sincos_rrule, sincospi, sincospi_frule, + sincospi_rrule, sind, sind_frule, sind_rrule, sinh, sinh_frule, sinh_rrule, sinpi, sinpi_frule, + sinpi_rrule, sqrt, sqrt_frule, sqrt_rrule, tan, tan_frule, tan_rrule, tand, tand_frule, + tand_rrule, tanh, tanh_frule, tanh_rrule, }; #[cfg(test)] diff --git a/crates/chainrules/src/power.rs b/crates/chainrules/src/power.rs index d9aa42c..9a190f6 100644 --- a/crates/chainrules/src/power.rs +++ b/crates/chainrules/src/power.rs @@ -1,29 +1,19 @@ -use num_traits::{One, Zero}; +use num_traits::{Float, One, Zero}; use crate::ScalarAd; /// Primal `powf`. -/// -/// # Examples -/// /// ```rust /// use chainrules::powf; -/// /// assert_eq!(powf(2.0_f64, 3.0), 8.0); /// ``` pub fn powf(x: S, exponent: S::Real) -> S { x.powf(exponent) } -/// Forward rule for `powf` with fixed exponent. -/// -/// Returns `(primal, tangent)`. -/// -/// # Examples -/// +/// Forward rule for `powf` with fixed exponent. Returns `(primal, tangent)`. 
/// ```rust /// use chainrules::powf_frule; -/// /// let (y, dy) = powf_frule(2.0_f64, 3.0, 1.0); /// assert_eq!(y, 8.0); /// assert_eq!(dy, 12.0); @@ -33,7 +23,7 @@ pub fn powf_frule(x: S, exponent: S::Real, dx: S) -> (S, S) { let dy = if exponent == S::Real::zero() { S::from_real(S::Real::zero()) } else { - dx * (S::from_real(exponent) * x.powf(exponent - S::Real::one())).conj() + dx * (S::from_real(exponent) * x.powf(exponent - S::Real::one())) }; (y, dy) } @@ -56,27 +46,17 @@ pub fn powf_rrule(x: S, exponent: S::Real, cotangent: S) -> S { } /// Primal `powi`. -/// -/// # Examples -/// /// ```rust /// use chainrules::powi; -/// /// assert_eq!(powi(2.0_f64, 4), 16.0); /// ``` pub fn powi(x: S, exponent: i32) -> S { x.powi(exponent) } -/// Forward rule for `powi` with fixed integer exponent. -/// -/// Returns `(primal, tangent)`. -/// -/// # Examples -/// +/// Forward rule for `powi` with fixed integer exponent. Returns `(primal, tangent)`. /// ```rust /// use chainrules::powi_frule; -/// /// let (y, dy) = powi_frule(2.0_f64, 4, 1.0); /// assert_eq!(y, 16.0); /// assert_eq!(dy, 32.0); @@ -86,7 +66,7 @@ pub fn powi_frule(x: S, exponent: i32, dx: S) -> (S, S) { let dy = if exponent == 0 { S::from_i32(0) } else { - dx * (S::from_i32(exponent) * x.powi(exponent - 1)).conj() + dx * (S::from_i32(exponent) * x.powi(exponent - 1)) }; (y, dy) } @@ -107,3 +87,74 @@ pub fn powi_rrule(x: S, exponent: i32, cotangent: S) -> S { } cotangent * (S::from_i32(exponent) * x.powi(exponent - 1)).conj() } + +#[doc = "Primal `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow;\n\nassert_eq!(pow(2.0_f64, 3.0_f64), 8.0);\n```"] +pub fn pow(x: S, exponent: S) -> S { + x.pow(exponent) +} +fn zero() -> S { + S::from_i32(0) +} +fn nan() -> S { + S::from_real(S::Real::nan()) +} +fn pow_x_scale(x: S, exponent: S) -> S { + if exponent == zero::() { + zero::() + } else { + (exponent * x.pow(exponent - S::from_i32(1))).conj() + } +} +fn pow_exp_scale(x: S, exponent: S) -> S { + if x == 
zero::() && exponent.imag() == S::Real::zero() { + if exponent.real() > S::Real::zero() { + zero::() + } else { + nan::() + } + } else { + (x.pow(exponent) * x.ln()).conj() + } +} +#[doc = "Forward rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-tangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_frule;\n\nlet (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0);\nassert_eq!(y, 8.0);\nassert!((dy - 12.0).abs() < 1e-12);\n```"] +pub fn pow_frule(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) { + let y = x.pow(exponent); + let dfdx = if dx == zero::() { + zero::() + } else { + dx * if exponent == zero::() { + zero::() + } else { + exponent * x.pow(exponent - S::from_i32(1)) + } + }; + let dfde = if dexponent == zero::() { + zero::() + } else { + dexponent + * if x == zero::() && exponent.imag() == S::Real::zero() { + if exponent.real() > S::Real::zero() { + zero::() + } else { + nan::() + } + } else { + x.pow(exponent) * x.ln() + } + }; + (y, dfdx + dfde) +} +#[doc = "Reverse rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-cotangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_rrule;\n\nlet (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0);\nassert_eq!(dx, 12.0);\nassert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] +pub fn pow_rrule(x: S, exponent: S, cotangent: S) -> (S, S) { + let dfdx = if cotangent == zero::() { + zero::() + } else { + cotangent * pow_x_scale(x, exponent) + }; + let dfde = if cotangent == zero::() { + zero::() + } else { + cotangent * pow_exp_scale(x, exponent) + }; + (dfdx, dfde) +} diff --git a/crates/chainrules/src/scalar_ad.rs b/crates/chainrules/src/scalar_ad.rs deleted file mode 100644 index 7b0a6dd..0000000 --- a/crates/chainrules/src/scalar_ad.rs +++ /dev/null @@ -1,314 +0,0 @@ -use core::ops::{Add, Div, Mul, Neg, 
Sub}; - -use num_complex::{Complex32, Complex64}; -use num_traits::Float; - -/// Scalar trait used by elementary AD rule helpers. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::ScalarAd; -/// -/// fn takes_scalar(_x: S) {} -/// -/// takes_scalar(1.0_f32); -/// takes_scalar(1.0_f64); -/// ``` -pub trait ScalarAd: - Copy - + PartialEq - + Neg - + Add - + Sub - + Mul - + Div -{ - /// Real exponent type for `powf`. - type Real: Copy + Float; - - /// Complex conjugate (identity for real scalars). - fn conj(self) -> Self; - - /// Square root. - fn sqrt(self) -> Self; - - /// Exponential. - fn exp(self) -> Self; - - /// `exp(self) - 1`. - fn expm1(self) -> Self; - - /// Natural logarithm. - fn ln(self) -> Self; - - /// `ln(1 + self)`. - fn log1p(self) -> Self; - - /// Sine. - fn sin(self) -> Self; - - /// Cosine. - fn cos(self) -> Self; - - /// Hyperbolic tangent. - fn tanh(self) -> Self; - - /// Arc sine. - fn asin(self) -> Self; - - /// Arc cosine. - fn acos(self) -> Self; - - /// Arc tangent. - fn atan(self) -> Self; - - /// Hyperbolic sine. - fn sinh(self) -> Self; - - /// Hyperbolic cosine. - fn cosh(self) -> Self; - - /// Area hyperbolic sine. - fn asinh(self) -> Self; - - /// Area hyperbolic cosine. - fn acosh(self) -> Self; - - /// Area hyperbolic tangent. - fn atanh(self) -> Self; - - /// Power by real exponent. - fn powf(self, exponent: Self::Real) -> Self; - - /// Power by integer exponent. - fn powi(self, exponent: i32) -> Self; - - /// Convert real scalar to this scalar type. - fn from_real(value: Self::Real) -> Self; - - /// Convert signed integer to this scalar type. - fn from_i32(value: i32) -> Self; -} - -macro_rules! 
impl_scalar_ad_real { - ($ty:ty) => { - impl ScalarAd for $ty { - type Real = $ty; - - fn conj(self) -> Self { - self - } - - fn sqrt(self) -> Self { - <$ty>::sqrt(self) - } - - fn exp(self) -> Self { - <$ty>::exp(self) - } - - fn expm1(self) -> Self { - <$ty>::exp_m1(self) - } - - fn ln(self) -> Self { - <$ty>::ln(self) - } - - fn log1p(self) -> Self { - <$ty>::ln_1p(self) - } - - fn sin(self) -> Self { - <$ty>::sin(self) - } - - fn cos(self) -> Self { - <$ty>::cos(self) - } - - fn tanh(self) -> Self { - <$ty>::tanh(self) - } - - fn asin(self) -> Self { - <$ty>::asin(self) - } - - fn acos(self) -> Self { - <$ty>::acos(self) - } - - fn atan(self) -> Self { - <$ty>::atan(self) - } - - fn sinh(self) -> Self { - <$ty>::sinh(self) - } - - fn cosh(self) -> Self { - <$ty>::cosh(self) - } - - fn asinh(self) -> Self { - <$ty>::asinh(self) - } - - fn acosh(self) -> Self { - <$ty>::acosh(self) - } - - fn atanh(self) -> Self { - <$ty>::atanh(self) - } - - fn powf(self, exponent: Self::Real) -> Self { - <$ty>::powf(self, exponent) - } - - fn powi(self, exponent: i32) -> Self { - <$ty>::powi(self, exponent) - } - - fn from_real(value: Self::Real) -> Self { - value - } - - fn from_i32(value: i32) -> Self { - value as $ty - } - } - }; -} - -macro_rules! 
impl_scalar_ad_complex { - ($complex_ty:ty, $real_ty:ty, $one:expr) => { - impl ScalarAd for $complex_ty { - type Real = $real_ty; - - fn conj(self) -> Self { - <$complex_ty>::conj(&self) - } - - fn sqrt(self) -> Self { - <$complex_ty>::sqrt(self) - } - - fn exp(self) -> Self { - <$complex_ty>::exp(self) - } - - fn expm1(self) -> Self { - <$complex_ty>::exp(self) - $one - } - - fn ln(self) -> Self { - <$complex_ty>::ln(self) - } - - fn log1p(self) -> Self { - <$complex_ty>::ln(self + $one) - } - - fn sin(self) -> Self { - <$complex_ty>::sin(self) - } - - fn cos(self) -> Self { - <$complex_ty>::cos(self) - } - - fn tanh(self) -> Self { - <$complex_ty>::tanh(self) - } - - fn asin(self) -> Self { - <$complex_ty>::asin(self) - } - - fn acos(self) -> Self { - <$complex_ty>::acos(self) - } - - fn atan(self) -> Self { - <$complex_ty>::atan(self) - } - - fn sinh(self) -> Self { - <$complex_ty>::sinh(self) - } - - fn cosh(self) -> Self { - <$complex_ty>::cosh(self) - } - - fn asinh(self) -> Self { - <$complex_ty>::asinh(self) - } - - fn acosh(self) -> Self { - <$complex_ty>::acosh(self) - } - - fn atanh(self) -> Self { - <$complex_ty>::atanh(self) - } - - fn powf(self, exponent: Self::Real) -> Self { - <$complex_ty>::powf(self, exponent) - } - - fn powi(self, exponent: i32) -> Self { - <$complex_ty>::powi(&self, exponent) - } - - fn from_real(value: Self::Real) -> Self { - <$complex_ty>::new(value, 0.0) - } - - fn from_i32(value: i32) -> Self { - <$complex_ty>::new(value as $real_ty, 0.0) - } - } - }; -} - -impl_scalar_ad_real!(f32); -impl_scalar_ad_real!(f64); -impl_scalar_ad_complex!(Complex32, f32, Complex32::new(1.0, 0.0)); -impl_scalar_ad_complex!(Complex64, f64, Complex64::new(1.0, 0.0)); - -/// PyTorch-style real-input / complex-gradient projection helper (`handle_r_to_c`). -/// -/// This is equivalent to taking the real part when a gradient for real input -/// becomes complex during intermediate algebra. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::handle_r_to_c_f64; -/// use num_complex::Complex64; -/// -/// let g = Complex64::new(1.25, -3.0); -/// assert_eq!(handle_r_to_c_f64(g), 1.25); -/// ``` -pub fn handle_r_to_c_f64(gradient: Complex64) -> f64 { - gradient.re -} - -/// `f32` variant of [`handle_r_to_c_f64`]. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::handle_r_to_c_f32; -/// use num_complex::Complex32; -/// -/// let g = Complex32::new(2.0, 4.0); -/// assert_eq!(handle_r_to_c_f32(g), 2.0); -/// ``` -pub fn handle_r_to_c_f32(gradient: Complex32) -> f32 { - gradient.re -} diff --git a/crates/chainrules/src/scalar_ad/complex.rs b/crates/chainrules/src/scalar_ad/complex.rs new file mode 100644 index 0000000..6d55e9f --- /dev/null +++ b/crates/chainrules/src/scalar_ad/complex.rs @@ -0,0 +1,155 @@ +use super::ScalarAd; +use num_complex::{Complex32, Complex64, ComplexFloat}; +use num_traits::{FloatConst, Zero}; + +macro_rules! impl_scalar_ad_complex { + ($complex_ty:ty, $real_ty:ty, $one:expr) => { + impl ScalarAd for $complex_ty { + type Real = $real_ty; + + fn conj(self) -> Self { + <$complex_ty>::conj(&self) + } + + fn recip(self) -> Self { + <$complex_ty as ComplexFloat>::recip(self) + } + + fn cbrt(self) -> Self { + <$complex_ty as ComplexFloat>::cbrt(self) + } + + fn sqrt(self) -> Self { + <$complex_ty as ComplexFloat>::sqrt(self) + } + + fn exp(self) -> Self { + <$complex_ty as ComplexFloat>::exp(self) + } + + fn exp2(self) -> Self { + <$complex_ty as ComplexFloat>::exp2(self) + } + + fn exp10(self) -> Self { + <$complex_ty as ComplexFloat>::exp( + self * <$complex_ty>::new( + <$real_ty as FloatConst>::LN_10(), + <$real_ty as Zero>::zero(), + ), + ) + } + + fn expm1(self) -> Self { + <$complex_ty as ComplexFloat>::exp(self) - $one + } + + fn ln(self) -> Self { + <$complex_ty as ComplexFloat>::ln(self) + } + + fn log1p(self) -> Self { + <$complex_ty as ComplexFloat>::ln(self + $one) + } + + fn log2(self) -> Self { + <$complex_ty 
as ComplexFloat>::log2(self) + } + + fn log10(self) -> Self { + <$complex_ty as ComplexFloat>::log10(self) + } + + fn sin(self) -> Self { + <$complex_ty as ComplexFloat>::sin(self) + } + + fn cos(self) -> Self { + <$complex_ty as ComplexFloat>::cos(self) + } + + fn tan(self) -> Self { + <$complex_ty as ComplexFloat>::tan(self) + } + + fn tanh(self) -> Self { + <$complex_ty as ComplexFloat>::tanh(self) + } + + fn asin(self) -> Self { + <$complex_ty as ComplexFloat>::asin(self) + } + + fn acos(self) -> Self { + <$complex_ty as ComplexFloat>::acos(self) + } + + fn atan(self) -> Self { + <$complex_ty as ComplexFloat>::atan(self) + } + + fn sinh(self) -> Self { + <$complex_ty as ComplexFloat>::sinh(self) + } + + fn cosh(self) -> Self { + <$complex_ty as ComplexFloat>::cosh(self) + } + + fn asinh(self) -> Self { + <$complex_ty as ComplexFloat>::asinh(self) + } + + fn acosh(self) -> Self { + <$complex_ty as ComplexFloat>::acosh(self) + } + + fn atanh(self) -> Self { + <$complex_ty as ComplexFloat>::atanh(self) + } + + fn abs(self) -> Self::Real { + <$complex_ty as ComplexFloat>::abs(self) + } + + fn abs2(self) -> Self::Real { + <$complex_ty>::norm_sqr(&self) + } + + fn real(self) -> Self::Real { + self.re + } + + fn imag(self) -> Self::Real { + self.im + } + + fn angle(self) -> Self::Real { + <$complex_ty as ComplexFloat>::arg(self) + } + + fn powf(self, exponent: Self::Real) -> Self { + <$complex_ty as ComplexFloat>::powf(self, exponent) + } + + fn powi(self, exponent: i32) -> Self { + <$complex_ty as ComplexFloat>::powi(self, exponent) + } + + fn pow(self, exponent: Self) -> Self { + <$complex_ty as ComplexFloat>::powc(self, exponent) + } + + fn from_real(value: Self::Real) -> Self { + <$complex_ty>::new(value, 0.0) + } + + fn from_i32(value: i32) -> Self { + <$complex_ty>::new(value as $real_ty, 0.0) + } + } + }; +} + +impl_scalar_ad_complex!(Complex32, f32, Complex32::new(1.0, 0.0)); +impl_scalar_ad_complex!(Complex64, f64, Complex64::new(1.0, 0.0)); diff --git 
a/crates/chainrules/src/scalar_ad/mod.rs b/crates/chainrules/src/scalar_ad/mod.rs new file mode 100644 index 0000000..8db3e32 --- /dev/null +++ b/crates/chainrules/src/scalar_ad/mod.rs @@ -0,0 +1,133 @@ +use core::ops::{Add, Div, Mul, Neg, Sub}; + +use num_traits::{Float, FloatConst}; + +/// Scalar trait used by elementary AD rule helpers. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ScalarAd; +/// +/// fn takes_scalar(_x: S) {} +/// +/// takes_scalar(1.0_f32); +/// takes_scalar(1.0_f64); +/// ``` +pub trait ScalarAd: + Copy + + PartialEq + + Neg + + Add + + Sub + + Mul + + Div +{ + /// Real exponent type for `powf`. + type Real: Copy + Float + FloatConst; + + /// Complex conjugate (identity for real scalars). + fn conj(self) -> Self; + + /// Reciprocal. + fn recip(self) -> Self; + + /// Cubic root. + fn cbrt(self) -> Self; + + /// Square root. + fn sqrt(self) -> Self; + + /// Exponential. + fn exp(self) -> Self; + + /// `2^self`. + fn exp2(self) -> Self; + + /// `10^self`. + fn exp10(self) -> Self; + + /// `exp(self) - 1`. + fn expm1(self) -> Self; + + /// Natural logarithm. + fn ln(self) -> Self; + + /// `ln(1 + self)`. + fn log1p(self) -> Self; + + /// Base-2 logarithm. + fn log2(self) -> Self; + + /// Base-10 logarithm. + fn log10(self) -> Self; + + /// Sine. + fn sin(self) -> Self; + + /// Cosine. + fn cos(self) -> Self; + + /// Tangent. + fn tan(self) -> Self; + + /// Hyperbolic tangent. + fn tanh(self) -> Self; + + /// Arc sine. + fn asin(self) -> Self; + + /// Arc cosine. + fn acos(self) -> Self; + + /// Arc tangent. + fn atan(self) -> Self; + + /// Hyperbolic sine. + fn sinh(self) -> Self; + + /// Hyperbolic cosine. + fn cosh(self) -> Self; + + /// Area hyperbolic sine. + fn asinh(self) -> Self; + + /// Area hyperbolic cosine. + fn acosh(self) -> Self; + + /// Area hyperbolic tangent. + fn atanh(self) -> Self; + + /// Absolute value. + fn abs(self) -> Self::Real; + + /// Absolute value squared. 
+ fn abs2(self) -> Self::Real; + + /// Real part. + fn real(self) -> Self::Real; + + /// Imaginary part. + fn imag(self) -> Self::Real; + + /// Polar angle. + fn angle(self) -> Self::Real; + + /// Power by real exponent. + fn powf(self, exponent: Self::Real) -> Self; + + /// Power by integer exponent. + fn powi(self, exponent: i32) -> Self; + + /// Power by same-scalar exponent. + fn pow(self, exponent: Self) -> Self; + + /// Convert real scalar to this scalar type. + fn from_real(value: Self::Real) -> Self; + + /// Convert signed integer to this scalar type. + fn from_i32(value: i32) -> Self; +} + +mod complex; +mod real; diff --git a/crates/chainrules/src/scalar_ad/real.rs b/crates/chainrules/src/scalar_ad/real.rs new file mode 100644 index 0000000..2069186 --- /dev/null +++ b/crates/chainrules/src/scalar_ad/real.rs @@ -0,0 +1,149 @@ +use super::ScalarAd; +use num_traits::{Float, FloatConst, Zero}; + +macro_rules! impl_scalar_ad_real { + ($ty:ty) => { + impl ScalarAd for $ty { + type Real = $ty; + + fn conj(self) -> Self { + self + } + + fn recip(self) -> Self { + <$ty as Float>::recip(self) + } + + fn cbrt(self) -> Self { + <$ty as Float>::cbrt(self) + } + + fn sqrt(self) -> Self { + <$ty as Float>::sqrt(self) + } + + fn exp(self) -> Self { + <$ty as Float>::exp(self) + } + + fn exp2(self) -> Self { + <$ty as Float>::exp2(self) + } + + fn exp10(self) -> Self { + <$ty as Float>::exp(self * <$ty as FloatConst>::LN_10()) + } + + fn expm1(self) -> Self { + <$ty as Float>::exp_m1(self) + } + + fn ln(self) -> Self { + <$ty as Float>::ln(self) + } + + fn log1p(self) -> Self { + <$ty as Float>::ln_1p(self) + } + + fn log2(self) -> Self { + <$ty as Float>::log2(self) + } + + fn log10(self) -> Self { + <$ty as Float>::log10(self) + } + + fn sin(self) -> Self { + <$ty as Float>::sin(self) + } + + fn cos(self) -> Self { + <$ty as Float>::cos(self) + } + + fn tan(self) -> Self { + <$ty as Float>::tan(self) + } + + fn tanh(self) -> Self { + <$ty as Float>::tanh(self) + } + + 
fn asin(self) -> Self { + <$ty as Float>::asin(self) + } + + fn acos(self) -> Self { + <$ty as Float>::acos(self) + } + + fn atan(self) -> Self { + <$ty as Float>::atan(self) + } + + fn sinh(self) -> Self { + <$ty as Float>::sinh(self) + } + + fn cosh(self) -> Self { + <$ty as Float>::cosh(self) + } + + fn asinh(self) -> Self { + <$ty as Float>::asinh(self) + } + + fn acosh(self) -> Self { + <$ty as Float>::acosh(self) + } + + fn atanh(self) -> Self { + <$ty as Float>::atanh(self) + } + + fn abs(self) -> Self::Real { + <$ty as Float>::abs(self) + } + + fn abs2(self) -> Self::Real { + self * self + } + + fn real(self) -> Self::Real { + self + } + + fn imag(self) -> Self::Real { + <$ty as Zero>::zero() + } + + fn angle(self) -> Self::Real { + <$ty as Float>::atan2(<$ty as Zero>::zero(), self) + } + + fn powf(self, exponent: Self::Real) -> Self { + <$ty as Float>::powf(self, exponent) + } + + fn powi(self, exponent: i32) -> Self { + <$ty as Float>::powi(self, exponent) + } + + fn pow(self, exponent: Self) -> Self { + <$ty as Float>::powf(self, exponent) + } + + fn from_real(value: Self::Real) -> Self { + value + } + + fn from_i32(value: i32) -> Self { + value as $ty + } + } + }; +} + +impl_scalar_ad_real!(f32); +impl_scalar_ad_real!(f64); diff --git a/crates/chainrules/src/tests/behavior.rs b/crates/chainrules/src/tests/behavior.rs index 4fe310a..60ea911 100644 --- a/crates/chainrules/src/tests/behavior.rs +++ b/crates/chainrules/src/tests/behavior.rs @@ -1,13 +1,12 @@ -use num_complex::{Complex32, Complex64}; +use num_complex::{Complex32, Complex64, ComplexFloat}; use crate::{ acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, atan, atan2, atan2_frule, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, - cosh_frule, cosh_rrule, exp, exp_frule, exp_rrule, expm1_frule, expm1_rrule, handle_r_to_c_f32, - handle_r_to_c_f64, log, 
log1p_frule, log1p_rrule, log_frule, log_rrule, sin, sin_frule, - sin_rrule, sinh, sinh_frule, sinh_rrule, sqrt, sqrt_frule, sqrt_rrule, tanh, tanh_frule, - tanh_rrule, ScalarAd, + cosh_frule, cosh_rrule, exp, exp_frule, exp_rrule, expm1_frule, expm1_rrule, log, log1p_frule, + log1p_rrule, log_frule, log_rrule, sin, sin_frule, sin_rrule, sinh, sinh_frule, sinh_rrule, + sqrt, sqrt_frule, sqrt_rrule, tanh, tanh_frule, tanh_rrule, ScalarAd, }; fn assert_close_f32(actual: f32, expected: f32) { @@ -78,31 +77,46 @@ fn scalar_ad_real_impls_match_std_real_ops() { #[test] fn scalar_ad_complex_impls_match_std_complex_ops() { let x32 = Complex32::new(0.25, -0.5); - assert_close_c32(::conj(x32), x32.conj()); - assert_close_c32(::sqrt(x32), x32.sqrt()); - assert_close_c32(::exp(x32), x32.exp()); + assert_close_c32(::conj(x32), ComplexFloat::conj(x32)); + assert_close_c32(::sqrt(x32), ComplexFloat::sqrt(x32)); + assert_close_c32(::exp(x32), ComplexFloat::exp(x32)); assert_close_c32( ::expm1(x32), - x32.exp() - Complex32::new(1.0, 0.0), + ComplexFloat::exp(x32) - Complex32::new(1.0, 0.0), ); - assert_close_c32(::ln(x32), x32.ln()); + assert_close_c32(::ln(x32), ComplexFloat::ln(x32)); assert_close_c32( ::log1p(x32), - (x32 + Complex32::new(1.0, 0.0)).ln(), - ); - assert_close_c32(::sin(x32), x32.sin()); - assert_close_c32(::cos(x32), x32.cos()); - assert_close_c32(::tanh(x32), x32.tanh()); - assert_close_c32(::asin(x32), x32.asin()); - assert_close_c32(::acos(x32), x32.acos()); - assert_close_c32(::atan(x32), x32.atan()); - assert_close_c32(::sinh(x32), x32.sinh()); - assert_close_c32(::cosh(x32), x32.cosh()); - assert_close_c32(::asinh(x32), x32.asinh()); - assert_close_c32(::acosh(x32), x32.acosh()); - assert_close_c32(::atanh(x32), x32.atanh()); - assert_close_c32(::powf(x32, 2.0), x32.powf(2.0)); - assert_close_c32(::powi(x32, 3), x32.powi(3)); + ComplexFloat::ln(x32 + Complex32::new(1.0, 0.0)), + ); + assert_close_c32(::sin(x32), ComplexFloat::sin(x32)); + 
assert_close_c32(::cos(x32), ComplexFloat::cos(x32)); + assert_close_c32(::tanh(x32), ComplexFloat::tanh(x32)); + assert_close_c32(::asin(x32), ComplexFloat::asin(x32)); + assert_close_c32(::acos(x32), ComplexFloat::acos(x32)); + assert_close_c32(::atan(x32), ComplexFloat::atan(x32)); + assert_close_c32(::sinh(x32), ComplexFloat::sinh(x32)); + assert_close_c32(::cosh(x32), ComplexFloat::cosh(x32)); + assert_close_c32( + ::asinh(x32), + ComplexFloat::asinh(x32), + ); + assert_close_c32( + ::acosh(x32), + ComplexFloat::acosh(x32), + ); + assert_close_c32( + ::atanh(x32), + ComplexFloat::atanh(x32), + ); + assert_close_c32( + ::powf(x32, 2.0), + ComplexFloat::powf(x32, 2.0), + ); + assert_close_c32( + ::powi(x32, 3), + ComplexFloat::powi(x32, 3), + ); assert_eq!( ::from_real(1.5), Complex32::new(1.5, 0.0) @@ -113,31 +127,46 @@ fn scalar_ad_complex_impls_match_std_complex_ops() { ); let x64 = Complex64::new(0.5, 0.75); - assert_close_c64(::conj(x64), x64.conj()); - assert_close_c64(::sqrt(x64), x64.sqrt()); - assert_close_c64(::exp(x64), x64.exp()); + assert_close_c64(::conj(x64), ComplexFloat::conj(x64)); + assert_close_c64(::sqrt(x64), ComplexFloat::sqrt(x64)); + assert_close_c64(::exp(x64), ComplexFloat::exp(x64)); assert_close_c64( ::expm1(x64), - x64.exp() - Complex64::new(1.0, 0.0), + ComplexFloat::exp(x64) - Complex64::new(1.0, 0.0), ); - assert_close_c64(::ln(x64), x64.ln()); + assert_close_c64(::ln(x64), ComplexFloat::ln(x64)); assert_close_c64( ::log1p(x64), - (x64 + Complex64::new(1.0, 0.0)).ln(), - ); - assert_close_c64(::sin(x64), x64.sin()); - assert_close_c64(::cos(x64), x64.cos()); - assert_close_c64(::tanh(x64), x64.tanh()); - assert_close_c64(::asin(x64), x64.asin()); - assert_close_c64(::acos(x64), x64.acos()); - assert_close_c64(::atan(x64), x64.atan()); - assert_close_c64(::sinh(x64), x64.sinh()); - assert_close_c64(::cosh(x64), x64.cosh()); - assert_close_c64(::asinh(x64), x64.asinh()); - assert_close_c64(::acosh(x64), x64.acosh()); - 
assert_close_c64(::atanh(x64), x64.atanh()); - assert_close_c64(::powf(x64, 1.5), x64.powf(1.5)); - assert_close_c64(::powi(x64, 2), x64.powi(2)); + ComplexFloat::ln(x64 + Complex64::new(1.0, 0.0)), + ); + assert_close_c64(::sin(x64), ComplexFloat::sin(x64)); + assert_close_c64(::cos(x64), ComplexFloat::cos(x64)); + assert_close_c64(::tanh(x64), ComplexFloat::tanh(x64)); + assert_close_c64(::asin(x64), ComplexFloat::asin(x64)); + assert_close_c64(::acos(x64), ComplexFloat::acos(x64)); + assert_close_c64(::atan(x64), ComplexFloat::atan(x64)); + assert_close_c64(::sinh(x64), ComplexFloat::sinh(x64)); + assert_close_c64(::cosh(x64), ComplexFloat::cosh(x64)); + assert_close_c64( + ::asinh(x64), + ComplexFloat::asinh(x64), + ); + assert_close_c64( + ::acosh(x64), + ComplexFloat::acosh(x64), + ); + assert_close_c64( + ::atanh(x64), + ComplexFloat::atanh(x64), + ); + assert_close_c64( + ::powf(x64, 1.5), + ComplexFloat::powf(x64, 1.5), + ); + assert_close_c64( + ::powi(x64, 2), + ComplexFloat::powi(x64, 2), + ); assert_eq!( ::from_real(2.5), Complex64::new(2.5, 0.0) @@ -149,10 +178,99 @@ fn scalar_ad_complex_impls_match_std_complex_ops() { } #[test] -fn direct_entrypoints_match_real_projection_and_atan2_formulas() { - assert_eq!(handle_r_to_c_f32(Complex32::new(2.0, -5.0)), 2.0); - assert_eq!(handle_r_to_c_f64(Complex64::new(-3.0, 1.5)), -3.0); +fn scalar_ad_real_extended_surface_matches_std_real_ops() { + let x32 = 0.25_f32; + assert_close_f32(::recip(x32), x32.recip()); + assert_close_f32(::cbrt(x32), x32.cbrt()); + assert_close_f32(::exp2(x32), x32.exp2()); + assert_close_f32(::exp10(x32), 10.0_f32.powf(x32)); + assert_close_f32(::log2(x32), x32.log2()); + assert_close_f32(::log10(x32), x32.log10()); + assert_close_f32(::tan(x32), x32.tan()); + assert_close_f32(::abs(x32), x32.abs()); + assert_close_f32(::abs2(x32), x32 * x32); + assert_close_f32(::real(x32), x32); + assert_close_f32(::imag(x32), 0.0); + assert_close_f32(::angle(-x32), 0.0_f32.atan2(-x32)); + 
assert_close_f32(::pow(x32, 1.5), x32.powf(1.5)); + + let x64 = 0.5_f64; + assert_close_f64(::recip(x64), x64.recip()); + assert_close_f64(::cbrt(x64), x64.cbrt()); + assert_close_f64(::exp2(x64), x64.exp2()); + assert_close_f64(::exp10(x64), 10.0_f64.powf(x64)); + assert_close_f64(::log2(x64), x64.log2()); + assert_close_f64(::log10(x64), x64.log10()); + assert_close_f64(::tan(x64), x64.tan()); + assert_close_f64(::abs(x64), x64.abs()); + assert_close_f64(::abs2(x64), x64 * x64); + assert_close_f64(::real(x64), x64); + assert_close_f64(::imag(x64), 0.0); + assert_close_f64(::angle(-x64), 0.0_f64.atan2(-x64)); + assert_close_f64(::pow(x64, 1.5), x64.powf(1.5)); +} +#[test] +fn scalar_ad_complex_extended_surface_matches_std_complex_ops() { + let x32 = Complex32::new(0.25, -0.5); + let y32 = Complex32::new(1.25, 0.75); + assert_close_c32( + ::recip(x32), + ComplexFloat::recip(x32), + ); + assert_close_c32(::cbrt(x32), ComplexFloat::cbrt(x32)); + assert_close_c32(::exp2(x32), ComplexFloat::exp2(x32)); + assert_close_c32( + ::exp10(x32), + ComplexFloat::exp(x32 * Complex32::new(std::f32::consts::LN_10, 0.0)), + ); + assert_close_c32(::log2(x32), ComplexFloat::log2(x32)); + assert_close_c32( + ::log10(x32), + ComplexFloat::log10(x32), + ); + assert_close_c32(::tan(x32), ComplexFloat::tan(x32)); + assert_close_f32(::abs(x32), x32.norm()); + assert_close_f32(::abs2(x32), x32.norm_sqr()); + assert_close_f32(::real(x32), x32.re); + assert_close_f32(::imag(x32), x32.im); + assert_close_f32(::angle(x32), x32.arg()); + assert_close_c32( + ::pow(x32, y32), + ComplexFloat::powc(x32, y32), + ); + + let x64 = Complex64::new(0.5, 0.75); + let y64 = Complex64::new(1.5, -0.25); + assert_close_c64( + ::recip(x64), + ComplexFloat::recip(x64), + ); + assert_close_c64(::cbrt(x64), ComplexFloat::cbrt(x64)); + assert_close_c64(::exp2(x64), ComplexFloat::exp2(x64)); + assert_close_c64( + ::exp10(x64), + ComplexFloat::exp(x64 * Complex64::new(std::f64::consts::LN_10, 0.0)), + ); + 
assert_close_c64(::log2(x64), ComplexFloat::log2(x64)); + assert_close_c64( + ::log10(x64), + ComplexFloat::log10(x64), + ); + assert_close_c64(::tan(x64), ComplexFloat::tan(x64)); + assert_close_f64(::abs(x64), x64.norm()); + assert_close_f64(::abs2(x64), x64.norm_sqr()); + assert_close_f64(::real(x64), x64.re); + assert_close_f64(::imag(x64), x64.im); + assert_close_f64(::angle(x64), x64.arg()); + assert_close_c64( + ::pow(x64, y64), + ComplexFloat::powc(x64, y64), + ); +} + +#[test] +fn direct_entrypoints_match_atan2_formulas() { let primal = atan2(3.0_f64, 4.0_f64); assert_close_f64(primal, 3.0_f64.atan2(4.0)); @@ -164,10 +282,10 @@ fn direct_entrypoints_match_real_projection_and_atan2_formulas() { #[test] fn unary_entrypoints_match_forward_and_reverse_formulas() { let complex = Complex32::new(1.0, -2.0); - assert_eq!(conj(complex), complex.conj()); + assert_eq!(conj(complex), ComplexFloat::conj(complex)); let (_y, dy) = conj_frule(complex, Complex32::new(3.0, 4.0)); assert_eq!(dy, Complex32::new(3.0, -4.0)); - assert_eq!(conj_rrule(complex), complex.conj()); + assert_eq!(conj_rrule(complex), ComplexFloat::conj(complex)); assert_eq!(sqrt(9.0_f32), 3.0); let (sqrt_y, sqrt_dy) = sqrt_frule(9.0_f32, 2.0_f32); @@ -305,46 +423,53 @@ fn trig_and_hyperbolic_primal_entrypoints_match_std_ops() { assert_close_f64(acosh(acosh_real), acosh_real.acosh()); let complex = Complex64::new(0.25, -0.5); - assert_close_c64(sin(complex), complex.sin()); - assert_close_c64(cos(complex), complex.cos()); - assert_close_c64(tanh(complex), complex.tanh()); - assert_close_c64(asin(complex), complex.asin()); - assert_close_c64(acos(complex), complex.acos()); - assert_close_c64(atan(complex), complex.atan()); - assert_close_c64(sinh(complex), complex.sinh()); - assert_close_c64(cosh(complex), complex.cosh()); - assert_close_c64(asinh(complex), complex.asinh()); - assert_close_c64(acosh(complex), complex.acosh()); - assert_close_c64(atanh(complex), complex.atanh()); + 
assert_close_c64(sin(complex), ComplexFloat::sin(complex)); + assert_close_c64(cos(complex), ComplexFloat::cos(complex)); + assert_close_c64(tanh(complex), ComplexFloat::tanh(complex)); + assert_close_c64(asin(complex), ComplexFloat::asin(complex)); + assert_close_c64(acos(complex), ComplexFloat::acos(complex)); + assert_close_c64(atan(complex), ComplexFloat::atan(complex)); + assert_close_c64(sinh(complex), ComplexFloat::sinh(complex)); + assert_close_c64(cosh(complex), ComplexFloat::cosh(complex)); + assert_close_c64(asinh(complex), ComplexFloat::asinh(complex)); + assert_close_c64(acosh(complex), ComplexFloat::acosh(complex)); + assert_close_c64(atanh(complex), ComplexFloat::atanh(complex)); } #[test] -fn extended_complex_unary_rules_conjugate_their_jacobians() { +fn extended_complex_unary_rules_use_standard_jvps_with_conjugate_rrules() { let x = Complex64::new(0.25, -0.5); let dx = Complex64::new(-0.75, 0.5); let cotangent = Complex64::new(0.5, -1.25); let (_sin_y, sin_dy) = sin_frule(x, dx); - assert_close_c64(sin_dy, dx * x.cos().conj()); - assert_close_c64(sin_rrule(x, cotangent), cotangent * x.cos().conj()); + assert_close_c64(sin_dy, dx * ComplexFloat::cos(x)); + assert_close_c64( + sin_rrule(x, cotangent), + cotangent * ComplexFloat::conj(ComplexFloat::cos(x)), + ); let (_cos_y, cos_dy) = cos_frule(x, dx); - assert_close_c64(cos_dy, dx * (-x.sin()).conj()); - assert_close_c64(cos_rrule(x, cotangent), cotangent * (-x.sin()).conj()); - - let tanh_y = x.tanh(); - let (_tanh_primal, tanh_dy) = tanh_frule(x, dx); + assert_close_c64(cos_dy, dx * -ComplexFloat::sin(x)); assert_close_c64( - tanh_dy, - dx * (Complex64::new(1.0, 0.0) - tanh_y * tanh_y).conj(), + cos_rrule(x, cotangent), + cotangent * ComplexFloat::conj(-ComplexFloat::sin(x)), ); + + let tanh_y = ComplexFloat::tanh(x); + let (_tanh_primal, tanh_dy) = tanh_frule(x, dx); + assert_close_c64(tanh_dy, dx * (Complex64::new(1.0, 0.0) - tanh_y * tanh_y)); assert_close_c64( tanh_rrule(tanh_y, cotangent), - 
cotangent * (Complex64::new(1.0, 0.0) - tanh_y * tanh_y).conj(), + cotangent * ComplexFloat::conj(Complex64::new(1.0, 0.0) - tanh_y * tanh_y), ); let (_asinh_y, asinh_dy) = asinh_frule(x, dx); - let asinh_scale = (Complex64::new(1.0, 0.0) / (Complex64::new(1.0, 0.0) + x * x).sqrt()).conj(); + let asinh_scale = + Complex64::new(1.0, 0.0) / ComplexFloat::sqrt(Complex64::new(1.0, 0.0) + x * x); assert_close_c64(asinh_dy, dx * asinh_scale); - assert_close_c64(asinh_rrule(x, cotangent), cotangent * asinh_scale); + assert_close_c64( + asinh_rrule(x, cotangent), + cotangent * ComplexFloat::conj(asinh_scale), + ); } diff --git a/crates/chainrules/src/tests/organization.rs b/crates/chainrules/src/tests/organization.rs index 0f6f72e..30eec50 100644 --- a/crates/chainrules/src/tests/organization.rs +++ b/crates/chainrules/src/tests/organization.rs @@ -9,22 +9,63 @@ fn assert_line_count(path: &str, content: &str, max_lines: usize) { // Do not delete or weaken this test: it protects the split scalar AD rule modules that keep this crate extensible. 
#[test] fn chainrules_modules_stay_under_size_guideline() { - assert_line_count("../lib.rs", include_str!("../lib.rs"), 120); - assert_line_count("../scalar_ad.rs", include_str!("../scalar_ad.rs"), 320); - assert_line_count("../binary.rs", include_str!("../binary.rs"), 260); - assert_line_count("../unary/mod.rs", include_str!("../unary/mod.rs"), 80); - assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 80); + assert_line_count("../lib.rs", include_str!("../lib.rs"), 60); + assert_line_count("../binary.rs", include_str!("../binary.rs"), 220); + assert_line_count( + "../binary_special.rs", + include_str!("../binary_special.rs"), + 200, + ); + assert_line_count("../unary/mod.rs", include_str!("../unary/mod.rs"), 60); + assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 40); + assert_line_count("../unary/roots.rs", include_str!("../unary/roots.rs"), 100); + assert_line_count( + "../unary/complex_parts.rs", + include_str!("../unary/complex_parts.rs"), + 210, + ); assert_line_count( "../unary/exp_log.rs", include_str!("../unary/exp_log.rs"), - 120, + 130, ); - assert_line_count("../unary/trig.rs", include_str!("../unary/trig.rs"), 140); + assert_line_count("../unary/trig.rs", include_str!("../unary/trig.rs"), 135); assert_line_count( "../unary/hyperbolic.rs", include_str!("../unary/hyperbolic.rs"), - 140, + 120, + ); + assert_line_count( + "../unary/trig_extra.rs", + include_str!("../unary/trig_extra.rs"), + 520, + ); + assert_line_count( + "../unary/hyperbolic_extra.rs", + include_str!("../unary/hyperbolic_extra.rs"), + 180, + ); + assert_line_count( + "../unary/nonsmooth.rs", + include_str!("../unary/nonsmooth.rs"), + 220, + ); + assert_line_count("../unary/smooth.rs", include_str!("../unary/smooth.rs"), 30); + assert_line_count("../power.rs", include_str!("../power.rs"), 170); + assert_line_count("../real_ops.rs", include_str!("../real_ops.rs"), 70); + assert_line_count( + "../scalar_ad/mod.rs", + 
include_str!("../scalar_ad/mod.rs"), + 180, + ); + assert_line_count( + "../scalar_ad/real.rs", + include_str!("../scalar_ad/real.rs"), + 160, + ); + assert_line_count( + "../scalar_ad/complex.rs", + include_str!("../scalar_ad/complex.rs"), + 170, ); - assert_line_count("../power.rs", include_str!("../power.rs"), 180); - assert_line_count("../real_ops.rs", include_str!("../real_ops.rs"), 120); } diff --git a/crates/chainrules/src/unary/basic.rs b/crates/chainrules/src/unary/basic.rs index 886c01b..7b76987 100644 --- a/crates/chainrules/src/unary/basic.rs +++ b/crates/chainrules/src/unary/basic.rs @@ -23,7 +23,7 @@ pub fn sqrt(x: S) -> S { /// Forward rule for `sqrt`. pub fn sqrt_frule(x: S, dx: S) -> (S, S) { let y = x.sqrt(); - let dy = dx / (S::from_i32(2) * y.conj()); + let dy = dx / (S::from_i32(2) * y); (y, dy) } diff --git a/crates/chainrules/src/unary/complex_parts.rs b/crates/chainrules/src/unary/complex_parts.rs new file mode 100644 index 0000000..c5fe129 --- /dev/null +++ b/crates/chainrules/src/unary/complex_parts.rs @@ -0,0 +1,198 @@ +use crate::ScalarAd; +use num_complex::Complex; +use num_traits::{Float, One, Zero}; + +trait ComplexProjectionScalar: ScalarAd { + fn from_parts(re: Self::Real, im: Self::Real) -> Self; +} + +impl ComplexProjectionScalar for num_complex::Complex32 { + fn from_parts(re: Self::Real, im: Self::Real) -> Self { + Complex::new(re, im) + } +} + +impl ComplexProjectionScalar for num_complex::Complex64 { + fn from_parts(re: Self::Real, im: Self::Real) -> Self { + Complex::new(re, im) + } +} + +/// Primal `abs`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs; +/// use num_complex::Complex64; +/// +/// assert_eq!(abs(Complex64::new(3.0, 4.0)), 5.0); +/// ``` +#[inline] +pub fn abs(x: S) -> S::Real { + x.abs() +} + +/// Primal `abs2`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs2; +/// use num_complex::Complex64; +/// +/// assert_eq!(abs2(Complex64::new(3.0, 4.0)), 25.0); +/// ``` +#[inline] +pub fn abs2(x: S) -> S::Real { + x.abs2() +} + +/// Primal `real`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::real; +/// use num_complex::Complex64; +/// +/// assert_eq!(real(Complex64::new(3.0, 4.0)), 3.0); +/// ``` +#[inline] +pub fn real(x: S) -> S::Real { + x.real() +} + +/// Primal `imag`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::imag; +/// use num_complex::Complex64; +/// +/// assert_eq!(imag(Complex64::new(3.0, 4.0)), 4.0); +/// ``` +#[inline] +pub fn imag(x: S) -> S::Real { + x.imag() +} + +/// Primal `angle`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::angle; +/// use num_complex::Complex64; +/// +/// assert!((angle(Complex64::new(3.0, 4.0)) - 0.9272952180016122).abs() < 1e-12); +/// ``` +#[inline] +pub fn angle(x: S) -> S::Real { + x.angle() +} + +/// Construct a complex number from real and imaginary parts. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::complex; +/// use num_complex::Complex64; +/// +/// assert_eq!(complex(3.0_f64, 4.0_f64), Complex64::new(3.0, 4.0)); +/// ``` +#[inline] +pub fn complex(re: R, im: R) -> Complex { + Complex::new(re, im) +} + +/// Forward rule for `abs2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs2_frule; +/// use num_complex::Complex64; +/// +/// let z = Complex64::new(3.0, 4.0); +/// let dz = Complex64::new(1.0, -2.0); +/// let (y, dy) = abs2_frule(z, dz); +/// assert_eq!(y, 25.0); +/// assert_eq!(dy, -10.0); +/// ``` +#[inline] +pub fn abs2_frule(x: S, dx: S) -> (S::Real, S::Real) { + let y = x.abs2(); + let two = S::Real::one() + S::Real::one(); + let dy = two * (x.real() * dx.real() + x.imag() * dx.imag()); + (y, dy) +} + +/// Reverse rule for `abs2`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs2_rrule; +/// use num_complex::Complex64; +/// +/// let z = Complex64::new(3.0, 4.0); +/// assert_eq!(abs2_rrule(z, 1.25), Complex64::new(7.5, 10.0)); +/// ``` +#[inline] +pub fn abs2_rrule(x: Complex, cotangent: R) -> Complex { + let two = R::one() + R::one(); + Complex::new(two * cotangent * x.re, two * cotangent * x.im) +} + +/// Reverse rule for `real`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::real_rrule; +/// use num_complex::Complex64; +/// +/// let grad: Complex64 = real_rrule(2.0); +/// assert_eq!(grad, Complex64::new(2.0, 0.0)); +/// ``` +#[inline] +#[allow(private_bounds)] +pub fn real_rrule(cotangent: S::Real) -> S { + S::from_parts(cotangent, S::Real::zero()) +} + +/// Reverse rule for `imag`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::imag_rrule; +/// use num_complex::Complex64; +/// +/// let grad: Complex64 = imag_rrule(2.0); +/// assert_eq!(grad, Complex64::new(0.0, 2.0)); +/// ``` +#[inline] +#[allow(private_bounds)] +pub fn imag_rrule(cotangent: S::Real) -> S { + S::from_parts(S::Real::zero(), cotangent) +} + +/// Reverse rule for `angle`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::angle_rrule; +/// use num_complex::Complex64; +/// +/// assert_eq!(angle_rrule(Complex64::new(3.0, 4.0), 1.0), Complex64::new(-0.16, 0.12)); +/// ``` +#[inline] +pub fn angle_rrule(x: Complex, cotangent: R) -> Complex { + let denom = x.re * x.re + x.im * x.im; + Complex::new(-x.im * cotangent / denom, x.re * cotangent / denom) +} diff --git a/crates/chainrules/src/unary/exp_log.rs b/crates/chainrules/src/unary/exp_log.rs index 3a7a0f0..7ebabbe 100644 --- a/crates/chainrules/src/unary/exp_log.rs +++ b/crates/chainrules/src/unary/exp_log.rs @@ -1,69 +1,118 @@ use crate::unary::one; use crate::ScalarAd; - +use num_traits::FloatConst; +fn ln_2() -> S { + S::from_real(S::Real::LN_2()) +} +fn ln_10() -> S { + S::from_real(S::Real::LN_10()) +} /// Primal `exp`. 
pub fn exp(x: S) -> S { x.exp() } - /// Forward rule for `exp`. pub fn exp_frule(x: S, dx: S) -> (S, S) { let y = x.exp(); - (y, dx * y.conj()) + (y, dx * y) } - /// Reverse rule for `exp`. pub fn exp_rrule(result: S, cotangent: S) -> S { cotangent * result.conj() } - /// Primal `exp(x) - 1`. pub fn expm1(x: S) -> S { x.expm1() } - /// Forward rule for `exp(x) - 1`. pub fn expm1_frule(x: S, dx: S) -> (S, S) { let y = x.expm1(); let scale = y + one::(); - (y, dx * scale.conj()) + (y, dx * scale) } - /// Reverse rule for `exp(x) - 1`. pub fn expm1_rrule(result: S, cotangent: S) -> S { cotangent * (result + one::()).conj() } - +#[doc = "Primal `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2;\n\nassert!((exp2(3.0_f64) - 8.0).abs() < 1e-12);\n```"] +pub fn exp2(x: S) -> S { + x.exp2() +} +#[doc = "Forward rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_frule;\n\nlet (y, dy) = exp2_frule(3.0_f64, 1.0);\nassert!((y - 8.0).abs() < 1e-12);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] +pub fn exp2_frule(x: S, dx: S) -> (S, S) { + let y = x.exp2(); + (y, dx * (y * ln_2::())) +} +#[doc = "Reverse rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_rrule;\n\nlet dy = exp2_rrule(8.0_f64, 1.0);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] +pub fn exp2_rrule(result: S, cotangent: S) -> S { + cotangent * (result * ln_2::()).conj() +} +#[doc = "Primal `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10;\n\nassert!((exp10(2.0_f64) - 100.0).abs() < 1e-12);\n```"] +pub fn exp10(x: S) -> S { + x.exp10() +} +#[doc = "Forward rule for `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10_frule;\n\nlet (y, dy) = exp10_frule(2.0_f64, 0.5);\nassert!((y - 100.0).abs() < 1e-12);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"] +pub fn exp10_frule(x: S, dx: S) -> (S, S) { + let y = x.exp10(); + (y, dx * (y * ln_10::())) +} +#[doc = "Reverse rule for `10^x`.\n\n# 
Examples\n```rust\nuse chainrules::exp10_rrule;\n\nlet dy = exp10_rrule(100.0_f64, 0.5);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"] +pub fn exp10_rrule(result: S, cotangent: S) -> S { + cotangent * (result * ln_10::()).conj() +} /// Primal `log`. pub fn log(x: S) -> S { x.ln() } - /// Forward rule for `log`. pub fn log_frule(x: S, dx: S) -> (S, S) { let y = x.ln(); - let dy = dx * (one::() / x).conj(); + let dy = dx * (one::() / x); (y, dy) } - /// Reverse rule for `log`. pub fn log_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / x).conj() } - /// Primal `log(1 + x)`. pub fn log1p(x: S) -> S { x.log1p() } - /// Forward rule for `log(1 + x)`. pub fn log1p_frule(x: S, dx: S) -> (S, S) { let y = x.log1p(); - let dy = dx * (one::() / (one::() + x)).conj(); + let dy = dx * (one::() / (one::() + x)); (y, dy) } - /// Reverse rule for `log(1 + x)`. pub fn log1p_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / (one::() + x)).conj() } +#[doc = "Primal `log2`.\n\n# Examples\n```rust\nuse chainrules::log2;\n\nassert_eq!(log2(8.0_f64), 3.0);\n```"] +pub fn log2(x: S) -> S { + x.log2() +} +#[doc = "Forward rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_frule;\n\nlet (y, dy) = log2_frule(8.0_f64, 2.0);\nassert!((y - 3.0).abs() < 1e-12);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"] +pub fn log2_frule(x: S, dx: S) -> (S, S) { + let y = x.log2(); + let scale = one::() / (x * ln_2::()); + (y, dx * scale) +} +#[doc = "Reverse rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_rrule;\n\nlet dy = log2_rrule(8.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"] +pub fn log2_rrule(x: S, cotangent: S) -> S { + cotangent * (one::() / (x * ln_2::())).conj() +} +#[doc = "Primal `log10`.\n\n# Examples\n```rust\nuse chainrules::log10;\n\nassert_eq!(log10(100.0_f64), 2.0);\n```"] +pub fn log10(x: S) -> S { + x.log10() +} +#[doc 
= "Forward rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_frule;\n\nlet (y, dy) = log10_frule(100.0_f64, 2.0);\nassert!((y - 2.0).abs() < 1e-12);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"] +pub fn log10_frule(x: S, dx: S) -> (S, S) { + let y = x.log10(); + let scale = one::() / (x * ln_10::()); + (y, dx * scale) +} +#[doc = "Reverse rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_rrule;\n\nlet dy = log10_rrule(100.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"] +pub fn log10_rrule(x: S, cotangent: S) -> S { + cotangent * (one::() / (x * ln_10::())).conj() +} diff --git a/crates/chainrules/src/unary/hyperbolic.rs b/crates/chainrules/src/unary/hyperbolic.rs index 5d09e9c..07cfca5 100644 --- a/crates/chainrules/src/unary/hyperbolic.rs +++ b/crates/chainrules/src/unary/hyperbolic.rs @@ -10,7 +10,7 @@ pub fn tanh(x: S) -> S { pub fn tanh_frule(x: S, dx: S) -> (S, S) { let y = x.tanh(); let scale = one::() - y * y; - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `tanh`. @@ -26,7 +26,7 @@ pub fn sinh(x: S) -> S { /// Forward rule for `sinh`. pub fn sinh_frule(x: S, dx: S) -> (S, S) { let y = x.sinh(); - (y, dx * x.cosh().conj()) + (y, dx * x.cosh()) } /// Reverse rule for `sinh`. @@ -42,7 +42,7 @@ pub fn cosh(x: S) -> S { /// Forward rule for `cosh`. pub fn cosh_frule(x: S, dx: S) -> (S, S) { let y = x.cosh(); - (y, dx * x.sinh().conj()) + (y, dx * x.sinh()) } /// Reverse rule for `cosh`. @@ -63,7 +63,7 @@ pub fn asinh(x: S) -> S { pub fn asinh_frule(x: S, dx: S) -> (S, S) { let y = x.asinh(); let scale = inverse_sqrt_one_plus_square(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `asinh`. 
@@ -84,7 +84,7 @@ pub fn acosh(x: S) -> S { pub fn acosh_frule(x: S, dx: S) -> (S, S) { let y = x.acosh(); let scale = inverse_acosh_scale(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `acosh`. @@ -101,7 +101,7 @@ pub fn atanh(x: S) -> S { pub fn atanh_frule(x: S, dx: S) -> (S, S) { let y = x.atanh(); let scale = one::() / (one::() - x * x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `atanh`. diff --git a/crates/chainrules/src/unary/hyperbolic_extra.rs b/crates/chainrules/src/unary/hyperbolic_extra.rs new file mode 100644 index 0000000..754fa97 --- /dev/null +++ b/crates/chainrules/src/unary/hyperbolic_extra.rs @@ -0,0 +1,141 @@ +use crate::unary::{ + cosh, cosh_frule, cosh_rrule, inv, inv_frule, inv_rrule, sinh, sinh_frule, sinh_rrule, tanh, + tanh_frule, tanh_rrule, +}; +use crate::ScalarAd; + +/// Primal `sech`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sech; +/// +/// assert!((sech(0.5_f64) - 1.0 / 0.5_f64.cosh()).abs() < 1e-12); +/// ``` +pub fn sech(x: S) -> S { + inv(cosh(x)) +} + +/// Forward rule for `sech`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sech_frule; +/// +/// let (_, dy) = sech_frule(0.5_f64, 1.0); +/// let sech_x = 1.0 / 0.5_f64.cosh(); +/// assert!((dy + sech_x * 0.5_f64.tanh()).abs() < 1e-12); +/// ``` +pub fn sech_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = cosh_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `sech`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sech_rrule; +/// +/// let dy = sech_rrule(0.5_f64, 1.0); +/// let sech_x = 1.0 / 0.5_f64.cosh(); +/// assert!((dy + sech_x * 0.5_f64.tanh()).abs() < 1e-12); +/// ``` +pub fn sech_rrule(x: S, cotangent: S) -> S { + let y = sech(x); + let d_y = inv_rrule(y, cotangent); + cosh_rrule(x, d_y) +} + +/// Primal `csch`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::csch; +/// +/// assert!((csch(0.5_f64) - 1.0 / 0.5_f64.sinh()).abs() < 1e-12); +/// ``` +pub fn csch(x: S) -> S { + inv(sinh(x)) +} + +/// Forward rule for `csch`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::csch_frule; +/// +/// let (_, dy) = csch_frule(0.5_f64, 1.0); +/// let csch_x = 1.0 / 0.5_f64.sinh(); +/// assert!((dy + csch_x * 0.5_f64.cosh() / 0.5_f64.sinh()).abs() < 1e-12); +/// ``` +pub fn csch_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = sinh_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `csch`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::csch_rrule; +/// +/// let dy = csch_rrule(0.5_f64, 1.0); +/// let csch_x = 1.0 / 0.5_f64.sinh(); +/// assert!((dy + csch_x * 0.5_f64.cosh() / 0.5_f64.sinh()).abs() < 1e-12); +/// ``` +pub fn csch_rrule(x: S, cotangent: S) -> S { + let y = csch(x); + let d_y = inv_rrule(y, cotangent); + sinh_rrule(x, d_y) +} + +/// Primal `coth`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::coth; +/// +/// assert!((coth(0.5_f64) - 1.0 / 0.5_f64.tanh()).abs() < 1e-12); +/// ``` +pub fn coth(x: S) -> S { + inv(tanh(x)) +} + +/// Forward rule for `coth`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::coth_frule; +/// +/// let (_, dy) = coth_frule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sinh().powi(2)).abs() < 1e-12); +/// ``` +pub fn coth_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = tanh_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `coth`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::coth_rrule; +/// +/// let dy = coth_rrule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sinh().powi(2)).abs() < 1e-12); +/// ``` +pub fn coth_rrule(x: S, cotangent: S) -> S { + let y = coth(x); + let d_y = inv_rrule(y, cotangent); + tanh_rrule(tanh(x), d_y) +} diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index 27488f4..dced019 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -1,7 +1,13 @@ mod basic; +mod complex_parts; mod exp_log; mod hyperbolic; +mod hyperbolic_extra; +mod nonsmooth; +mod roots; +mod smooth; mod trig; +mod trig_extra; use crate::ScalarAd; @@ -10,16 +16,38 @@ fn one() -> S { } pub use basic::{conj, conj_frule, conj_rrule, sqrt, sqrt_frule, sqrt_rrule}; +pub use complex_parts::{ + abs, abs2, abs2_frule, abs2_rrule, angle, angle_rrule, complex, imag, imag_rrule, real, + real_rrule, +}; pub use exp_log::{ - exp, exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, log, log1p, log1p_frule, - log1p_rrule, log_frule, log_rrule, + exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, + expm1, expm1_frule, expm1_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, + log1p_rrule, log2, log2_frule, log2_rrule, log_frule, log_rrule, }; pub use hyperbolic::{ acosh, acosh_frule, acosh_rrule, asinh, asinh_frule, asinh_rrule, atanh, atanh_frule, atanh_rrule, cosh, cosh_frule, cosh_rrule, sinh, sinh_frule, sinh_rrule, tanh, tanh_frule, tanh_rrule, }; +pub use hyperbolic_extra::{ + coth, coth_frule, coth_rrule, csch, csch_frule, csch_rrule, sech, sech_frule, sech_rrule, +}; +pub use nonsmooth::{ + ceil, ceil_frule, ceil_rrule, floor, floor_frule, floor_rrule, round, round_frule, round_rrule, + sign, sign_frule, sign_rrule, +}; +pub use roots::{cbrt, cbrt_frule, cbrt_rrule, inv, inv_frule, inv_rrule}; +pub use smooth::{ + hypot, hypot_frule, hypot_rrule, pow, pow_frule, 
pow_rrule, sincos, sincos_frule, sincos_rrule, + tan, tan_frule, tan_rrule, +}; pub use trig::{ acos, acos_frule, acos_rrule, asin, asin_frule, asin_rrule, atan, atan_frule, atan_rrule, cos, cos_frule, cos_rrule, sin, sin_frule, sin_rrule, }; +pub use trig_extra::{ + cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, csc, + csc_frule, csc_rrule, sec, sec_frule, sec_rrule, sincospi, sincospi_frule, sincospi_rrule, + sind, sind_frule, sind_rrule, sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, +}; diff --git a/crates/chainrules/src/unary/nonsmooth.rs b/crates/chainrules/src/unary/nonsmooth.rs new file mode 100644 index 0000000..7c2c958 --- /dev/null +++ b/crates/chainrules/src/unary/nonsmooth.rs @@ -0,0 +1,204 @@ +use num_traits::Float; + +/// Primal `round`. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point, including integer inputs. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::round; +/// +/// assert_eq!(round(1.4_f64), 1.0); +/// assert_eq!(round(1.5_f64), 2.0); +/// ``` +pub fn round(x: R) -> R { + x.round() +} + +/// Forward rule for `round`. +/// +/// The tangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::round_frule; +/// +/// let (y, dy) = round_frule(1.6_f64, 0.75); +/// assert_eq!(y, 2.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn round_frule(x: R, _dx: R) -> (R, R) { + (x.round(), R::zero()) +} + +/// Reverse rule for `round`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::round_rrule; +/// +/// assert_eq!(round_rrule(1.0_f64, 0.5), 0.0); +/// ``` +pub fn round_rrule(_x: R, _cotangent: R) -> R { + R::zero() +} + +/// Primal `floor`. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::floor; +/// +/// assert_eq!(floor(1.9_f64), 1.0); +/// ``` +pub fn floor(x: R) -> R { + x.floor() +} + +/// Forward rule for `floor`. +/// +/// The tangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::floor_frule; +/// +/// let (y, dy) = floor_frule(1.6_f64, 0.75); +/// assert_eq!(y, 1.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn floor_frule(x: R, _dx: R) -> (R, R) { + (x.floor(), R::zero()) +} + +/// Reverse rule for `floor`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::floor_rrule; +/// +/// assert_eq!(floor_rrule(1.0_f64, 0.5), 0.0); +/// ``` +pub fn floor_rrule(_x: R, _cotangent: R) -> R { + R::zero() +} + +/// Primal `ceil`. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ceil; +/// +/// assert_eq!(ceil(1.1_f64), 2.0); +/// ``` +pub fn ceil(x: R) -> R { + x.ceil() +} + +/// Forward rule for `ceil`. +/// +/// The tangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ceil_frule; +/// +/// let (y, dy) = ceil_frule(1.1_f64, 0.75); +/// assert_eq!(y, 2.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn ceil_frule(x: R, _dx: R) -> (R, R) { + (x.ceil(), R::zero()) +} + +/// Reverse rule for `ceil`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ceil_rrule; +/// +/// assert_eq!(ceil_rrule(1.0_f64, 0.5), 0.0); +/// ``` +pub fn ceil_rrule(_x: R, _cotangent: R) -> R { + R::zero() +} + +/// Primal `sign`. +/// +/// The primal follows Julia-style `sign`: it returns signed zero for zero +/// inputs, `+1`/`-1` for positive/negative infinities, and `x.signum()` +/// otherwise. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::sign; +/// +/// assert_eq!(sign(-3.0_f64), -1.0); +/// assert_eq!(sign(0.0_f64), 0.0); +/// assert_eq!(sign(-0.0_f64).is_sign_negative(), true); +/// ``` +pub fn sign(x: R) -> R { + if x == R::zero() { + x + } else { + x.signum() + } +} + +/// Forward rule for `sign`. +/// +/// The tangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sign_frule; +/// +/// let (y, dy) = sign_frule(-2.0_f64, 0.75); +/// assert_eq!(y, -1.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn sign_frule(x: R, _dx: R) -> (R, R) { + (sign(x), R::zero()) +} + +/// Reverse rule for `sign`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sign_rrule; +/// +/// assert_eq!(sign_rrule(1.0_f64, 0.5), 0.0); +/// ``` +pub fn sign_rrule(_x: R, _cotangent: R) -> R { + R::zero() +} diff --git a/crates/chainrules/src/unary/roots.rs b/crates/chainrules/src/unary/roots.rs new file mode 100644 index 0000000..b1eed73 --- /dev/null +++ b/crates/chainrules/src/unary/roots.rs @@ -0,0 +1,88 @@ +use crate::ScalarAd; + +/// Primal `cbrt`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt; +/// +/// assert_eq!(cbrt(8.0_f64), 2.0); +/// ``` +pub fn cbrt(x: S) -> S { + x.cbrt() +} + +/// Forward rule for `cbrt`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt_frule; +/// +/// let (y, dy) = cbrt_frule(8.0_f64, 1.0); +/// assert_eq!(y, 2.0); +/// assert!((dy - (1.0 / 12.0)).abs() < 1e-12); +/// ``` +pub fn cbrt_frule(x: S, dx: S) -> (S, S) { + let y = x.cbrt(); + let scale = S::from_i32(1) / (S::from_i32(3) * y * y); + (y, dx * scale) +} + +/// Reverse rule for `cbrt`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt_rrule; +/// +/// let dx = cbrt_rrule(2.0_f64, 1.0); +/// assert!((dx - (1.0 / 12.0)).abs() < 1e-12); +/// ``` +pub fn cbrt_rrule(result: S, cotangent: S) -> S { + cotangent * (S::from_i32(1) / (S::from_i32(3) * result * result)).conj() +} + +/// Primal `inv`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::inv; +/// +/// assert_eq!(inv(4.0_f64), 0.25); +/// ``` +pub fn inv(x: S) -> S { + x.recip() +} + +/// Forward rule for `inv`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::inv_frule; +/// +/// let (y, dy) = inv_frule(4.0_f64, 2.0); +/// assert_eq!(y, 0.25); +/// assert!((dy + 0.125).abs() < 1e-12); +/// ``` +pub fn inv_frule(x: S, dx: S) -> (S, S) { + let y = x.recip(); + (y, dx * (-(y * y))) +} + +/// Reverse rule for `inv`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::inv_rrule; +/// +/// let dx = inv_rrule(0.25_f64, 2.0); +/// assert!((dx + 0.125).abs() < 1e-12); +/// ``` +pub fn inv_rrule(result: S, cotangent: S) -> S { + cotangent * (-(result * result)).conj() +} diff --git a/crates/chainrules/src/unary/smooth.rs b/crates/chainrules/src/unary/smooth.rs new file mode 100644 index 0000000..402d157 --- /dev/null +++ b/crates/chainrules/src/unary/smooth.rs @@ -0,0 +1,10 @@ +#![allow(unused_imports)] + +pub use super::exp_log::{ + exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, log10, log10_frule, log10_rrule, + log2, log2_frule, log2_rrule, +}; +pub use super::roots::{cbrt, cbrt_frule, cbrt_rrule, inv, inv_frule, inv_rrule}; +pub use super::trig::{sincos, sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule}; +pub use crate::binary_special::{hypot, hypot_frule, hypot_rrule}; +pub use crate::power::{pow, pow_frule, pow_rrule}; diff --git a/crates/chainrules/src/unary/trig.rs b/crates/chainrules/src/unary/trig.rs index e5cdf3c..e740b1a 100644 --- a/crates/chainrules/src/unary/trig.rs +++ b/crates/chainrules/src/unary/trig.rs @@ -9,7 +9,7 @@ 
pub fn sin(x: S) -> S { /// Forward rule for `sin`. pub fn sin_frule(x: S, dx: S) -> (S, S) { let y = x.sin(); - (y, dx * x.cos().conj()) + (y, dx * x.cos()) } /// Reverse rule for `sin`. @@ -25,7 +25,7 @@ pub fn cos(x: S) -> S { /// Forward rule for `cos`. pub fn cos_frule(x: S, dx: S) -> (S, S) { let y = x.cos(); - (y, dx * (-x.sin()).conj()) + (y, dx * -x.sin()) } /// Reverse rule for `cos`. @@ -46,7 +46,7 @@ pub fn asin(x: S) -> S { pub fn asin_frule(x: S, dx: S) -> (S, S) { let y = x.asin(); let scale = inverse_sqrt_one_minus_square(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `asin`. @@ -63,7 +63,7 @@ pub fn acos(x: S) -> S { pub fn acos_frule(x: S, dx: S) -> (S, S) { let y = x.acos(); let scale = -inverse_sqrt_one_minus_square(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `acos`. @@ -80,10 +80,46 @@ pub fn atan(x: S) -> S { pub fn atan_frule(x: S, dx: S) -> (S, S) { let y = x.atan(); let scale = one::() / (one::() + x * x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `atan`. 
pub fn atan_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / (one::() + x * x)).conj() } + +#[doc = "Primal `tan`.\n\n# Examples\n```rust\nuse chainrules::tan;\n\nassert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12);\n```"] +pub fn tan(x: S) -> S { + x.tan() +} + +#[doc = "Forward rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_frule;\n\nlet (y, dy) = tan_frule(0.25_f64, 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"] +pub fn tan_frule(x: S, dx: S) -> (S, S) { + let y = x.tan(); + (y, dx * (one::() + y * y)) +} + +#[doc = "Reverse rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_rrule;\n\nlet dy = tan_rrule(0.25_f64.tan(), 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"] +pub fn tan_rrule(result: S, cotangent: S) -> S { + cotangent * (one::() + result * result).conj() +} + +#[doc = "Primal `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos;\n\nlet (s, c) = sincos(0.5_f64);\nassert!((s - 0.5_f64.sin()).abs() < 1e-12);\nassert!((c - 0.5_f64.cos()).abs() < 1e-12);\n```"] +pub fn sincos(x: S) -> (S, S) { + (x.sin(), x.cos()) +} + +#[doc = "Forward rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_frule;\n\nlet ((s, c), (ds, dc)) = sincos_frule(0.25_f64, 1.0);\nassert!((ds - 0.25_f64.cos()).abs() < 1e-12);\nassert!((dc + 0.25_f64.sin()).abs() < 1e-12);\n```"] +pub fn sincos_frule(x: S, dx: S) -> ((S, S), (S, S)) { + let sin_x = x.sin(); + let cos_x = x.cos(); + ((sin_x, cos_x), (dx * cos_x, dx * -sin_x)) +} + +#[doc = "Reverse rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_rrule;\n\nlet dx = sincos_rrule(0.25_f64, (1.0, 1.0));\nassert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12);\n```"] +pub fn sincos_rrule(x: S, cotangents: (S, S)) -> S { + let (cotangent_sin, cotangent_cos) = cotangents; + let sin_x = x.sin(); + let cos_x = x.cos(); + cotangent_sin * cos_x.conj() + cotangent_cos * (-sin_x).conj() +} diff --git 
a/crates/chainrules/src/unary/trig_extra.rs b/crates/chainrules/src/unary/trig_extra.rs new file mode 100644 index 0000000..b6a52d8 --- /dev/null +++ b/crates/chainrules/src/unary/trig_extra.rs @@ -0,0 +1,490 @@ +use crate::binary::{mul_frule, mul_rrule}; +use crate::unary::{ + cos, cos_frule, cos_rrule, inv, inv_frule, inv_rrule, sin, sin_frule, sin_rrule, sincos, tan, + tan_frule, tan_rrule, +}; +use crate::ScalarAd; +use num_traits::{Float, FloatConst, Zero}; + +fn pi() -> S { + S::from_real(S::Real::PI()) +} + +fn real(value: f64) -> R { + match R::from(value) { + Some(value) => value, + None => unreachable!("float constant conversion should succeed"), + } +} + +fn deg2rad() -> S { + pi::() / S::from_real(real::(180.0)) +} + +fn real_input(x: S) -> Option { + if x.imag().is_zero() { + Some(x.real()) + } else { + None + } +} + +fn sinpi_real(x: R) -> R { + let two = real::(2.0); + let reduced = x - (x / two).floor() * two; + let zero = real::(0.0); + let one = real::(1.0); + let half = real::(0.5); + let three_half = real::(1.5); + if reduced == zero || reduced == one { + zero + } else if reduced == half { + one + } else if reduced == three_half { + -one + } else { + (R::PI() * reduced).sin() + } +} + +fn cospi_real(x: R) -> R { + let two = real::(2.0); + let reduced = x - (x / two).floor() * two; + let zero = real::(0.0); + let one = real::(1.0); + let half = real::(0.5); + let three_half = real::(1.5); + if reduced == zero { + one + } else if reduced == one { + -one + } else if reduced == half || reduced == three_half { + zero + } else { + (R::PI() * reduced).cos() + } +} + +fn tand_real(x: R) -> R { + let one_eighty = real::(180.0); + let reduced = x - (x / one_eighty).floor() * one_eighty; + let zero = real::(0.0); + let forty_five = real::(45.0); + let ninety = real::(90.0); + let one_thirty_five = real::(135.0); + if reduced == zero { + zero + } else if reduced == forty_five { + real::(1.0) + } else if reduced == ninety { + 
R::infinity().copysign(sinpi_real(x / one_eighty)) + } else if reduced == one_thirty_five { + real::(-1.0) + } else { + (R::PI() * reduced / one_eighty).tan() + } +} + +/// Primal `sec`. +/// +/// # Examples +/// ```rust +/// use chainrules::sec; +/// assert!((sec(0.5_f64) - 1.0 / 0.5_f64.cos()).abs() < 1e-12); +/// ``` +pub fn sec(x: S) -> S { + inv(cos(x)) +} + +/// Forward rule for `sec`. +/// +/// # Examples +/// ```rust +/// use chainrules::sec_frule; +/// let (y, dy) = sec_frule(0.5_f64, 1.0); +/// assert!((y - 1.0 / 0.5_f64.cos()).abs() < 1e-12); +/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12); +/// ``` +pub fn sec_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = cos_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `sec`. +/// +/// # Examples +/// ```rust +/// use chainrules::sec_rrule; +/// let dy = sec_rrule(0.5_f64, 1.0); +/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12); +/// ``` +pub fn sec_rrule(x: S, cotangent: S) -> S { + let y = sec(x); + let d_y = inv_rrule(y, cotangent); + cos_rrule(x, d_y) +} + +/// Primal `csc`. +/// +/// # Examples +/// ```rust +/// use chainrules::csc; +/// assert!((csc(0.5_f64) - 1.0 / 0.5_f64.sin()).abs() < 1e-12); +/// ``` +pub fn csc(x: S) -> S { + inv(sin(x)) +} + +/// Forward rule for `csc`. +/// +/// # Examples +/// ```rust +/// use chainrules::csc_frule; +/// let (_, dy) = csc_frule(0.5_f64, 1.0); +/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` +pub fn csc_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = sin_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `csc`. +/// +/// # Examples +/// ```rust +/// use chainrules::csc_rrule; +/// let dy = csc_rrule(0.5_f64, 1.0); +/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` +pub fn csc_rrule(x: S, cotangent: S) -> S { + let y = csc(x); + let d_y = inv_rrule(y, cotangent); + sin_rrule(x, d_y) +} + +/// Primal `cot`. 
+/// +/// # Examples +/// ```rust +/// use chainrules::cot; +/// assert!((cot(0.5_f64) - 1.0 / 0.5_f64.tan()).abs() < 1e-12); +/// ``` +pub fn cot(x: S) -> S { + inv(tan(x)) +} + +/// Forward rule for `cot`. +/// +/// # Examples +/// ```rust +/// use chainrules::cot_frule; +/// let (_, dy) = cot_frule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` +pub fn cot_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = tan_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `cot`. +/// +/// # Examples +/// ```rust +/// use chainrules::cot_rrule; +/// let dy = cot_rrule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` +pub fn cot_rrule(x: S, cotangent: S) -> S { + let y = cot(x); + let d_y = inv_rrule(y, cotangent); + tan_rrule(tan(x), d_y) +} + +/// Primal `sinpi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sinpi; +/// +/// assert!((sinpi(0.25_f64) - 0.25_f64.mul_add(std::f64::consts::PI, 0.0).sin()).abs() < 1e-12); +/// ``` +pub fn sinpi(x: S) -> S { + if let Some(x_real) = real_input(x) { + return S::from_real(sinpi_real(x_real)); + } + sincos(pi::() * x).0 +} + +/// Forward rule for `sinpi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sinpi_frule; +/// +/// let (_, dy) = sinpi_frule(0.25_f64, 1.0); +/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// ``` +pub fn sinpi_frule(x: S, dx: S) -> (S, S) { + let y = sinpi(x); + let scale = pi::() * cospi(x); + (y, dx * scale) +} + +/// Reverse rule for `sinpi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sinpi_rrule; +/// +/// let dy = sinpi_rrule(0.25_f64, 1.0); +/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// ``` +pub fn sinpi_rrule(x: S, cotangent: S) -> S { + cotangent * (pi::() * cospi(x)).conj() +} + +/// Primal `cospi`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::cospi; +/// +/// assert!((cospi(0.25_f64) - (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// ``` +pub fn cospi(x: S) -> S { + if let Some(x_real) = real_input(x) { + return S::from_real(cospi_real(x_real)); + } + sincos(pi::() * x).1 +} + +/// Forward rule for `cospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cospi_frule; +/// +/// let (_, dy) = cospi_frule(0.25_f64, 1.0); +/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12); +/// ``` +pub fn cospi_frule(x: S, dx: S) -> (S, S) { + let y = cospi(x); + let scale = -(pi::() * sinpi(x)); + (y, dx * scale) +} + +/// Reverse rule for `cospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cospi_rrule; +/// +/// let dy = cospi_rrule(0.25_f64, 1.0); +/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12); +/// ``` +pub fn cospi_rrule(x: S, cotangent: S) -> S { + cotangent * (-(pi::() * sinpi(x))).conj() +} + +/// Primal `sincospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincospi; +/// +/// let (s, c) = sincospi(0.25_f64); +/// assert!((s - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12); +/// assert!((c - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12); +/// ``` +pub fn sincospi(x: S) -> (S, S) { + (sinpi(x), cospi(x)) +} + +/// Forward rule for `sincospi`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincospi_frule; +/// +/// let ((_, _), (ds, dc)) = sincospi_frule(0.25_f64, 1.0); +/// assert!((ds - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// assert!((dc + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12); +/// ``` +pub fn sincospi_frule(x: S, dx: S) -> ((S, S), (S, S)) { + let sin_x = sinpi(x); + let cos_x = cospi(x); + ( + (sin_x, cos_x), + (dx * (pi::() * cos_x), dx * (-(pi::() * sin_x))), + ) +} + +/// Reverse rule for `sincospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincospi_rrule; +/// +/// let dx = sincospi_rrule(0.25_f64, (1.0, 1.0)); +/// assert!( +/// (dx - (std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos() +/// - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin())) +/// .abs() +/// < 1e-12 +/// ); +/// ``` +pub fn sincospi_rrule(x: S, cotangents: (S, S)) -> S { + let (cotangent_sin, cotangent_cos) = cotangents; + sinpi_rrule(x, cotangent_sin) + cospi_rrule(x, cotangent_cos) +} + +/// Primal `sind`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sind; +/// +/// assert!((sind(30.0_f64) - 0.5_f64).abs() < 1e-12); +/// ``` +pub fn sind(x: S) -> S { + sinpi(x / S::from_real(real::(180.0))) +} + +/// Forward rule for `sind`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sind_frule; +/// +/// let (_, dy) = sind_frule(30.0_f64, 1.0); +/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12); +/// ``` +pub fn sind_frule(x: S, dx: S) -> (S, S) { + let scale = S::from_real(real::(1.0 / 180.0)); + let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); + sinpi_frule(scaled_x, dscaled_x) +} + +/// Reverse rule for `sind`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::sind_rrule; +/// +/// let dy = sind_rrule(30.0_f64, 1.0); +/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12); +/// ``` +pub fn sind_rrule(x: S, cotangent: S) -> S { + let scale = S::from_real(real::(1.0 / 180.0)); + let scaled_x = scale * x; + let dscaled_x = sinpi_rrule(scaled_x, cotangent); + let (_, dx) = mul_rrule(scale, x, dscaled_x); + dx +} + +/// Primal `cosd`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cosd; +/// +/// assert!((cosd(60.0_f64) - 0.5_f64).abs() < 1e-12); +/// ``` +pub fn cosd(x: S) -> S { + cospi(x / S::from_real(real::(180.0))) +} + +/// Forward rule for `cosd`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cosd_frule; +/// +/// let (_, dy) = cosd_frule(60.0_f64, 1.0); +/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12); +/// ``` +pub fn cosd_frule(x: S, dx: S) -> (S, S) { + let scale = S::from_real(real::(1.0 / 180.0)); + let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); + cospi_frule(scaled_x, dscaled_x) +} + +/// Reverse rule for `cosd`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cosd_rrule; +/// +/// let dy = cosd_rrule(60.0_f64, 1.0); +/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12); +/// ``` +pub fn cosd_rrule(x: S, cotangent: S) -> S { + let scale = S::from_real(real::(1.0 / 180.0)); + let scaled_x = scale * x; + let dscaled_x = cospi_rrule(scaled_x, cotangent); + let (_, dx) = mul_rrule(scale, x, dscaled_x); + dx +} + +/// Primal `tand`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tand; +/// +/// assert_eq!(tand(45.0_f64), 1.0); +/// ``` +pub fn tand(x: S) -> S { + if let Some(x_real) = real_input(x) { + return S::from_real(tand_real(x_real)); + } + tan(deg2rad::() * x) +} + +/// Forward rule for `tand`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::tand_frule; +/// +/// let (_, dy) = tand_frule(45.0_f64, 1.0); +/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12); +/// ``` +pub fn tand_frule(x: S, dx: S) -> (S, S) { + let y = tand(x); + let scale = deg2rad::() * (S::from_i32(1) + y * y); + (y, dx * scale) +} + +/// Reverse rule for `tand`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tand_rrule; +/// +/// let dy = tand_rrule(45.0_f64, 1.0); +/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12); +/// ``` +pub fn tand_rrule(x: S, cotangent: S) -> S { + let y = tand(x); + let scale = deg2rad::() * (S::from_i32(1) + y * y); + cotangent * scale.conj() +} diff --git a/crates/chainrules/tests/common.rs b/crates/chainrules/tests/common.rs new file mode 100644 index 0000000..00b0c81 --- /dev/null +++ b/crates/chainrules/tests/common.rs @@ -0,0 +1,228 @@ +#![allow(dead_code)] + +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; + +use num_complex::Complex64; +use serde_json::Value; + +pub struct UnaryOracleCase { + pub op: &'static str, + pub frule: fn(T, T) -> (T, T), + pub rrule: fn(T, T, T) -> T, +} + +pub trait OracleScalar: Copy + core::fmt::Debug { + fn dtype() -> &'static str; + fn is_scalar_value(value: &Value) -> bool; + fn from_json(value: &Value, path: &str) -> Self; + fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str); +} + +impl OracleScalar for f64 { + fn dtype() -> &'static str { + "float64" + } + + fn is_scalar_value(value: &Value) -> bool { + value.is_number() + } + + fn from_json(value: &Value, path: &str) -> Self { + scalar_f64(value, path) + } + + fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str) { + assert_close_f64(actual, expected, atol, rtol, label); + } +} + +impl OracleScalar for Complex64 { + fn dtype() -> &'static str { + "complex128" + } + + fn is_scalar_value(value: &Value) -> bool { + value + 
.as_array() + .is_some_and(|items| items.len() == 2 && items.iter().all(Value::is_number)) + } + + fn from_json(value: &Value, path: &str) -> Self { + scalar_complex64(value, path) + } + + fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str) { + assert_close_complex64(actual, expected, atol, rtol, label); + } +} + +fn oracle_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("..") + .join("third_party") + .join("tensor-ad-oracles") +} + +pub fn successful_cases(op: &str, dtype: &str) -> Vec { + let path = oracle_root().join("cases").join(op).join("identity.jsonl"); + let file = File::open(&path).unwrap_or_else(|err| panic!("open {}: {err}", path.display())); + let reader = BufReader::new(file); + let mut cases = Vec::new(); + + for line in reader.lines() { + let line = line.unwrap_or_else(|err| panic!("read {}: {err}", path.display())); + let value: Value = serde_json::from_str(&line) + .unwrap_or_else(|err| panic!("parse {}: {err}", path.display())); + let case_dtype = value["dtype"].as_str(); + let behavior = value["expected_behavior"].as_str(); + if case_dtype == Some(dtype) && behavior == Some("success") { + cases.push(value); + } + } + + assert!( + !cases.is_empty(), + "no successful {dtype} cases found in {}", + path.display() + ); + cases +} + +pub fn scalar_f64(value: &Value, path: &str) -> f64 { + value + .as_f64() + .unwrap_or_else(|| panic!("expected float64 at {path}, got {value}")) +} + +pub fn scalar_complex64(value: &Value, path: &str) -> Complex64 { + let data = value + .as_array() + .unwrap_or_else(|| panic!("expected complex128 pair at {path}, got {value}")); + assert!( + data.len() == 2, + "expected complex128 pair at {path}, got {value}" + ); + Complex64::new(scalar_f64(&data[0], path), scalar_f64(&data[1], path)) +} + +pub fn assert_close_f64(actual: f64, expected: f64, atol: f64, rtol: f64, label: &str) { + let tol = atol + rtol * expected.abs().max(actual.abs()); + assert!( + 
(actual - expected).abs() <= tol, + "{label}: actual={actual}, expected={expected}, atol={atol}, rtol={rtol}", + ); +} + +pub fn assert_close_complex64( + actual: Complex64, + expected: Complex64, + atol: f64, + rtol: f64, + label: &str, +) { + let tol = atol + rtol * expected.norm().max(actual.norm()); + assert!( + (actual - expected).norm() <= tol, + "{label}: actual={actual:?}, expected={expected:?}, atol={atol}, rtol={rtol}", + ); +} + +fn collect_scalar_values(value: &Value, path: &str, out: &mut Vec) { + if T::is_scalar_value(value) { + out.push(T::from_json(value, path)); + return; + } + + let items = value + .as_array() + .unwrap_or_else(|| panic!("expected array at {path}, got {value}")); + for (index, item) in items.iter().enumerate() { + collect_scalar_values::(item, &format!("{path}[{index}]"), out); + } +} + +fn scalar_values(value: &Value, path: &str) -> Vec { + let mut out = Vec::new(); + collect_scalar_values(value, path, &mut out); + out +} + +pub fn run_unary_oracle_cases(cases: &[UnaryOracleCase]) { + for case in cases { + for (case_index, oracle) in successful_cases(case.op, T::dtype()) + .into_iter() + .enumerate() + { + let inputs = scalar_values::(&oracle["inputs"]["a"]["data"], "inputs.a.data"); + let probes = oracle["probes"] + .as_array() + .unwrap_or_else(|| panic!("expected probes array for {}", case.op)); + let atol = scalar_f64( + &oracle["comparison"]["first_order"]["atol"], + "comparison.first_order.atol", + ); + let rtol = scalar_f64( + &oracle["comparison"]["first_order"]["rtol"], + "comparison.first_order.rtol", + ); + + for (probe_index, probe) in probes.iter().enumerate() { + let tangents = scalar_values::( + &probe["direction"]["a"]["data"], + &format!("probes[{probe_index}].direction.a.data"), + ); + let cotangents = scalar_values::( + &probe["cotangent"]["value"]["data"], + &format!("probes[{probe_index}].cotangent.value.data"), + ); + let expected_jvps = scalar_values::( + &probe["pytorch_ref"]["jvp"]["value"]["data"], + 
&format!("probes[{probe_index}].pytorch_ref.jvp.value.data"), + ); + let expected_vjps = scalar_values::( + &probe["pytorch_ref"]["vjp"]["a"]["data"], + &format!("probes[{probe_index}].pytorch_ref.vjp.a.data"), + ); + + assert_eq!( + inputs.len(), + tangents.len(), + "{} case {case_index} probe {probe_index}: input and tangent lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + cotangents.len(), + "{} case {case_index} probe {probe_index}: input and cotangent lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + expected_jvps.len(), + "{} case {case_index} probe {probe_index}: input and expected jvp lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + expected_vjps.len(), + "{} case {case_index} probe {probe_index}: input and expected vjp lengths differ", + case.op + ); + + for index in 0..inputs.len() { + let (result, actual_jvp) = (case.frule)(inputs[index], tangents[index]); + let actual_vjp = (case.rrule)(inputs[index], result, cotangents[index]); + let label = format!( + "{} case {case_index} probe {probe_index} element {index}", + case.op + ); + T::assert_close(actual_jvp, expected_jvps[index], atol, rtol, &label); + T::assert_close(actual_vjp, expected_vjps[index], atol, rtol, &label); + } + } + } + } +} diff --git a/crates/chainrules/tests/complex_helper_tests.rs b/crates/chainrules/tests/complex_helper_tests.rs new file mode 100644 index 0000000..8b8dfc9 --- /dev/null +++ b/crates/chainrules/tests/complex_helper_tests.rs @@ -0,0 +1,70 @@ +mod common; + +use chainrules::{ + abs, abs2, abs2_frule, abs2_rrule, angle, angle_rrule, complex, imag, imag_rrule, real, + real_rrule, +}; +use num_complex::Complex64; + +use common::{assert_close_complex64, assert_close_f64}; + +#[test] +fn complex_helpers_match_expected_formulas() { + let x = 3.0_f64; + let z = Complex64::new(3.0, 4.0); + let dz = Complex64::new(1.0, -2.0); + + let constructed: Complex64 = complex(3.0, 4.0); + assert_eq!(constructed, z); + assert_close_f64(abs(x), 3.0, 
1.0e-12, 0.0, "abs(x)"); + assert_close_f64(abs2(x), 9.0, 1.0e-12, 0.0, "abs2(x)"); + assert_close_f64(real(x), 3.0, 1.0e-12, 0.0, "real(x)"); + assert_close_f64(imag(x), 0.0, 1.0e-12, 0.0, "imag(x)"); + assert_close_f64(angle(x), 0.0_f64.atan2(x), 1.0e-12, 0.0, "angle(x)"); + assert_close_f64(abs(z), 5.0, 1.0e-12, 0.0, "abs(z)"); + assert_close_f64(abs2(z), 25.0, 1.0e-12, 0.0, "abs2(z)"); + assert_close_f64(real(z), 3.0, 1.0e-12, 0.0, "real(z)"); + assert_close_f64(imag(z), 4.0, 1.0e-12, 0.0, "imag(z)"); + assert_close_f64(angle(z), z.arg(), 1.0e-12, 0.0, "angle(z)"); + + let (abs2_y, abs2_dy) = abs2_frule(z, dz); + assert_close_f64(abs2_y, 25.0, 1.0e-12, 0.0, "abs2.y"); + assert_close_f64( + abs2_dy, + 2.0 * (z.re * dz.re + z.im * dz.im), + 1.0e-12, + 0.0, + "abs2.dy", + ); + + assert_close_complex64( + abs2_rrule(z, 1.25), + Complex64::new(7.5, 10.0), + 1.0e-12, + 0.0, + "abs2.rrule", + ); + let real_grad: Complex64 = real_rrule(2.0); + assert_close_complex64( + real_grad, + Complex64::new(2.0, 0.0), + 1.0e-12, + 0.0, + "real.rrule", + ); + let imag_grad: Complex64 = imag_rrule(2.0); + assert_close_complex64( + imag_grad, + Complex64::new(0.0, 2.0), + 1.0e-12, + 0.0, + "imag.rrule", + ); + assert_close_complex64( + angle_rrule(z, 1.0), + Complex64::new(-0.16, 0.12), + 1.0e-12, + 0.0, + "angle.rrule", + ); +} diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs new file mode 100644 index 0000000..51680c0 --- /dev/null +++ b/crates/chainrules/tests/julia_compat_trig_tests.rs @@ -0,0 +1,388 @@ +mod common; + +use chainrules::{ + cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, coth, + coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, csch_frule, csch_rrule, sec, + sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sincospi, sincospi_frule, sincospi_rrule, + sind, sind_frule, sind_rrule, sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, +}; 
+use common::{assert_close_complex64, assert_close_f64}; +use num_complex::Complex64; + +#[test] +fn julia_compat_landmark_real_inputs_match_julia_style_values() { + assert_eq!(sinpi(1.0_f64), 0.0_f64); + assert_eq!(sinpi(0.5_f64), 1.0_f64); + assert_eq!(cospi(0.5_f64), 0.0_f64); + let (s, c) = sincospi(0.5_f64); + assert_eq!(s, 1.0_f64); + assert_eq!(c, 0.0_f64); + assert_eq!(sind(180.0_f64), 0.0_f64); + assert_eq!(cosd(90.0_f64), 0.0_f64); + assert_eq!(tand(45.0_f64), 1.0_f64); + assert_eq!(tand(90.0_f64), f64::INFINITY); + assert_eq!(tand(-90.0_f64), f64::NEG_INFINITY); + assert_eq!(tand(270.0_f64), f64::NEG_INFINITY); +} + +#[test] +fn julia_compat_primal_helpers_match_expected_values() { + let x = 0.25_f64; + assert_close_f64(sec(x), 1.0 / x.cos(), 1e-12, 0.0, "sec"); + assert_close_f64(csc(x), 1.0 / x.sin(), 1e-12, 0.0, "csc"); + assert_close_f64(cot(x), 1.0 / x.tan(), 1e-12, 0.0, "cot"); + assert_close_f64( + sinpi(x), + (std::f64::consts::PI * x).sin(), + 1e-12, + 0.0, + "sinpi", + ); + assert_close_f64( + cospi(x), + (std::f64::consts::PI * x).cos(), + 1e-12, + 0.0, + "cospi", + ); + let (s, c) = sincospi(x); + assert_close_f64( + s, + (std::f64::consts::PI * x).sin(), + 1e-12, + 0.0, + "sincospi.sin", + ); + assert_close_f64( + c, + (std::f64::consts::PI * x).cos(), + 1e-12, + 0.0, + "sincospi.cos", + ); + assert_close_f64(sind(30.0_f64), 0.5_f64, 1e-12, 0.0, "sind"); + assert_close_f64(cosd(60.0_f64), 0.5_f64, 1e-12, 0.0, "cosd"); + assert_close_f64(tand(45.0_f64), 1.0_f64, 1e-12, 0.0, "tand"); + assert_close_f64(sech(x), 1.0 / x.cosh(), 1e-12, 0.0, "sech"); + assert_close_f64(csch(x), 1.0 / x.sinh(), 1e-12, 0.0, "csch"); + assert_close_f64(coth(x), 1.0 / x.tanh(), 1e-12, 0.0, "coth"); +} + +#[test] +fn julia_compat_derivative_helpers_match_expected_values() { + let x = 0.25_f64; + let g = 1.0_f64; + + let (_, dsec) = sec_frule(x, g); + assert_close_f64(dsec, x.sin() / x.cos().powi(2), 1e-12, 0.0, "sec_frule"); + assert_close_f64( + sec_rrule(x, g), + 
x.sin() / x.cos().powi(2), + 1e-12, + 0.0, + "sec_rrule", + ); + + let (_, dsinpi_landmark) = sinpi_frule(1.0_f64, g); + assert_close_f64( + dsinpi_landmark, + -std::f64::consts::PI, + 1e-12, + 0.0, + "sinpi_frule landmark", + ); + assert_close_f64( + sinpi_rrule(1.0_f64, g), + -std::f64::consts::PI, + 1e-12, + 0.0, + "sinpi_rrule landmark", + ); + + let (_, dcospi_landmark) = cospi_frule(0.5_f64, g); + assert_close_f64( + dcospi_landmark, + -std::f64::consts::PI, + 1e-12, + 0.0, + "cospi_frule landmark", + ); + assert_close_f64( + cospi_rrule(0.5_f64, g), + -std::f64::consts::PI, + 1e-12, + 0.0, + "cospi_rrule landmark", + ); + + let (_, dsinpi) = sinpi_frule(x, g); + assert_close_f64( + dsinpi, + std::f64::consts::PI * (std::f64::consts::PI * x).cos(), + 1e-12, + 0.0, + "sinpi_frule", + ); + assert_close_f64( + sinpi_rrule(x, g), + std::f64::consts::PI * (std::f64::consts::PI * x).cos(), + 1e-12, + 0.0, + "sinpi_rrule", + ); + + let (_, dtand) = tand_frule(45.0_f64, g); + assert_close_f64( + dtand, + std::f64::consts::PI / 180.0 * 2.0, + 1e-12, + 0.0, + "tand_frule", + ); + assert_close_f64( + tand_rrule(45.0_f64, g), + std::f64::consts::PI / 180.0 * 2.0, + 1e-12, + 0.0, + "tand_rrule", + ); + + let (_, dsech) = sech_frule(x, g); + let sech_x: f64 = 1.0 / x.cosh(); + assert_close_f64(dsech, -sech_x * x.tanh(), 1e-12, 0.0, "sech_frule"); + assert_close_f64( + sech_rrule(x, g), + -sech_x * x.tanh(), + 1e-12, + 0.0, + "sech_rrule", + ); + + let (_, dcsc) = csc_frule(x, g); + assert_close_f64(dcsc, -(x.cos() / x.sin().powi(2)), 1e-12, 0.0, "csc_frule"); + assert_close_f64( + csc_rrule(x, g), + -(x.cos() / x.sin().powi(2)), + 1e-12, + 0.0, + "csc_rrule", + ); + + let (_, dcot) = cot_frule(x, g); + assert_close_f64(dcot, -(1.0 / x.sin().powi(2)), 1e-12, 0.0, "cot_frule"); + assert_close_f64( + cot_rrule(x, g), + -(1.0 / x.sin().powi(2)), + 1e-12, + 0.0, + "cot_rrule", + ); + + let (_, dcsch) = csch_frule(x, g); + let csch_x: f64 = 1.0 / x.sinh(); + assert_close_f64( + 
dcsch, + -csch_x * x.cosh() / x.sinh(), + 1e-12, + 0.0, + "csch_frule", + ); + assert_close_f64( + csch_rrule(x, g), + -csch_x * x.cosh() / x.sinh(), + 1e-12, + 0.0, + "csch_rrule", + ); + + let (_, dcoth) = coth_frule(x, g); + assert_close_f64(dcoth, -(1.0 / x.sinh().powi(2)), 1e-12, 0.0, "coth_frule"); + assert_close_f64( + coth_rrule(x, g), + -(1.0 / x.sinh().powi(2)), + 1e-12, + 0.0, + "coth_rrule", + ); + + let (_, dsind) = sind_frule(30.0_f64, g); + assert_close_f64( + dsind, + std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos(), + 1e-12, + 0.0, + "sind_frule", + ); + assert_close_f64( + sind_rrule(30.0_f64, g), + std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos(), + 1e-12, + 0.0, + "sind_rrule", + ); + + let (_, dcospi) = cospi_frule(x, g); + assert_close_f64( + dcospi, + -std::f64::consts::PI * (std::f64::consts::PI * x).sin(), + 1e-12, + 0.0, + "cospi_frule", + ); + assert_close_f64( + cospi_rrule(x, g), + -std::f64::consts::PI * (std::f64::consts::PI * x).sin(), + 1e-12, + 0.0, + "cospi_rrule", + ); + + let (_, dcosd) = cosd_frule(60.0_f64, g); + assert_close_f64( + dcosd, + -std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin(), + 1e-12, + 0.0, + "cosd_frule", + ); + assert_close_f64( + cosd_rrule(60.0_f64, g), + -std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin(), + 1e-12, + 0.0, + "cosd_rrule", + ); + + let (_, dsincospi) = sincospi_frule(x, g); + assert_close_f64( + dsincospi.0, + std::f64::consts::PI * (std::f64::consts::PI * x).cos(), + 1e-12, + 0.0, + "sincospi_frule.sin", + ); + assert_close_f64( + dsincospi.1, + -std::f64::consts::PI * (std::f64::consts::PI * x).sin(), + 1e-12, + 0.0, + "sincospi_frule.cos", + ); + assert_close_f64( + sincospi_rrule(x, (g, g)), + std::f64::consts::PI * (std::f64::consts::PI * x).cos() + - std::f64::consts::PI * (std::f64::consts::PI * x).sin(), + 1e-12, + 0.0, + "sincospi_rrule", + ); +} + +#[test] +fn julia_compat_helpers_cover_complex_primal_forward_and_cotangent_paths() { + 
let z = Complex64::new(0.25, -0.1); + let dz = Complex64::new(1.0, -0.25); + let cotangent = Complex64::new(0.75, -0.5); + let pi_z = Complex64::new(std::f64::consts::PI, 0.0) * z; + + assert_close_complex64(sinpi(z), pi_z.sin(), 1e-12, 0.0, "sinpi(z)"); + assert_close_complex64(cospi(z), pi_z.cos(), 1e-12, 0.0, "cospi(z)"); + + let (_, dsinpi) = sinpi_frule(z, dz); + assert_close_complex64( + dsinpi, + dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), + 1e-12, + 0.0, + "sinpi_frule(z)", + ); + let (_, dcospi) = cospi_frule(z, dz); + assert_close_complex64( + dcospi, + dz * (-(Complex64::new(std::f64::consts::PI, 0.0) * pi_z.sin())), + 1e-12, + 0.0, + "cospi_frule(z)", + ); + let ((_, _), (dsinpi_pair, dcospi_pair)) = sincospi_frule(z, dz); + assert_close_complex64( + dsinpi_pair, + dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), + 1e-12, + 0.0, + "sincospi_frule.sin(z)", + ); + assert_close_complex64( + dcospi_pair, + dz * (-(Complex64::new(std::f64::consts::PI, 0.0) * pi_z.sin())), + 1e-12, + 0.0, + "sincospi_frule.cos(z)", + ); + let (_, dtand) = tand_frule(z, dz); + let tand_z = (Complex64::new(std::f64::consts::PI / 180.0, 0.0) * z).tan(); + assert_close_complex64( + dtand, + dz * (Complex64::new(std::f64::consts::PI / 180.0, 0.0) + * (Complex64::new(1.0, 0.0) + tand_z * tand_z)), + 1e-12, + 0.0, + "tand_frule(z)", + ); + + assert_close_complex64( + sec_rrule(z, cotangent), + cotangent * (z.sin() / z.cos().powi(2)).conj(), + 1e-12, + 0.0, + "sec_rrule(z)", + ); + assert_close_complex64( + sinpi_rrule(z, cotangent), + cotangent * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()).conj(), + 1e-12, + 0.0, + "sinpi_rrule(z)", + ); + let (_, dcot) = cot_frule(z, dz); + assert_close_complex64( + dcot, + dz * (-(Complex64::new(1.0, 0.0) / z.sin().powi(2))), + 1e-12, + 0.0, + "cot_frule(z)", + ); + let (_, dsech) = sech_frule(z, dz); + let sech_z = sech(z); + assert_close_complex64( + dsech, + dz * (-(sech_z * z.tanh())), + 1e-12, + 
0.0, + "sech_frule(z)", + ); +} + +#[test] +fn julia_compat_complex_inputs_cover_generic_surface() { + let z = Complex64::new(0.25, -0.5); + + let pi_z = Complex64::new(std::f64::consts::PI, 0.0) * z; + assert_close_complex64(sinpi(z), pi_z.sin(), 1e-12, 0.0, "sinpi(z)"); + assert_close_complex64(cospi(z), pi_z.cos(), 1e-12, 0.0, "cospi(z)"); + let (s, c) = sincospi(z); + assert_close_complex64(s, pi_z.sin(), 1e-12, 0.0, "sincospi.sin(z)"); + assert_close_complex64(c, pi_z.cos(), 1e-12, 0.0, "sincospi.cos(z)"); + + let deg_z = Complex64::new(std::f64::consts::PI / 180.0, 0.0) * z; + assert_close_complex64(sind(z), deg_z.sin(), 1e-12, 0.0, "sind(z)"); + assert_close_complex64(cosd(z), deg_z.cos(), 1e-12, 0.0, "cosd(z)"); + assert_close_complex64(tand(z), deg_z.tan(), 1e-12, 0.0, "tand(z)"); + assert_close_complex64( + coth(z), + Complex64::new(1.0, 0.0) / z.tanh(), + 1e-12, + 0.0, + "coth(z)", + ); +} diff --git a/crates/chainrules/tests/nonsmooth_scalar_tests.rs b/crates/chainrules/tests/nonsmooth_scalar_tests.rs new file mode 100644 index 0000000..6636b40 --- /dev/null +++ b/crates/chainrules/tests/nonsmooth_scalar_tests.rs @@ -0,0 +1,176 @@ +use chainrules::{ + ceil, ceil_frule, ceil_rrule, floor, floor_frule, floor_rrule, max, max_frule, max_rrule, min, + min_frule, min_rrule, round, round_frule, round_rrule, sign, sign_frule, sign_rrule, +}; +use num_traits::Float; + +fn assert_close<T>(actual: T, expected: T) +where + T: core::fmt::Debug + PartialEq, +{ + assert_eq!(actual, expected); +} + +fn assert_zero<T>(actual: T) +where + T: core::fmt::Debug + PartialEq + Float, +{ + assert_eq!(actual, T::zero()); +} + +fn assert_negative_zero<T>(actual: T) +where + T: core::fmt::Debug + PartialEq + Float, +{ + assert_eq!(actual, T::zero()); + assert!( + actual.is_sign_negative(), + "expected negative zero, got {actual:?}" + ); +} + +fn cast<T>(value: f32) -> T +where + T: Float, +{ + T::from(value).expect("cast to float") +} + +fn check_nonsmooth_scalar_rules<T>() +where + T: Copy + 
core::fmt::Debug + PartialEq + Float, +{ + let x = cast::<T>(1.6_f32); + let y = cast::<T>(-2.4_f32); + let zero = cast::<T>(0.0_f32); + let neg_zero = cast::<T>(-0.0_f32); + let inf = T::infinity(); + let neg_inf = T::neg_infinity(); + + assert_close(round(x), cast::<T>(2.0_f32)); + assert_close(floor(x), cast::<T>(1.0_f32)); + assert_close(ceil(y), cast::<T>(-2.0_f32)); + assert_close(sign(y), cast::<T>(-1.0_f32)); + assert_zero(sign(zero)); + assert_negative_zero(sign(neg_zero)); + assert_close(sign(inf), T::one()); + assert_close(sign(neg_inf), -T::one()); + + let (round_y, round_dy) = round_frule(x, cast::<T>(7.0_f32)); + assert_close(round_y, cast::<T>(2.0_f32)); + assert_zero(round_dy); + assert_zero(round_rrule(x, cast::<T>(7.0_f32))); + + let (floor_y, floor_dy) = floor_frule(x, cast::<T>(5.0_f32)); + assert_close(floor_y, cast::<T>(1.0_f32)); + assert_zero(floor_dy); + assert_zero(floor_rrule(x, cast::<T>(5.0_f32))); + + let (ceil_y, ceil_dy) = ceil_frule(y, cast::<T>(11.0_f32)); + assert_close(ceil_y, cast::<T>(-2.0_f32)); + assert_zero(ceil_dy); + assert_zero(ceil_rrule(y, cast::<T>(11.0_f32))); + + let (sign_y, sign_dy) = sign_frule(y, cast::<T>(3.0_f32)); + assert_close(sign_y, cast::<T>(-1.0_f32)); + assert_zero(sign_dy); + assert_zero(sign_rrule(y, cast::<T>(3.0_f32))); + + let (min_y, min_dy) = min_frule( + cast::<T>(1.0_f32), + cast::<T>(2.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close( + min(cast::<T>(1.0_f32), cast::<T>(2.0_f32)), + cast::<T>(1.0_f32), + ); + assert_close(min_y, cast::<T>(1.0_f32)); + assert_close(min_dy, cast::<T>(4.0_f32)); + let (min_tie_y, min_tie_dy) = min_frule( + cast::<T>(3.0_f32), + cast::<T>(3.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close(min_tie_y, cast::<T>(3.0_f32)); + assert_close(min_tie_dy, cast::<T>(8.0_f32)); + let (min_dx, min_dy) = min_rrule(cast::<T>(3.0_f32), cast::<T>(3.0_f32), cast::<T>(6.0_f32)); + assert_zero(min_dx); + assert_close(min_dy, cast::<T>(6.0_f32)); + let (min_dx, min_dy) = min_rrule(cast::<T>(2.0_f32), cast::<T>(3.0_f32), cast::<T>(5.0_f32)); + 
assert_close(min_dx, cast::<T>(5.0_f32)); + assert_zero(min_dy); + let (min_tie_y, min_tie_dy) = min_frule(neg_zero, zero, cast::<T>(4.0_f32), cast::<T>(8.0_f32)); + assert_zero(min_tie_y); + assert_close(min_tie_dy, cast::<T>(8.0_f32)); + let (min_dx, min_dy) = min_rrule(neg_zero, zero, cast::<T>(5.0_f32)); + assert_zero(min_dx); + assert_close(min_dy, cast::<T>(5.0_f32)); + let (min_nan_y, min_nan_dy) = min_frule( + cast::<T>(2.0_f32), + T::nan(), + cast::<T>(6.0_f32), + cast::<T>(9.0_f32), + ); + assert_close(min_nan_y, cast::<T>(2.0_f32)); + assert_close(min_nan_dy, cast::<T>(6.0_f32)); + let (min_dx, min_dy) = min_rrule(cast::<T>(2.0_f32), T::nan(), cast::<T>(5.0_f32)); + assert_close(min_dx, cast::<T>(5.0_f32)); + assert_zero(min_dy); + + let (max_y, max_dy) = max_frule( + cast::<T>(1.0_f32), + cast::<T>(2.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close( + max(cast::<T>(1.0_f32), cast::<T>(2.0_f32)), + cast::<T>(2.0_f32), + ); + assert_close(max_y, cast::<T>(2.0_f32)); + assert_close(max_dy, cast::<T>(8.0_f32)); + let (max_tie_y, max_tie_dy) = max_frule( + cast::<T>(3.0_f32), + cast::<T>(3.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close(max_tie_y, cast::<T>(3.0_f32)); + assert_close(max_tie_dy, cast::<T>(8.0_f32)); + let (max_dx, max_dy) = max_rrule(cast::<T>(3.0_f32), cast::<T>(3.0_f32), cast::<T>(6.0_f32)); + assert_zero(max_dx); + assert_close(max_dy, cast::<T>(6.0_f32)); + let (max_dx, max_dy) = max_rrule(cast::<T>(2.0_f32), cast::<T>(3.0_f32), cast::<T>(5.0_f32)); + assert_zero(max_dx); + assert_close(max_dy, cast::<T>(5.0_f32)); + let (max_tie_y, max_tie_dy) = max_frule(neg_zero, zero, cast::<T>(4.0_f32), cast::<T>(8.0_f32)); + assert_zero(max_tie_y); + assert_close(max_tie_dy, cast::<T>(8.0_f32)); + let (max_dx, max_dy) = max_rrule(neg_zero, zero, cast::<T>(5.0_f32)); + assert_zero(max_dx); + assert_close(max_dy, cast::<T>(5.0_f32)); + let (max_nan_y, max_nan_dy) = max_frule( + T::nan(), + cast::<T>(3.0_f32), + cast::<T>(6.0_f32), + cast::<T>(9.0_f32), + ); + assert_close(max_nan_y, cast::<T>(3.0_f32)); + assert_close(max_nan_dy, 
cast::(9.0_f32)); + let (max_dx, max_dy) = max_rrule(T::nan(), cast::(3.0_f32), cast::(5.0_f32)); + assert_zero(max_dx); + assert_close(max_dy, cast::(5.0_f32)); +} + +#[test] +fn nonsmooth_scalar_rules_match_expected_policy_for_f64() { + check_nonsmooth_scalar_rules::(); +} + +#[test] +fn nonsmooth_scalar_rules_match_expected_policy_for_f32() { + check_nonsmooth_scalar_rules::(); +} diff --git a/crates/chainrules/tests/oracle_scalar_rules.rs b/crates/chainrules/tests/oracle_scalar_rules.rs index 2a9aefd..0f3a01b 100644 --- a/crates/chainrules/tests/oracle_scalar_rules.rs +++ b/crates/chainrules/tests/oracle_scalar_rules.rs @@ -1,181 +1,138 @@ -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::path::PathBuf; +mod common; use chainrules::{ acos_frule, acos_rrule, acosh_frule, acosh_rrule, asin_frule, asin_rrule, asinh_frule, asinh_rrule, atan_frule, atan_rrule, atanh_frule, atanh_rrule, cos_frule, cos_rrule, - cosh_frule, cosh_rrule, exp_frule, exp_rrule, expm1_frule, expm1_rrule, log1p_frule, - log1p_rrule, log_frule, log_rrule, sin_frule, sin_rrule, sinh_frule, sinh_rrule, sqrt_frule, - sqrt_rrule, tanh_frule, tanh_rrule, + cosh_frule, cosh_rrule, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1_frule, expm1_rrule, + log1p_frule, log1p_rrule, log2_frule, log2_rrule, log_frule, log_rrule, sin_frule, sin_rrule, + sinh_frule, sinh_rrule, sqrt_frule, sqrt_rrule, tan_frule, tan_rrule, tanh_frule, tanh_rrule, }; -use serde_json::Value; +use num_complex::Complex64; -struct UnaryRuleCase { - op: &'static str, - frule: fn(f64, f64) -> (f64, f64), - rrule: fn(f64, f64, f64) -> f64, -} - -fn oracle_root() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("..") - .join("..") - .join("third_party") - .join("tensor-ad-oracles") -} - -fn first_successful_float64_case(op: &str) -> Value { - let path = oracle_root().join("cases").join(op).join("identity.jsonl"); - let file = File::open(&path).unwrap_or_else(|err| panic!("open {}: {err}", 
path.display())); - let reader = BufReader::new(file); - - for line in reader.lines() { - let line = line.unwrap_or_else(|err| panic!("read {}: {err}", path.display())); - let value: Value = serde_json::from_str(&line) - .unwrap_or_else(|err| panic!("parse {}: {err}", path.display())); - let dtype = value["dtype"].as_str(); - let behavior = value["expected_behavior"].as_str(); - if dtype == Some("float64") && behavior == Some("success") { - return value; - } - } - - panic!("no successful float64 case found in {}", path.display()); -} - -fn scalar(value: &Value, path: &str) -> f64 { - value - .as_f64() - .unwrap_or_else(|| panic!("expected float64 at {path}, got {value}")) -} - -fn assert_close(actual: f64, expected: f64, atol: f64, rtol: f64, label: &str) { - let tol = atol + rtol * expected.abs().max(actual.abs()); - assert!( - (actual - expected).abs() <= tol, - "{label}: actual={actual}, expected={expected}, atol={atol}, rtol={rtol}", - ); -} +use common::{run_unary_oracle_cases, UnaryOracleCase}; #[test] fn published_float64_oracles_match_unary_rule_entrypoints() { - let cases = [ - UnaryRuleCase { + let cases: [UnaryOracleCase<f64>; 19] = [ + UnaryOracleCase { op: "sqrt", frule: sqrt_frule, - rrule: |x, _result, cotangent| sqrt_rrule(x.sqrt(), cotangent), + rrule: |x: f64, _result, cotangent| sqrt_rrule(x.sqrt(), cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "exp", frule: exp_frule, - rrule: |_x, result, cotangent| exp_rrule(result, cotangent), + rrule: |_x: f64, result, cotangent| exp_rrule(result, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "expm1", frule: expm1_frule, - rrule: |_x, result, cotangent| expm1_rrule(result, cotangent), + rrule: |_x: f64, result, cotangent| expm1_rrule(result, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "log", frule: log_frule, - rrule: |x, _result, cotangent| log_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| log_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: 
"log1p", frule: log1p_frule, - rrule: |x, _result, cotangent| log1p_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| log1p_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "sin", frule: sin_frule, - rrule: |x, _result, cotangent| sin_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| sin_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "cos", frule: cos_frule, - rrule: |x, _result, cotangent| cos_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| cos_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "tanh", frule: tanh_frule, - rrule: |_x, result, cotangent| tanh_rrule(result, cotangent), + rrule: |_x: f64, result, cotangent| tanh_rrule(result, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "asin", frule: asin_frule, - rrule: |x, _result, cotangent| asin_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| asin_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "acos", frule: acos_frule, - rrule: |x, _result, cotangent| acos_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| acos_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "atan", frule: atan_frule, - rrule: |x, _result, cotangent| atan_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| atan_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "sinh", frule: sinh_frule, - rrule: |x, _result, cotangent| sinh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| sinh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "cosh", frule: cosh_frule, - rrule: |x, _result, cotangent| cosh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| cosh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "asinh", frule: asinh_frule, - rrule: |x, _result, cotangent| asinh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| asinh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "acosh", frule: acosh_frule, - rrule: 
|x, _result, cotangent| acosh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| acosh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "atanh", frule: atanh_frule, - rrule: |x, _result, cotangent| atanh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| atanh_rrule(x, cotangent), + }, + UnaryOracleCase { + op: "tan", + frule: tan_frule, + rrule: |_x: f64, result, cotangent| tan_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "exp2", + frule: exp2_frule, + rrule: |_x: f64, result, cotangent| exp2_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "log2", + frule: log2_frule, + rrule: |x: f64, _result, cotangent| log2_rrule(x, cotangent), }, ]; - for case in cases { - let oracle = first_successful_float64_case(case.op); - let input = scalar(&oracle["inputs"]["a"]["data"][0], "inputs.a.data[0]"); - let probe = &oracle["probes"][0]; - let tangent = scalar( - &probe["direction"]["a"]["data"][0], - "probes[0].direction.a.data[0]", - ); - let cotangent = scalar( - &probe["cotangent"]["value"]["data"][0], - "probes[0].cotangent.value.data[0]", - ); - let expected_jvp = scalar( - &probe["pytorch_ref"]["jvp"]["value"]["data"][0], - "probes[0].pytorch_ref.jvp.value.data[0]", - ); - let expected_vjp = scalar( - &probe["pytorch_ref"]["vjp"]["a"]["data"][0], - "probes[0].pytorch_ref.vjp.a.data[0]", - ); - let atol = scalar( - &oracle["comparison"]["first_order"]["atol"], - "comparison.first_order.atol", - ); - let rtol = scalar( - &oracle["comparison"]["first_order"]["rtol"], - "comparison.first_order.rtol", - ); + run_unary_oracle_cases(&cases); +} - let (result, actual_jvp) = (case.frule)(input, tangent); - let actual_vjp = (case.rrule)(input, result, cotangent); +#[test] +fn published_complex128_oracles_match_unary_rule_entrypoints() { + let cases: [UnaryOracleCase<Complex64>; 3] = [ + UnaryOracleCase { + op: "tan", + frule: tan_frule, + rrule: |_x: Complex64, result, cotangent| tan_rrule(result, cotangent), + }, + UnaryOracleCase { + op: 
"exp2", + frule: exp2_frule, + rrule: |_x: Complex64, result, cotangent| exp2_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "log2", + frule: log2_frule, + rrule: |x: Complex64, _result, cotangent| log2_rrule(x, cotangent), + }, + ]; - assert_close(actual_jvp, expected_jvp, atol, rtol, case.op); - assert_close(actual_vjp, expected_vjp, atol, rtol, case.op); - } + run_unary_oracle_cases(&cases); } diff --git a/crates/chainrules/tests/scalarops_tests.rs b/crates/chainrules/tests/scalarops_tests.rs index 11216bf..8ae3084 100644 --- a/crates/chainrules/tests/scalarops_tests.rs +++ b/crates/chainrules/tests/scalarops_tests.rs @@ -1,18 +1,10 @@ use chainrules::{ - add, add_frule, add_rrule, conj, conj_frule, conj_rrule, div, div_frule, div_rrule, - handle_r_to_c_f32, handle_r_to_c_f64, mul, mul_frule, mul_rrule, powf, powf_frule, powf_rrule, - powi, powi_frule, powi_rrule, sqrt, sqrt_frule, sqrt_rrule, sub, sub_frule, sub_rrule, + add, add_frule, add_rrule, conj, conj_frule, conj_rrule, div, div_frule, div_rrule, mul, + mul_frule, mul_rrule, powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule, sqrt, + sqrt_frule, sqrt_rrule, sub, sub_frule, sub_rrule, }; use num_complex::{Complex32, Complex64}; -#[test] -fn handle_r_to_c_projects_real_part() { - let g32 = Complex32::new(1.25, -9.0); - let g64 = Complex64::new(-3.5, 2.0); - assert_eq!(handle_r_to_c_f32(g32), 1.25_f32); - assert_eq!(handle_r_to_c_f64(g64), -3.5_f64); -} - #[test] fn conj_rules_match_formula_complex64() { let x = Complex64::new(2.0, -3.0); @@ -128,7 +120,7 @@ fn mul_div_rules_match_formula_complex64() { let (mul_y, mul_dy) = mul_frule(x, y, dx, dy); assert!((mul_y - (x * y)).norm() < 1e-12); - let expected_mul_tangent = dx * y.conj() + dy * x.conj(); + let expected_mul_tangent = dx * y + dy * x; assert!((mul_dy - expected_mul_tangent).norm() < 1e-12); let (mul_dx, mul_dy_rr) = mul_rrule(x, y, g); assert!((mul_dx - g * y.conj()).norm() < 1e-12); @@ -136,8 +128,8 @@ fn 
mul_div_rules_match_formula_complex64() { let (div_y, div_dy) = div_frule(x, y, dx, dy); assert!((div_y - (x / y)).norm() < 1e-12); - let expected_div_tangent = dx * (Complex64::new(1.0, 0.0) / y).conj() - + dy * ((Complex64::new(-1.0, 0.0) * x) / (y * y)).conj(); + let expected_div_tangent = + dx * (Complex64::new(1.0, 0.0) / y) + dy * ((Complex64::new(-1.0, 0.0) * x) / (y * y)); assert!((div_dy - expected_div_tangent).norm() < 1e-12); let (div_dx, div_dy_rr) = div_rrule(x, y, g); assert!((div_dx - g * (Complex64::new(1.0, 0.0) / y).conj()).norm() < 1e-12); @@ -155,11 +147,11 @@ fn powi_rules_match_formula_complex64() { let expected_y = x * x * x; assert!((y - expected_y).norm() < 1e-12); - let expected_scale = (Complex64::new(3.0, 0.0) * x.powi(2)).conj(); + let expected_scale = Complex64::new(3.0, 0.0) * x.powi(2); assert!((dy - (dx * expected_scale)).norm() < 1e-12); let grad = powi_rrule(x, exponent, g); - assert!((grad - (g * expected_scale)).norm() < 1e-12); + assert!((grad - (g * expected_scale.conj())).norm() < 1e-12); } #[test] @@ -204,16 +196,16 @@ fn complex_frules_and_rrules_cover_from_real_paths() { let g32 = Complex32::new(0.6, -0.2); let (_y32, dy32) = powf_frule(x32, 2.0_f32, dx32); let grad32 = powf_rrule(x32, 2.0_f32, g32); - let expected_scale32 = (Complex32::new(2.0, 0.0) * x32.powf(1.0_f32)).conj(); + let expected_scale32 = Complex32::new(2.0, 0.0) * x32.powf(1.0_f32); assert!((dy32 - dx32 * expected_scale32).norm() < 1e-5); - assert!((grad32 - g32 * expected_scale32).norm() < 1e-5); + assert!((grad32 - g32 * expected_scale32.conj()).norm() < 1e-5); let x64 = Complex64::new(-0.8, 1.1); let dx64 = Complex64::new(0.9, -0.4); let g64 = Complex64::new(0.3, 0.2); let (_y64, dy64) = powi_frule(x64, 4, dx64); let grad64 = powi_rrule(x64, 4, g64); - let expected_scale64 = (Complex64::new(4.0, 0.0) * x64.powi(3)).conj(); + let expected_scale64 = Complex64::new(4.0, 0.0) * x64.powi(3); assert!((dy64 - dx64 * expected_scale64).norm() < 1e-12); - 
assert!((grad64 - g64 * expected_scale64).norm() < 1e-12); + assert!((grad64 - g64 * expected_scale64.conj()).norm() < 1e-12); } diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs new file mode 100644 index 0000000..f5b7bb5 --- /dev/null +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -0,0 +1,275 @@ +mod common; + +use chainrules::{ + cbrt, cbrt_frule, cbrt_rrule, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, + hypot, hypot_frule, hypot_rrule, inv, inv_frule, inv_rrule, log10, log10_frule, log10_rrule, + log2, log2_frule, log2_rrule, pow, pow_frule, pow_rrule, sincos, sincos_frule, sincos_rrule, + tan, tan_frule, tan_rrule, +}; +use num_complex::{Complex64, ComplexFloat}; + +use common::{assert_close_complex64, assert_close_f64}; + +#[test] +fn smooth_basis_helpers_are_reexported_from_chainrules() { + assert_close_f64(cbrt(8.0_f64), 2.0, 1.0e-12, 0.0, "cbrt"); + assert_close_f64(inv(4.0_f64), 0.25, 1.0e-12, 0.0, "inv"); + assert_close_f64(exp2(3.0_f64), 8.0, 1.0e-12, 0.0, "exp2"); + assert_close_f64(exp10(2.0_f64), 100.0, 1.0e-12, 0.0, "exp10"); + assert_close_f64(hypot(3.0_f64, 4.0_f64), 5.0, 1.0e-12, 0.0, "hypot"); + assert_close_f64(log2(8.0_f64), 3.0, 1.0e-12, 0.0, "log2"); + assert_close_f64(log10(100.0_f64), 2.0, 1.0e-12, 0.0, "log10"); + assert_close_f64(pow(2.0_f64, 3.0_f64), 8.0, 1.0e-12, 0.0, "pow"); + assert_close_f64(tan(0.5_f64), 0.5_f64.tan(), 1.0e-12, 0.0, "tan"); + let (sin_x, cos_x) = sincos(0.5_f64); + assert_close_f64(sin_x, 0.5_f64.sin(), 1.0e-12, 0.0, "sincos.sin"); + assert_close_f64(cos_x, 0.5_f64.cos(), 1.0e-12, 0.0, "sincos.cos"); + + let z = Complex64::new(1.0, 2.0); + let _ = pow(z, Complex64::new(2.0, 0.0)); +} + +#[test] +fn smooth_basis_frules_and_rrules_match_expected_derivatives() { + let (tan_y, tan_dy) = tan_frule(0.25_f64, 1.0_f64); + assert_close_f64(tan_y, 0.25_f64.tan(), 1.0e-12, 0.0, "tan.y"); + assert_close_f64( + tan_dy, + 1.0_f64 + 
0.25_f64.tan().powi(2), + 1.0e-12, + 0.0, + "tan.dy", + ); + assert_close_f64( + tan_rrule(0.25_f64.tan(), 1.0_f64), + 1.0_f64 + 0.25_f64.tan().powi(2), + 1.0e-12, + 0.0, + "tan.rrule", + ); + + let (exp2_y, exp2_dy) = exp2_frule(3.0_f64, 1.0_f64); + assert_close_f64(exp2_y, 8.0, 1.0e-12, 0.0, "exp2.y"); + assert_close_f64( + exp2_dy, + 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "exp2.dy", + ); + assert_close_f64( + exp2_rrule(8.0_f64, 1.0_f64), + 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "exp2.rrule", + ); + + let (hypot_y, hypot_dy) = hypot_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64); + assert_close_f64(hypot_y, 5.0, 1.0e-12, 0.0, "hypot.y"); + assert_close_f64(hypot_dy, 0.5, 1.0e-12, 0.0, "hypot.dy"); + let (hypot_dx, hypot_dy) = hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64); + assert_close_f64(hypot_dx, 0.6_f64, 1.0e-12, 0.0, "hypot.rrule.dx"); + assert_close_f64(hypot_dy, 0.8_f64, 1.0e-12, 0.0, "hypot.rrule.dy"); + + let (pow_y, pow_dy) = pow_frule(2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); + assert_close_f64(pow_y, 8.0, 1.0e-12, 0.0, "pow.y"); + assert_close_f64(pow_dy, 12.0, 1.0e-12, 0.0, "pow.dy"); + let (pow_dx, pow_dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0_f64); + assert_close_f64(pow_dx, 12.0, 1.0e-12, 0.0, "pow.rrule.dx"); + assert_close_f64( + pow_dexp, + 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "pow.rrule.dexp", + ); + + let (pow_y, pow_dy) = pow_frule(2.0_f64, 3.0_f64, 1.0_f64, 0.5_f64); + assert_close_f64(pow_y, 8.0, 1.0e-12, 0.0, "pow.y.dexp"); + assert_close_f64( + pow_dy, + 12.0 + 0.5_f64 * 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "pow.dy.dexp", + ); + + let (sincos_y, sincos_dy) = sincos_frule(0.25_f64, 1.0_f64); + assert_close_f64(sincos_y.0, 0.25_f64.sin(), 1.0e-12, 0.0, "sincos.y.sin"); + assert_close_f64(sincos_y.1, 0.25_f64.cos(), 1.0e-12, 0.0, "sincos.y.cos"); + assert_close_f64(sincos_dy.0, 0.25_f64.cos(), 1.0e-12, 0.0, "sincos.dy.sin"); + assert_close_f64(sincos_dy.1, -0.25_f64.sin(), 1.0e-12, 0.0, 
"sincos.dy.cos"); + assert_close_f64( + sincos_rrule(0.25_f64, (1.0_f64, 1.0_f64)), + 0.25_f64.cos() - 0.25_f64.sin(), + 1.0e-12, + 0.0, + "sincos.rrule", + ); + + let (cbrt_y, cbrt_dy) = cbrt_frule(8.0_f64, 1.0_f64); + assert_close_f64(cbrt_y, 2.0, 1.0e-12, 0.0, "cbrt.y"); + assert_close_f64( + cbrt_dy, + 1.0_f64 / (3.0_f64 * 4.0_f64), + 1.0e-12, + 0.0, + "cbrt.dy", + ); + assert_close_f64( + cbrt_rrule(2.0_f64, 1.0_f64), + 1.0_f64 / (3.0_f64 * 4.0_f64), + 1.0e-12, + 0.0, + "cbrt.rrule", + ); + + let (inv_y, inv_dy) = inv_frule(4.0_f64, 2.0_f64); + assert_close_f64(inv_y, 0.25, 1.0e-12, 0.0, "inv.y"); + assert_close_f64(inv_dy, -0.125, 1.0e-12, 0.0, "inv.dy"); + assert_close_f64( + inv_rrule(0.25_f64, 2.0_f64), + -0.125, + 1.0e-12, + 0.0, + "inv.rrule", + ); + + let (log2_y, log2_dy) = log2_frule(8.0_f64, 2.0_f64); + assert_close_f64(log2_y, 3.0, 1.0e-12, 0.0, "log2.y"); + let expected_log2 = 2.0_f64 / (8.0_f64 * std::f64::consts::LN_2); + assert_close_f64(log2_dy, expected_log2, 1.0e-12, 0.0, "log2.dy"); + assert_close_f64( + log2_rrule(8.0_f64, 2.0_f64), + expected_log2, + 1.0e-12, + 0.0, + "log2.rrule", + ); + + let (log10_y, log10_dy) = log10_frule(100.0_f64, 2.0_f64); + assert_close_f64(log10_y, 2.0, 1.0e-12, 0.0, "log10.y"); + let expected_log10 = 2.0_f64 / (100.0_f64 * std::f64::consts::LN_10); + assert_close_f64(log10_dy, expected_log10, 1.0e-12, 0.0, "log10.dy"); + assert_close_f64( + log10_rrule(100.0_f64, 2.0_f64), + expected_log10, + 1.0e-12, + 0.0, + "log10.rrule", + ); + + let (exp10_y, exp10_dy) = exp10_frule(2.0_f64, 0.5_f64); + assert_close_f64(exp10_y, 100.0, 1.0e-12, 0.0, "exp10.y"); + let expected_exp10 = 100.0_f64 * std::f64::consts::LN_10 * 0.5_f64; + assert_close_f64(exp10_dy, expected_exp10, 1.0e-12, 0.0, "exp10.dy"); + assert_close_f64( + exp10_rrule(100.0_f64, 0.5_f64), + expected_exp10, + 1.0e-12, + 0.0, + "exp10.rrule", + ); +} + +#[test] +fn smooth_basis_complex_frules_match_standard_jvps() { + let z = Complex64::new(0.25, -0.5); + let 
dz = Complex64::new(0.5, -0.25); + + let (tan_y, tan_dy) = tan_frule(z, dz); + let tan_scale = Complex64::new(1.0, 0.0) + tan_y * tan_y; + assert_close_complex64(tan_y, z.tan(), 1.0e-12, 0.0, "tan.z"); + assert_close_complex64(tan_dy, dz * tan_scale, 1.0e-12, 0.0, "tan.dz"); + + let (exp2_y, exp2_dy) = exp2_frule(z, dz); + let exp2_scale = exp2_y * Complex64::new(std::f64::consts::LN_2, 0.0); + assert_close_complex64(exp2_y, z.exp2(), 1.0e-12, 0.0, "exp2.z"); + assert_close_complex64(exp2_dy, dz * exp2_scale, 1.0e-12, 0.0, "exp2.dz"); + + let (log2_y, log2_dy) = log2_frule(z, dz); + let log2_scale = Complex64::new(1.0, 0.0) / (z * Complex64::new(std::f64::consts::LN_2, 0.0)); + assert_close_complex64(log2_y, z.log2(), 1.0e-12, 0.0, "log2.z"); + assert_close_complex64(log2_dy, dz * log2_scale, 1.0e-12, 0.0, "log2.dz"); +} + +#[test] +fn smooth_basis_complex_frules_cover_additional_standard_jvps() { + let z = Complex64::new(0.25, -0.5); + let dz = Complex64::new(0.5, -0.25); + + let ((sin_y, cos_y), (dsin_dz, dcos_dz)) = sincos_frule(z, dz); + assert_close_complex64(sin_y, z.sin(), 1.0e-12, 0.0, "sincos.z.sin"); + assert_close_complex64(cos_y, z.cos(), 1.0e-12, 0.0, "sincos.z.cos"); + assert_close_complex64(dsin_dz, dz * z.cos(), 1.0e-12, 0.0, "sincos.dz.sin"); + assert_close_complex64(dcos_dz, dz * -z.sin(), 1.0e-12, 0.0, "sincos.dz.cos"); + + let (inv_y, inv_dy) = inv_frule(z, dz); + assert_close_complex64(inv_y, z.recip(), 1.0e-12, 0.0, "inv.z"); + assert_close_complex64(inv_dy, dz * (-(inv_y * inv_y)), 1.0e-12, 0.0, "inv.dz"); +} + +#[test] +fn pow_rules_handle_zero_and_negative_real_paths() { + let (neg_y, neg_dy) = pow_frule(-2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); + assert!((neg_y + 8.0).abs() < 1.0e-12); + assert!((neg_dy - 12.0).abs() < 1.0e-12); + + let (zero_y, zero_dy) = pow_frule(0.0_f64, 2.0_f64, 1.0_f64, 0.0_f64); + assert!((zero_y - 0.0).abs() < 1.0e-12); + assert!((zero_dy - 0.0).abs() < 1.0e-12); + + let (dx, dexp) = pow_rrule(0.0_f64, 2.0_f64, 
1.0_f64); + assert!((dx - 0.0).abs() < 1.0e-12); + assert!((dexp - 0.0).abs() < 1.0e-12); +} + +#[test] +fn pow_rules_mark_zero_base_exponent_singularities_for_real_inputs() { + let (_, zero_zero_dy) = pow_frule(0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64); + assert!(zero_zero_dy.is_nan()); + + let (_, zero_neg_dy) = pow_frule(0.0_f64, -1.0_f64, 0.0_f64, 1.0_f64); + assert!(zero_neg_dy.is_nan()); + + let (_, zero_zero_dexp) = pow_rrule(0.0_f64, 0.0_f64, 1.0_f64); + assert!(zero_zero_dexp.is_nan()); + + let (_, zero_neg_dexp) = pow_rrule(0.0_f64, -1.0_f64, 1.0_f64); + assert!(zero_neg_dexp.is_nan()); + + let (_, zero_zero_dy32) = pow_frule(0.0_f32, 0.0_f32, 0.0_f32, 1.0_f32); + assert!(zero_zero_dy32.is_nan()); + + let (_, zero_neg_dexp32) = pow_rrule(0.0_f32, -1.0_f32, 1.0_f32); + assert!(zero_neg_dexp32.is_nan()); +} + +#[test] +fn pow_rules_cover_complex_frule_and_rrule_paths() { + let x = Complex64::new(1.0, 1.0); + let exponent = Complex64::new(2.0, 0.5); + let dx = Complex64::new(0.5, -0.25); + let dexp = Complex64::new(0.1, -0.2); + + let (y, dy) = pow_frule(x, exponent, dx, dexp); + let expected_y = x.powc(exponent); + let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))) + + dexp * (expected_y * x.ln()); + assert!((y - expected_y).norm() < 1.0e-12); + assert!((dy - expected_dy).norm() < 1.0e-12); + + let cotangent = Complex64::new(0.5, -0.25); + let (dx_rr, dexp_rr) = pow_rrule(x, exponent, cotangent); + let expected_dx_rr = + cotangent * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))).conj(); + let expected_dexp_rr = cotangent * (expected_y * x.ln()).conj(); + assert!((dx_rr - expected_dx_rr).norm() < 1.0e-12); + assert!((dexp_rr - expected_dexp_rr).norm() < 1.0e-12); + + let imag_x = Complex64::new(0.0, 1.0); + let (_, imag_dexp_rr) = pow_rrule(imag_x, Complex64::new(2.0, 0.0), Complex64::new(1.0, 0.0)); + let expected_imag_dexp_rr = (imag_x.powc(Complex64::new(2.0, 0.0)) * imag_x.ln()).conj(); + assert!((imag_dexp_rr - 
expected_imag_dexp_rr).norm() < 1.0e-12); +} diff --git a/docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md b/docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md new file mode 100644 index 0000000..93ee8fe --- /dev/null +++ b/docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md @@ -0,0 +1,220 @@ +# ChainRules Shared Scalar Basis Design + +## Goal + +Expand `chainrules-rs` into a broader shared scalar automatic-differentiation +basis for Rust crates while preserving the existing crate boundary: + +- `chainrules-core` stays protocol-only +- `chainrules` grows into the reusable scalar rule library +- runtime execution remains in engine crates such as `tidu-rs` +- tensor, array, and operation-specific rules stay in downstream crates + +The target scalar domains are: + +- `f32` +- `f64` +- `Complex32` +- `Complex64` + +## Boundary + +`chainrules-core` remains responsible for: + +- `Differentiable` +- `ReverseRule` +- `ForwardRule` +- `AutodiffError` +- shared AD result and node types + +`chainrules` remains responsible for: + +- stateless scalar primal helpers +- stateless scalar `*_frule` helpers +- stateless scalar `*_rrule` helpers +- PyTorch-style real-input/complex-gradient projection helpers + +Explicitly out of scope for this repository: + +- tape or traced-value runtimes +- `RuleConfig`, `rrule_via_ad`, `frule_via_ad` +- generic `ProjectTo`, `Thunk`, or `InplaceableThunk` machinery +- tensor, array, broadcast, reduction, or factorization rules +- operation-specific AD rules such as einsum or SVD + +## Documentation Layout + +Documentation is split by responsibility. + +- `README.md` + Repository boundary, crate roles, and representative function families. +- `crates/chainrules/README.md` + Canonical list of supported scalar functions and examples. +- `crates/chainrules-core/README.md` + Protocol-only documentation. It explicitly states that this crate does not + ship function rules. 
+ +This keeps the supported function inventory attached to the crate that +actually provides it, while still making the repository-level boundary obvious. + +## API Shape + +The public API keeps the existing flat naming style: + +- `foo` +- `foo_frule` +- `foo_rrule` + +The implementation stays internally modular and family-oriented, with +small focused source files and flat re-exports from `chainrules::lib`. + +`rrule` helpers take the minimum saved values needed for reverse-mode formulas. +Examples: + +- `exp_rrule(result, cotangent)` +- `log_rrule(x, cotangent)` +- `pow_rrule(x, p, cotangent) -> (dx, dp)` +- `sincos_rrule(x, (dsin, dcos)) -> dx` + +## Function Families + +The end state is a broad scalar rule basis, not a narrow Rust-idiomatic subset. +This keeps `chainrules` useful as a landing zone for Julia-to-Rust rule ports. + +### Core Numeric Basis + +- arithmetic: `add`, `sub`, `mul`, `div` +- powers and roots: `pow`, `powf`, `powi`, `sqrt`, `cbrt`, `inv` +- exponentials and logs: + `exp`, `exp2`, `exp10`, `expm1`, `log`, `log2`, `log10`, `log1p` +- trigonometric: + `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `atan2`, `sincos` +- hyperbolic: + `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` +- combined real/complex helpers: + `hypot` + +### Complex And Projection Helpers + +- `conj` +- `real` +- `imag` +- `abs` +- `abs2` +- `angle` +- `complex` +- `handle_r_to_c` + +### Julia Migration Convenience Surface + +- reciprocal trig: `sec`, `csc`, `cot` +- reciprocal hyperbolic: `sech`, `csch`, `coth` +- pi-based trig: `sinpi`, `cospi`, `sincospi` +- degree-based trig: `sind`, `cosd`, `tand` +- unit conversion: `deg2rad`, `rad2deg` + +### Non-Smooth Utilities + +- `sign` +- `copysign` +- `min` +- `max` +- `round` +- `floor` +- `ceil` +- `rem` +- `mod` +- `fma` +- `muladd` + +## Compatibility Policy + +The crate prefers semantic compatibility with the Julia scalar surface where +that improves migration and shared rule reuse, but it does not attempt to +become 
a full port of `ChainRules.jl`. + +Compatibility rules: + +- include Julia-style convenience scalar functions when they are useful as + migration landing zones +- keep tensor and runtime-specific abstractions out of scope +- document branch cuts, singularities, and tie behavior explicitly +- preserve the existing complex-gradient convention used by the crate + +## Implementation Shape + +The public API remains flat, but source files should stay small. The expected +end-state module layout in `crates/chainrules/src` is: + +- `binary.rs` for basic binary arithmetic +- `binary_special.rs` for `atan2`, `hypot`, `min/max`, `copysign`, + `rem/mod`, `fma`, `muladd` +- `power.rs` for `pow`, `powf`, `powi` +- `scalar_ad.rs` for the scalar trait surface and shared scalar helpers +- `unary/basic.rs` +- `unary/roots.rs` +- `unary/exp_log.rs` +- `unary/trig.rs` +- `unary/trig_extra.rs` +- `unary/hyperbolic.rs` +- `unary/hyperbolic_extra.rs` +- `unary/complex_parts.rs` +- `unary/nonsmooth.rs` + +The exact split may adjust slightly while keeping file sizes under control. + +## Testing Strategy + +Testing extends the current two-layer approach. + +### Formula Tests + +Add family-specific integration tests for: + +- smooth numeric basis +- complex and projection helpers +- Julia compatibility helpers +- non-smooth utilities + +Use generic helpers to avoid duplicating logic across scalar types. Favor +`f64` and `Complex64` as the primary correctness targets, with focused +sanity checks for `f32` and `Complex32`. + +### Oracle Tests + +Continue using the vendored `tensor-ad-oracles` data for published references. 
+Expand coverage for functions that already exist in the oracle set, such as: + +- `tan` +- `exp2` +- `exp10` +- `log2` +- `log10` +- `abs` +- `angle` +- `hypot` +- `deg2rad` +- `rad2deg` +- `sign` + +### Behavioral Tests + +Keep hand-written tests for: + +- singularities such as `sqrt(0)` +- branch-cut-sensitive complex functions +- non-smooth behavior for `sign`, `round`, `floor`, `ceil` +- tie behavior for `min` and `max` + +## Rollout + +Implement in four ordered phases. + +1. README reorganization and crate-boundary documentation. +2. Smooth core scalar basis and complex/projection helpers. +3. Julia convenience scalar surface. +4. Non-smooth and tie-sensitive utility rules. + +This sequencing makes the public boundary clear first, lands the highest-value +smooth basis next, and defers specification-heavy non-smooth behavior until the +rest of the surface is stable. diff --git a/docs/plans/2026-03-21-complex-jvp-alignment.md b/docs/plans/2026-03-21-complex-jvp-alignment.md new file mode 100644 index 0000000..9869cee --- /dev/null +++ b/docs/plans/2026-03-21-complex-jvp-alignment.md @@ -0,0 +1,437 @@ +# Complex JVP Alignment Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Align complex `frule` semantics with PyTorch and standard JVP over `C ~= R^2`, keep complex `rrule` on the existing conjugate-Wirtinger convention for real-valued losses, remove the unused `handle_r_to_c_*` helpers, and validate complex forward rules directly against vendored oracle JVPs. + +**Architecture:** Treat this as a family-wide semantic bug, not a one-off formula mismatch. Any complex `frule` that currently applies `conj(local_derivative)` must switch to the plain pushforward `J * dx`, while `rrule` keeps the existing `J^H * cotangent` behavior. 
Audit the shared scalar basis by module family, update local formula tests first, then re-enable direct complex oracle replay so the repository validates both float64 and Complex64 entrypoints against the same published PyTorch JVP/VJP data. + +**Tech Stack:** Rust 2021, `chainrules`, `chainrules-core`, `num-complex`, vendored `tensor-ad-oracles`, `cargo fmt`, `cargo nextest`, `cargo test --doc`, `cargo clippy`, `cargo llvm-cov`. + +--- + +### Task 1: Re-enable Direct Complex Oracle Replay + +**Files:** +- Modify: `crates/chainrules/tests/common.rs` +- Modify: `crates/chainrules/tests/oracle_scalar_rules.rs` +- Modify: `crates/chainrules/src/unary/trig.rs` +- Modify: `crates/chainrules/src/unary/exp_log.rs` + +**Step 1: Write the failing test** + +Replace the reverse-only Complex64 oracle entrypoint with a direct `UnaryOracleCase` replay for the currently covered ops: + +```rust +#[test] +fn published_complex128_oracles_match_unary_rule_entrypoints() { + let cases: [UnaryOracleCase; 3] = [ + UnaryOracleCase { + op: "tan", + frule: tan_frule, + rrule: |_x: Complex64, result, cotangent| tan_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "exp2", + frule: exp2_frule, + rrule: |_x: Complex64, result, cotangent| exp2_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "log2", + frule: log2_frule, + rrule: |x: Complex64, _result, cotangent| log2_rrule(x, cotangent), + }, + ]; + + run_unary_oracle_cases(&cases); +} +``` + +Delete the reverse-only helper path and its comment in `tests/common.rs`. + +**Step 2: Run test to verify it fails** + +Run: + +```bash +cargo nextest run --release --test oracle_scalar_rules +``` + +Expected: FAIL with Complex64 JVP mismatches for `tan`, `exp2`, or `log2`. 
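Why the oracle replay is expected to fail can be illustrated without the crate. The sketch below is only an illustration under stated assumptions: `C` is a hypothetical stand-in for `num_complex::Complex64`, and `square_frule_plain` / `square_frule_conjugated` are made-up names, not `chainrules` entrypoints. It compares both forward conventions for `f(z) = z^2` against a central finite difference, which is the quantity a published JVP reference approximates.

```rust
// Hypothetical minimal complex type; a stand-in for num_complex::Complex64.
#[derive(Clone, Copy, Debug)]
struct C {
    re: f64,
    im: f64,
}

impl C {
    fn new(re: f64, im: f64) -> Self {
        C { re, im }
    }
    fn add(self, o: C) -> C {
        C::new(self.re + o.re, self.im + o.im)
    }
    fn sub(self, o: C) -> C {
        C::new(self.re - o.re, self.im - o.im)
    }
    fn mul(self, o: C) -> C {
        C::new(
            self.re * o.re - self.im * o.im,
            self.re * o.im + self.im * o.re,
        )
    }
    fn scale(self, s: f64) -> C {
        C::new(self.re * s, self.im * s)
    }
    fn conj(self) -> C {
        C::new(self.re, -self.im)
    }
    fn norm(self) -> f64 {
        self.re.hypot(self.im)
    }
}

// f(z) = z^2 with local derivative f'(z) = 2z.
fn square(z: C) -> C {
    z.mul(z)
}

// Standard JVP: tangent = dx * f'(x).
fn square_frule_plain(x: C, dx: C) -> (C, C) {
    (square(x), dx.mul(x.scale(2.0)))
}

// Conjugated variant: tangent = dx * conj(f'(x)).
fn square_frule_conjugated(x: C, dx: C) -> (C, C) {
    (square(x), dx.mul(x.scale(2.0).conj()))
}

// Central finite difference of f along the direction dx.
fn finite_difference_jvp(x: C, dx: C, h: f64) -> C {
    let plus = square(x.add(dx.scale(h)));
    let minus = square(x.sub(dx.scale(h)));
    plus.sub(minus).scale(1.0 / (2.0 * h))
}

fn main() {
    let x = C::new(0.25, -0.5);
    let dx = C::new(0.5, -0.25);
    let reference = finite_difference_jvp(x, dx, 1.0e-6);
    let (_, plain) = square_frule_plain(x, dx);
    let (_, conjugated) = square_frule_conjugated(x, dx);
    println!("plain error:      {:e}", plain.sub(reference).norm());
    println!("conjugated error: {:e}", conjugated.sub(reference).norm());
}
```

Only the plain pushforward tracks the finite difference; the conjugated variant is off by an amount proportional to the imaginary part of the local derivative, which is why the mismatch appears only for Complex64 inputs.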
+ +**Step 3: Write minimal implementation** + +Make the smallest source change that matches the oracle: + +```rust +pub fn tan_frule(x: S, dx: S) -> (S, S) { + let y = x.tan(); + (y, dx * (one::<S>() + y * y)) +} + +pub fn exp2_frule(x: S, dx: S) -> (S, S) { + let y = x.exp2(); + (y, dx * (y * ln_2::<S>())) +} + +pub fn log2_frule(x: S, dx: S) -> (S, S) { + let y = x.log2(); + let scale = one::<S>() / (x * ln_2::<S>()); + (y, dx * scale) +} +``` + +Keep the corresponding `rrule` implementations unchanged. + +**Step 4: Run test to verify it passes** + +Run: + +```bash +cargo fmt --all +cargo nextest run --release --test oracle_scalar_rules +``` + +Expected: PASS for the Complex64 oracle replay test. + +**Step 5: Commit** + +```bash +git add crates/chainrules/tests/common.rs crates/chainrules/tests/oracle_scalar_rules.rs crates/chainrules/src/unary/trig.rs crates/chainrules/src/unary/exp_log.rs +git commit -m "refactor: replay complex unary oracles with direct jvp" +``` + +### Task 2: Align Smooth Unary Complex `frule` Families + +**Files:** +- Modify: `crates/chainrules/src/unary/basic.rs` +- Modify: `crates/chainrules/src/unary/exp_log.rs` +- Modify: `crates/chainrules/src/unary/trig.rs` +- Modify: `crates/chainrules/src/unary/hyperbolic.rs` +- Modify: `crates/chainrules/src/unary/roots.rs` +- Modify: `crates/chainrules/src/tests/behavior.rs` +- Modify: `crates/chainrules/tests/smooth_basis_tests.rs` + +**Step 1: Write the failing test** + +Rename the behavior test so it states the new rule, then switch the expectations to plain JVP: + +```rust +#[test] +fn extended_complex_unary_frules_match_standard_jvp_while_rrules_stay_conjugate() { + let x = Complex64::new(0.25, -0.5); + let dx = Complex64::new(-0.75, 0.5); + let cotangent = Complex64::new(0.5, -1.25); + + let (_sin_y, sin_dy) = sin_frule(x, dx); + assert_close_c64(sin_dy, dx * ComplexFloat::cos(x)); + assert_close_c64( + sin_rrule(x, cotangent), + cotangent * ComplexFloat::conj(ComplexFloat::cos(x)), + ); +} +``` +
+Update `smooth_basis_complex_frules_match_expected_derivatives` the same way for `tan_frule`, `exp2_frule`, and `log2_frule`: remove `.conj()` from the expected forward scales but keep the reverse expectations unchanged. + +**Step 2: Run test to verify it fails** + +Run: + +```bash +cargo nextest run --release --test smooth_basis_tests --test chainrules \ + smooth_basis_complex_frules_match_expected_derivatives \ + extended_complex_unary_frules_match_standard_jvp_while_rrules_stay_conjugate +``` + +Expected: FAIL because the remaining unary `frule` implementations still conjugate their local derivatives. + +**Step 3: Write minimal implementation** + +Remove forward-mode conjugation from every smooth unary helper in these files: + +```rust +pub fn sqrt_frule(x: S, dx: S) -> (S, S) { + let y = x.sqrt(); + let dy = dx / (S::from_i32(2) * y); + (y, dy) +} + +pub fn exp_frule(x: S, dx: S) -> (S, S) { + let y = x.exp(); + (y, dx * y) +} + +pub fn asinh_frule(x: S, dx: S) -> (S, S) { + let y = x.asinh(); + let scale = inverse_sqrt_one_plus_square(x); + (y, dx * scale) +} +``` + +Audit every `*_frule` in the listed modules and change only the forward formulas. Keep `rrule` code as-is. + +**Step 4: Run test to verify it passes** + +Run: + +```bash +cargo fmt --all +cargo nextest run --release --test smooth_basis_tests +cargo nextest run --release --test chainrules extended_complex_unary_frules_match_standard_jvp_while_rrules_stay_conjugate +``` + +Expected: PASS. 
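Once forward rules use `dx * J` and reverse rules keep `cotangent * conj(J)`, the two are adjoint under the real inner product `Re(conj(a) * b)` on `C ~= R^2`. The sketch below checks that identity for `f(z) = z^3`; it assumes a hypothetical minimal `C` type in place of `num_complex::Complex64`, and `cube_frule` / `cube_rrule` are illustrative names rather than crate functions.

```rust
// Hypothetical minimal complex type; a stand-in for num_complex::Complex64.
#[derive(Clone, Copy, Debug)]
struct C {
    re: f64,
    im: f64,
}

impl C {
    fn new(re: f64, im: f64) -> Self {
        C { re, im }
    }
    fn mul(self, o: C) -> C {
        C::new(
            self.re * o.re - self.im * o.im,
            self.re * o.im + self.im * o.re,
        )
    }
    fn scale(self, s: f64) -> C {
        C::new(self.re * s, self.im * s)
    }
    fn conj(self) -> C {
        C::new(self.re, -self.im)
    }
}

// Real inner product on C viewed as R^2: Re(conj(a) * b).
fn real_inner(a: C, b: C) -> f64 {
    a.re * b.re + a.im * b.im
}

// Forward rule for f(z) = z^3: plain pushforward dx * 3z^2.
fn cube_frule(x: C, dx: C) -> (C, C) {
    let jacobian = x.mul(x).scale(3.0);
    (x.mul(x).mul(x), dx.mul(jacobian))
}

// Reverse rule for f(z) = z^3: cotangent * conj(3z^2).
fn cube_rrule(x: C, cotangent: C) -> C {
    cotangent.mul(x.mul(x).scale(3.0).conj())
}

fn main() {
    let x = C::new(0.25, -0.5);
    let dx = C::new(-0.75, 0.5);
    let cotangent = C::new(0.5, -1.25);

    let (_, dy) = cube_frule(x, dx);
    let grad = cube_rrule(x, cotangent);

    // <cotangent, J dx> should equal <J^H cotangent, dx>.
    let forward_side = real_inner(cotangent, dy);
    let reverse_side = real_inner(grad, dx);
    println!("forward side: {forward_side}");
    println!("reverse side: {reverse_side}");
    println!("difference:   {:e}", (forward_side - reverse_side).abs());
}
```

A check of this shape could serve as a cheap cross-family sanity test, since it needs no reference data beyond the two rule entrypoints.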
+ +**Step 5: Commit** + +```bash +git add crates/chainrules/src/unary/basic.rs crates/chainrules/src/unary/exp_log.rs crates/chainrules/src/unary/trig.rs crates/chainrules/src/unary/hyperbolic.rs crates/chainrules/src/unary/roots.rs crates/chainrules/src/tests/behavior.rs crates/chainrules/tests/smooth_basis_tests.rs +git commit -m "refactor: align smooth unary complex frules with standard jvp" +``` + +### Task 3: Align Binary And Power Complex `frule` Families + +**Files:** +- Modify: `crates/chainrules/src/binary.rs` +- Modify: `crates/chainrules/src/power.rs` +- Modify: `crates/chainrules/tests/scalarops_tests.rs` +- Modify: `crates/chainrules/tests/smooth_basis_tests.rs` + +**Step 1: Write the failing test** + +Change the complex forward expectations so they use the plain Jacobian instead of the conjugated one: + +```rust +#[test] +fn mul_div_rules_match_formula_complex64() { + let x = Complex64::new(1.5, -0.5); + let y = Complex64::new(-0.25, 2.0); + let dx = Complex64::new(0.3, -0.2); + let dy = Complex64::new(-0.1, 0.4); + + let (mul_y, mul_dy) = mul_frule(x, y, dx, dy); + assert!((mul_y - (x * y)).norm() < 1e-12); + assert!((mul_dy - (dx * y + dy * x)).norm() < 1e-12); +} +``` + +Update the power tests likewise: + +```rust +let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))) + + dexp * (expected_y * x.ln()); +assert!((dy - expected_dy).norm() < 1e-12); +``` + +Keep all `rrule` expectations on the conjugated derivative. + +**Step 2: Run test to verify it fails** + +Run: + +```bash +cargo nextest run --release --test scalarops_tests --test smooth_basis_tests +``` + +Expected: FAIL in the Complex64 forward checks for `mul`, `div`, `powf`, `powi`, or `pow`. 
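The binary forward formula for division can be sanity-checked the same way. This is a minimal sketch, assuming a hypothetical `C` type with hand-written operators instead of `num_complex::Complex64`; `div_frule_sketch` mirrors the `dx * (1/y) + dy * (-x/y^2)` shape used in the updated test expectations but is not the crate's implementation.

```rust
use std::ops::{Add, Div, Mul, Neg, Sub};

// Hypothetical minimal complex type; a stand-in for num_complex::Complex64.
#[derive(Clone, Copy, Debug)]
struct C {
    re: f64,
    im: f64,
}

impl C {
    fn new(re: f64, im: f64) -> Self {
        C { re, im }
    }
    fn scale(self, s: f64) -> C {
        C::new(self.re * s, self.im * s)
    }
    fn norm(self) -> f64 {
        self.re.hypot(self.im)
    }
}

impl Add for C {
    type Output = C;
    fn add(self, o: C) -> C {
        C::new(self.re + o.re, self.im + o.im)
    }
}

impl Sub for C {
    type Output = C;
    fn sub(self, o: C) -> C {
        C::new(self.re - o.re, self.im - o.im)
    }
}

impl Neg for C {
    type Output = C;
    fn neg(self) -> C {
        C::new(-self.re, -self.im)
    }
}

impl Mul for C {
    type Output = C;
    fn mul(self, o: C) -> C {
        C::new(
            self.re * o.re - self.im * o.im,
            self.re * o.im + self.im * o.re,
        )
    }
}

impl Div for C {
    type Output = C;
    // a / b = a * conj(b) / |b|^2
    fn div(self, o: C) -> C {
        let d = o.re * o.re + o.im * o.im;
        C::new(
            (self.re * o.re + self.im * o.im) / d,
            (self.im * o.re - self.re * o.im) / d,
        )
    }
}

// Plain pushforward for x / y: no conjugation on either partial.
fn div_frule_sketch(x: C, y: C, dx: C, dy: C) -> (C, C) {
    let inv_y = C::new(1.0, 0.0) / y;
    let dfdx = inv_y;
    let dfdy = -(x * inv_y * inv_y);
    (x / y, dx * dfdx + dy * dfdy)
}

// Central finite difference of x / y along the joint direction (dx, dy).
fn finite_difference_jvp(x: C, y: C, dx: C, dy: C, h: f64) -> C {
    let plus = (x + dx.scale(h)) / (y + dy.scale(h));
    let minus = (x - dx.scale(h)) / (y - dy.scale(h));
    (plus - minus).scale(1.0 / (2.0 * h))
}

fn main() {
    let x = C::new(1.5, -0.5);
    let y = C::new(-0.25, 2.0);
    let dx = C::new(0.3, -0.2);
    let dy = C::new(-0.1, 0.4);
    let (_, tangent) = div_frule_sketch(x, y, dx, dy);
    let reference = finite_difference_jvp(x, y, dx, dy, 1.0e-6);
    println!("jvp error: {:e}", (tangent - reference).norm());
}
```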
+ +**Step 3: Write minimal implementation** + +Update only the forward formulas: + +```rust +pub fn mul_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { + let primal = x * y; + let tangent = dx * y + dy * x; + (primal, tangent) +} + +pub fn div_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { + let primal = x / y; + let inv_y = one::<S>() / y; + let dfdx = inv_y; + let dfdy = -(x * inv_y * inv_y); + (primal, dx * dfdx + dy * dfdy) +} +``` + +Mirror the same change in `powf_frule`, `powi_frule`, and `pow_frule`. Keep the singularity behavior at `pow(0, 0)` and zero-base negative exponents unchanged. + +**Step 4: Run test to verify it passes** + +Run: + +```bash +cargo fmt --all +cargo nextest run --release --test scalarops_tests +cargo nextest run --release --test smooth_basis_tests +``` + +Expected: PASS. + +**Step 5: Commit** + +```bash +git add crates/chainrules/src/binary.rs crates/chainrules/src/power.rs crates/chainrules/tests/scalarops_tests.rs crates/chainrules/tests/smooth_basis_tests.rs +git commit -m "refactor: align binary and power complex frules with standard jvp" +``` + +### Task 4: Align Julia Compatibility Helpers With Standard JVP + +**Files:** +- Modify: `crates/chainrules/src/unary/trig_extra.rs` +- Modify: `crates/chainrules/src/unary/hyperbolic_extra.rs` +- Modify: `crates/chainrules/tests/julia_compat_trig_tests.rs` + +**Step 1: Write the failing test** + +Change the Complex64 forward expectations in the Julia-compat tests and add one extra representative check so the family is covered: + +```rust +let (_, dsinpi) = sinpi_frule(z, dz); +assert_close_complex64( + dsinpi, + dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), + 1e-12, + 0.0, + "sinpi_frule(z)", +); + +let ((_sin_y, _cos_y), (dsin, dcos)) = sincospi_frule(z, dz); +assert_close_complex64(dsin, dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), 1e-12, 0.0, "sincospi_frule.sin"); +assert_close_complex64(dcos, dz * (-(Complex64::new(std::f64::consts::PI, 0.0) * pi_z.sin())),
1e-12, 0.0, "sincospi_frule.cos"); +``` + +Keep the `rrule` expectations on the conjugated derivative. + +**Step 2: Run test to verify it fails** + +Run: + +```bash +cargo nextest run --release --test julia_compat_trig_tests +``` + +Expected: FAIL for Complex64 forward checks in `sinpi_frule`, `sincospi_frule`, or related helpers. + +**Step 3: Write minimal implementation** + +Audit the Julia migration helpers and remove forward-mode conjugation everywhere it still appears: + +```rust +pub fn sinpi_frule(x: S, dx: S) -> (S, S) { + let y = sinpi(x); + let scale = pi::<S>() * cospi(x); + (y, dx * scale) +} +``` + +Repeat the same cleanup for `cospi_frule`, `sincospi_frule`, `tand_frule`, `sec_frule`, `csc_frule`, `cot_frule`, `sech_frule`, `csch_frule`, and `coth_frule`. If a degree-based helper delegates entirely to corrected primitives, prefer reuse over re-deriving a new formula. + +**Step 4: Run test to verify it passes** + +Run: + +```bash +cargo fmt --all +cargo nextest run --release --test julia_compat_trig_tests +``` + +Expected: PASS.
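The `pi * cospi(x)` scale used by the pi-based helpers can be spot-checked on real inputs, where conjugation is a no-op and the forward and reverse scales coincide. The following is a rough sketch on `f64` only; `sinpi`, `cospi`, `sinpi_frule_sketch`, and `sinpi_rrule_sketch` are local illustrative functions with assumed signatures, not the generic crate entrypoints.

```rust
use std::f64::consts::PI;

// Local real-valued primals for the sketch.
fn sinpi(x: f64) -> f64 {
    (PI * x).sin()
}

fn cospi(x: f64) -> f64 {
    (PI * x).cos()
}

// Forward rule: d/dx sin(pi * x) = pi * cos(pi * x), applied to the tangent.
fn sinpi_frule_sketch(x: f64, dx: f64) -> (f64, f64) {
    (sinpi(x), dx * (PI * cospi(x)))
}

// Reverse rule: the same scale applied to the cotangent. On reals the
// conjugate of the scale is the scale itself.
fn sinpi_rrule_sketch(x: f64, cotangent: f64) -> f64 {
    cotangent * (PI * cospi(x))
}

fn main() {
    let x = 0.25_f64;
    let h = 1.0e-6;
    let (_, dy) = sinpi_frule_sketch(x, 1.0);
    let reference = (sinpi(x + h) - sinpi(x - h)) / (2.0 * h);
    println!("frule error vs finite difference: {:e}", (dy - reference).abs());
    println!("rrule minus frule scale: {:e}", (sinpi_rrule_sketch(x, 1.0) - dy).abs());
}
```

The forward/reverse distinction this plan introduces therefore only shows up for `Complex32` and `Complex64`; real-valued tests in this family should keep passing unchanged.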
+ +**Step 5: Commit** + +```bash +git add crates/chainrules/src/unary/trig_extra.rs crates/chainrules/src/unary/hyperbolic_extra.rs crates/chainrules/tests/julia_compat_trig_tests.rs +git commit -m "refactor: align Julia compatibility frules with standard jvp" +``` + +### Task 5: Remove `handle_r_to_c_*` And Update Public Docs + +**Files:** +- Modify: `crates/chainrules/src/scalar_ad/mod.rs` +- Modify: `crates/chainrules/src/lib.rs` +- Modify: `crates/chainrules/src/tests/behavior.rs` +- Modify: `crates/chainrules/tests/scalarops_tests.rs` +- Modify: `README.md` +- Modify: `crates/chainrules/README.md` + +**Step 1: Write the failing cleanup** + +Remove the public re-exports first: + +```rust +#[doc(inline)] +pub use scalar_ad::ScalarAd; +``` + +Then update the test imports so any remaining helper references are explicit failures: + +```rust +use chainrules::{ + add, add_frule, add_rrule, conj, conj_frule, conj_rrule, div, div_frule, div_rrule, mul, + mul_frule, mul_rrule, powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule, sqrt, + sqrt_frule, sqrt_rrule, sub, sub_frule, sub_rrule, +}; +``` + +**Step 2: Run test to verify it fails** + +Run: + +```bash +cargo nextest run --release --test scalarops_tests --test chainrules +cargo test --doc --release -p chainrules +``` + +Expected: FAIL because `handle_r_to_c_*` doctests and dedicated tests still exist. 
+ +**Step 3: Write minimal implementation** + +Delete the unused helpers and update the docs to state the new convention clearly: + +- remove `handle_r_to_c_f32` and `handle_r_to_c_f64` from `scalar_ad/mod.rs` +- delete the dedicated helper checks from `src/tests/behavior.rs` and `tests/scalarops_tests.rs` +- update `README.md` so the testing section says Complex64 oracle replay is direct, not reverse-only +- update `crates/chainrules/README.md` so it: + - removes `handle_r_to_c_*` from the supported surface + - states that complex `frule` follows standard JVP on `C ~= R^2` + - states that complex `rrule` remains conjugate-Wirtinger for real-valued losses + - keeps the provenance note that this crate is a landing zone for scalar rules ported or adapted from `ChainRules.jl`, not a full port + +**Step 4: Run test to verify it passes** + +Run: + +```bash +cargo fmt --all +cargo nextest run --release --test scalarops_tests --test chainrules +cargo test --doc --release -p chainrules +``` + +Expected: PASS. + +**Step 5: Commit** + +```bash +git add crates/chainrules/src/scalar_ad/mod.rs crates/chainrules/src/lib.rs crates/chainrules/src/tests/behavior.rs crates/chainrules/tests/scalarops_tests.rs README.md crates/chainrules/README.md +git commit -m "cleanup: remove projection helpers and document complex jvp" +``` + +## Final Verification + +Run the full repository verification after the last task: + +```bash +cargo fmt --all --check +cargo nextest run --release --workspace --no-fail-fast +cargo test --doc --release --workspace +cargo clippy --workspace --all-targets -- -D warnings +cargo llvm-cov nextest --workspace --release --json --output-path coverage.json +python3 scripts/check-coverage.py coverage.json +cargo doc --workspace --no-deps +python3 scripts/check-docs-site.py +``` + +Expected: all commands PASS with no direct-complex-oracle failures, no remaining `handle_r_to_c_*` references, and no docs claiming that complex forward replay is local-only.