From 9861189ffee03cc81de4b57a9b48f393f1f3a492 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Fri, 20 Mar 2026 08:58:45 +0900 Subject: [PATCH 01/32] Add MIT and Apache-2.0 license files Cargo.toml already declares `license = "MIT OR Apache-2.0"` but the actual license text files were missing from the repository. Co-Authored-By: Claude Opus 4.6 --- LICENSE-APACHE | 190 +++++++++++++++++++++++++++++++++++++++++++++++++ LICENSE-MIT | 21 ++++++ 2 files changed, 211 insertions(+) create mode 100644 LICENSE-APACHE create mode 100644 LICENSE-MIT diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..7c436b4 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,190 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to the Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by the Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding any notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +Copyright 2026 Hiroshi Shinaoka, GiggleLiu, and contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..98eeb52 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Hiroshi Shinaoka, GiggleLiu, and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
From e23d0e02a971fae204d1599a5bfa715d29957104 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:04:08 +0900 Subject: [PATCH 02/32] docs: record shared scalar basis design --- ...1-chainrules-shared-scalar-basis-design.md | 220 ++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md diff --git a/docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md b/docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md new file mode 100644 index 0000000..93ee8fe --- /dev/null +++ b/docs/plans/2026-03-21-chainrules-shared-scalar-basis-design.md @@ -0,0 +1,220 @@ +# ChainRules Shared Scalar Basis Design + +## Goal + +Expand `chainrules-rs` into a broader shared scalar automatic-differentiation +basis for Rust crates while preserving the existing crate boundary: + +- `chainrules-core` stays protocol-only +- `chainrules` grows into the reusable scalar rule library +- runtime execution remains in engine crates such as `tidu-rs` +- tensor, array, and operation-specific rules stay in downstream crates + +The target scalar domains are: + +- `f32` +- `f64` +- `Complex32` +- `Complex64` + +## Boundary + +`chainrules-core` remains responsible for: + +- `Differentiable` +- `ReverseRule` +- `ForwardRule` +- `AutodiffError` +- shared AD result and node types + +`chainrules` remains responsible for: + +- stateless scalar primal helpers +- stateless scalar `*_frule` helpers +- stateless scalar `*_rrule` helpers +- PyTorch-style real-input/complex-gradient projection helpers + +Explicitly out of scope for this repository: + +- tape or traced-value runtimes +- `RuleConfig`, `rrule_via_ad`, `frule_via_ad` +- generic `ProjectTo`, `Thunk`, or `InplaceableThunk` machinery +- tensor, array, broadcast, reduction, or factorization rules +- operation-specific AD rules such as einsum or SVD + +## Documentation Layout + +Documentation is split by responsibility. 
+ +- `README.md` + Repository boundary, crate roles, and representative function families. +- `crates/chainrules/README.md` + Canonical list of supported scalar functions and examples. +- `crates/chainrules-core/README.md` + Protocol-only documentation. It explicitly states that this crate does not + ship function rules. + +This keeps the supported function inventory attached to the crate that +actually provides it, while still making the repository-level boundary obvious. + +## API Shape + +The public API keeps the existing flat naming style: + +- `foo` +- `foo_frule` +- `foo_rrule` + +The implementation stays internally modular and family-oriented, with +small focused source files and flat re-exports from `chainrules::lib`. + +`rrule` helpers take the minimum saved values needed for reverse-mode formulas. +Examples: + +- `exp_rrule(result, cotangent)` +- `log_rrule(x, cotangent)` +- `pow_rrule(x, p, cotangent) -> (dx, dp)` +- `sincos_rrule(x, (dsin, dcos)) -> dx` + +## Function Families + +The end state is a broad scalar rule basis, not a narrow Rust-idiomatic subset. +This keeps `chainrules` useful as a landing zone for Julia-to-Rust rule ports. 
+ +### Core Numeric Basis + +- arithmetic: `add`, `sub`, `mul`, `div` +- powers and roots: `pow`, `powf`, `powi`, `sqrt`, `cbrt`, `inv` +- exponentials and logs: + `exp`, `exp2`, `exp10`, `expm1`, `log`, `log2`, `log10`, `log1p` +- trigonometric: + `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `atan2`, `sincos` +- hyperbolic: + `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` +- combined real/complex helpers: + `hypot` + +### Complex And Projection Helpers + +- `conj` +- `real` +- `imag` +- `abs` +- `abs2` +- `angle` +- `complex` +- `handle_r_to_c` + +### Julia Migration Convenience Surface + +- reciprocal trig: `sec`, `csc`, `cot` +- reciprocal hyperbolic: `sech`, `csch`, `coth` +- pi-based trig: `sinpi`, `cospi`, `sincospi` +- degree-based trig: `sind`, `cosd`, `tand` +- unit conversion: `deg2rad`, `rad2deg` + +### Non-Smooth Utilities + +- `sign` +- `copysign` +- `min` +- `max` +- `round` +- `floor` +- `ceil` +- `rem` +- `mod` +- `fma` +- `muladd` + +## Compatibility Policy + +The crate prefers semantic compatibility with the Julia scalar surface where +that improves migration and shared rule reuse, but it does not attempt to +become a full port of `ChainRules.jl`. + +Compatibility rules: + +- include Julia-style convenience scalar functions when they are useful as + migration landing zones +- keep tensor and runtime-specific abstractions out of scope +- document branch cuts, singularities, and tie behavior explicitly +- preserve the existing complex-gradient convention used by the crate + +## Implementation Shape + +The public API remains flat, but source files should stay small. 
The expected +end-state module layout in `crates/chainrules/src` is: + +- `binary.rs` for basic binary arithmetic +- `binary_special.rs` for `atan2`, `hypot`, `min/max`, `copysign`, + `rem/mod`, `fma`, `muladd` +- `power.rs` for `pow`, `powf`, `powi` +- `scalar_ad.rs` for the scalar trait surface and shared scalar helpers +- `unary/basic.rs` +- `unary/roots.rs` +- `unary/exp_log.rs` +- `unary/trig.rs` +- `unary/trig_extra.rs` +- `unary/hyperbolic.rs` +- `unary/hyperbolic_extra.rs` +- `unary/complex_parts.rs` +- `unary/nonsmooth.rs` + +The exact split may adjust slightly while keeping file sizes under control. + +## Testing Strategy + +Testing extends the current two-layer approach. + +### Formula Tests + +Add family-specific integration tests for: + +- smooth numeric basis +- complex and projection helpers +- Julia compatibility helpers +- non-smooth utilities + +Use generic helpers to avoid duplicating logic across scalar types. Favor +`f64` and `Complex64` as the primary correctness targets, with focused +sanity checks for `f32` and `Complex32`. + +### Oracle Tests + +Continue using the vendored `tensor-ad-oracles` data for published references. +Expand coverage for functions that already exist in the oracle set, such as: + +- `tan` +- `exp2` +- `exp10` +- `log2` +- `log10` +- `abs` +- `angle` +- `hypot` +- `deg2rad` +- `rad2deg` +- `sign` + +### Behavioral Tests + +Keep hand-written tests for: + +- singularities such as `sqrt(0)` +- branch-cut-sensitive complex functions +- non-smooth behavior for `sign`, `round`, `floor`, `ceil` +- tie behavior for `min` and `max` + +## Rollout + +Implement in four ordered phases. + +1. README reorganization and crate-boundary documentation. +2. Smooth core scalar basis and complex/projection helpers. +3. Julia convenience scalar surface. +4. Non-smooth and tie-sensitive utility rules. 
+ +This sequencing makes the public boundary clear first, lands the highest-value +smooth basis next, and defers specification-heavy non-smooth behavior until the +rest of the surface is stable. From 9c5d113c7a2c179f298469fcf613837aee1c5ec3 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:08:21 +0900 Subject: [PATCH 03/32] docs: clarify chainrules crate boundaries --- README.md | 18 +++++++++---- crates/chainrules-core/README.md | 31 +++++++++++++++++++++ crates/chainrules/README.md | 46 ++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 crates/chainrules-core/README.md create mode 100644 crates/chainrules/README.md diff --git a/README.md b/README.md index 7b8569b..55dac97 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ specific AD engine. It contains: -- `chainrules-core`: core AD traits and error types -- `chainrules`: reusable scalar `frule`/`rrule` helpers and related utilities +- `chainrules-core`: engine-independent AD protocol +- `chainrules`: shared scalar rule basis It intentionally does **not** ship a tape, traced value type, or any other AD engine runtime. Those live in separate engine crates such as @@ -25,9 +25,9 @@ engine runtime. 
Those live in separate engine crates such as ## Repository Layout -- [`crates/chainrules-core`](crates/chainrules-core): `Differentiable`, - `ReverseRule`, `ForwardRule`, `AutodiffError`, and related core types -- [`crates/chainrules`](crates/chainrules): scalar rule implementations such as +- [`crates/chainrules-core`](crates/chainrules-core): protocol-only crate for + `Differentiable`, `ReverseRule`, `ForwardRule`, and `AutodiffError` +- [`crates/chainrules`](crates/chainrules): shared scalar rules such as `exp`, `log1p`, `sin`, `atanh`, `powf`, and `atan2` - [`third_party/tensor-ad-oracles`](third_party/tensor-ad-oracles): vendored oracle data used to validate scalar rules against published references @@ -48,6 +48,14 @@ engine that executes those rules over a tape. The boundary is deliberate: - `tidu-rs` can evolve independently as an engine - downstream tensor libraries can swap engines without rewriting scalar rules +## Crate Roles + +`chainrules-core` does not provide function rules. +`chainrules` provides stateless scalar `foo`, `foo_frule`, and `foo_rrule` +helpers. + +See the crate READMEs for the supported scalar function inventory and examples. + ## Testing ```bash diff --git a/crates/chainrules-core/README.md b/crates/chainrules-core/README.md new file mode 100644 index 0000000..d288b82 --- /dev/null +++ b/crates/chainrules-core/README.md @@ -0,0 +1,31 @@ +# chainrules-core + +`chainrules-core` defines the engine-independent AD protocol used by +`chainrules-rs`. + +It is intentionally small and does not provide function rules. The crate exists +to define the traits and error types that downstream AD engines and rule +libraries build on. + +## What It Provides + +- `Differentiable` +- `ReverseRule` +- `ForwardRule` +- `AutodiffError` +- `NodeId` +- `SavePolicy` + +## Example + +```rust +use chainrules_core::NodeId; + +let id = NodeId::new(7); +assert_eq!(id.index(), 7); +``` + +## Notes + +This crate is protocol-only. 
Shared scalar `frule` and `rrule` helpers live in +[`chainrules`](../chainrules). diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md new file mode 100644 index 0000000..ae17fcf --- /dev/null +++ b/crates/chainrules/README.md @@ -0,0 +1,46 @@ +# chainrules + +`chainrules` provides a shared scalar rule basis for Rust automatic +differentiation crates. + +It is designed for reusable scalar calculus, not for tapes, traced values, or +tensor-specific execution engines. The crate focuses on stateless helpers that +can be called from downstream AD runtimes and tensor libraries. + +## What It Provides + +- stateless scalar primal helpers +- stateless scalar `foo_frule` helpers +- stateless scalar `foo_rrule` helpers +- real/complex projection helpers for common scalar formulas + +Supported scalar domains: + +- `f32` +- `f64` +- `Complex32` +- `Complex64` + +## Examples + +```rust +use chainrules::{powf, powf_frule, powf_rrule}; + +let y = powf(2.0_f64, 3.0_f64); +assert_eq!(y, 8.0_f64); + +let (y, dy) = powf_frule(2.0_f64, 3.0_f64, 1.0_f64); +assert_eq!(y, 8.0_f64); +assert_eq!(dy, 12.0_f64); + +let dx = powf_rrule(2.0_f64, 3.0_f64, 1.0_f64); +assert_eq!(dx, 12.0_f64); +``` + +## Notes + +This crate is the landing zone for shared scalar rule logic, including +Julia-style convenience functions when they help migration. + +It does not define tensor, array, broadcast, reduction, or engine-specific +rules. 
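Reviewer note: the flat `foo` / `foo_frule` / `foo_rrule` convention described in the READMEs above can be sketched for real scalars as follows. This is only an illustration under stated assumptions. The functions are local stand-ins and are not the `chainrules` crate's implementations. The argument order is assumed from the design note's `exp_rrule(result, cotangent)` and `log_rrule(x, cotangent)` examples. Only `f64` is covered, so the crate's complex-gradient convention does not come into play.

```rust
// Hypothetical stand-ins for illustration; NOT the `chainrules` crate API.
// Signatures are assumed from the `foo`, `foo_frule`, `foo_rrule` naming scheme.

/// Primal helper: y = exp(x).
fn exp(x: f64) -> f64 {
    x.exp()
}

/// Forward rule: returns (y, dy) where dy = exp(x) * dx.
fn exp_frule(x: f64, dx: f64) -> (f64, f64) {
    let y = x.exp();
    (y, y * dx)
}

/// Reverse rule: only the saved result y = exp(x) is needed,
/// because d/dx exp(x) = exp(x). Returns dx = y * cotangent.
fn exp_rrule(result: f64, cotangent: f64) -> f64 {
    result * cotangent
}

/// Reverse rule for log: the saved value is the input x,
/// because d/dx ln(x) = 1 / x. Returns dx = cotangent / x.
fn log_rrule(x: f64, cotangent: f64) -> f64 {
    cotangent / x
}

fn main() {
    let x = 0.5_f64;

    // Forward mode: push a unit tangent through exp.
    let (y, dy) = exp_frule(x, 1.0);
    assert_eq!(y, exp(x));

    // Reverse mode: pull a unit cotangent back using only the saved result.
    let dx = exp_rrule(y, 1.0);

    // For a scalar-to-scalar real function both modes give the same derivative.
    assert_eq!(dy, dx);

    // log saves its input rather than its result.
    assert_eq!(log_rrule(4.0, 1.0), 0.25);
}
```

The point of the sketch is the "minimum saved values" idea: `exp_rrule` needs only the primal result, while `log_rrule` needs only the primal input, so a downstream engine can decide what to keep per rule.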
From 4cbdf5205643c250e2330c4ce7dbea85cff7269d Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:15:47 +0900 Subject: [PATCH 04/32] docs: wire scalar readmes into rustdoc --- crates/chainrules-core/src/lib.rs | 2 ++ crates/chainrules/README.md | 16 ++++++++++++++++ crates/chainrules/src/lib.rs | 2 ++ 3 files changed, 20 insertions(+) diff --git a/crates/chainrules-core/src/lib.rs b/crates/chainrules-core/src/lib.rs index ddcde00..883f108 100644 --- a/crates/chainrules-core/src/lib.rs +++ b/crates/chainrules-core/src/lib.rs @@ -1,3 +1,5 @@ +#![doc = include_str!("../README.md")] + //! Core AD trait definitions (like Julia's ChainRulesCore.jl). //! //! This crate defines the interface for automatic differentiation without diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index ae17fcf..b2c4852 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -21,6 +21,22 @@ Supported scalar domains: - `Complex32` - `Complex64` +## Supported Functions + +Current shipped scalar families: + +- arithmetic: `add`, `sub`, `mul`, `div` +- powers and roots: `powf`, `powi`, `sqrt` +- exponentials and logs: `exp`, `expm1`, `log`, `log1p` +- trigonometric: `sin`, `cos`, `asin`, `acos`, `atan` +- hyperbolic: `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` +- complex and projection helpers: `conj`, `handle_r_to_c` +- real-valued binary helpers: `atan2` + +This crate is intended as a landing zone for scalar rules ported or adapted +from Julia's `ChainRules.jl` where they fit this crate boundary, but it is not +a full port of `ChainRules.jl`. + ## Examples ```rust diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 9fb6fa8..9619666 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -1,3 +1,5 @@ +#![doc = include_str!("../README.md")] + //! Engine-independent scalar AD helper rules for elementary operations. //! //! 
This crate provides stateless primal/frule/rrule helpers for scalar From 4f6d09ba2058e45372f76da14abf40f7a16a2c52 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:21:21 +0900 Subject: [PATCH 05/32] docs: remove duplicate scalar crate overview --- crates/chainrules/README.md | 2 +- crates/chainrules/src/lib.rs | 36 ------------------------------------ 2 files changed, 1 insertion(+), 37 deletions(-) diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index b2c4852..0aae8d2 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -30,7 +30,7 @@ Current shipped scalar families: - exponentials and logs: `exp`, `expm1`, `log`, `log1p` - trigonometric: `sin`, `cos`, `asin`, `acos`, `atan` - hyperbolic: `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` -- complex and projection helpers: `conj`, `handle_r_to_c` +- complex and projection helpers: `conj`, `handle_r_to_c_f32`, `handle_r_to_c_f64` - real-valued binary helpers: `atan2` This crate is intended as a landing zone for scalar rules ported or adapted diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 9619666..e5b0f21 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -1,41 +1,5 @@ #![doc = include_str!("../README.md")] -//! Engine-independent scalar AD helper rules for elementary operations. -//! -//! This crate provides stateless primal/frule/rrule helpers for scalar -//! operations. It is designed to be shared across AD engines and higher-level -//! tensor libraries. -//! -//! Public helper families: -//! -//! - `add`, `sub`, `mul`, `div` -//! - `conj` -//! - `sqrt` -//! - `exp`, `expm1`, `log`, `log1p` -//! - `sin`, `cos` -//! - `sinh`, `cosh` -//! - `asin`, `acos`, `atan`, `asinh`, `acosh`, `atanh` -//! - `powf` (fixed real exponent) -//! - `powi` (fixed integer exponent) -//! - `atan2` (real scalars) -//! -//! The [`ScalarAd`] trait also provides the scalar method basis used by the -//! 
higher-level wrappers, including `expm1`, `log1p`, `sin`, `cos`, and -//! `tanh`. -//! -//! # Examples -//! -//! ```rust -//! use chainrules::{powf_frule, powf_rrule}; -//! -//! let (y, dy) = powf_frule(2.0_f64, 3.0, 1.0); -//! assert_eq!(y, 8.0); -//! assert_eq!(dy, 12.0); -//! -//! let dx = powf_rrule(2.0_f64, 3.0, 1.0); -//! assert_eq!(dx, 12.0); -//! ``` - pub use chainrules_core::{ AdResult, AutodiffError, Differentiable, ForwardRule, NodeId, PullbackEntry, PullbackWithTangentsEntry, ReverseRule, SavePolicy, From 93873b4751de49c80578c9017432e74671971ee4 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:27:16 +0900 Subject: [PATCH 06/32] docs: describe scalar rule provenance and validation --- README.md | 10 ++++++++++ crates/chainrules/README.md | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/README.md b/README.md index 55dac97..92df6a2 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,20 @@ engine that executes those rules over a tape. The boundary is deliberate: `chainrules` provides stateless scalar `foo`, `foo_frule`, and `foo_rrule` helpers. +`chainrules` is a landing zone for scalar rules ported or adapted from Julia's +`ChainRules.jl` where they fit this repository boundary, but `chainrules-rs` is +not a full port of `ChainRules.jl`. + See the crate READMEs for the supported scalar function inventory and examples. 
## Testing +Scalar rules are checked in two complementary ways: + +- formula and behavior tests in `crates/chainrules/tests/scalarops_tests.rs` +- oracle replay tests in `crates/chainrules/tests/oracle_scalar_rules.rs` + against vendored published cases from `third_party/tensor-ad-oracles` + ```bash cargo test --workspace --release cargo llvm-cov --workspace --json --output-path coverage.json diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index 0aae8d2..b19cf91 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -37,6 +37,16 @@ This crate is intended as a landing zone for scalar rules ported or adapted from Julia's `ChainRules.jl` where they fit this crate boundary, but it is not a full port of `ChainRules.jl`. +## Validation + +Rules in this crate are not accepted on provenance alone. They are checked with +repository-local tests: + +- `tests/scalarops_tests.rs` covers direct formulas, edge cases, and + real/complex behavior +- `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from + `../../third_party/tensor-ad-oracles` + ## Examples ```rust From ab341bfad36c136799064f341048f10144c04d27 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:33:18 +0900 Subject: [PATCH 07/32] refactor: extend scalar trait surface for shared basis --- crates/chainrules/src/lib.rs | 8 +- crates/chainrules/src/scalar_ad.rs | 226 +++++++++++++++--- crates/chainrules/src/unary/mod.rs | 2 + crates/chainrules/src/unary/smooth.rs | 80 +++++++ crates/chainrules/tests/smooth_basis_tests.rs | 15 ++ 5 files changed, 288 insertions(+), 43 deletions(-) create mode 100644 crates/chainrules/src/unary/smooth.rs create mode 100644 crates/chainrules/tests/smooth_basis_tests.rs diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index e5b0f21..39da7ad 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -26,10 +26,10 @@ pub use scalar_ad::{handle_r_to_c_f32, 
handle_r_to_c_f64, ScalarAd}; pub use unary::{ acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, - conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, cosh_frule, cosh_rrule, exp, - exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, log, log1p, log1p_frule, log1p_rrule, - log_frule, log_rrule, sin, sin_frule, sin_rrule, sinh, sinh_frule, sinh_rrule, sqrt, - sqrt_frule, sqrt_rrule, tanh, tanh_frule, tanh_rrule, + cbrt, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, cosh_frule, cosh_rrule, + exp, exp2, exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, hypot, log, log1p, + log1p_frule, log1p_rrule, log2, log_frule, log_rrule, pow, sin, sin_frule, sin_rrule, sinh, + sinh_frule, sinh_rrule, sqrt, sqrt_frule, sqrt_rrule, tan, tanh, tanh_frule, tanh_rrule, }; #[cfg(test)] diff --git a/crates/chainrules/src/scalar_ad.rs b/crates/chainrules/src/scalar_ad.rs index 7b0a6dd..0ea43a3 100644 --- a/crates/chainrules/src/scalar_ad.rs +++ b/crates/chainrules/src/scalar_ad.rs @@ -1,7 +1,7 @@ use core::ops::{Add, Div, Mul, Neg, Sub}; -use num_complex::{Complex32, Complex64}; -use num_traits::Float; +use num_complex::{Complex32, Complex64, ComplexFloat}; +use num_traits::{Float, FloatConst, Zero}; /// Scalar trait used by elementary AD rule helpers. /// @@ -25,17 +25,29 @@ pub trait ScalarAd: + Div { /// Real exponent type for `powf`. - type Real: Copy + Float; + type Real: Copy + Float + FloatConst; /// Complex conjugate (identity for real scalars). fn conj(self) -> Self; + /// Reciprocal. + fn recip(self) -> Self; + + /// Cubic root. + fn cbrt(self) -> Self; + /// Square root. fn sqrt(self) -> Self; /// Exponential. fn exp(self) -> Self; + /// `2^self`. + fn exp2(self) -> Self; + + /// `10^self`. + fn exp10(self) -> Self; + /// `exp(self) - 1`. 
fn expm1(self) -> Self; @@ -45,12 +57,21 @@ pub trait ScalarAd: /// `ln(1 + self)`. fn log1p(self) -> Self; + /// Base-2 logarithm. + fn log2(self) -> Self; + + /// Base-10 logarithm. + fn log10(self) -> Self; + /// Sine. fn sin(self) -> Self; /// Cosine. fn cos(self) -> Self; + /// Tangent. + fn tan(self) -> Self; + /// Hyperbolic tangent. fn tanh(self) -> Self; @@ -78,12 +99,30 @@ pub trait ScalarAd: /// Area hyperbolic tangent. fn atanh(self) -> Self; + /// Absolute value. + fn abs(self) -> Self::Real; + + /// Absolute value squared. + fn abs2(self) -> Self::Real; + + /// Real part. + fn real(self) -> Self::Real; + + /// Imaginary part. + fn imag(self) -> Self::Real; + + /// Polar angle. + fn angle(self) -> Self::Real; + /// Power by real exponent. fn powf(self, exponent: Self::Real) -> Self; /// Power by integer exponent. fn powi(self, exponent: i32) -> Self; + /// Power by same-scalar exponent. + fn pow(self, exponent: Self) -> Self; + /// Convert real scalar to this scalar type. fn from_real(value: Self::Real) -> Self; @@ -100,76 +139,128 @@ macro_rules! 
impl_scalar_ad_real { self } + fn recip(self) -> Self { + <$ty as Float>::recip(self) + } + + fn cbrt(self) -> Self { + <$ty as Float>::cbrt(self) + } + fn sqrt(self) -> Self { - <$ty>::sqrt(self) + <$ty as Float>::sqrt(self) } fn exp(self) -> Self { - <$ty>::exp(self) + <$ty as Float>::exp(self) + } + + fn exp2(self) -> Self { + <$ty as Float>::exp2(self) + } + + fn exp10(self) -> Self { + (<$ty as Float>::exp(self * <$ty as FloatConst>::LN_10())) } fn expm1(self) -> Self { - <$ty>::exp_m1(self) + <$ty as Float>::exp_m1(self) } fn ln(self) -> Self { - <$ty>::ln(self) + <$ty as Float>::ln(self) } fn log1p(self) -> Self { - <$ty>::ln_1p(self) + <$ty as Float>::ln_1p(self) + } + + fn log2(self) -> Self { + <$ty as Float>::log2(self) + } + + fn log10(self) -> Self { + <$ty as Float>::log10(self) } fn sin(self) -> Self { - <$ty>::sin(self) + <$ty as Float>::sin(self) } fn cos(self) -> Self { - <$ty>::cos(self) + <$ty as Float>::cos(self) + } + + fn tan(self) -> Self { + <$ty as Float>::tan(self) } fn tanh(self) -> Self { - <$ty>::tanh(self) + <$ty as Float>::tanh(self) } fn asin(self) -> Self { - <$ty>::asin(self) + <$ty as Float>::asin(self) } fn acos(self) -> Self { - <$ty>::acos(self) + <$ty as Float>::acos(self) } fn atan(self) -> Self { - <$ty>::atan(self) + <$ty as Float>::atan(self) } fn sinh(self) -> Self { - <$ty>::sinh(self) + <$ty as Float>::sinh(self) } fn cosh(self) -> Self { - <$ty>::cosh(self) + <$ty as Float>::cosh(self) } fn asinh(self) -> Self { - <$ty>::asinh(self) + <$ty as Float>::asinh(self) } fn acosh(self) -> Self { - <$ty>::acosh(self) + <$ty as Float>::acosh(self) } fn atanh(self) -> Self { - <$ty>::atanh(self) + <$ty as Float>::atanh(self) + } + + fn abs(self) -> Self::Real { + <$ty as Float>::abs(self) + } + + fn abs2(self) -> Self::Real { + self * self + } + + fn real(self) -> Self::Real { + self + } + + fn imag(self) -> Self::Real { + <$ty as Zero>::zero() + } + + fn angle(self) -> Self::Real { + <$ty as Float>::atan2(<$ty as 
Zero>::zero(), self) } fn powf(self, exponent: Self::Real) -> Self { - <$ty>::powf(self, exponent) + <$ty as Float>::powf(self, exponent) } fn powi(self, exponent: i32) -> Self { - <$ty>::powi(self, exponent) + <$ty as Float>::powi(self, exponent) + } + + fn pow(self, exponent: Self) -> Self { + <$ty as Float>::powf(self, exponent) } fn from_real(value: Self::Real) -> Self { @@ -192,76 +283,133 @@ macro_rules! impl_scalar_ad_complex { <$complex_ty>::conj(&self) } + fn recip(self) -> Self { + <$complex_ty as ComplexFloat>::recip(self) + } + + fn cbrt(self) -> Self { + <$complex_ty as ComplexFloat>::cbrt(self) + } + fn sqrt(self) -> Self { - <$complex_ty>::sqrt(self) + <$complex_ty as ComplexFloat>::sqrt(self) } fn exp(self) -> Self { - <$complex_ty>::exp(self) + <$complex_ty as ComplexFloat>::exp(self) + } + + fn exp2(self) -> Self { + <$complex_ty as ComplexFloat>::exp2(self) + } + + fn exp10(self) -> Self { + (<$complex_ty as ComplexFloat>::exp( + self * <$complex_ty>::new( + <$real_ty as FloatConst>::LN_10(), + <$real_ty as Zero>::zero(), + ), + )) } fn expm1(self) -> Self { - <$complex_ty>::exp(self) - $one + <$complex_ty as ComplexFloat>::exp(self) - $one } fn ln(self) -> Self { - <$complex_ty>::ln(self) + <$complex_ty as ComplexFloat>::ln(self) } fn log1p(self) -> Self { - <$complex_ty>::ln(self + $one) + <$complex_ty as ComplexFloat>::ln(self + $one) + } + + fn log2(self) -> Self { + <$complex_ty as ComplexFloat>::log2(self) + } + + fn log10(self) -> Self { + <$complex_ty as ComplexFloat>::log10(self) } fn sin(self) -> Self { - <$complex_ty>::sin(self) + <$complex_ty as ComplexFloat>::sin(self) } fn cos(self) -> Self { - <$complex_ty>::cos(self) + <$complex_ty as ComplexFloat>::cos(self) + } + + fn tan(self) -> Self { + <$complex_ty as ComplexFloat>::tan(self) } fn tanh(self) -> Self { - <$complex_ty>::tanh(self) + <$complex_ty as ComplexFloat>::tanh(self) } fn asin(self) -> Self { - <$complex_ty>::asin(self) + <$complex_ty as ComplexFloat>::asin(self) } fn 
acos(self) -> Self { - <$complex_ty>::acos(self) + <$complex_ty as ComplexFloat>::acos(self) } fn atan(self) -> Self { - <$complex_ty>::atan(self) + <$complex_ty as ComplexFloat>::atan(self) } fn sinh(self) -> Self { - <$complex_ty>::sinh(self) + <$complex_ty as ComplexFloat>::sinh(self) } fn cosh(self) -> Self { - <$complex_ty>::cosh(self) + <$complex_ty as ComplexFloat>::cosh(self) } fn asinh(self) -> Self { - <$complex_ty>::asinh(self) + <$complex_ty as ComplexFloat>::asinh(self) } fn acosh(self) -> Self { - <$complex_ty>::acosh(self) + <$complex_ty as ComplexFloat>::acosh(self) } fn atanh(self) -> Self { - <$complex_ty>::atanh(self) + <$complex_ty as ComplexFloat>::atanh(self) + } + + fn abs(self) -> Self::Real { + <$complex_ty as ComplexFloat>::abs(self) + } + + fn abs2(self) -> Self::Real { + <$complex_ty>::norm_sqr(&self) + } + + fn real(self) -> Self::Real { + self.re + } + + fn imag(self) -> Self::Real { + self.im + } + + fn angle(self) -> Self::Real { + <$complex_ty as ComplexFloat>::arg(self) } fn powf(self, exponent: Self::Real) -> Self { - <$complex_ty>::powf(self, exponent) + <$complex_ty as ComplexFloat>::powf(self, exponent) } fn powi(self, exponent: i32) -> Self { - <$complex_ty>::powi(&self, exponent) + <$complex_ty as ComplexFloat>::powi(self, exponent) + } + + fn pow(self, exponent: Self) -> Self { + <$complex_ty as ComplexFloat>::powc(self, exponent) } fn from_real(value: Self::Real) -> Self { diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index 27488f4..031f94f 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -1,6 +1,7 @@ mod basic; mod exp_log; mod hyperbolic; +mod smooth; mod trig; use crate::ScalarAd; @@ -19,6 +20,7 @@ pub use hyperbolic::{ atanh_rrule, cosh, cosh_frule, cosh_rrule, sinh, sinh_frule, sinh_rrule, tanh, tanh_frule, tanh_rrule, }; +pub use smooth::{cbrt, exp2, hypot, log2, pow, tan}; pub use trig::{ acos, acos_frule, acos_rrule, asin, asin_frule, 
asin_rrule, atan, atan_frule, atan_rrule, cos, cos_frule, cos_rrule, sin, sin_frule, sin_rrule, diff --git a/crates/chainrules/src/unary/smooth.rs b/crates/chainrules/src/unary/smooth.rs new file mode 100644 index 0000000..9d13393 --- /dev/null +++ b/crates/chainrules/src/unary/smooth.rs @@ -0,0 +1,80 @@ +use crate::ScalarAd; +use num_traits::Float; + +/// Primal `cbrt`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt; +/// +/// assert_eq!(cbrt(8.0_f64), 2.0); +/// ``` +pub fn cbrt<S: ScalarAd>(x: S) -> S { + x.cbrt() +} + +/// Primal `exp2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp2; +/// +/// assert_eq!(exp2(3.0_f64), 8.0); +/// ``` +pub fn exp2<S: ScalarAd>(x: S) -> S { + x.exp2() +} + +/// Primal `hypot`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot; +/// +/// assert_eq!(hypot(3.0_f64, 4.0_f64), 5.0); +/// ``` +pub fn hypot<R: Float>(x: R, y: R) -> R { + x.hypot(y) +} + +/// Primal `log2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::log2; +/// +/// assert_eq!(log2(8.0_f64), 3.0); +/// ``` +pub fn log2<S: ScalarAd>(x: S) -> S { + x.log2() +} + +/// Primal `pow`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::pow; +/// +/// assert_eq!(pow(2.0_f64, 3.0_f64), 8.0); +/// ``` +pub fn pow<S: ScalarAd>(x: S, exponent: S) -> S { + x.pow(exponent) +} + +/// Primal `tan`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::tan; +/// +/// assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12); +/// ``` +pub fn tan<S: ScalarAd>(x: S) -> S { + x.tan() +} diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs new file mode 100644 index 0000000..3d91d17 --- /dev/null +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -0,0 +1,15 @@ +use chainrules::{cbrt, exp2, hypot, log2, pow, tan}; +use num_complex::Complex64; + +#[test] +fn smooth_basis_helpers_are_reexported_from_chainrules() { + assert!((cbrt(8.0_f64) - 2.0).abs() < 1.0e-12); + assert!((exp2(3.0_f64) - 8.0).abs() < 1.0e-12); + assert!((hypot(3.0_f64, 4.0_f64) - 5.0).abs() < 1.0e-12); + assert!((log2(8.0_f64) - 3.0).abs() < 1.0e-12); + assert!((pow(2.0_f64, 3.0_f64) - 8.0).abs() < 1.0e-12); + assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1.0e-12); + + let z = Complex64::new(1.0, 2.0); + let _ = pow(z, Complex64::new(2.0, 0.0)); +} From ee78587142ffcd82b15ab440574bd54695e67568 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:48:43 +0900 Subject: [PATCH 08/32] fix: split scalar ad surface and add behavior coverage --- crates/chainrules/src/scalar_ad/complex.rs | 155 ++++++++++++ crates/chainrules/src/scalar_ad/mod.rs | 167 +++++++++++++ crates/chainrules/src/scalar_ad/real.rs | 149 ++++++++++++ crates/chainrules/src/tests/behavior.rs | 254 +++++++++++++++----- crates/chainrules/src/tests/organization.rs | 21 +- 5 files changed, 683 insertions(+), 63 deletions(-) create mode 100644 crates/chainrules/src/scalar_ad/complex.rs create mode 100644 crates/chainrules/src/scalar_ad/mod.rs create mode 100644 crates/chainrules/src/scalar_ad/real.rs diff --git a/crates/chainrules/src/scalar_ad/complex.rs b/crates/chainrules/src/scalar_ad/complex.rs new file mode 100644 index 0000000..6d55e9f --- /dev/null +++ b/crates/chainrules/src/scalar_ad/complex.rs @@ -0,0 +1,155 @@ +use super::ScalarAd; +use num_complex::{Complex32, 
Complex64, ComplexFloat}; +use num_traits::{FloatConst, Zero}; + +macro_rules! impl_scalar_ad_complex { + ($complex_ty:ty, $real_ty:ty, $one:expr) => { + impl ScalarAd for $complex_ty { + type Real = $real_ty; + + fn conj(self) -> Self { + <$complex_ty>::conj(&self) + } + + fn recip(self) -> Self { + <$complex_ty as ComplexFloat>::recip(self) + } + + fn cbrt(self) -> Self { + <$complex_ty as ComplexFloat>::cbrt(self) + } + + fn sqrt(self) -> Self { + <$complex_ty as ComplexFloat>::sqrt(self) + } + + fn exp(self) -> Self { + <$complex_ty as ComplexFloat>::exp(self) + } + + fn exp2(self) -> Self { + <$complex_ty as ComplexFloat>::exp2(self) + } + + fn exp10(self) -> Self { + <$complex_ty as ComplexFloat>::exp( + self * <$complex_ty>::new( + <$real_ty as FloatConst>::LN_10(), + <$real_ty as Zero>::zero(), + ), + ) + } + + fn expm1(self) -> Self { + <$complex_ty as ComplexFloat>::exp(self) - $one + } + + fn ln(self) -> Self { + <$complex_ty as ComplexFloat>::ln(self) + } + + fn log1p(self) -> Self { + <$complex_ty as ComplexFloat>::ln(self + $one) + } + + fn log2(self) -> Self { + <$complex_ty as ComplexFloat>::log2(self) + } + + fn log10(self) -> Self { + <$complex_ty as ComplexFloat>::log10(self) + } + + fn sin(self) -> Self { + <$complex_ty as ComplexFloat>::sin(self) + } + + fn cos(self) -> Self { + <$complex_ty as ComplexFloat>::cos(self) + } + + fn tan(self) -> Self { + <$complex_ty as ComplexFloat>::tan(self) + } + + fn tanh(self) -> Self { + <$complex_ty as ComplexFloat>::tanh(self) + } + + fn asin(self) -> Self { + <$complex_ty as ComplexFloat>::asin(self) + } + + fn acos(self) -> Self { + <$complex_ty as ComplexFloat>::acos(self) + } + + fn atan(self) -> Self { + <$complex_ty as ComplexFloat>::atan(self) + } + + fn sinh(self) -> Self { + <$complex_ty as ComplexFloat>::sinh(self) + } + + fn cosh(self) -> Self { + <$complex_ty as ComplexFloat>::cosh(self) + } + + fn asinh(self) -> Self { + <$complex_ty as ComplexFloat>::asinh(self) + } + + fn acosh(self) -> 
Self { + <$complex_ty as ComplexFloat>::acosh(self) + } + + fn atanh(self) -> Self { + <$complex_ty as ComplexFloat>::atanh(self) + } + + fn abs(self) -> Self::Real { + <$complex_ty as ComplexFloat>::abs(self) + } + + fn abs2(self) -> Self::Real { + <$complex_ty>::norm_sqr(&self) + } + + fn real(self) -> Self::Real { + self.re + } + + fn imag(self) -> Self::Real { + self.im + } + + fn angle(self) -> Self::Real { + <$complex_ty as ComplexFloat>::arg(self) + } + + fn powf(self, exponent: Self::Real) -> Self { + <$complex_ty as ComplexFloat>::powf(self, exponent) + } + + fn powi(self, exponent: i32) -> Self { + <$complex_ty as ComplexFloat>::powi(self, exponent) + } + + fn pow(self, exponent: Self) -> Self { + <$complex_ty as ComplexFloat>::powc(self, exponent) + } + + fn from_real(value: Self::Real) -> Self { + <$complex_ty>::new(value, 0.0) + } + + fn from_i32(value: i32) -> Self { + <$complex_ty>::new(value as $real_ty, 0.0) + } + } + }; +} + +impl_scalar_ad_complex!(Complex32, f32, Complex32::new(1.0, 0.0)); +impl_scalar_ad_complex!(Complex64, f64, Complex64::new(1.0, 0.0)); diff --git a/crates/chainrules/src/scalar_ad/mod.rs b/crates/chainrules/src/scalar_ad/mod.rs new file mode 100644 index 0000000..a40d67b --- /dev/null +++ b/crates/chainrules/src/scalar_ad/mod.rs @@ -0,0 +1,167 @@ +use core::ops::{Add, Div, Mul, Neg, Sub}; + +use num_complex::{Complex32, Complex64}; +use num_traits::{Float, FloatConst}; + +/// Scalar trait used by elementary AD rule helpers. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ScalarAd; +/// +/// fn takes_scalar<S: ScalarAd>(_x: S) {} +/// +/// takes_scalar(1.0_f32); +/// takes_scalar(1.0_f64); +/// ``` +pub trait ScalarAd: + Copy + + PartialEq + + Neg<Output = Self> + + Add<Output = Self> + + Sub<Output = Self> + + Mul<Output = Self> + + Div<Output = Self> +{ + /// Real exponent type for `powf`. + type Real: Copy + Float + FloatConst; + + /// Complex conjugate (identity for real scalars). + fn conj(self) -> Self; + + /// Reciprocal. + fn recip(self) -> Self; + + /// Cubic root. 
+ fn cbrt(self) -> Self; + + /// Square root. + fn sqrt(self) -> Self; + + /// Exponential. + fn exp(self) -> Self; + + /// `2^self`. + fn exp2(self) -> Self; + + /// `10^self`. + fn exp10(self) -> Self; + + /// `exp(self) - 1`. + fn expm1(self) -> Self; + + /// Natural logarithm. + fn ln(self) -> Self; + + /// `ln(1 + self)`. + fn log1p(self) -> Self; + + /// Base-2 logarithm. + fn log2(self) -> Self; + + /// Base-10 logarithm. + fn log10(self) -> Self; + + /// Sine. + fn sin(self) -> Self; + + /// Cosine. + fn cos(self) -> Self; + + /// Tangent. + fn tan(self) -> Self; + + /// Hyperbolic tangent. + fn tanh(self) -> Self; + + /// Arc sine. + fn asin(self) -> Self; + + /// Arc cosine. + fn acos(self) -> Self; + + /// Arc tangent. + fn atan(self) -> Self; + + /// Hyperbolic sine. + fn sinh(self) -> Self; + + /// Hyperbolic cosine. + fn cosh(self) -> Self; + + /// Area hyperbolic sine. + fn asinh(self) -> Self; + + /// Area hyperbolic cosine. + fn acosh(self) -> Self; + + /// Area hyperbolic tangent. + fn atanh(self) -> Self; + + /// Absolute value. + fn abs(self) -> Self::Real; + + /// Absolute value squared. + fn abs2(self) -> Self::Real; + + /// Real part. + fn real(self) -> Self::Real; + + /// Imaginary part. + fn imag(self) -> Self::Real; + + /// Polar angle. + fn angle(self) -> Self::Real; + + /// Power by real exponent. + fn powf(self, exponent: Self::Real) -> Self; + + /// Power by integer exponent. + fn powi(self, exponent: i32) -> Self; + + /// Power by same-scalar exponent. + fn pow(self, exponent: Self) -> Self; + + /// Convert real scalar to this scalar type. + fn from_real(value: Self::Real) -> Self; + + /// Convert signed integer to this scalar type. + fn from_i32(value: i32) -> Self; +} + +mod complex; +mod real; + +/// PyTorch-style real-input / complex-gradient projection helper (`handle_r_to_c`). +/// +/// This is equivalent to taking the real part when a gradient for real input +/// becomes complex during intermediate algebra. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::handle_r_to_c_f64; +/// use num_complex::Complex64; +/// +/// let g = Complex64::new(1.25, -3.0); +/// assert_eq!(handle_r_to_c_f64(g), 1.25); +/// ``` +pub fn handle_r_to_c_f64(gradient: Complex64) -> f64 { + gradient.re +} + +/// `f32` variant of [`handle_r_to_c_f64`]. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::handle_r_to_c_f32; +/// use num_complex::Complex32; +/// +/// let g = Complex32::new(2.0, 4.0); +/// assert_eq!(handle_r_to_c_f32(g), 2.0); +/// ``` +pub fn handle_r_to_c_f32(gradient: Complex32) -> f32 { + gradient.re +} diff --git a/crates/chainrules/src/scalar_ad/real.rs b/crates/chainrules/src/scalar_ad/real.rs new file mode 100644 index 0000000..2069186 --- /dev/null +++ b/crates/chainrules/src/scalar_ad/real.rs @@ -0,0 +1,149 @@ +use super::ScalarAd; +use num_traits::{Float, FloatConst, Zero}; + +macro_rules! impl_scalar_ad_real { + ($ty:ty) => { + impl ScalarAd for $ty { + type Real = $ty; + + fn conj(self) -> Self { + self + } + + fn recip(self) -> Self { + <$ty as Float>::recip(self) + } + + fn cbrt(self) -> Self { + <$ty as Float>::cbrt(self) + } + + fn sqrt(self) -> Self { + <$ty as Float>::sqrt(self) + } + + fn exp(self) -> Self { + <$ty as Float>::exp(self) + } + + fn exp2(self) -> Self { + <$ty as Float>::exp2(self) + } + + fn exp10(self) -> Self { + <$ty as Float>::exp(self * <$ty as FloatConst>::LN_10()) + } + + fn expm1(self) -> Self { + <$ty as Float>::exp_m1(self) + } + + fn ln(self) -> Self { + <$ty as Float>::ln(self) + } + + fn log1p(self) -> Self { + <$ty as Float>::ln_1p(self) + } + + fn log2(self) -> Self { + <$ty as Float>::log2(self) + } + + fn log10(self) -> Self { + <$ty as Float>::log10(self) + } + + fn sin(self) -> Self { + <$ty as Float>::sin(self) + } + + fn cos(self) -> Self { + <$ty as Float>::cos(self) + } + + fn tan(self) -> Self { + <$ty as Float>::tan(self) + } + + fn tanh(self) -> Self { + <$ty as Float>::tanh(self) + } + + fn asin(self) -> 
Self { + <$ty as Float>::asin(self) + } + + fn acos(self) -> Self { + <$ty as Float>::acos(self) + } + + fn atan(self) -> Self { + <$ty as Float>::atan(self) + } + + fn sinh(self) -> Self { + <$ty as Float>::sinh(self) + } + + fn cosh(self) -> Self { + <$ty as Float>::cosh(self) + } + + fn asinh(self) -> Self { + <$ty as Float>::asinh(self) + } + + fn acosh(self) -> Self { + <$ty as Float>::acosh(self) + } + + fn atanh(self) -> Self { + <$ty as Float>::atanh(self) + } + + fn abs(self) -> Self::Real { + <$ty as Float>::abs(self) + } + + fn abs2(self) -> Self::Real { + self * self + } + + fn real(self) -> Self::Real { + self + } + + fn imag(self) -> Self::Real { + <$ty as Zero>::zero() + } + + fn angle(self) -> Self::Real { + <$ty as Float>::atan2(<$ty as Zero>::zero(), self) + } + + fn powf(self, exponent: Self::Real) -> Self { + <$ty as Float>::powf(self, exponent) + } + + fn powi(self, exponent: i32) -> Self { + <$ty as Float>::powi(self, exponent) + } + + fn pow(self, exponent: Self) -> Self { + <$ty as Float>::powf(self, exponent) + } + + fn from_real(value: Self::Real) -> Self { + value + } + + fn from_i32(value: i32) -> Self { + value as $ty + } + } + }; +} + +impl_scalar_ad_real!(f32); +impl_scalar_ad_real!(f64); diff --git a/crates/chainrules/src/tests/behavior.rs b/crates/chainrules/src/tests/behavior.rs index 4fe310a..ee52acf 100644 --- a/crates/chainrules/src/tests/behavior.rs +++ b/crates/chainrules/src/tests/behavior.rs @@ -1,4 +1,4 @@ -use num_complex::{Complex32, Complex64}; +use num_complex::{Complex32, Complex64, ComplexFloat}; use crate::{ acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, asin_rrule, @@ -78,31 +78,46 @@ fn scalar_ad_real_impls_match_std_real_ops() { #[test] fn scalar_ad_complex_impls_match_std_complex_ops() { let x32 = Complex32::new(0.25, -0.5); - assert_close_c32(::conj(x32), x32.conj()); - assert_close_c32(::sqrt(x32), x32.sqrt()); - assert_close_c32(::exp(x32), x32.exp()); + 
assert_close_c32(::conj(x32), ComplexFloat::conj(x32)); + assert_close_c32(::sqrt(x32), ComplexFloat::sqrt(x32)); + assert_close_c32(::exp(x32), ComplexFloat::exp(x32)); assert_close_c32( ::expm1(x32), - x32.exp() - Complex32::new(1.0, 0.0), + ComplexFloat::exp(x32) - Complex32::new(1.0, 0.0), ); - assert_close_c32(::ln(x32), x32.ln()); + assert_close_c32(::ln(x32), ComplexFloat::ln(x32)); assert_close_c32( ::log1p(x32), - (x32 + Complex32::new(1.0, 0.0)).ln(), - ); - assert_close_c32(::sin(x32), x32.sin()); - assert_close_c32(::cos(x32), x32.cos()); - assert_close_c32(::tanh(x32), x32.tanh()); - assert_close_c32(::asin(x32), x32.asin()); - assert_close_c32(::acos(x32), x32.acos()); - assert_close_c32(::atan(x32), x32.atan()); - assert_close_c32(::sinh(x32), x32.sinh()); - assert_close_c32(::cosh(x32), x32.cosh()); - assert_close_c32(::asinh(x32), x32.asinh()); - assert_close_c32(::acosh(x32), x32.acosh()); - assert_close_c32(::atanh(x32), x32.atanh()); - assert_close_c32(::powf(x32, 2.0), x32.powf(2.0)); - assert_close_c32(::powi(x32, 3), x32.powi(3)); + ComplexFloat::ln(x32 + Complex32::new(1.0, 0.0)), + ); + assert_close_c32(::sin(x32), ComplexFloat::sin(x32)); + assert_close_c32(::cos(x32), ComplexFloat::cos(x32)); + assert_close_c32(::tanh(x32), ComplexFloat::tanh(x32)); + assert_close_c32(::asin(x32), ComplexFloat::asin(x32)); + assert_close_c32(::acos(x32), ComplexFloat::acos(x32)); + assert_close_c32(::atan(x32), ComplexFloat::atan(x32)); + assert_close_c32(::sinh(x32), ComplexFloat::sinh(x32)); + assert_close_c32(::cosh(x32), ComplexFloat::cosh(x32)); + assert_close_c32( + ::asinh(x32), + ComplexFloat::asinh(x32), + ); + assert_close_c32( + ::acosh(x32), + ComplexFloat::acosh(x32), + ); + assert_close_c32( + ::atanh(x32), + ComplexFloat::atanh(x32), + ); + assert_close_c32( + ::powf(x32, 2.0), + ComplexFloat::powf(x32, 2.0), + ); + assert_close_c32( + ::powi(x32, 3), + ComplexFloat::powi(x32, 3), + ); assert_eq!( ::from_real(1.5), Complex32::new(1.5, 0.0) 
@@ -113,31 +128,46 @@ fn scalar_ad_complex_impls_match_std_complex_ops() { ); let x64 = Complex64::new(0.5, 0.75); - assert_close_c64(::conj(x64), x64.conj()); - assert_close_c64(::sqrt(x64), x64.sqrt()); - assert_close_c64(::exp(x64), x64.exp()); + assert_close_c64(::conj(x64), ComplexFloat::conj(x64)); + assert_close_c64(::sqrt(x64), ComplexFloat::sqrt(x64)); + assert_close_c64(::exp(x64), ComplexFloat::exp(x64)); assert_close_c64( ::expm1(x64), - x64.exp() - Complex64::new(1.0, 0.0), + ComplexFloat::exp(x64) - Complex64::new(1.0, 0.0), ); - assert_close_c64(::ln(x64), x64.ln()); + assert_close_c64(::ln(x64), ComplexFloat::ln(x64)); assert_close_c64( ::log1p(x64), - (x64 + Complex64::new(1.0, 0.0)).ln(), - ); - assert_close_c64(::sin(x64), x64.sin()); - assert_close_c64(::cos(x64), x64.cos()); - assert_close_c64(::tanh(x64), x64.tanh()); - assert_close_c64(::asin(x64), x64.asin()); - assert_close_c64(::acos(x64), x64.acos()); - assert_close_c64(::atan(x64), x64.atan()); - assert_close_c64(::sinh(x64), x64.sinh()); - assert_close_c64(::cosh(x64), x64.cosh()); - assert_close_c64(::asinh(x64), x64.asinh()); - assert_close_c64(::acosh(x64), x64.acosh()); - assert_close_c64(::atanh(x64), x64.atanh()); - assert_close_c64(::powf(x64, 1.5), x64.powf(1.5)); - assert_close_c64(::powi(x64, 2), x64.powi(2)); + ComplexFloat::ln(x64 + Complex64::new(1.0, 0.0)), + ); + assert_close_c64(::sin(x64), ComplexFloat::sin(x64)); + assert_close_c64(::cos(x64), ComplexFloat::cos(x64)); + assert_close_c64(::tanh(x64), ComplexFloat::tanh(x64)); + assert_close_c64(::asin(x64), ComplexFloat::asin(x64)); + assert_close_c64(::acos(x64), ComplexFloat::acos(x64)); + assert_close_c64(::atan(x64), ComplexFloat::atan(x64)); + assert_close_c64(::sinh(x64), ComplexFloat::sinh(x64)); + assert_close_c64(::cosh(x64), ComplexFloat::cosh(x64)); + assert_close_c64( + ::asinh(x64), + ComplexFloat::asinh(x64), + ); + assert_close_c64( + ::acosh(x64), + ComplexFloat::acosh(x64), + ); + assert_close_c64( + 
::atanh(x64), + ComplexFloat::atanh(x64), + ); + assert_close_c64( + ::powf(x64, 1.5), + ComplexFloat::powf(x64, 1.5), + ); + assert_close_c64( + ::powi(x64, 2), + ComplexFloat::powi(x64, 2), + ); assert_eq!( ::from_real(2.5), Complex64::new(2.5, 0.0) @@ -148,6 +178,98 @@ fn scalar_ad_complex_impls_match_std_complex_ops() { ); } +#[test] +fn scalar_ad_real_extended_surface_matches_std_real_ops() { + let x32 = 0.25_f32; + assert_close_f32(::recip(x32), x32.recip()); + assert_close_f32(::cbrt(x32), x32.cbrt()); + assert_close_f32(::exp2(x32), x32.exp2()); + assert_close_f32(::exp10(x32), 10.0_f32.powf(x32)); + assert_close_f32(::log2(x32), x32.log2()); + assert_close_f32(::log10(x32), x32.log10()); + assert_close_f32(::tan(x32), x32.tan()); + assert_close_f32(::abs(x32), x32.abs()); + assert_close_f32(::abs2(x32), x32 * x32); + assert_close_f32(::real(x32), x32); + assert_close_f32(::imag(x32), 0.0); + assert_close_f32(::angle(-x32), 0.0_f32.atan2(-x32)); + assert_close_f32(::pow(x32, 1.5), x32.powf(1.5)); + + let x64 = 0.5_f64; + assert_close_f64(::recip(x64), x64.recip()); + assert_close_f64(::cbrt(x64), x64.cbrt()); + assert_close_f64(::exp2(x64), x64.exp2()); + assert_close_f64(::exp10(x64), 10.0_f64.powf(x64)); + assert_close_f64(::log2(x64), x64.log2()); + assert_close_f64(::log10(x64), x64.log10()); + assert_close_f64(::tan(x64), x64.tan()); + assert_close_f64(::abs(x64), x64.abs()); + assert_close_f64(::abs2(x64), x64 * x64); + assert_close_f64(::real(x64), x64); + assert_close_f64(::imag(x64), 0.0); + assert_close_f64(::angle(-x64), 0.0_f64.atan2(-x64)); + assert_close_f64(::pow(x64, 1.5), x64.powf(1.5)); +} + +#[test] +fn scalar_ad_complex_extended_surface_matches_std_complex_ops() { + let x32 = Complex32::new(0.25, -0.5); + let y32 = Complex32::new(1.25, 0.75); + assert_close_c32( + ::recip(x32), + ComplexFloat::recip(x32), + ); + assert_close_c32(::cbrt(x32), ComplexFloat::cbrt(x32)); + assert_close_c32(::exp2(x32), ComplexFloat::exp2(x32)); + 
assert_close_c32( + ::exp10(x32), + ComplexFloat::exp(x32 * Complex32::new(std::f32::consts::LN_10, 0.0)), + ); + assert_close_c32(::log2(x32), ComplexFloat::log2(x32)); + assert_close_c32( + ::log10(x32), + ComplexFloat::log10(x32), + ); + assert_close_c32(::tan(x32), ComplexFloat::tan(x32)); + assert_close_f32(::abs(x32), x32.norm()); + assert_close_f32(::abs2(x32), x32.norm_sqr()); + assert_close_f32(::real(x32), x32.re); + assert_close_f32(::imag(x32), x32.im); + assert_close_f32(::angle(x32), x32.arg()); + assert_close_c32( + ::pow(x32, y32), + ComplexFloat::powc(x32, y32), + ); + + let x64 = Complex64::new(0.5, 0.75); + let y64 = Complex64::new(1.5, -0.25); + assert_close_c64( + ::recip(x64), + ComplexFloat::recip(x64), + ); + assert_close_c64(::cbrt(x64), ComplexFloat::cbrt(x64)); + assert_close_c64(::exp2(x64), ComplexFloat::exp2(x64)); + assert_close_c64( + ::exp10(x64), + ComplexFloat::exp(x64 * Complex64::new(std::f64::consts::LN_10, 0.0)), + ); + assert_close_c64(::log2(x64), ComplexFloat::log2(x64)); + assert_close_c64( + ::log10(x64), + ComplexFloat::log10(x64), + ); + assert_close_c64(::tan(x64), ComplexFloat::tan(x64)); + assert_close_f64(::abs(x64), x64.norm()); + assert_close_f64(::abs2(x64), x64.norm_sqr()); + assert_close_f64(::real(x64), x64.re); + assert_close_f64(::imag(x64), x64.im); + assert_close_f64(::angle(x64), x64.arg()); + assert_close_c64( + ::pow(x64, y64), + ComplexFloat::powc(x64, y64), + ); +} + #[test] fn direct_entrypoints_match_real_projection_and_atan2_formulas() { assert_eq!(handle_r_to_c_f32(Complex32::new(2.0, -5.0)), 2.0); @@ -164,10 +286,10 @@ fn direct_entrypoints_match_real_projection_and_atan2_formulas() { #[test] fn unary_entrypoints_match_forward_and_reverse_formulas() { let complex = Complex32::new(1.0, -2.0); - assert_eq!(conj(complex), complex.conj()); + assert_eq!(conj(complex), ComplexFloat::conj(complex)); let (_y, dy) = conj_frule(complex, Complex32::new(3.0, 4.0)); assert_eq!(dy, Complex32::new(3.0, -4.0)); 
- assert_eq!(conj_rrule(complex), complex.conj()); + assert_eq!(conj_rrule(complex), ComplexFloat::conj(complex)); assert_eq!(sqrt(9.0_f32), 3.0); let (sqrt_y, sqrt_dy) = sqrt_frule(9.0_f32, 2.0_f32); @@ -305,17 +427,17 @@ fn trig_and_hyperbolic_primal_entrypoints_match_std_ops() { assert_close_f64(acosh(acosh_real), acosh_real.acosh()); let complex = Complex64::new(0.25, -0.5); - assert_close_c64(sin(complex), complex.sin()); - assert_close_c64(cos(complex), complex.cos()); - assert_close_c64(tanh(complex), complex.tanh()); - assert_close_c64(asin(complex), complex.asin()); - assert_close_c64(acos(complex), complex.acos()); - assert_close_c64(atan(complex), complex.atan()); - assert_close_c64(sinh(complex), complex.sinh()); - assert_close_c64(cosh(complex), complex.cosh()); - assert_close_c64(asinh(complex), complex.asinh()); - assert_close_c64(acosh(complex), complex.acosh()); - assert_close_c64(atanh(complex), complex.atanh()); + assert_close_c64(sin(complex), ComplexFloat::sin(complex)); + assert_close_c64(cos(complex), ComplexFloat::cos(complex)); + assert_close_c64(tanh(complex), ComplexFloat::tanh(complex)); + assert_close_c64(asin(complex), ComplexFloat::asin(complex)); + assert_close_c64(acos(complex), ComplexFloat::acos(complex)); + assert_close_c64(atan(complex), ComplexFloat::atan(complex)); + assert_close_c64(sinh(complex), ComplexFloat::sinh(complex)); + assert_close_c64(cosh(complex), ComplexFloat::cosh(complex)); + assert_close_c64(asinh(complex), ComplexFloat::asinh(complex)); + assert_close_c64(acosh(complex), ComplexFloat::acosh(complex)); + assert_close_c64(atanh(complex), ComplexFloat::atanh(complex)); } #[test] @@ -325,26 +447,34 @@ fn extended_complex_unary_rules_conjugate_their_jacobians() { let cotangent = Complex64::new(0.5, -1.25); let (_sin_y, sin_dy) = sin_frule(x, dx); - assert_close_c64(sin_dy, dx * x.cos().conj()); - assert_close_c64(sin_rrule(x, cotangent), cotangent * x.cos().conj()); + assert_close_c64(sin_dy, dx * 
ComplexFloat::conj(ComplexFloat::cos(x))); + assert_close_c64( + sin_rrule(x, cotangent), + cotangent * ComplexFloat::conj(ComplexFloat::cos(x)), + ); let (_cos_y, cos_dy) = cos_frule(x, dx); - assert_close_c64(cos_dy, dx * (-x.sin()).conj()); - assert_close_c64(cos_rrule(x, cotangent), cotangent * (-x.sin()).conj()); + assert_close_c64(cos_dy, dx * ComplexFloat::conj(-ComplexFloat::sin(x))); + assert_close_c64( + cos_rrule(x, cotangent), + cotangent * ComplexFloat::conj(-ComplexFloat::sin(x)), + ); - let tanh_y = x.tanh(); + let tanh_y = ComplexFloat::tanh(x); let (_tanh_primal, tanh_dy) = tanh_frule(x, dx); assert_close_c64( tanh_dy, - dx * (Complex64::new(1.0, 0.0) - tanh_y * tanh_y).conj(), + dx * ComplexFloat::conj(Complex64::new(1.0, 0.0) - tanh_y * tanh_y), ); assert_close_c64( tanh_rrule(tanh_y, cotangent), - cotangent * (Complex64::new(1.0, 0.0) - tanh_y * tanh_y).conj(), + cotangent * ComplexFloat::conj(Complex64::new(1.0, 0.0) - tanh_y * tanh_y), ); let (_asinh_y, asinh_dy) = asinh_frule(x, dx); - let asinh_scale = (Complex64::new(1.0, 0.0) / (Complex64::new(1.0, 0.0) + x * x).sqrt()).conj(); + let asinh_scale = ComplexFloat::conj( + Complex64::new(1.0, 0.0) / ComplexFloat::sqrt(Complex64::new(1.0, 0.0) + x * x), + ); assert_close_c64(asinh_dy, dx * asinh_scale); assert_close_c64(asinh_rrule(x, cotangent), cotangent * asinh_scale); } diff --git a/crates/chainrules/src/tests/organization.rs b/crates/chainrules/src/tests/organization.rs index 0f6f72e..123787e 100644 --- a/crates/chainrules/src/tests/organization.rs +++ b/crates/chainrules/src/tests/organization.rs @@ -10,7 +10,6 @@ fn assert_line_count(path: &str, content: &str, max_lines: usize) { #[test] fn chainrules_modules_stay_under_size_guideline() { assert_line_count("../lib.rs", include_str!("../lib.rs"), 120); - assert_line_count("../scalar_ad.rs", include_str!("../scalar_ad.rs"), 320); assert_line_count("../binary.rs", include_str!("../binary.rs"), 260); assert_line_count("../unary/mod.rs", 
include_str!("../unary/mod.rs"), 80); assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 80); @@ -25,6 +24,26 @@ fn chainrules_modules_stay_under_size_guideline() { include_str!("../unary/hyperbolic.rs"), 140, ); + assert_line_count( + "../unary/smooth.rs", + include_str!("../unary/smooth.rs"), + 120, + ); assert_line_count("../power.rs", include_str!("../power.rs"), 180); assert_line_count("../real_ops.rs", include_str!("../real_ops.rs"), 120); + assert_line_count( + "../scalar_ad/mod.rs", + include_str!("../scalar_ad/mod.rs"), + 220, + ); + assert_line_count( + "../scalar_ad/real.rs", + include_str!("../scalar_ad/real.rs"), + 220, + ); + assert_line_count( + "../scalar_ad/complex.rs", + include_str!("../scalar_ad/complex.rs"), + 220, + ); } From 6e15cbd4e73dc2bb4ec8f65e319df9f71423ed39 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 09:48:49 +0900 Subject: [PATCH 09/32] fix: remove legacy scalar ad module file --- crates/chainrules/src/scalar_ad.rs | 462 ----------------------------- 1 file changed, 462 deletions(-) delete mode 100644 crates/chainrules/src/scalar_ad.rs diff --git a/crates/chainrules/src/scalar_ad.rs b/crates/chainrules/src/scalar_ad.rs deleted file mode 100644 index 0ea43a3..0000000 --- a/crates/chainrules/src/scalar_ad.rs +++ /dev/null @@ -1,462 +0,0 @@ -use core::ops::{Add, Div, Mul, Neg, Sub}; - -use num_complex::{Complex32, Complex64, ComplexFloat}; -use num_traits::{Float, FloatConst, Zero}; - -/// Scalar trait used by elementary AD rule helpers. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::ScalarAd; -/// -/// fn takes_scalar(_x: S) {} -/// -/// takes_scalar(1.0_f32); -/// takes_scalar(1.0_f64); -/// ``` -pub trait ScalarAd: - Copy - + PartialEq - + Neg - + Add - + Sub - + Mul - + Div -{ - /// Real exponent type for `powf`. - type Real: Copy + Float + FloatConst; - - /// Complex conjugate (identity for real scalars). - fn conj(self) -> Self; - - /// Reciprocal. 
- fn recip(self) -> Self; - - /// Cubic root. - fn cbrt(self) -> Self; - - /// Square root. - fn sqrt(self) -> Self; - - /// Exponential. - fn exp(self) -> Self; - - /// `2^self`. - fn exp2(self) -> Self; - - /// `10^self`. - fn exp10(self) -> Self; - - /// `exp(self) - 1`. - fn expm1(self) -> Self; - - /// Natural logarithm. - fn ln(self) -> Self; - - /// `ln(1 + self)`. - fn log1p(self) -> Self; - - /// Base-2 logarithm. - fn log2(self) -> Self; - - /// Base-10 logarithm. - fn log10(self) -> Self; - - /// Sine. - fn sin(self) -> Self; - - /// Cosine. - fn cos(self) -> Self; - - /// Tangent. - fn tan(self) -> Self; - - /// Hyperbolic tangent. - fn tanh(self) -> Self; - - /// Arc sine. - fn asin(self) -> Self; - - /// Arc cosine. - fn acos(self) -> Self; - - /// Arc tangent. - fn atan(self) -> Self; - - /// Hyperbolic sine. - fn sinh(self) -> Self; - - /// Hyperbolic cosine. - fn cosh(self) -> Self; - - /// Area hyperbolic sine. - fn asinh(self) -> Self; - - /// Area hyperbolic cosine. - fn acosh(self) -> Self; - - /// Area hyperbolic tangent. - fn atanh(self) -> Self; - - /// Absolute value. - fn abs(self) -> Self::Real; - - /// Absolute value squared. - fn abs2(self) -> Self::Real; - - /// Real part. - fn real(self) -> Self::Real; - - /// Imaginary part. - fn imag(self) -> Self::Real; - - /// Polar angle. - fn angle(self) -> Self::Real; - - /// Power by real exponent. - fn powf(self, exponent: Self::Real) -> Self; - - /// Power by integer exponent. - fn powi(self, exponent: i32) -> Self; - - /// Power by same-scalar exponent. - fn pow(self, exponent: Self) -> Self; - - /// Convert real scalar to this scalar type. - fn from_real(value: Self::Real) -> Self; - - /// Convert signed integer to this scalar type. - fn from_i32(value: i32) -> Self; -} - -macro_rules! 
impl_scalar_ad_real { - ($ty:ty) => { - impl ScalarAd for $ty { - type Real = $ty; - - fn conj(self) -> Self { - self - } - - fn recip(self) -> Self { - <$ty as Float>::recip(self) - } - - fn cbrt(self) -> Self { - <$ty as Float>::cbrt(self) - } - - fn sqrt(self) -> Self { - <$ty as Float>::sqrt(self) - } - - fn exp(self) -> Self { - <$ty as Float>::exp(self) - } - - fn exp2(self) -> Self { - <$ty as Float>::exp2(self) - } - - fn exp10(self) -> Self { - (<$ty as Float>::exp(self * <$ty as FloatConst>::LN_10())) - } - - fn expm1(self) -> Self { - <$ty as Float>::exp_m1(self) - } - - fn ln(self) -> Self { - <$ty as Float>::ln(self) - } - - fn log1p(self) -> Self { - <$ty as Float>::ln_1p(self) - } - - fn log2(self) -> Self { - <$ty as Float>::log2(self) - } - - fn log10(self) -> Self { - <$ty as Float>::log10(self) - } - - fn sin(self) -> Self { - <$ty as Float>::sin(self) - } - - fn cos(self) -> Self { - <$ty as Float>::cos(self) - } - - fn tan(self) -> Self { - <$ty as Float>::tan(self) - } - - fn tanh(self) -> Self { - <$ty as Float>::tanh(self) - } - - fn asin(self) -> Self { - <$ty as Float>::asin(self) - } - - fn acos(self) -> Self { - <$ty as Float>::acos(self) - } - - fn atan(self) -> Self { - <$ty as Float>::atan(self) - } - - fn sinh(self) -> Self { - <$ty as Float>::sinh(self) - } - - fn cosh(self) -> Self { - <$ty as Float>::cosh(self) - } - - fn asinh(self) -> Self { - <$ty as Float>::asinh(self) - } - - fn acosh(self) -> Self { - <$ty as Float>::acosh(self) - } - - fn atanh(self) -> Self { - <$ty as Float>::atanh(self) - } - - fn abs(self) -> Self::Real { - <$ty as Float>::abs(self) - } - - fn abs2(self) -> Self::Real { - self * self - } - - fn real(self) -> Self::Real { - self - } - - fn imag(self) -> Self::Real { - <$ty as Zero>::zero() - } - - fn angle(self) -> Self::Real { - <$ty as Float>::atan2(<$ty as Zero>::zero(), self) - } - - fn powf(self, exponent: Self::Real) -> Self { - <$ty as Float>::powf(self, exponent) - } - - fn powi(self, exponent: 
i32) -> Self { - <$ty as Float>::powi(self, exponent) - } - - fn pow(self, exponent: Self) -> Self { - <$ty as Float>::powf(self, exponent) - } - - fn from_real(value: Self::Real) -> Self { - value - } - - fn from_i32(value: i32) -> Self { - value as $ty - } - } - }; -} - -macro_rules! impl_scalar_ad_complex { - ($complex_ty:ty, $real_ty:ty, $one:expr) => { - impl ScalarAd for $complex_ty { - type Real = $real_ty; - - fn conj(self) -> Self { - <$complex_ty>::conj(&self) - } - - fn recip(self) -> Self { - <$complex_ty as ComplexFloat>::recip(self) - } - - fn cbrt(self) -> Self { - <$complex_ty as ComplexFloat>::cbrt(self) - } - - fn sqrt(self) -> Self { - <$complex_ty as ComplexFloat>::sqrt(self) - } - - fn exp(self) -> Self { - <$complex_ty as ComplexFloat>::exp(self) - } - - fn exp2(self) -> Self { - <$complex_ty as ComplexFloat>::exp2(self) - } - - fn exp10(self) -> Self { - (<$complex_ty as ComplexFloat>::exp( - self * <$complex_ty>::new( - <$real_ty as FloatConst>::LN_10(), - <$real_ty as Zero>::zero(), - ), - )) - } - - fn expm1(self) -> Self { - <$complex_ty as ComplexFloat>::exp(self) - $one - } - - fn ln(self) -> Self { - <$complex_ty as ComplexFloat>::ln(self) - } - - fn log1p(self) -> Self { - <$complex_ty as ComplexFloat>::ln(self + $one) - } - - fn log2(self) -> Self { - <$complex_ty as ComplexFloat>::log2(self) - } - - fn log10(self) -> Self { - <$complex_ty as ComplexFloat>::log10(self) - } - - fn sin(self) -> Self { - <$complex_ty as ComplexFloat>::sin(self) - } - - fn cos(self) -> Self { - <$complex_ty as ComplexFloat>::cos(self) - } - - fn tan(self) -> Self { - <$complex_ty as ComplexFloat>::tan(self) - } - - fn tanh(self) -> Self { - <$complex_ty as ComplexFloat>::tanh(self) - } - - fn asin(self) -> Self { - <$complex_ty as ComplexFloat>::asin(self) - } - - fn acos(self) -> Self { - <$complex_ty as ComplexFloat>::acos(self) - } - - fn atan(self) -> Self { - <$complex_ty as ComplexFloat>::atan(self) - } - - fn sinh(self) -> Self { - <$complex_ty as 
ComplexFloat>::sinh(self) - } - - fn cosh(self) -> Self { - <$complex_ty as ComplexFloat>::cosh(self) - } - - fn asinh(self) -> Self { - <$complex_ty as ComplexFloat>::asinh(self) - } - - fn acosh(self) -> Self { - <$complex_ty as ComplexFloat>::acosh(self) - } - - fn atanh(self) -> Self { - <$complex_ty as ComplexFloat>::atanh(self) - } - - fn abs(self) -> Self::Real { - <$complex_ty as ComplexFloat>::abs(self) - } - - fn abs2(self) -> Self::Real { - <$complex_ty>::norm_sqr(&self) - } - - fn real(self) -> Self::Real { - self.re - } - - fn imag(self) -> Self::Real { - self.im - } - - fn angle(self) -> Self::Real { - <$complex_ty as ComplexFloat>::arg(self) - } - - fn powf(self, exponent: Self::Real) -> Self { - <$complex_ty as ComplexFloat>::powf(self, exponent) - } - - fn powi(self, exponent: i32) -> Self { - <$complex_ty as ComplexFloat>::powi(self, exponent) - } - - fn pow(self, exponent: Self) -> Self { - <$complex_ty as ComplexFloat>::powc(self, exponent) - } - - fn from_real(value: Self::Real) -> Self { - <$complex_ty>::new(value, 0.0) - } - - fn from_i32(value: i32) -> Self { - <$complex_ty>::new(value as $real_ty, 0.0) - } - } - }; -} - -impl_scalar_ad_real!(f32); -impl_scalar_ad_real!(f64); -impl_scalar_ad_complex!(Complex32, f32, Complex32::new(1.0, 0.0)); -impl_scalar_ad_complex!(Complex64, f64, Complex64::new(1.0, 0.0)); - -/// PyTorch-style real-input / complex-gradient projection helper (`handle_r_to_c`). -/// -/// This is equivalent to taking the real part when a gradient for real input -/// becomes complex during intermediate algebra. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::handle_r_to_c_f64; -/// use num_complex::Complex64; -/// -/// let g = Complex64::new(1.25, -3.0); -/// assert_eq!(handle_r_to_c_f64(g), 1.25); -/// ``` -pub fn handle_r_to_c_f64(gradient: Complex64) -> f64 { - gradient.re -} - -/// `f32` variant of [`handle_r_to_c_f64`]. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::handle_r_to_c_f32; -/// use num_complex::Complex32; -/// -/// let g = Complex32::new(2.0, 4.0); -/// assert_eq!(handle_r_to_c_f32(g), 2.0); -/// ``` -pub fn handle_r_to_c_f32(gradient: Complex32) -> f32 { - gradient.re -} From ec5b2310256ac6ad3866b3f5a04ee92d13cfe039 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:02:25 +0900 Subject: [PATCH 10/32] feat: add smooth shared scalar basis rules --- crates/chainrules/src/binary_special.rs | 48 +++++ crates/chainrules/src/lib.rs | 12 +- crates/chainrules/src/power.rs | 57 ++++++ crates/chainrules/src/unary/exp_log.rs | 8 + crates/chainrules/src/unary/exp_log_smooth.rs | 185 ++++++++++++++++++ crates/chainrules/src/unary/mod.rs | 12 +- crates/chainrules/src/unary/roots.rs | 88 +++++++++ crates/chainrules/src/unary/smooth.rs | 90 +-------- crates/chainrules/src/unary/trig.rs | 5 + crates/chainrules/src/unary/trig_smooth.rs | 92 +++++++++ crates/chainrules/tests/smooth_basis_tests.rs | 83 +++++++- 11 files changed, 592 insertions(+), 88 deletions(-) create mode 100644 crates/chainrules/src/binary_special.rs create mode 100644 crates/chainrules/src/unary/exp_log_smooth.rs create mode 100644 crates/chainrules/src/unary/roots.rs create mode 100644 crates/chainrules/src/unary/trig_smooth.rs diff --git a/crates/chainrules/src/binary_special.rs b/crates/chainrules/src/binary_special.rs new file mode 100644 index 0000000..f2e3b4b --- /dev/null +++ b/crates/chainrules/src/binary_special.rs @@ -0,0 +1,48 @@ +use num_traits::Float; + +/// Primal `hypot`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot; +/// +/// assert_eq!(hypot(3.0_f64, 4.0_f64), 5.0); +/// ``` +pub fn hypot<R: Float>(x: R, y: R) -> R { + x.hypot(y) +} + +/// Forward rule for `hypot`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot_frule; +/// +/// let (r, dr) = hypot_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64); +/// assert_eq!(r, 5.0); +/// assert!((dr - 0.5).abs() < 1e-12); +/// ``` +pub fn hypot_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) { + let r = x.hypot(y); + let inv_r = R::one() / r; + (r, dx * (x * inv_r) + dy * (y * inv_r)) +} + +/// Reverse rule for `hypot`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::hypot_rrule; +/// +/// let (dx, dy) = hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64); +/// assert!((dx - 0.6).abs() < 1e-12); +/// assert!((dy - 0.8).abs() < 1e-12); +/// ``` +pub fn hypot_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) { + let r = x.hypot(y); + let inv_r = R::one() / r; + (cotangent * (x * inv_r), cotangent * (y * inv_r)) +} diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 39da7ad..09ecfdb 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -6,6 +6,7 @@ pub use chainrules_core::{ }; mod binary; +mod binary_special; mod power; mod real_ops; mod scalar_ad; @@ -26,10 +27,13 @@ pub use scalar_ad::{handle_r_to_c_f32, handle_r_to_c_f64, ScalarAd}; pub use unary::{ acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, - cbrt, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, cosh_frule, cosh_rrule, - exp, exp2, exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, hypot, log, log1p, - log1p_frule, log1p_rrule, log2, log_frule, log_rrule, pow, sin, sin_frule, sin_rrule, sinh, - sinh_frule, sinh_rrule, sqrt, sqrt_frule, sqrt_rrule, tan, tanh, tanh_frule, tanh_rrule, + cbrt, cbrt_frule, cbrt_rrule, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, + cosh_frule, cosh_rrule, exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, + exp_frule, exp_rrule, expm1, 
expm1_frule, expm1_rrule, 
hypot, hypot_frule, hypot_rrule, inv, + inv_frule, inv_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, log1p_rrule, + log2, log2_frule, log2_rrule, log_frule, log_rrule, pow, pow_frule, pow_rrule, sin, sin_frule, + sin_rrule, sincos, sincos_frule, sincos_rrule, sinh, sinh_frule, sinh_rrule, sqrt, sqrt_frule, + sqrt_rrule, tan, tan_frule, tan_rrule, tanh, tanh_frule, tanh_rrule, }; #[cfg(test)] diff --git a/crates/chainrules/src/power.rs b/crates/chainrules/src/power.rs index d9aa42c..d7e3033 100644 --- a/crates/chainrules/src/power.rs +++ b/crates/chainrules/src/power.rs @@ -107,3 +107,60 @@ pub fn powi_rrule(x: S, exponent: i32, cotangent: S) -> S { } cotangent * (S::from_i32(exponent) * x.powi(exponent - 1)).conj() } + +/// Primal `pow(x, exponent)`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::pow; +/// +/// assert_eq!(pow(2.0_f64, 3.0_f64), 8.0); +/// ``` +pub fn pow(x: S, exponent: S) -> S { + x.pow(exponent) +} + +/// Forward rule for `pow(x, exponent)`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::pow_frule; +/// +/// let (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0); +/// assert_eq!(y, 8.0); +/// assert!((dy - 12.0).abs() < 1e-12); +/// ``` +pub fn pow_frule(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) { + let y = x.pow(exponent); + let dfdx = if exponent == S::from_i32(0) { + S::from_i32(0) + } else { + (exponent * x.pow(exponent - S::from_i32(1))).conj() + }; + let dfde = (y * x.ln()).conj(); + (y, dx * dfdx + dexponent * dfde) +} + +/// Reverse rule for `pow(x, exponent)`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::pow_rrule; +/// +/// let (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0); +/// assert_eq!(dx, 12.0); +/// assert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12); +/// ``` +pub fn pow_rrule(x: S, exponent: S, cotangent: S) -> (S, S) { + let y = x.pow(exponent); + let dfdx = if exponent == S::from_i32(0) { + S::from_i32(0) + } else { + (exponent * x.pow(exponent - S::from_i32(1))).conj() + }; + let dfde = (y * x.ln()).conj(); + (cotangent * dfdx, cotangent * dfde) +} diff --git a/crates/chainrules/src/unary/exp_log.rs b/crates/chainrules/src/unary/exp_log.rs index 3a7a0f0..3f39864 100644 --- a/crates/chainrules/src/unary/exp_log.rs +++ b/crates/chainrules/src/unary/exp_log.rs @@ -1,6 +1,14 @@ +#[path = "exp_log_smooth.rs"] +mod smooth_rules; + use crate::unary::one; use crate::ScalarAd; +pub use smooth_rules::{ + exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, log10, log10_frule, log10_rrule, + log2, log2_frule, log2_rrule, +}; + /// Primal `exp`. pub fn exp(x: S) -> S { x.exp() diff --git a/crates/chainrules/src/unary/exp_log_smooth.rs b/crates/chainrules/src/unary/exp_log_smooth.rs new file mode 100644 index 0000000..cff5a92 --- /dev/null +++ b/crates/chainrules/src/unary/exp_log_smooth.rs @@ -0,0 +1,185 @@ +use crate::unary::one; +use crate::ScalarAd; +use num_traits::FloatConst; + +fn ln_2() -> S { + S::from_real(S::Real::LN_2()) +} + +fn ln_10() -> S { + S::from_real(S::Real::LN_10()) +} + +/// Primal `2^x`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp2; +/// +/// assert!((exp2(3.0_f64) - 8.0).abs() < 1e-12); +/// ``` +pub fn exp2(x: S) -> S { + x.exp2() +} + +/// Forward rule for `2^x`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp2_frule; +/// +/// let (y, dy) = exp2_frule(3.0_f64, 1.0); +/// assert_eq!(y, 8.0); +/// assert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12); +/// ``` +pub fn exp2_frule(x: S, dx: S) -> (S, S) { + let y = x.exp2(); + (y, dx * (y * ln_2::()).conj()) +} + +/// Reverse rule for `2^x`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp2_rrule; +/// +/// let dy = exp2_rrule(8.0_f64, 1.0); +/// assert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12); +/// ``` +pub fn exp2_rrule(result: S, cotangent: S) -> S { + cotangent * (result * ln_2::()).conj() +} + +/// Primal `10^x`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp10; +/// +/// assert!((exp10(2.0_f64) - 100.0).abs() < 1e-12); +/// ``` +pub fn exp10(x: S) -> S { + x.exp10() +} + +/// Forward rule for `10^x`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp10_frule; +/// +/// let (y, dy) = exp10_frule(2.0_f64, 0.5); +/// assert!((y - 100.0).abs() < 1e-12); +/// assert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12); +/// ``` +pub fn exp10_frule(x: S, dx: S) -> (S, S) { + let y = x.exp10(); + (y, dx * (y * ln_10::()).conj()) +} + +/// Reverse rule for `10^x`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::exp10_rrule; +/// +/// let dy = exp10_rrule(100.0_f64, 0.5); +/// assert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12); +/// ``` +pub fn exp10_rrule(result: S, cotangent: S) -> S { + cotangent * (result * ln_10::()).conj() +} + +/// Primal `log2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::log2; +/// +/// assert_eq!(log2(8.0_f64), 3.0); +/// ``` +pub fn log2(x: S) -> S { + x.log2() +} + +/// Forward rule for `log2`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::log2_frule; +/// +/// let (y, dy) = log2_frule(8.0_f64, 2.0); +/// assert!((y - 3.0).abs() < 1e-12); +/// assert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12); +/// ``` +pub fn log2_frule(x: S, dx: S) -> (S, S) { + let y = x.log2(); + let scale = one::() / (x * ln_2::()); + (y, dx * scale.conj()) +} + +/// Reverse rule for `log2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::log2_rrule; +/// +/// let dy = log2_rrule(8.0_f64, 2.0); +/// assert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12); +/// ``` +pub fn log2_rrule(x: S, cotangent: S) -> S { + cotangent * (one::() / (x * ln_2::())).conj() +} + +/// Primal `log10`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::log10; +/// +/// assert_eq!(log10(100.0_f64), 2.0); +/// ``` +pub fn log10(x: S) -> S { + x.log10() +} + +/// Forward rule for `log10`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::log10_frule; +/// +/// let (y, dy) = log10_frule(100.0_f64, 2.0); +/// assert!((y - 2.0).abs() < 1e-12); +/// assert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12); +/// ``` +pub fn log10_frule(x: S, dx: S) -> (S, S) { + let y = x.log10(); + let scale = one::() / (x * ln_10::()); + (y, dx * scale.conj()) +} + +/// Reverse rule for `log10`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::log10_rrule; +/// +/// let dy = log10_rrule(100.0_f64, 2.0); +/// assert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12); +/// ``` +pub fn log10_rrule(x: S, cotangent: S) -> S { + cotangent * (one::() / (x * ln_10::())).conj() +} diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index 031f94f..1ec625a 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -1,6 +1,7 @@ mod basic; mod exp_log; mod hyperbolic; +mod roots; mod smooth; mod trig; @@ -12,15 +13,20 @@ fn one() -> S { pub use basic::{conj, conj_frule, conj_rrule, sqrt, sqrt_frule, sqrt_rrule}; pub use exp_log::{ - exp, exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, log, log1p, log1p_frule, - log1p_rrule, log_frule, log_rrule, + exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, + expm1, expm1_frule, expm1_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, + log1p_rrule, log2, log2_frule, log2_rrule, log_frule, log_rrule, }; pub use hyperbolic::{ acosh, acosh_frule, acosh_rrule, asinh, asinh_frule, asinh_rrule, atanh, atanh_frule, atanh_rrule, cosh, cosh_frule, cosh_rrule, sinh, sinh_frule, sinh_rrule, tanh, tanh_frule, tanh_rrule, }; -pub use smooth::{cbrt, exp2, hypot, log2, pow, tan}; +pub use roots::{cbrt, cbrt_frule, cbrt_rrule, inv, inv_frule, inv_rrule}; +pub use smooth::{ + hypot, hypot_frule, hypot_rrule, pow, pow_frule, pow_rrule, sincos, sincos_frule, sincos_rrule, + tan, tan_frule, tan_rrule, +}; pub use trig::{ acos, acos_frule, acos_rrule, asin, asin_frule, asin_rrule, atan, atan_frule, atan_rrule, cos, cos_frule, cos_rrule, sin, sin_frule, sin_rrule, diff --git a/crates/chainrules/src/unary/roots.rs b/crates/chainrules/src/unary/roots.rs new file mode 100644 index 0000000..18d7e7e --- /dev/null +++ b/crates/chainrules/src/unary/roots.rs @@ -0,0 +1,88 @@ +use crate::ScalarAd; + 
+/// Primal `cbrt`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt; +/// +/// assert_eq!(cbrt(8.0_f64), 2.0); +/// ``` +pub fn cbrt<S: ScalarAd>(x: S) -> S { + x.cbrt() +} + +/// Forward rule for `cbrt`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt_frule; +/// +/// let (y, dy) = cbrt_frule(8.0_f64, 1.0); +/// assert_eq!(y, 2.0); +/// assert!((dy - (1.0 / 12.0)).abs() < 1e-12); +/// ``` +pub fn cbrt_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) { + let y = x.cbrt(); + let scale = S::from_i32(1) / (S::from_i32(3) * y * y); + (y, dx * scale.conj()) +} + +/// Reverse rule for `cbrt`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cbrt_rrule; +/// +/// let dx = cbrt_rrule(2.0_f64, 1.0); +/// assert!((dx - (1.0 / 12.0)).abs() < 1e-12); +/// ``` +pub fn cbrt_rrule<S: ScalarAd>(result: S, cotangent: S) -> S { + cotangent * (S::from_i32(1) / (S::from_i32(3) * result * result)).conj() +} + +/// Primal `inv`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::inv; +/// +/// assert_eq!(inv(4.0_f64), 0.25); +/// ``` +pub fn inv<S: ScalarAd>(x: S) -> S { + x.recip() +} + +/// Forward rule for `inv`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::inv_frule; +/// +/// let (y, dy) = inv_frule(4.0_f64, 2.0); +/// assert_eq!(y, 0.25); +/// assert!((dy + 0.125).abs() < 1e-12); +/// ``` +pub fn inv_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) { + let y = x.recip(); + (y, dx * (-(y * y)).conj()) +} + +/// Reverse rule for `inv`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::inv_rrule; +/// +/// let dx = inv_rrule(0.25_f64, 2.0); +/// assert!((dx + 0.125).abs() < 1e-12); +/// ``` +pub fn inv_rrule<S: ScalarAd>(result: S, cotangent: S) -> S { + cotangent * (-(result * result)).conj() +} diff --git a/crates/chainrules/src/unary/smooth.rs b/crates/chainrules/src/unary/smooth.rs index 9d13393..402d157 100644 --- a/crates/chainrules/src/unary/smooth.rs +++ b/crates/chainrules/src/unary/smooth.rs @@ -1,80 +1,10 @@ -use crate::ScalarAd; -use num_traits::Float; - -/// Primal `cbrt`. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::cbrt; -/// -/// assert_eq!(cbrt(8.0_f64), 2.0); -/// ``` -pub fn cbrt(x: S) -> S { - x.cbrt() -} - -/// Primal `exp2`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp2; -/// -/// assert_eq!(exp2(3.0_f64), 8.0); -/// ``` -pub fn exp2(x: S) -> S { - x.exp2() -} - -/// Primal `hypot`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::hypot; -/// -/// assert_eq!(hypot(3.0_f64, 4.0_f64), 5.0); -/// ``` -pub fn hypot(x: R, y: R) -> R { - x.hypot(y) -} - -/// Primal `log2`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::log2; -/// -/// assert_eq!(log2(8.0_f64), 3.0); -/// ``` -pub fn log2(x: S) -> S { - x.log2() -} - -/// Primal `pow`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::pow; -/// -/// assert_eq!(pow(2.0_f64, 3.0_f64), 8.0); -/// ``` -pub fn pow(x: S, exponent: S) -> S { - x.pow(exponent) -} - -/// Primal `tan`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::tan; -/// -/// assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12); -/// ``` -pub fn tan(x: S) -> S { - x.tan() -} +#![allow(unused_imports)] + +pub use super::exp_log::{ + exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, log10, log10_frule, log10_rrule, + log2, log2_frule, log2_rrule, +}; +pub use super::roots::{cbrt, cbrt_frule, cbrt_rrule, inv, inv_frule, inv_rrule}; +pub use super::trig::{sincos, sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule}; +pub use crate::binary_special::{hypot, hypot_frule, hypot_rrule}; +pub use crate::power::{pow, pow_frule, pow_rrule}; diff --git a/crates/chainrules/src/unary/trig.rs b/crates/chainrules/src/unary/trig.rs index e5cdf3c..abdcb9e 100644 --- a/crates/chainrules/src/unary/trig.rs +++ b/crates/chainrules/src/unary/trig.rs @@ -1,6 +1,11 @@ +#[path = "trig_smooth.rs"] +mod smooth_rules; + use crate::unary::one; use crate::ScalarAd; +pub use smooth_rules::{sincos, sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule}; + /// 
Primal `sin`. pub fn sin(x: S) -> S { x.sin() diff --git a/crates/chainrules/src/unary/trig_smooth.rs b/crates/chainrules/src/unary/trig_smooth.rs new file mode 100644 index 0000000..b1a71ac --- /dev/null +++ b/crates/chainrules/src/unary/trig_smooth.rs @@ -0,0 +1,92 @@ +use crate::unary::one; +use crate::ScalarAd; + +/// Primal `tan`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tan; +/// +/// assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12); +/// ``` +pub fn tan(x: S) -> S { + x.tan() +} + +/// Forward rule for `tan`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tan_frule; +/// +/// let (y, dy) = tan_frule(0.25_f64, 1.0); +/// assert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12); +/// ``` +pub fn tan_frule(x: S, dx: S) -> (S, S) { + let y = x.tan(); + (y, dx * (one::() + y * y).conj()) +} + +/// Reverse rule for `tan`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tan_rrule; +/// +/// let dy = tan_rrule(0.25_f64.tan(), 1.0); +/// assert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12); +/// ``` +pub fn tan_rrule(result: S, cotangent: S) -> S { + cotangent * (one::() + result * result).conj() +} + +/// Primal `sincos`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincos; +/// +/// let (s, c) = sincos(0.5_f64); +/// assert!((s - 0.5_f64.sin()).abs() < 1e-12); +/// assert!((c - 0.5_f64.cos()).abs() < 1e-12); +/// ``` +pub fn sincos(x: S) -> (S, S) { + (x.sin(), x.cos()) +} + +/// Forward rule for `sincos`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincos_frule; +/// +/// let ((s, c), (ds, dc)) = sincos_frule(0.25_f64, 1.0); +/// assert!((ds - 0.25_f64.cos()).abs() < 1e-12); +/// assert!((dc + 0.25_f64.sin()).abs() < 1e-12); +/// ``` +pub fn sincos_frule(x: S, dx: S) -> ((S, S), (S, S)) { + let sin_x = x.sin(); + let cos_x = x.cos(); + ((sin_x, cos_x), (dx * cos_x.conj(), dx * (-sin_x).conj())) +} + +/// Reverse rule for `sincos`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincos_rrule; +/// +/// let dx = sincos_rrule(0.25_f64, 1.0, 1.0); +/// assert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12); +/// ``` +pub fn sincos_rrule(x: S, cotangent_sin: S, cotangent_cos: S) -> S { + let sin_x = x.sin(); + let cos_x = x.cos(); + cotangent_sin * cos_x.conj() + cotangent_cos * (-sin_x).conj() +} diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index 3d91d17..8302996 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -1,15 +1,96 @@ -use chainrules::{cbrt, exp2, hypot, log2, pow, tan}; +use chainrules::{ + cbrt, cbrt_frule, cbrt_rrule, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, + hypot, hypot_frule, hypot_rrule, inv, inv_frule, inv_rrule, log10, log10_frule, log10_rrule, + log2, log2_frule, log2_rrule, pow, pow_frule, pow_rrule, sincos, sincos_frule, sincos_rrule, + tan, tan_frule, tan_rrule, +}; use num_complex::Complex64; #[test] fn smooth_basis_helpers_are_reexported_from_chainrules() { assert!((cbrt(8.0_f64) - 2.0).abs() < 1.0e-12); + assert!((inv(4.0_f64) - 0.25).abs() < 1.0e-12); assert!((exp2(3.0_f64) - 8.0).abs() < 1.0e-12); + assert!((exp10(2.0_f64) - 100.0).abs() < 1.0e-12); assert!((hypot(3.0_f64, 4.0_f64) - 5.0).abs() < 1.0e-12); assert!((log2(8.0_f64) - 3.0).abs() < 1.0e-12); + assert!((log10(100.0_f64) - 2.0).abs() < 1.0e-12); assert!((pow(2.0_f64, 3.0_f64) - 8.0).abs() < 1.0e-12); assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1.0e-12); + let (sin_x, cos_x) = sincos(0.5_f64); + assert!((sin_x - 0.5_f64.sin()).abs() < 1.0e-12); + assert!((cos_x - 0.5_f64.cos()).abs() < 1.0e-12); let z = Complex64::new(1.0, 2.0); let _ = pow(z, Complex64::new(2.0, 0.0)); } + +#[test] +fn smooth_basis_frules_and_rrules_match_expected_derivatives() { + let (tan_y, tan_dy) = tan_frule(0.25_f64, 1.0_f64); + assert!((tan_y - 0.25_f64.tan()).abs() 
< 1.0e-12); + assert!((tan_dy - (1.0_f64 + 0.25_f64.tan().powi(2))).abs() < 1.0e-12); + assert!( + (tan_rrule(0.25_f64.tan(), 1.0_f64) - (1.0_f64 + 0.25_f64.tan().powi(2))).abs() < 1.0e-12 + ); + + let (exp2_y, exp2_dy) = exp2_frule(3.0_f64, 1.0_f64); + assert!((exp2_y - 8.0).abs() < 1.0e-12); + assert!((exp2_dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1.0e-12); + assert!((exp2_rrule(8.0_f64, 1.0_f64) - 8.0_f64 * std::f64::consts::LN_2).abs() < 1.0e-12); + + let (hypot_y, hypot_dy) = hypot_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64); + assert!((hypot_y - 5.0).abs() < 1.0e-12); + assert!((hypot_dy - 0.5).abs() < 1.0e-12); + assert!((hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64).0 - 0.6_f64).abs() < 1.0e-12); + assert!((hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64).1 - 0.8_f64).abs() < 1.0e-12); + + let (pow_y, pow_dy) = pow_frule(2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); + assert!((pow_y - 8.0).abs() < 1.0e-12); + assert!((pow_dy - 12.0).abs() < 1.0e-12); + let (pow_dx, pow_dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0_f64); + assert!((pow_dx - 12.0).abs() < 1.0e-12); + assert!((pow_dexp - (8.0_f64 * std::f64::consts::LN_2)).abs() < 1.0e-12); + + let (sincos_y, sincos_dy) = sincos_frule(0.25_f64, 1.0_f64); + assert!((sincos_y.0 - 0.25_f64.sin()).abs() < 1.0e-12); + assert!((sincos_y.1 - 0.25_f64.cos()).abs() < 1.0e-12); + assert!((sincos_dy.0 - 0.25_f64.cos()).abs() < 1.0e-12); + assert!((sincos_dy.1 + 0.25_f64.sin()).abs() < 1.0e-12); + assert!( + (sincos_rrule(0.25_f64, 1.0_f64, 1.0_f64) - (0.25_f64.cos() - 0.25_f64.sin())).abs() + < 1.0e-12 + ); + + let (cbrt_y, cbrt_dy) = cbrt_frule(8.0_f64, 1.0_f64); + assert!((cbrt_y - 2.0).abs() < 1.0e-12); + assert!((cbrt_dy - (1.0_f64 / (3.0_f64 * 4.0_f64))).abs() < 1.0e-12); + assert!((cbrt_rrule(2.0_f64, 1.0_f64) - (1.0_f64 / (3.0_f64 * 4.0_f64))).abs() < 1.0e-12); + + let (inv_y, inv_dy) = inv_frule(4.0_f64, 2.0_f64); + assert!((inv_y - 0.25).abs() < 1.0e-12); + assert!((inv_dy + 0.125).abs() < 1.0e-12); + assert!((inv_rrule(0.25_f64, 2.0_f64) 
+ 0.125).abs() < 1.0e-12); + + let (log2_y, log2_dy) = log2_frule(8.0_f64, 2.0_f64); + assert!((log2_y - 3.0).abs() < 1.0e-12); + let expected_log2 = 2.0_f64 / (8.0_f64 * std::f64::consts::LN_2); + assert!((log2_dy - expected_log2).abs() < 1.0e-12); + assert!((log2_rrule(8.0_f64, 2.0_f64) - expected_log2).abs() < 1.0e-12); + + let (log10_y, log10_dy) = log10_frule(100.0_f64, 2.0_f64); + assert!((log10_y - 2.0).abs() < 1.0e-12); + assert!((log10_dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1.0e-12); + assert!( + (log10_rrule(100.0_f64, 2.0_f64) - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() + < 1.0e-12 + ); + + let (exp10_y, exp10_dy) = exp10_frule(2.0_f64, 0.5_f64); + assert!((exp10_y - 100.0).abs() < 1.0e-12); + assert!((exp10_dy - (100.0_f64 * std::f64::consts::LN_10 * 0.5_f64)).abs() < 1.0e-12); + assert!( + (exp10_rrule(100.0_f64, 0.5_f64) - (100.0_f64 * std::f64::consts::LN_10 * 0.5_f64)).abs() + < 1.0e-12 + ); +} From 3896e32f0a343ef1da8fc2f2bd318398d78e66ff Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:10:35 +0900 Subject: [PATCH 11/32] fix: align smooth scalar basis API --- crates/chainrules/src/unary/exp_log.rs | 81 ++++++-- crates/chainrules/src/unary/exp_log_smooth.rs | 185 ------------------ crates/chainrules/src/unary/trig.rs | 41 +++- crates/chainrules/src/unary/trig_smooth.rs | 92 --------- crates/chainrules/tests/smooth_basis_tests.rs | 2 +- 5 files changed, 98 insertions(+), 303 deletions(-) delete mode 100644 crates/chainrules/src/unary/exp_log_smooth.rs delete mode 100644 crates/chainrules/src/unary/trig_smooth.rs diff --git a/crates/chainrules/src/unary/exp_log.rs b/crates/chainrules/src/unary/exp_log.rs index 3f39864..abb7457 100644 --- a/crates/chainrules/src/unary/exp_log.rs +++ b/crates/chainrules/src/unary/exp_log.rs @@ -1,77 +1,118 @@ -#[path = "exp_log_smooth.rs"] -mod smooth_rules; - use crate::unary::one; use crate::ScalarAd; - -pub use smooth_rules::{ - exp10, exp10_frule, 
exp10_rrule, exp2, exp2_frule, exp2_rrule, log10, log10_frule, log10_rrule, - log2, log2_frule, log2_rrule, -}; - +use num_traits::FloatConst; +fn ln_2() -> S { + S::from_real(S::Real::LN_2()) +} +fn ln_10() -> S { + S::from_real(S::Real::LN_10()) +} /// Primal `exp`. pub fn exp(x: S) -> S { x.exp() } - /// Forward rule for `exp`. pub fn exp_frule(x: S, dx: S) -> (S, S) { let y = x.exp(); (y, dx * y.conj()) } - /// Reverse rule for `exp`. pub fn exp_rrule(result: S, cotangent: S) -> S { cotangent * result.conj() } - /// Primal `exp(x) - 1`. pub fn expm1(x: S) -> S { x.expm1() } - /// Forward rule for `exp(x) - 1`. pub fn expm1_frule(x: S, dx: S) -> (S, S) { let y = x.expm1(); let scale = y + one::(); (y, dx * scale.conj()) } - /// Reverse rule for `exp(x) - 1`. pub fn expm1_rrule(result: S, cotangent: S) -> S { cotangent * (result + one::()).conj() } - +#[doc = "Primal `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2;\n\nassert!((exp2(3.0_f64) - 8.0).abs() < 1e-12);\n```"] +pub fn exp2(x: S) -> S { + x.exp2() +} +#[doc = "Forward rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_frule;\n\nlet (y, dy) = exp2_frule(3.0_f64, 1.0);\nassert!((y - 8.0).abs() < 1e-12);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] +pub fn exp2_frule(x: S, dx: S) -> (S, S) { + let y = x.exp2(); + (y, dx * (y * ln_2::()).conj()) +} +#[doc = "Reverse rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_rrule;\n\nlet dy = exp2_rrule(8.0_f64, 1.0);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] +pub fn exp2_rrule(result: S, cotangent: S) -> S { + cotangent * (result * ln_2::()).conj() +} +#[doc = "Primal `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10;\n\nassert!((exp10(2.0_f64) - 100.0).abs() < 1e-12);\n```"] +pub fn exp10(x: S) -> S { + x.exp10() +} +#[doc = "Forward rule for `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10_frule;\n\nlet (y, dy) = exp10_frule(2.0_f64, 0.5);\nassert!((y - 100.0).abs() < 
1e-12);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"] +pub fn exp10_frule(x: S, dx: S) -> (S, S) { + let y = x.exp10(); + (y, dx * (y * ln_10::()).conj()) +} +#[doc = "Reverse rule for `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10_rrule;\n\nlet dy = exp10_rrule(100.0_f64, 0.5);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"] +pub fn exp10_rrule(result: S, cotangent: S) -> S { + cotangent * (result * ln_10::()).conj() +} /// Primal `log`. pub fn log(x: S) -> S { x.ln() } - /// Forward rule for `log`. pub fn log_frule(x: S, dx: S) -> (S, S) { let y = x.ln(); let dy = dx * (one::() / x).conj(); (y, dy) } - /// Reverse rule for `log`. pub fn log_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / x).conj() } - /// Primal `log(1 + x)`. pub fn log1p(x: S) -> S { x.log1p() } - /// Forward rule for `log(1 + x)`. pub fn log1p_frule(x: S, dx: S) -> (S, S) { let y = x.log1p(); let dy = dx * (one::() / (one::() + x)).conj(); (y, dy) } - /// Reverse rule for `log(1 + x)`. 
pub fn log1p_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / (one::() + x)).conj() } +#[doc = "Primal `log2`.\n\n# Examples\n```rust\nuse chainrules::log2;\n\nassert_eq!(log2(8.0_f64), 3.0);\n```"] +pub fn log2(x: S) -> S { + x.log2() +} +#[doc = "Forward rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_frule;\n\nlet (y, dy) = log2_frule(8.0_f64, 2.0);\nassert!((y - 3.0).abs() < 1e-12);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"] +pub fn log2_frule(x: S, dx: S) -> (S, S) { + let y = x.log2(); + let scale = one::() / (x * ln_2::()); + (y, dx * scale.conj()) +} +#[doc = "Reverse rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_rrule;\n\nlet dy = log2_rrule(8.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"] +pub fn log2_rrule(x: S, cotangent: S) -> S { + cotangent * (one::() / (x * ln_2::())).conj() +} +#[doc = "Primal `log10`.\n\n# Examples\n```rust\nuse chainrules::log10;\n\nassert_eq!(log10(100.0_f64), 2.0);\n```"] +pub fn log10(x: S) -> S { + x.log10() +} +#[doc = "Forward rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_frule;\n\nlet (y, dy) = log10_frule(100.0_f64, 2.0);\nassert!((y - 2.0).abs() < 1e-12);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"] +pub fn log10_frule(x: S, dx: S) -> (S, S) { + let y = x.log10(); + let scale = one::() / (x * ln_10::()); + (y, dx * scale.conj()) +} +#[doc = "Reverse rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_rrule;\n\nlet dy = log10_rrule(100.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"] +pub fn log10_rrule(x: S, cotangent: S) -> S { + cotangent * (one::() / (x * ln_10::())).conj() +} diff --git a/crates/chainrules/src/unary/exp_log_smooth.rs b/crates/chainrules/src/unary/exp_log_smooth.rs deleted file mode 100644 index cff5a92..0000000 --- 
a/crates/chainrules/src/unary/exp_log_smooth.rs +++ /dev/null @@ -1,185 +0,0 @@ -use crate::unary::one; -use crate::ScalarAd; -use num_traits::FloatConst; - -fn ln_2() -> S { - S::from_real(S::Real::LN_2()) -} - -fn ln_10() -> S { - S::from_real(S::Real::LN_10()) -} - -/// Primal `2^x`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp2; -/// -/// assert!((exp2(3.0_f64) - 8.0).abs() < 1e-12); -/// ``` -pub fn exp2(x: S) -> S { - x.exp2() -} - -/// Forward rule for `2^x`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp2_frule; -/// -/// let (y, dy) = exp2_frule(3.0_f64, 1.0); -/// assert_eq!(y, 8.0); -/// assert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12); -/// ``` -pub fn exp2_frule(x: S, dx: S) -> (S, S) { - let y = x.exp2(); - (y, dx * (y * ln_2::()).conj()) -} - -/// Reverse rule for `2^x`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp2_rrule; -/// -/// let dy = exp2_rrule(8.0_f64, 1.0); -/// assert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12); -/// ``` -pub fn exp2_rrule(result: S, cotangent: S) -> S { - cotangent * (result * ln_2::()).conj() -} - -/// Primal `10^x`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp10; -/// -/// assert!((exp10(2.0_f64) - 100.0).abs() < 1e-12); -/// ``` -pub fn exp10(x: S) -> S { - x.exp10() -} - -/// Forward rule for `10^x`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp10_frule; -/// -/// let (y, dy) = exp10_frule(2.0_f64, 0.5); -/// assert!((y - 100.0).abs() < 1e-12); -/// assert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12); -/// ``` -pub fn exp10_frule(x: S, dx: S) -> (S, S) { - let y = x.exp10(); - (y, dx * (y * ln_10::()).conj()) -} - -/// Reverse rule for `10^x`. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::exp10_rrule; -/// -/// let dy = exp10_rrule(100.0_f64, 0.5); -/// assert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12); -/// ``` -pub fn exp10_rrule(result: S, cotangent: S) -> S { - cotangent * (result * ln_10::()).conj() -} - -/// Primal `log2`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::log2; -/// -/// assert_eq!(log2(8.0_f64), 3.0); -/// ``` -pub fn log2(x: S) -> S { - x.log2() -} - -/// Forward rule for `log2`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::log2_frule; -/// -/// let (y, dy) = log2_frule(8.0_f64, 2.0); -/// assert!((y - 3.0).abs() < 1e-12); -/// assert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12); -/// ``` -pub fn log2_frule(x: S, dx: S) -> (S, S) { - let y = x.log2(); - let scale = one::() / (x * ln_2::()); - (y, dx * scale.conj()) -} - -/// Reverse rule for `log2`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::log2_rrule; -/// -/// let dy = log2_rrule(8.0_f64, 2.0); -/// assert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12); -/// ``` -pub fn log2_rrule(x: S, cotangent: S) -> S { - cotangent * (one::() / (x * ln_2::())).conj() -} - -/// Primal `log10`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::log10; -/// -/// assert_eq!(log10(100.0_f64), 2.0); -/// ``` -pub fn log10(x: S) -> S { - x.log10() -} - -/// Forward rule for `log10`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::log10_frule; -/// -/// let (y, dy) = log10_frule(100.0_f64, 2.0); -/// assert!((y - 2.0).abs() < 1e-12); -/// assert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12); -/// ``` -pub fn log10_frule(x: S, dx: S) -> (S, S) { - let y = x.log10(); - let scale = one::() / (x * ln_10::()); - (y, dx * scale.conj()) -} - -/// Reverse rule for `log10`. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::log10_rrule; -/// -/// let dy = log10_rrule(100.0_f64, 2.0); -/// assert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12); -/// ``` -pub fn log10_rrule(x: S, cotangent: S) -> S { - cotangent * (one::() / (x * ln_10::())).conj() -} diff --git a/crates/chainrules/src/unary/trig.rs b/crates/chainrules/src/unary/trig.rs index abdcb9e..3f73f1c 100644 --- a/crates/chainrules/src/unary/trig.rs +++ b/crates/chainrules/src/unary/trig.rs @@ -1,11 +1,6 @@ -#[path = "trig_smooth.rs"] -mod smooth_rules; - use crate::unary::one; use crate::ScalarAd; -pub use smooth_rules::{sincos, sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule}; - /// Primal `sin`. pub fn sin(x: S) -> S { x.sin() @@ -92,3 +87,39 @@ pub fn atan_frule(x: S, dx: S) -> (S, S) { pub fn atan_rrule(x: S, cotangent: S) -> S { cotangent * (one::() / (one::() + x * x)).conj() } + +#[doc = "Primal `tan`.\n\n# Examples\n```rust\nuse chainrules::tan;\n\nassert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12);\n```"] +pub fn tan(x: S) -> S { + x.tan() +} + +#[doc = "Forward rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_frule;\n\nlet (y, dy) = tan_frule(0.25_f64, 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"] +pub fn tan_frule(x: S, dx: S) -> (S, S) { + let y = x.tan(); + (y, dx * (one::() + y * y).conj()) +} + +#[doc = "Reverse rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_rrule;\n\nlet dy = tan_rrule(0.25_f64.tan(), 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"] +pub fn tan_rrule(result: S, cotangent: S) -> S { + cotangent * (one::() + result * result).conj() +} + +#[doc = "Primal `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos;\n\nlet (s, c) = sincos(0.5_f64);\nassert!((s - 0.5_f64.sin()).abs() < 1e-12);\nassert!((c - 0.5_f64.cos()).abs() < 1e-12);\n```"] +pub fn sincos(x: S) -> (S, S) { + (x.sin(), x.cos()) +} + +#[doc = "Forward rule for 
`sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_frule;\n\nlet ((s, c), (ds, dc)) = sincos_frule(0.25_f64, 1.0);\nassert!((ds - 0.25_f64.cos()).abs() < 1e-12);\nassert!((dc + 0.25_f64.sin()).abs() < 1e-12);\n```"] +pub fn sincos_frule(x: S, dx: S) -> ((S, S), (S, S)) { + let sin_x = x.sin(); + let cos_x = x.cos(); + ((sin_x, cos_x), (dx * cos_x.conj(), dx * (-sin_x).conj())) +} + +#[doc = "Reverse rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_rrule;\n\nlet dx = sincos_rrule(0.25_f64, (1.0, 1.0));\nassert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12);\n```"] +pub fn sincos_rrule(x: S, cotangents: (S, S)) -> S { + let (cotangent_sin, cotangent_cos) = cotangents; + let sin_x = x.sin(); + let cos_x = x.cos(); + cotangent_sin * cos_x.conj() + cotangent_cos * (-sin_x).conj() +} diff --git a/crates/chainrules/src/unary/trig_smooth.rs b/crates/chainrules/src/unary/trig_smooth.rs deleted file mode 100644 index b1a71ac..0000000 --- a/crates/chainrules/src/unary/trig_smooth.rs +++ /dev/null @@ -1,92 +0,0 @@ -use crate::unary::one; -use crate::ScalarAd; - -/// Primal `tan`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::tan; -/// -/// assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12); -/// ``` -pub fn tan(x: S) -> S { - x.tan() -} - -/// Forward rule for `tan`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::tan_frule; -/// -/// let (y, dy) = tan_frule(0.25_f64, 1.0); -/// assert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12); -/// ``` -pub fn tan_frule(x: S, dx: S) -> (S, S) { - let y = x.tan(); - (y, dx * (one::() + y * y).conj()) -} - -/// Reverse rule for `tan`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::tan_rrule; -/// -/// let dy = tan_rrule(0.25_f64.tan(), 1.0); -/// assert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12); -/// ``` -pub fn tan_rrule(result: S, cotangent: S) -> S { - cotangent * (one::() + result * result).conj() -} - -/// Primal `sincos`. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::sincos; -/// -/// let (s, c) = sincos(0.5_f64); -/// assert!((s - 0.5_f64.sin()).abs() < 1e-12); -/// assert!((c - 0.5_f64.cos()).abs() < 1e-12); -/// ``` -pub fn sincos(x: S) -> (S, S) { - (x.sin(), x.cos()) -} - -/// Forward rule for `sincos`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::sincos_frule; -/// -/// let ((s, c), (ds, dc)) = sincos_frule(0.25_f64, 1.0); -/// assert!((ds - 0.25_f64.cos()).abs() < 1e-12); -/// assert!((dc + 0.25_f64.sin()).abs() < 1e-12); -/// ``` -pub fn sincos_frule(x: S, dx: S) -> ((S, S), (S, S)) { - let sin_x = x.sin(); - let cos_x = x.cos(); - ((sin_x, cos_x), (dx * cos_x.conj(), dx * (-sin_x).conj())) -} - -/// Reverse rule for `sincos`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::sincos_rrule; -/// -/// let dx = sincos_rrule(0.25_f64, 1.0, 1.0); -/// assert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12); -/// ``` -pub fn sincos_rrule(x: S, cotangent_sin: S, cotangent_cos: S) -> S { - let sin_x = x.sin(); - let cos_x = x.cos(); - cotangent_sin * cos_x.conj() + cotangent_cos * (-sin_x).conj() -} diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index 8302996..e7edca0 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -58,7 +58,7 @@ fn smooth_basis_frules_and_rrules_match_expected_derivatives() { assert!((sincos_dy.0 - 0.25_f64.cos()).abs() < 1.0e-12); assert!((sincos_dy.1 + 0.25_f64.sin()).abs() < 1.0e-12); assert!( - (sincos_rrule(0.25_f64, 1.0_f64, 1.0_f64) - (0.25_f64.cos() - 0.25_f64.sin())).abs() + (sincos_rrule(0.25_f64, (1.0_f64, 1.0_f64)) - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1.0e-12 ); From 6fbd4706a26a9206bc6dc4c857aae524c10dbece Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:22:17 +0900 Subject: [PATCH 12/32] fix: avoid NaN in same-scalar pow rules --- 
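Reviewer note (not part of the commit): a minimal real-only sketch of the NaN this patch avoids, assuming plain `f64` in place of the crate's generic scalar type. The names `pow_frule_naive` and `pow_frule_guarded` are hypothetical and only mirror the shape of the old and new `pow_frule` bodies in the diff below; they are not the crate's API. The old rule always evaluated `y * ln(x)`, which is NaN for a negative base, and also for a zero base (`0 * -inf`). Multiplying that NaN by a zero exponent tangent still gives NaN, so the whole tangent was poisoned.

```rust
/// Hypothetical f64 stand-in for the pre-patch rule: the exponent
/// sensitivity `y * ln(x)` is always evaluated, even when `dexponent` is 0.
fn pow_frule_naive(x: f64, exponent: f64, dx: f64, dexponent: f64) -> (f64, f64) {
    let y = x.powf(exponent);
    let dfdx = if exponent == 0.0 {
        0.0
    } else {
        exponent * x.powf(exponent - 1.0)
    };
    // NaN for x < 0 (ln of a negative real) and for x == 0 (0 * -inf).
    let dfde = y * x.ln();
    (y, dx * dfdx + dexponent * dfde)
}

/// Hypothetical f64 stand-in for the patched rule: each term is skipped
/// when its tangent is exactly zero, and the zero-base case is special-cased.
fn pow_frule_guarded(x: f64, exponent: f64, dx: f64, dexponent: f64) -> (f64, f64) {
    let y = x.powf(exponent);
    let dfdx = if dx == 0.0 || exponent == 0.0 {
        0.0
    } else {
        dx * exponent * x.powf(exponent - 1.0)
    };
    let dfde = if dexponent == 0.0 || (x == 0.0 && exponent >= 0.0) {
        0.0
    } else {
        dexponent * y * x.ln()
    };
    (y, dfdx + dfde)
}
```

With `x = -2`, `exponent = 3`, `dx = 1`, `dexponent = 0`, the naive version returns a NaN tangent, while the guarded version returns 12, matching the `pow_rules_handle_zero_and_negative_real_paths` test added in this patch.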
crates/chainrules/src/power.rs | 85 +++++++++---------- crates/chainrules/tests/smooth_basis_tests.rs | 39 ++++++++- 2 files changed, 79 insertions(+), 45 deletions(-) diff --git a/crates/chainrules/src/power.rs b/crates/chainrules/src/power.rs index d7e3033..851bdd4 100644 --- a/crates/chainrules/src/power.rs +++ b/crates/chainrules/src/power.rs @@ -108,59 +108,56 @@ pub fn powi_rrule(x: S, exponent: i32, cotangent: S) -> S { cotangent * (S::from_i32(exponent) * x.powi(exponent - 1)).conj() } -/// Primal `pow(x, exponent)`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::pow; -/// -/// assert_eq!(pow(2.0_f64, 3.0_f64), 8.0); -/// ``` +#[doc = "Primal `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow;\n\nassert_eq!(pow(2.0_f64, 3.0_f64), 8.0);\n```"] pub fn pow(x: S, exponent: S) -> S { x.pow(exponent) } - -/// Forward rule for `pow(x, exponent)`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::pow_frule; -/// -/// let (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0); -/// assert_eq!(y, 8.0); -/// assert!((dy - 12.0).abs() < 1e-12); -/// ``` +fn zero() -> S { + S::from_i32(0) +} +fn pow_x_scale(x: S, exponent: S) -> S { + if exponent == zero::() { + zero::() + } else { + (exponent * x.pow(exponent - S::from_i32(1))).conj() + } +} +fn pow_exp_scale(x: S, exponent: S) -> S { + if x.real() == S::Real::zero() + && exponent.imag() == S::Real::zero() + && exponent.real() >= S::Real::zero() + { + zero::() + } else { + (x.pow(exponent) * x.ln()).conj() + } +} +#[doc = "Forward rule for `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow_frule;\n\nlet (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0);\nassert_eq!(y, 8.0);\nassert!((dy - 12.0).abs() < 1e-12);\n```"] pub fn pow_frule(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) { let y = x.pow(exponent); - let dfdx = if exponent == S::from_i32(0) { - S::from_i32(0) + let dfdx = if dx == zero::() { + zero::() } else { - (exponent * x.pow(exponent - S::from_i32(1))).conj() + 
dx * pow_x_scale(x, exponent) + }; + let dfde = if dexponent == zero::() { + zero::() + } else { + dexponent * pow_exp_scale(x, exponent) }; - let dfde = (y * x.ln()).conj(); - (y, dx * dfdx + dexponent * dfde) + (y, dfdx + dfde) } - -/// Reverse rule for `pow(x, exponent)`. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::pow_rrule; -/// -/// let (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0); -/// assert_eq!(dx, 12.0); -/// assert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12); -/// ``` +#[doc = "Reverse rule for `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow_rrule;\n\nlet (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0);\nassert_eq!(dx, 12.0);\nassert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] pub fn pow_rrule(x: S, exponent: S, cotangent: S) -> (S, S) { - let y = x.pow(exponent); - let dfdx = if exponent == S::from_i32(0) { - S::from_i32(0) + let dfdx = if cotangent == zero::() { + zero::() } else { - (exponent * x.pow(exponent - S::from_i32(1))).conj() + cotangent * pow_x_scale(x, exponent) + }; + let dfde = if cotangent == zero::() { + zero::() + } else { + cotangent * pow_exp_scale(x, exponent) }; - let dfde = (y * x.ln()).conj(); - (cotangent * dfdx, cotangent * dfde) + (dfdx, dfde) } diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index e7edca0..117307c 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -4,7 +4,7 @@ use chainrules::{ log2, log2_frule, log2_rrule, pow, pow_frule, pow_rrule, sincos, sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule, }; -use num_complex::Complex64; +use num_complex::{Complex64, ComplexFloat}; #[test] fn smooth_basis_helpers_are_reexported_from_chainrules() { @@ -94,3 +94,40 @@ fn smooth_basis_frules_and_rrules_match_expected_derivatives() { < 1.0e-12 ); } + +#[test] +fn pow_rules_handle_zero_and_negative_real_paths() { + let (neg_y, neg_dy) = 
pow_frule(-2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); + assert!((neg_y + 8.0).abs() < 1.0e-12); + assert!((neg_dy - 12.0).abs() < 1.0e-12); + + let (zero_y, zero_dy) = pow_frule(0.0_f64, 2.0_f64, 1.0_f64, 0.0_f64); + assert!((zero_y - 0.0).abs() < 1.0e-12); + assert!((zero_dy - 0.0).abs() < 1.0e-12); + + let (dx, dexp) = pow_rrule(0.0_f64, 2.0_f64, 1.0_f64); + assert!((dx - 0.0).abs() < 1.0e-12); + assert!((dexp - 0.0).abs() < 1.0e-12); +} + +#[test] +fn pow_rules_cover_complex_frule_and_rrule_paths() { + let x = Complex64::new(1.0, 1.0); + let exponent = Complex64::new(2.0, 0.0); + let dx = Complex64::new(0.5, -0.25); + let dexp = Complex64::new(0.0, 0.0); + + let (y, dy) = pow_frule(x, exponent, dx, dexp); + let expected_y = x.powc(exponent); + let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))).conj(); + assert!((y - expected_y).norm() < 1.0e-12); + assert!((dy - expected_dy).norm() < 1.0e-12); + + let cotangent = Complex64::new(0.5, -0.25); + let (dx_rr, dexp_rr) = pow_rrule(x, exponent, cotangent); + let expected_dx_rr = + cotangent * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))).conj(); + let expected_dexp_rr = cotangent * (expected_y * x.ln()).conj(); + assert!((dx_rr - expected_dx_rr).norm() < 1.0e-12); + assert!((dexp_rr - expected_dexp_rr).norm() < 1.0e-12); +} From faa3cf7a057b05d5773b120dce4b55a806b47548 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:27:10 +0900 Subject: [PATCH 13/32] fix: handle complex zero-base pow exponent scale --- crates/chainrules/src/power.rs | 4 +--- crates/chainrules/tests/smooth_basis_tests.rs | 5 +++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/crates/chainrules/src/power.rs b/crates/chainrules/src/power.rs index 851bdd4..921127f 100644 --- a/crates/chainrules/src/power.rs +++ b/crates/chainrules/src/power.rs @@ -123,9 +123,7 @@ fn pow_x_scale(x: S, exponent: S) -> S { } } fn pow_exp_scale(x: S, exponent: S) -> S { - if x.real() == S::Real::zero() 
- && exponent.imag() == S::Real::zero() - && exponent.real() >= S::Real::zero() + if x == zero::() && exponent.imag() == S::Real::zero() && exponent.real() >= S::Real::zero() { zero::() } else { diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index 117307c..48a0a43 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -130,4 +130,9 @@ fn pow_rules_cover_complex_frule_and_rrule_paths() { let expected_dexp_rr = cotangent * (expected_y * x.ln()).conj(); assert!((dx_rr - expected_dx_rr).norm() < 1.0e-12); assert!((dexp_rr - expected_dexp_rr).norm() < 1.0e-12); + + let imag_x = Complex64::new(0.0, 1.0); + let (_, imag_dexp_rr) = pow_rrule(imag_x, Complex64::new(2.0, 0.0), Complex64::new(1.0, 0.0)); + let expected_imag_dexp_rr = (imag_x.powc(Complex64::new(2.0, 0.0)) * imag_x.ln()).conj(); + assert!((imag_dexp_rr - expected_imag_dexp_rr).norm() < 1.0e-12); } From a63d5a599289e693722dfa065415545bc5f31b6e Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:36:54 +0900 Subject: [PATCH 14/32] feat: add complex scalar helper rules --- crates/chainrules/src/lib.rs | 18 +- crates/chainrules/src/unary/complex_parts.rs | 178 ++++++++++++++++++ crates/chainrules/src/unary/mod.rs | 5 + .../chainrules/tests/complex_helper_tests.rs | 28 +++ 4 files changed, 221 insertions(+), 8 deletions(-) create mode 100644 crates/chainrules/src/unary/complex_parts.rs create mode 100644 crates/chainrules/tests/complex_helper_tests.rs diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 09ecfdb..f08379f 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -25,15 +25,17 @@ pub use real_ops::{atan2, atan2_frule, atan2_rrule}; pub use scalar_ad::{handle_r_to_c_f32, handle_r_to_c_f64, ScalarAd}; #[doc(inline)] pub use unary::{ - acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, 
asin_rrule, - asinh, asinh_frule, asinh_rrule, atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, - cbrt, cbrt_frule, cbrt_rrule, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, - cosh_frule, cosh_rrule, exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, - exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, hypot, hypot_frule, hypot_rrule, inv, + abs, abs2, abs2_frule, abs2_rrule, acos, acos_frule, acos_rrule, acosh, acosh_frule, + acosh_rrule, angle, angle_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, + atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, cbrt, cbrt_frule, cbrt_rrule, + complex, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, cosh_frule, cosh_rrule, + exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, + expm1, expm1_frule, expm1_rrule, hypot, hypot_frule, hypot_rrule, imag, imag_rrule, inv, inv_frule, inv_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, log1p_rrule, - log2, log2_frule, log2_rrule, log_frule, log_rrule, pow, pow_frule, pow_rrule, sin, sin_frule, - sin_rrule, sincos, sincos_frule, sincos_rrule, sinh, sinh_frule, sinh_rrule, sqrt, sqrt_frule, - sqrt_rrule, tan, tan_frule, tan_rrule, tanh, tanh_frule, tanh_rrule, + log2, log2_frule, log2_rrule, log_frule, log_rrule, pow, pow_frule, pow_rrule, real, + real_rrule, sin, sin_frule, sin_rrule, sincos, sincos_frule, sincos_rrule, sinh, sinh_frule, + sinh_rrule, sqrt, sqrt_frule, sqrt_rrule, tan, tan_frule, tan_rrule, tanh, tanh_frule, + tanh_rrule, }; #[cfg(test)] diff --git a/crates/chainrules/src/unary/complex_parts.rs b/crates/chainrules/src/unary/complex_parts.rs new file mode 100644 index 0000000..6af3ba2 --- /dev/null +++ b/crates/chainrules/src/unary/complex_parts.rs @@ -0,0 +1,178 @@ +use crate::ScalarAd; +use num_complex::Complex; +use num_traits::{Float, One, Zero}; + +/// Primal `abs`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs; +/// use num_complex::Complex64; +/// +/// assert_eq!(abs(Complex64::new(3.0, 4.0)), 5.0); +/// ``` +#[inline] +pub fn abs(x: S) -> S::Real { + x.abs() +} + +/// Primal `abs2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs2; +/// use num_complex::Complex64; +/// +/// assert_eq!(abs2(Complex64::new(3.0, 4.0)), 25.0); +/// ``` +#[inline] +pub fn abs2(x: S) -> S::Real { + x.abs2() +} + +/// Primal `real`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::real; +/// use num_complex::Complex64; +/// +/// assert_eq!(real(Complex64::new(3.0, 4.0)), 3.0); +/// ``` +#[inline] +pub fn real(x: S) -> S::Real { + x.real() +} + +/// Primal `imag`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::imag; +/// use num_complex::Complex64; +/// +/// assert_eq!(imag(Complex64::new(3.0, 4.0)), 4.0); +/// ``` +#[inline] +pub fn imag(x: S) -> S::Real { + x.imag() +} + +/// Primal `angle`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::angle; +/// use num_complex::Complex64; +/// +/// assert!((angle(Complex64::new(3.0, 4.0)) - 0.9272952180016122).abs() < 1e-12); +/// ``` +#[inline] +pub fn angle(x: S) -> S::Real { + x.angle() +} + +/// Construct a complex number from real and imaginary parts. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::complex; +/// use num_complex::Complex64; +/// +/// assert_eq!(complex(3.0_f64, 4.0_f64), Complex64::new(3.0, 4.0)); +/// ``` +#[inline] +pub fn complex(re: R, im: R) -> Complex { + Complex::new(re, im) +} + +/// Forward rule for `abs2`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs2_frule; +/// use num_complex::Complex64; +/// +/// let z = Complex64::new(3.0, 4.0); +/// let dz = Complex64::new(1.0, -2.0); +/// let (y, dy) = abs2_frule(z, dz); +/// assert_eq!(y, 25.0); +/// assert_eq!(dy, -10.0); +/// ``` +#[inline] +pub fn abs2_frule(x: S, dx: S) -> (S::Real, S::Real) { + let y = x.abs2(); + let two = S::Real::one() + S::Real::one(); + let dy = two * (x.real() * dx.real() + x.imag() * dx.imag()); + (y, dy) +} + +/// Reverse rule for `abs2`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::abs2_rrule; +/// use num_complex::Complex64; +/// +/// let z = Complex64::new(3.0, 4.0); +/// assert_eq!(abs2_rrule(z, 1.25), Complex64::new(7.5, 10.0)); +/// ``` +#[inline] +pub fn abs2_rrule(x: Complex, cotangent: R) -> Complex { + let two = R::one() + R::one(); + Complex::new(two * cotangent * x.re, two * cotangent * x.im) +} + +/// Reverse rule for `real`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::real_rrule; +/// use num_complex::Complex64; +/// +/// assert_eq!(real_rrule::(2.0), Complex64::new(2.0, 0.0)); +/// ``` +#[inline] +pub fn real_rrule(cotangent: S::Real) -> Complex { + Complex::new(cotangent, S::Real::zero()) +} + +/// Reverse rule for `imag`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::imag_rrule; +/// use num_complex::Complex64; +/// +/// assert_eq!(imag_rrule::(2.0), Complex64::new(0.0, 2.0)); +/// ``` +#[inline] +pub fn imag_rrule(cotangent: S::Real) -> Complex { + Complex::new(S::Real::zero(), cotangent) +} + +/// Reverse rule for `angle`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::angle_rrule; +/// use num_complex::Complex64; +/// +/// assert_eq!(angle_rrule(Complex64::new(3.0, 4.0), 1.0), Complex64::new(-0.16, 0.12)); +/// ``` +#[inline] +pub fn angle_rrule(x: Complex, cotangent: R) -> Complex { + let denom = x.re * x.re + x.im * x.im; + Complex::new(-x.im * cotangent / denom, x.re * cotangent / denom) +} diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index 1ec625a..e8410be 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -1,4 +1,5 @@ mod basic; +mod complex_parts; mod exp_log; mod hyperbolic; mod roots; @@ -12,6 +13,10 @@ fn one() -> S { } pub use basic::{conj, conj_frule, conj_rrule, sqrt, sqrt_frule, sqrt_rrule}; +pub use complex_parts::{ + abs, abs2, abs2_frule, abs2_rrule, angle, angle_rrule, complex, imag, imag_rrule, real, + real_rrule, +}; pub use exp_log::{ exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, diff --git a/crates/chainrules/tests/complex_helper_tests.rs b/crates/chainrules/tests/complex_helper_tests.rs new file mode 100644 index 0000000..f3e614b --- /dev/null +++ b/crates/chainrules/tests/complex_helper_tests.rs @@ -0,0 +1,28 @@ +use chainrules::{ + abs, abs2, abs2_frule, abs2_rrule, angle, angle_rrule, complex, imag, imag_rrule, real, + real_rrule, +}; +use num_complex::Complex64; + +#[test] +fn complex_helpers_match_expected_formulas() { + let z = Complex64::new(3.0, 4.0); + let dz = Complex64::new(1.0, -2.0); + + let constructed: Complex64 = complex(3.0, 4.0); + assert_eq!(constructed, z); + assert_eq!(abs(z), 5.0); + assert_eq!(abs2(z), 25.0); + assert_eq!(real(z), 3.0); + assert_eq!(imag(z), 4.0); + assert_eq!(angle(z), z.arg()); + + let (abs2_y, abs2_dy) = abs2_frule(z, dz); + assert_eq!(abs2_y, 25.0); + assert_eq!(abs2_dy, 2.0 * (z.re * dz.re + 
z.im * dz.im)); + + assert_eq!(abs2_rrule(z, 1.25), Complex64::new(7.5, 10.0)); + assert_eq!(real_rrule::(2.0), Complex64::new(2.0, 0.0)); + assert_eq!(imag_rrule::(2.0), Complex64::new(0.0, 2.0)); + assert_eq!(angle_rrule(z, 1.0), Complex64::new(-0.16, 0.12)); +} From 5fbff1af4ad76a0a1383277dd9e47d2d82ceae73 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:43:27 +0900 Subject: [PATCH 15/32] docs: update chainrules README surface --- crates/chainrules/README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index b19cf91..2697d36 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -30,7 +30,8 @@ Current shipped scalar families: - exponentials and logs: `exp`, `expm1`, `log`, `log1p` - trigonometric: `sin`, `cos`, `asin`, `acos`, `atan` - hyperbolic: `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` -- complex and projection helpers: `conj`, `handle_r_to_c_f32`, `handle_r_to_c_f64` +- smooth helpers: `cbrt`, `inv`, `exp2`, `exp10`, `log2`, `log10`, `hypot`, `pow`, `sincos`, `tan` +- complex and projection helpers: `conj`, `abs`, `abs2`, `angle`, `real`, `imag`, `complex`, `handle_r_to_c_f32`, `handle_r_to_c_f64` - real-valued binary helpers: `atan2` This crate is intended as a landing zone for scalar rules ported or adapted @@ -42,8 +43,10 @@ a full port of `ChainRules.jl`. Rules in this crate are not accepted on provenance alone. 
They are checked with repository-local tests: -- `tests/scalarops_tests.rs` covers direct formulas, edge cases, and +- `tests/scalarops_tests.rs` covers direct formulas, edge cases, and smooth real/complex behavior +- `tests/complex_helper_tests.rs` covers the projection helpers and complex + constructor surface - `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from `../../third_party/tensor-ad-oracles` From dbeaa65d52d6faecc386f55b16b6e83ef01721d2 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 10:50:52 +0900 Subject: [PATCH 16/32] fix: narrow complex projection pullbacks --- crates/chainrules/src/unary/complex_parts.rs | 32 +++++++++++++++---- .../chainrules/tests/complex_helper_tests.rs | 12 +++++-- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/crates/chainrules/src/unary/complex_parts.rs b/crates/chainrules/src/unary/complex_parts.rs index 6af3ba2..c5fe129 100644 --- a/crates/chainrules/src/unary/complex_parts.rs +++ b/crates/chainrules/src/unary/complex_parts.rs @@ -2,6 +2,22 @@ use crate::ScalarAd; use num_complex::Complex; use num_traits::{Float, One, Zero}; +trait ComplexProjectionScalar: ScalarAd { + fn from_parts(re: Self::Real, im: Self::Real) -> Self; +} + +impl ComplexProjectionScalar for num_complex::Complex32 { + fn from_parts(re: Self::Real, im: Self::Real) -> Self { + Complex::new(re, im) + } +} + +impl ComplexProjectionScalar for num_complex::Complex64 { + fn from_parts(re: Self::Real, im: Self::Real) -> Self { + Complex::new(re, im) + } +} + /// Primal `abs`. 
/// /// # Examples @@ -139,11 +155,13 @@ pub fn abs2_rrule(x: Complex, cotangent: R) -> Complex { /// use chainrules::real_rrule; /// use num_complex::Complex64; /// -/// assert_eq!(real_rrule::(2.0), Complex64::new(2.0, 0.0)); +/// let grad: Complex64 = real_rrule(2.0); +/// assert_eq!(grad, Complex64::new(2.0, 0.0)); /// ``` #[inline] -pub fn real_rrule(cotangent: S::Real) -> Complex { - Complex::new(cotangent, S::Real::zero()) +#[allow(private_bounds)] +pub fn real_rrule(cotangent: S::Real) -> S { + S::from_parts(cotangent, S::Real::zero()) } /// Reverse rule for `imag`. @@ -154,11 +172,13 @@ pub fn real_rrule(cotangent: S::Real) -> Complex { /// use chainrules::imag_rrule; /// use num_complex::Complex64; /// -/// assert_eq!(imag_rrule::(2.0), Complex64::new(0.0, 2.0)); +/// let grad: Complex64 = imag_rrule(2.0); +/// assert_eq!(grad, Complex64::new(0.0, 2.0)); /// ``` #[inline] -pub fn imag_rrule(cotangent: S::Real) -> Complex { - Complex::new(S::Real::zero(), cotangent) +#[allow(private_bounds)] +pub fn imag_rrule(cotangent: S::Real) -> S { + S::from_parts(S::Real::zero(), cotangent) } /// Reverse rule for `angle`. 
diff --git a/crates/chainrules/tests/complex_helper_tests.rs b/crates/chainrules/tests/complex_helper_tests.rs index f3e614b..5883afd 100644 --- a/crates/chainrules/tests/complex_helper_tests.rs +++ b/crates/chainrules/tests/complex_helper_tests.rs @@ -6,11 +6,17 @@ use num_complex::Complex64; #[test] fn complex_helpers_match_expected_formulas() { + let x = 3.0_f64; let z = Complex64::new(3.0, 4.0); let dz = Complex64::new(1.0, -2.0); let constructed: Complex64 = complex(3.0, 4.0); assert_eq!(constructed, z); + assert_eq!(abs(x), 3.0); + assert_eq!(abs2(x), 9.0); + assert_eq!(real(x), 3.0); + assert_eq!(imag(x), 0.0); + assert_eq!(angle(x), 0.0_f64.atan2(x)); assert_eq!(abs(z), 5.0); assert_eq!(abs2(z), 25.0); assert_eq!(real(z), 3.0); @@ -22,7 +28,9 @@ fn complex_helpers_match_expected_formulas() { assert_eq!(abs2_dy, 2.0 * (z.re * dz.re + z.im * dz.im)); assert_eq!(abs2_rrule(z, 1.25), Complex64::new(7.5, 10.0)); - assert_eq!(real_rrule::(2.0), Complex64::new(2.0, 0.0)); - assert_eq!(imag_rrule::(2.0), Complex64::new(0.0, 2.0)); + let real_grad: Complex64 = real_rrule(2.0); + assert_eq!(real_grad, Complex64::new(2.0, 0.0)); + let imag_grad: Complex64 = imag_rrule(2.0); + assert_eq!(imag_grad, Complex64::new(0.0, 2.0)); assert_eq!(angle_rrule(z, 1.0), Complex64::new(-0.16, 0.12)); } From 46791c77bfcc8c9ba17860294a1af026119a251d Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 11:00:02 +0900 Subject: [PATCH 17/32] feat: add Julia compatibility scalar helpers --- crates/chainrules/src/lib.rs | 18 +- .../chainrules/src/unary/hyperbolic_extra.rs | 67 +++++ crates/chainrules/src/unary/mod.rs | 10 + crates/chainrules/src/unary/trig_extra.rs | 231 ++++++++++++++++++ .../tests/julia_compat_trig_tests.rs | 95 +++++++ 5 files changed, 414 insertions(+), 7 deletions(-) create mode 100644 crates/chainrules/src/unary/hyperbolic_extra.rs create mode 100644 crates/chainrules/src/unary/trig_extra.rs create mode 100644 
crates/chainrules/tests/julia_compat_trig_tests.rs diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index f08379f..7a040a4 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -28,13 +28,17 @@ pub use unary::{ abs, abs2, abs2_frule, abs2_rrule, acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, angle, angle_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, cbrt, cbrt_frule, cbrt_rrule, - complex, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, cosh_frule, cosh_rrule, - exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, - expm1, expm1_frule, expm1_rrule, hypot, hypot_frule, hypot_rrule, imag, imag_rrule, inv, - inv_frule, inv_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, log1p_rrule, - log2, log2_frule, log2_rrule, log_frule, log_rrule, pow, pow_frule, pow_rrule, real, - real_rrule, sin, sin_frule, sin_rrule, sincos, sincos_frule, sincos_rrule, sinh, sinh_frule, - sinh_rrule, sqrt, sqrt_frule, sqrt_rrule, tan, tan_frule, tan_rrule, tanh, tanh_frule, + complex, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosd, cosd_frule, cosd_rrule, + cosh, cosh_frule, cosh_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, coth, + coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, csch_frule, csch_rrule, exp, exp10, + exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1, + expm1_frule, expm1_rrule, hypot, hypot_frule, hypot_rrule, imag, imag_rrule, inv, inv_frule, + inv_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, log1p_rrule, log2, + log2_frule, log2_rrule, log_frule, log_rrule, pow, pow_frule, pow_rrule, real, real_rrule, sec, + sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sin, sin_frule, sin_rrule, sincos, + sincos_frule, sincos_rrule, sincospi, sincospi_frule, 
sincospi_rrule, sind, sind_frule, + sind_rrule, sinh, sinh_frule, sinh_rrule, sinpi, sinpi_frule, sinpi_rrule, sqrt, sqrt_frule, + sqrt_rrule, tan, tan_frule, tan_rrule, tand, tand_frule, tand_rrule, tanh, tanh_frule, tanh_rrule, }; diff --git a/crates/chainrules/src/unary/hyperbolic_extra.rs b/crates/chainrules/src/unary/hyperbolic_extra.rs new file mode 100644 index 0000000..26b1b26 --- /dev/null +++ b/crates/chainrules/src/unary/hyperbolic_extra.rs @@ -0,0 +1,67 @@ +use crate::unary::{ + cosh, cosh_frule, cosh_rrule, inv, inv_frule, inv_rrule, sinh, sinh_frule, sinh_rrule, tanh, + tanh_frule, tanh_rrule, +}; +use crate::ScalarAd; + +/// Primal `sech`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sech; +/// +/// assert!((sech(0.5_f64) - 1.0 / 0.5_f64.cosh()).abs() < 1e-12); +/// ``` +pub fn sech(x: S) -> S { + inv(cosh(x)) +} + +/// Forward rule for `sech`. +pub fn sech_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = cosh_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `sech`. +pub fn sech_rrule(x: S, cotangent: S) -> S { + let y = sech(x); + let d_y = inv_rrule(y, cotangent); + cosh_rrule(x, d_y) +} + +/// Primal `csch`. +pub fn csch(x: S) -> S { + inv(sinh(x)) +} + +/// Forward rule for `csch`. +pub fn csch_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = sinh_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `csch`. +pub fn csch_rrule(x: S, cotangent: S) -> S { + let y = csch(x); + let d_y = inv_rrule(y, cotangent); + sinh_rrule(x, d_y) +} + +/// Primal `coth`. +pub fn coth(x: S) -> S { + inv(tanh(x)) +} + +/// Forward rule for `coth`. +pub fn coth_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = tanh_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `coth`. 
+pub fn coth_rrule(x: S, cotangent: S) -> S { + let y = coth(x); + let d_y = inv_rrule(y, cotangent); + tanh_rrule(tanh(x), d_y) +} diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index e8410be..66abb4d 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -2,9 +2,11 @@ mod basic; mod complex_parts; mod exp_log; mod hyperbolic; +mod hyperbolic_extra; mod roots; mod smooth; mod trig; +mod trig_extra; use crate::ScalarAd; @@ -27,6 +29,9 @@ pub use hyperbolic::{ atanh_rrule, cosh, cosh_frule, cosh_rrule, sinh, sinh_frule, sinh_rrule, tanh, tanh_frule, tanh_rrule, }; +pub use hyperbolic_extra::{ + coth, coth_frule, coth_rrule, csch, csch_frule, csch_rrule, sech, sech_frule, sech_rrule, +}; pub use roots::{cbrt, cbrt_frule, cbrt_rrule, inv, inv_frule, inv_rrule}; pub use smooth::{ hypot, hypot_frule, hypot_rrule, pow, pow_frule, pow_rrule, sincos, sincos_frule, sincos_rrule, @@ -36,3 +41,8 @@ pub use trig::{ acos, acos_frule, acos_rrule, asin, asin_frule, asin_rrule, atan, atan_frule, atan_rrule, cos, cos_frule, cos_rrule, sin, sin_frule, sin_rrule, }; +pub use trig_extra::{ + cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, csc, + csc_frule, csc_rrule, sec, sec_frule, sec_rrule, sincospi, sincospi_frule, sincospi_rrule, + sind, sind_frule, sind_rrule, sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, +}; diff --git a/crates/chainrules/src/unary/trig_extra.rs b/crates/chainrules/src/unary/trig_extra.rs new file mode 100644 index 0000000..15e6e21 --- /dev/null +++ b/crates/chainrules/src/unary/trig_extra.rs @@ -0,0 +1,231 @@ +use crate::binary::{mul_frule, mul_rrule}; +use crate::unary::{ + cos, cos_frule, cos_rrule, inv, inv_frule, inv_rrule, sin, sin_frule, sin_rrule, sincos, + sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule, +}; +use crate::ScalarAd; +use num_traits::FloatConst; + +fn pi() -> S { + S::from_real(S::Real::PI()) +} + +fn 
deg2rad() -> S { + pi::() / S::from_i32(180) +} + +/// Primal `sec`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sec; +/// +/// assert!((sec(0.5_f64) - 1.0 / 0.5_f64.cos()).abs() < 1e-12); +/// ``` +pub fn sec(x: S) -> S { + inv(cos(x)) +} + +/// Forward rule for `sec`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sec_frule; +/// +/// let (y, dy) = sec_frule(0.5_f64, 1.0); +/// assert!((y - 1.0 / 0.5_f64.cos()).abs() < 1e-12); +/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12); +/// ``` +pub fn sec_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = cos_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `sec`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sec_rrule; +/// +/// let dy = sec_rrule(0.5_f64, 1.0); +/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12); +/// ``` +pub fn sec_rrule(x: S, cotangent: S) -> S { + let y = sec(x); + let d_y = inv_rrule(y, cotangent); + cos_rrule(x, d_y) +} + +/// Primal `csc`. +pub fn csc(x: S) -> S { + inv(sin(x)) +} + +/// Forward rule for `csc`. +pub fn csc_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = sin_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `csc`. +pub fn csc_rrule(x: S, cotangent: S) -> S { + let y = csc(x); + let d_y = inv_rrule(y, cotangent); + sin_rrule(x, d_y) +} + +/// Primal `cot`. +pub fn cot(x: S) -> S { + inv(tan(x)) +} + +/// Forward rule for `cot`. +pub fn cot_frule(x: S, dx: S) -> (S, S) { + let (y, dy) = tan_frule(x, dx); + inv_frule(y, dy) +} + +/// Reverse rule for `cot`. +pub fn cot_rrule(x: S, cotangent: S) -> S { + let y = cot(x); + let d_y = inv_rrule(y, cotangent); + tan_rrule(tan(x), d_y) +} + +/// Primal `sinpi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sinpi; +/// +/// assert!((sinpi(0.25_f64) - 0.25_f64.mul_add(std::f64::consts::PI, 0.0).sin()).abs() < 1e-12); +/// ``` +pub fn sinpi(x: S) -> S { + sincospi(x).0 +} + +/// Forward rule for `sinpi`. 
+pub fn sinpi_frule(x: S, dx: S) -> (S, S) { + let ((y, _), (dy, _)) = sincospi_frule(x, dx); + (y, dy) +} + +/// Reverse rule for `sinpi`. +pub fn sinpi_rrule(x: S, cotangent: S) -> S { + sincospi_rrule(x, (cotangent, S::from_i32(0))) +} + +/// Primal `cospi`. +pub fn cospi(x: S) -> S { + sincospi(x).1 +} + +/// Forward rule for `cospi`. +pub fn cospi_frule(x: S, dx: S) -> (S, S) { + let ((_, y), (_, dy)) = sincospi_frule(x, dx); + (y, dy) +} + +/// Reverse rule for `cospi`. +pub fn cospi_rrule(x: S, cotangent: S) -> S { + sincospi_rrule(x, (S::from_i32(0), cotangent)) +} + +/// Primal `sincospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincospi; +/// +/// let (s, c) = sincospi(0.25_f64); +/// assert!((s - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12); +/// assert!((c - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12); +/// ``` +pub fn sincospi(x: S) -> (S, S) { + sincos(pi::() * x) +} + +/// Forward rule for `sincospi`. +pub fn sincospi_frule(x: S, dx: S) -> ((S, S), (S, S)) { + let scale = pi::(); + let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); + sincos_frule(scaled_x, dscaled_x) +} + +/// Reverse rule for `sincospi`. +pub fn sincospi_rrule(x: S, cotangents: (S, S)) -> S { + let scale = pi::(); + let scaled_x = scale * x; + let dscaled_x = sincos_rrule(scaled_x, cotangents); + let (_, dx) = mul_rrule(scale, x, dscaled_x); + dx +} + +/// Primal `sind`. +pub fn sind(x: S) -> S { + sin(deg2rad::() * x) +} + +/// Forward rule for `sind`. +pub fn sind_frule(x: S, dx: S) -> (S, S) { + let scale = deg2rad::(); + let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); + sin_frule(scaled_x, dscaled_x) +} + +/// Reverse rule for `sind`. +pub fn sind_rrule(x: S, cotangent: S) -> S { + let scale = deg2rad::(); + let scaled_x = scale * x; + let dscaled_x = sin_rrule(scaled_x, cotangent); + let (_, dx) = mul_rrule(scale, x, dscaled_x); + dx +} + +/// Primal `cosd`. 
+pub fn cosd(x: S) -> S { + cos(deg2rad::() * x) +} + +/// Forward rule for `cosd`. +pub fn cosd_frule(x: S, dx: S) -> (S, S) { + let scale = deg2rad::(); + let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); + cos_frule(scaled_x, dscaled_x) +} + +/// Reverse rule for `cosd`. +pub fn cosd_rrule(x: S, cotangent: S) -> S { + let scale = deg2rad::(); + let scaled_x = scale * x; + let dscaled_x = cos_rrule(scaled_x, cotangent); + let (_, dx) = mul_rrule(scale, x, dscaled_x); + dx +} + +/// Primal `tand`. +pub fn tand(x: S) -> S { + tan(deg2rad::() * x) +} + +/// Forward rule for `tand`. +pub fn tand_frule(x: S, dx: S) -> (S, S) { + let scale = deg2rad::(); + let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); + tan_frule(scaled_x, dscaled_x) +} + +/// Reverse rule for `tand`. +pub fn tand_rrule(x: S, cotangent: S) -> S { + let scale = deg2rad::(); + let scaled_x = scale * x; + let y = tan(scaled_x); + let dscaled_x = tan_rrule(y, cotangent); + let (_, dx) = mul_rrule(scale, x, dscaled_x); + dx +} diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs new file mode 100644 index 0000000..fd6394a --- /dev/null +++ b/crates/chainrules/tests/julia_compat_trig_tests.rs @@ -0,0 +1,95 @@ +use chainrules::{ + cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, csc, + csc_frule, csc_rrule, csch, csch_frule, csch_rrule, sec, sec_frule, sec_rrule, sech, + sech_frule, sech_rrule, sincospi, sincospi_frule, sincospi_rrule, sind, sind_frule, sind_rrule, + sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, +}; + +#[test] +fn julia_compat_primal_helpers_match_expected_values() { + let x = 0.25_f64; + assert!((sec(x) - (1.0 / x.cos())).abs() < 1e-12); + assert!((csc(x) - (1.0 / x.sin())).abs() < 1e-12); + assert!((cot(x) - (1.0 / x.tan())).abs() < 1e-12); + assert!((sinpi(x) - (std::f64::consts::PI * x).sin()).abs() < 1e-12); + 
assert!((cospi(x) - (std::f64::consts::PI * x).cos()).abs() < 1e-12); + let (s, c) = sincospi(x); + assert!((s - (std::f64::consts::PI * x).sin()).abs() < 1e-12); + assert!((c - (std::f64::consts::PI * x).cos()).abs() < 1e-12); + assert!((sind(30.0_f64) - 0.5_f64).abs() < 1e-12); + assert!((cosd(60.0_f64) - 0.5_f64).abs() < 1e-12); + assert!((tand(45.0_f64) - 1.0_f64).abs() < 1e-12); + assert!((sech(x) - (1.0 / x.cosh())).abs() < 1e-12); + assert!((csch(x) - (1.0 / x.sinh())).abs() < 1e-12); +} + +#[test] +fn julia_compat_derivative_helpers_match_expected_values() { + let x = 0.25_f64; + let g = 1.0_f64; + + let (_, dsec) = sec_frule(x, g); + assert!((dsec - (x.sin() / x.cos().powi(2))).abs() < 1e-12); + assert!((sec_rrule(x, g) - (x.sin() / x.cos().powi(2))).abs() < 1e-12); + + let (_, dsinpi) = sinpi_frule(x, g); + assert!((dsinpi - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12); + assert!( + (sinpi_rrule(x, g) - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12 + ); + + let (_, dtand) = tand_frule(45.0_f64, g); + assert!((dtand - std::f64::consts::PI / 180.0 * 2.0).abs() < 1e-12); + assert!((tand_rrule(45.0_f64, g) - std::f64::consts::PI / 180.0 * 2.0).abs() < 1e-12); + + let (_, dsech) = sech_frule(x, g); + let sech_x: f64 = 1.0 / x.cosh(); + assert!((dsech - (-sech_x * x.tanh())).abs() < 1e-12); + assert!((sech_rrule(x, g) - (-sech_x * x.tanh())).abs() < 1e-12); + + let (_, dcsc) = csc_frule(x, g); + assert!((dcsc - (-(x.cos() / x.sin().powi(2)))).abs() < 1e-12); + assert!((csc_rrule(x, g) - (-(x.cos() / x.sin().powi(2)))).abs() < 1e-12); + + let (_, dcot) = cot_frule(x, g); + assert!((dcot - (-(1.0 / x.sin().powi(2)))).abs() < 1e-12); + assert!((cot_rrule(x, g) - (-(1.0 / x.sin().powi(2)))).abs() < 1e-12); + + let (_, dcsch) = csch_frule(x, g); + let csch_x: f64 = 1.0 / x.sinh(); + assert!((dcsch - (-csch_x * x.cosh() / x.sinh())).abs() < 1e-12); + assert!((csch_rrule(x, g) - (-csch_x * x.cosh() / x.sinh())).abs() 
< 1e-12); + + let (_, dsind) = sind_frule(30.0_f64, g); + assert!((dsind - (std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos())).abs() < 1e-12); + assert!( + (sind_rrule(30.0_f64, g) - (std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos())) + .abs() + < 1e-12 + ); + + let (_, dcospi) = cospi_frule(x, g); + assert!((dcospi + std::f64::consts::PI * (std::f64::consts::PI * x).sin()).abs() < 1e-12); + assert!( + (cospi_rrule(x, g) + std::f64::consts::PI * (std::f64::consts::PI * x).sin()).abs() < 1e-12 + ); + + let (_, dcosd) = cosd_frule(60.0_f64, g); + assert!((dcosd + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12); + assert!( + (cosd_rrule(60.0_f64, g) + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()) + .abs() + < 1e-12 + ); + + let (_, dsincospi) = sincospi_frule(x, g); + assert!((dsincospi.0 - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12); + assert!((dsincospi.1 + std::f64::consts::PI * (std::f64::consts::PI * x).sin()).abs() < 1e-12); + assert!( + (sincospi_rrule(x, (g, g)) + - (std::f64::consts::PI * (std::f64::consts::PI * x).cos() + - std::f64::consts::PI * (std::f64::consts::PI * x).sin())) + .abs() + < 1e-12 + ); +} From 4bc19b94ea99e6126d50c58a6cdce9c6c650c8a6 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 12:08:22 +0900 Subject: [PATCH 18/32] test: add coth coverage --- crates/chainrules/tests/julia_compat_trig_tests.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs index fd6394a..c68c83d 100644 --- a/crates/chainrules/tests/julia_compat_trig_tests.rs +++ b/crates/chainrules/tests/julia_compat_trig_tests.rs @@ -1,8 +1,8 @@ use chainrules::{ - cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, csc, - csc_frule, csc_rrule, csch, csch_frule, csch_rrule, sec, sec_frule, 
sec_rrule, sech, - sech_frule, sech_rrule, sincospi, sincospi_frule, sincospi_rrule, sind, sind_frule, sind_rrule, - sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, + cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, coth, + coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, csch_frule, csch_rrule, sec, + sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sincospi, sincospi_frule, sincospi_rrule, + sind, sind_frule, sind_rrule, sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, }; #[test] @@ -21,6 +21,7 @@ fn julia_compat_primal_helpers_match_expected_values() { assert!((tand(45.0_f64) - 1.0_f64).abs() < 1e-12); assert!((sech(x) - (1.0 / x.cosh())).abs() < 1e-12); assert!((csch(x) - (1.0 / x.sinh())).abs() < 1e-12); + assert!((coth(x) - (1.0 / x.tanh())).abs() < 1e-12); } #[test] @@ -60,6 +61,10 @@ fn julia_compat_derivative_helpers_match_expected_values() { assert!((dcsch - (-csch_x * x.cosh() / x.sinh())).abs() < 1e-12); assert!((csch_rrule(x, g) - (-csch_x * x.cosh() / x.sinh())).abs() < 1e-12); + let (_, dcoth) = coth_frule(x, g); + assert!((dcoth - (-(1.0 / x.sinh().powi(2)))).abs() < 1e-12); + assert!((coth_rrule(x, g) - (-(1.0 / x.sinh().powi(2)))).abs() < 1e-12); + let (_, dsind) = sind_frule(30.0_f64, g); assert!((dsind - (std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos())).abs() < 1e-12); assert!( From d40acc3e5e9dd7556364e4f120763780bc27fbbc Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 12:23:05 +0900 Subject: [PATCH 19/32] fix: add Julia landmark trig compatibility --- crates/chainrules/README.md | 4 + crates/chainrules/src/unary/trig_extra.rs | 162 +++++++++++++----- .../tests/julia_compat_trig_tests.rs | 85 +++++++++ 3 files changed, 211 insertions(+), 40 deletions(-) diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index 2697d36..cb4e71b 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -30,6 
+30,8 @@ Current shipped scalar families: - exponentials and logs: `exp`, `expm1`, `log`, `log1p` - trigonometric: `sin`, `cos`, `asin`, `acos`, `atan` - hyperbolic: `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh` +- Julia-compatible trigonometric helpers: `sec`, `csc`, `cot`, `sinpi`, `cospi`, `sincospi`, `sind`, `cosd`, `tand` +- Julia-compatible hyperbolic helpers: `sech`, `csch`, `coth` - smooth helpers: `cbrt`, `inv`, `exp2`, `exp10`, `log2`, `log10`, `hypot`, `pow`, `sincos`, `tan` - complex and projection helpers: `conj`, `abs`, `abs2`, `angle`, `real`, `imag`, `complex`, `handle_r_to_c_f32`, `handle_r_to_c_f64` - real-valued binary helpers: `atan2` @@ -45,6 +47,8 @@ repository-local tests: - `tests/scalarops_tests.rs` covers direct formulas, edge cases, and smooth real/complex behavior +- `tests/julia_compat_trig_tests.rs` covers Julia migration helpers, including + landmark real inputs and representative Complex64 behavior - `tests/complex_helper_tests.rs` covers the projection helpers and complex constructor surface - `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from diff --git a/crates/chainrules/src/unary/trig_extra.rs b/crates/chainrules/src/unary/trig_extra.rs index 15e6e21..c3eb878 100644 --- a/crates/chainrules/src/unary/trig_extra.rs +++ b/crates/chainrules/src/unary/trig_extra.rs @@ -1,17 +1,88 @@ use crate::binary::{mul_frule, mul_rrule}; use crate::unary::{ - cos, cos_frule, cos_rrule, inv, inv_frule, inv_rrule, sin, sin_frule, sin_rrule, sincos, - sincos_frule, sincos_rrule, tan, tan_frule, tan_rrule, + cos, cos_frule, cos_rrule, inv, inv_frule, inv_rrule, sin, sin_frule, sin_rrule, sincos, tan, + tan_frule, tan_rrule, }; use crate::ScalarAd; -use num_traits::FloatConst; +use num_traits::{Float, FloatConst, Zero}; fn pi() -> S { S::from_real(S::Real::PI()) } +fn real(value: f64) -> R { + match R::from(value) { + Some(value) => value, + None => unreachable!("float constant conversion should succeed"), + } +} + fn 
deg2rad() -> S { - pi::() / S::from_i32(180) + pi::() / S::from_real(real::(180.0)) +} + +fn real_input(x: S) -> Option { + if x.imag().is_zero() { + Some(x.real()) + } else { + None + } +} + +fn sinpi_real(x: R) -> R { + let two = real::(2.0); + let reduced = x - (x / two).floor() * two; + let zero = real::(0.0); + let one = real::(1.0); + let half = real::(0.5); + let three_half = real::(1.5); + if reduced == zero || reduced == one { + zero + } else if reduced == half { + one + } else if reduced == three_half { + -one + } else { + (R::PI() * reduced).sin() + } +} + +fn cospi_real(x: R) -> R { + let two = real::(2.0); + let reduced = x - (x / two).floor() * two; + let zero = real::(0.0); + let one = real::(1.0); + let half = real::(0.5); + let three_half = real::(1.5); + if reduced == zero { + one + } else if reduced == one { + -one + } else if reduced == half || reduced == three_half { + zero + } else { + (R::PI() * reduced).cos() + } +} + +fn tand_real(x: R) -> R { + let one_eighty = real::(180.0); + let reduced = x - (x / one_eighty).floor() * one_eighty; + let zero = real::(0.0); + let forty_five = real::(45.0); + let ninety = real::(90.0); + let one_thirty_five = real::(135.0); + if reduced == zero { + zero + } else if reduced == forty_five { + real::(1.0) + } else if reduced == ninety { + R::infinity() + } else if reduced == one_thirty_five { + real::(-1.0) + } else { + (R::PI() * reduced / one_eighty).tan() + } } /// Primal `sec`. @@ -105,34 +176,42 @@ pub fn cot_rrule(x: S, cotangent: S) -> S { /// assert!((sinpi(0.25_f64) - 0.25_f64.mul_add(std::f64::consts::PI, 0.0).sin()).abs() < 1e-12); /// ``` pub fn sinpi(x: S) -> S { - sincospi(x).0 + if let Some(x_real) = real_input(x) { + return S::from_real(sinpi_real(x_real)); + } + sincos(pi::() * x).0 } /// Forward rule for `sinpi`. 
pub fn sinpi_frule(x: S, dx: S) -> (S, S) { - let ((y, _), (dy, _)) = sincospi_frule(x, dx); - (y, dy) + let y = sinpi(x); + let scale = pi::() * cospi(x); + (y, dx * scale.conj()) } /// Reverse rule for `sinpi`. pub fn sinpi_rrule(x: S, cotangent: S) -> S { - sincospi_rrule(x, (cotangent, S::from_i32(0))) + cotangent * (pi::() * cospi(x)).conj() } /// Primal `cospi`. pub fn cospi(x: S) -> S { - sincospi(x).1 + if let Some(x_real) = real_input(x) { + return S::from_real(cospi_real(x_real)); + } + sincos(pi::() * x).1 } /// Forward rule for `cospi`. pub fn cospi_frule(x: S, dx: S) -> (S, S) { - let ((_, y), (_, dy)) = sincospi_frule(x, dx); - (y, dy) + let y = cospi(x); + let scale = -(pi::() * sinpi(x)); + (y, dx * scale.conj()) } /// Reverse rule for `cospi`. pub fn cospi_rrule(x: S, cotangent: S) -> S { - sincospi_rrule(x, (S::from_i32(0), cotangent)) + cotangent * (-(pi::() * sinpi(x))).conj() } /// Primal `sincospi`. @@ -147,85 +226,88 @@ pub fn cospi_rrule(x: S, cotangent: S) -> S { /// assert!((c - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12); /// ``` pub fn sincospi(x: S) -> (S, S) { - sincos(pi::() * x) + (sinpi(x), cospi(x)) } /// Forward rule for `sincospi`. pub fn sincospi_frule(x: S, dx: S) -> ((S, S), (S, S)) { - let scale = pi::(); - let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); - sincos_frule(scaled_x, dscaled_x) + let sin_x = sinpi(x); + let cos_x = cospi(x); + ( + (sin_x, cos_x), + ( + dx * (pi::() * cos_x).conj(), + dx * (-(pi::() * sin_x)).conj(), + ), + ) } /// Reverse rule for `sincospi`. pub fn sincospi_rrule(x: S, cotangents: (S, S)) -> S { - let scale = pi::(); - let scaled_x = scale * x; - let dscaled_x = sincos_rrule(scaled_x, cotangents); - let (_, dx) = mul_rrule(scale, x, dscaled_x); - dx + let (cotangent_sin, cotangent_cos) = cotangents; + sinpi_rrule(x, cotangent_sin) + cospi_rrule(x, cotangent_cos) } /// Primal `sind`. 
pub fn sind(x: S) -> S { - sin(deg2rad::() * x) + sinpi(x / S::from_real(real::(180.0))) } /// Forward rule for `sind`. pub fn sind_frule(x: S, dx: S) -> (S, S) { - let scale = deg2rad::(); + let scale = S::from_real(real::(1.0 / 180.0)); let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); - sin_frule(scaled_x, dscaled_x) + sinpi_frule(scaled_x, dscaled_x) } /// Reverse rule for `sind`. pub fn sind_rrule(x: S, cotangent: S) -> S { - let scale = deg2rad::(); + let scale = S::from_real(real::(1.0 / 180.0)); let scaled_x = scale * x; - let dscaled_x = sin_rrule(scaled_x, cotangent); + let dscaled_x = sinpi_rrule(scaled_x, cotangent); let (_, dx) = mul_rrule(scale, x, dscaled_x); dx } /// Primal `cosd`. pub fn cosd(x: S) -> S { - cos(deg2rad::() * x) + cospi(x / S::from_real(real::(180.0))) } /// Forward rule for `cosd`. pub fn cosd_frule(x: S, dx: S) -> (S, S) { - let scale = deg2rad::(); + let scale = S::from_real(real::(1.0 / 180.0)); let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); - cos_frule(scaled_x, dscaled_x) + cospi_frule(scaled_x, dscaled_x) } /// Reverse rule for `cosd`. pub fn cosd_rrule(x: S, cotangent: S) -> S { - let scale = deg2rad::(); + let scale = S::from_real(real::(1.0 / 180.0)); let scaled_x = scale * x; - let dscaled_x = cos_rrule(scaled_x, cotangent); + let dscaled_x = cospi_rrule(scaled_x, cotangent); let (_, dx) = mul_rrule(scale, x, dscaled_x); dx } /// Primal `tand`. pub fn tand(x: S) -> S { + if let Some(x_real) = real_input(x) { + return S::from_real(tand_real(x_real)); + } tan(deg2rad::() * x) } /// Forward rule for `tand`. pub fn tand_frule(x: S, dx: S) -> (S, S) { - let scale = deg2rad::(); - let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); - tan_frule(scaled_x, dscaled_x) + let y = tand(x); + let scale = deg2rad::() * (S::from_i32(1) + y * y); + (y, dx * scale.conj()) } /// Reverse rule for `tand`. 
pub fn tand_rrule(x: S, cotangent: S) -> S { - let scale = deg2rad::(); - let scaled_x = scale * x; - let y = tan(scaled_x); - let dscaled_x = tan_rrule(y, cotangent); - let (_, dx) = mul_rrule(scale, x, dscaled_x); - dx + let y = tand(x); + let scale = deg2rad::() * (S::from_i32(1) + y * y); + cotangent * scale.conj() } diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs index c68c83d..367fd9a 100644 --- a/crates/chainrules/tests/julia_compat_trig_tests.rs +++ b/crates/chainrules/tests/julia_compat_trig_tests.rs @@ -4,6 +4,24 @@ use chainrules::{ sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sincospi, sincospi_frule, sincospi_rrule, sind, sind_frule, sind_rrule, sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule, }; +use num_complex::Complex64; + +fn assert_complex_close(actual: Complex64, expected: Complex64) { + assert!((actual - expected).norm() < 1e-12); +} + +#[test] +fn julia_compat_landmark_real_inputs_match_julia_style_values() { + assert_eq!(sinpi(1.0_f64), 0.0_f64); + assert_eq!(sinpi(0.5_f64), 1.0_f64); + assert_eq!(cospi(0.5_f64), 0.0_f64); + let (s, c) = sincospi(0.5_f64); + assert_eq!(s, 1.0_f64); + assert_eq!(c, 0.0_f64); + assert_eq!(sind(180.0_f64), 0.0_f64); + assert_eq!(cosd(90.0_f64), 0.0_f64); + assert_eq!(tand(45.0_f64), 1.0_f64); +} #[test] fn julia_compat_primal_helpers_match_expected_values() { @@ -33,6 +51,14 @@ fn julia_compat_derivative_helpers_match_expected_values() { assert!((dsec - (x.sin() / x.cos().powi(2))).abs() < 1e-12); assert!((sec_rrule(x, g) - (x.sin() / x.cos().powi(2))).abs() < 1e-12); + let (_, dsinpi_landmark) = sinpi_frule(1.0_f64, g); + assert!((dsinpi_landmark + std::f64::consts::PI).abs() < 1e-12); + assert!((sinpi_rrule(1.0_f64, g) + std::f64::consts::PI).abs() < 1e-12); + + let (_, dcospi_landmark) = cospi_frule(0.5_f64, g); + assert!((dcospi_landmark + std::f64::consts::PI).abs() < 1e-12); + assert!((cospi_rrule(0.5_f64, g) + 
std::f64::consts::PI).abs() < 1e-12); + let (_, dsinpi) = sinpi_frule(x, g); assert!((dsinpi - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12); assert!( @@ -98,3 +124,62 @@ fn julia_compat_derivative_helpers_match_expected_values() { < 1e-12 ); } + +#[test] +fn julia_compat_landmark_inputs_match_expected_values() { + assert_eq!(sinpi(1.0_f64), 0.0); + assert_eq!(cospi(0.5_f64), 0.0); + + let (s, c) = sincospi(0.5_f64); + assert_eq!(s, 1.0); + assert_eq!(c, 0.0); + + assert_eq!(sind(180.0_f64), 0.0); + assert_eq!(cosd(90.0_f64), 0.0); + assert_eq!(tand(45.0_f64), 1.0); + assert!(tand(90.0_f64).is_infinite()); +} + +#[test] +fn julia_compat_helpers_cover_complex_primal_and_cotangent_paths() { + let z = Complex64::new(0.25, -0.1); + let dz = Complex64::new(1.0, -0.25); + let cotangent = Complex64::new(0.75, -0.5); + let pi_z = Complex64::new(std::f64::consts::PI, 0.0) * z; + + assert_complex_close(sinpi(z), pi_z.sin()); + assert_complex_close(cospi(z), pi_z.cos()); + + let (_, dsinpi) = sinpi_frule(z, dz); + assert_complex_close( + dsinpi, + dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()).conj(), + ); + + assert_complex_close( + sec_rrule(z, cotangent), + cotangent * (z.sin() / z.cos().powi(2)).conj(), + ); + assert_complex_close( + sinpi_rrule(z, cotangent), + cotangent * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()).conj(), + ); +} + +#[test] +fn julia_compat_complex_inputs_cover_generic_surface() { + let z = Complex64::new(0.25, -0.5); + + let pi_z = Complex64::new(std::f64::consts::PI, 0.0) * z; + assert!((sinpi(z) - pi_z.sin()).norm() < 1e-12); + assert!((cospi(z) - pi_z.cos()).norm() < 1e-12); + let (s, c) = sincospi(z); + assert!((s - pi_z.sin()).norm() < 1e-12); + assert!((c - pi_z.cos()).norm() < 1e-12); + + let deg_z = Complex64::new(std::f64::consts::PI / 180.0, 0.0) * z; + assert!((sind(z) - deg_z.sin()).norm() < 1e-12); + assert!((cosd(z) - deg_z.cos()).norm() < 1e-12); + assert!((tand(z) - 
deg_z.tan()).norm() < 1e-12); + assert!((coth(z) - Complex64::new(1.0, 0.0) / z.tanh()).norm() < 1e-12); +} From 7fbbeefe04632461bd7f8a5be193c201d740e47b Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 12:33:29 +0900 Subject: [PATCH 20/32] fix: harden Julia compatibility helpers --- README.md | 5 +- crates/chainrules/src/tests/organization.rs | 10 + .../chainrules/src/unary/hyperbolic_extra.rs | 74 +++++++ crates/chainrules/src/unary/trig_extra.rs | 194 +++++++++++++++++- .../tests/julia_compat_trig_tests.rs | 18 +- 5 files changed, 278 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 92df6a2..0b3c767 100644 --- a/README.md +++ b/README.md @@ -62,9 +62,12 @@ See the crate READMEs for the supported scalar function inventory and examples. ## Testing -Scalar rules are checked in two complementary ways: +Scalar rules are checked in complementary ways: - formula and behavior tests in `crates/chainrules/tests/scalarops_tests.rs` +- compatibility and edge-case tests such as + `crates/chainrules/tests/julia_compat_trig_tests.rs` and + `crates/chainrules/tests/complex_helper_tests.rs` - oracle replay tests in `crates/chainrules/tests/oracle_scalar_rules.rs` against vendored published cases from `third_party/tensor-ad-oracles` diff --git a/crates/chainrules/src/tests/organization.rs b/crates/chainrules/src/tests/organization.rs index 123787e..c66ae0c 100644 --- a/crates/chainrules/src/tests/organization.rs +++ b/crates/chainrules/src/tests/organization.rs @@ -24,6 +24,16 @@ fn chainrules_modules_stay_under_size_guideline() { include_str!("../unary/hyperbolic.rs"), 140, ); + assert_line_count( + "../unary/trig_extra.rs", + include_str!("../unary/trig_extra.rs"), + 500, + ); + assert_line_count( + "../unary/hyperbolic_extra.rs", + include_str!("../unary/hyperbolic_extra.rs"), + 180, + ); assert_line_count( "../unary/smooth.rs", include_str!("../unary/smooth.rs"), diff --git a/crates/chainrules/src/unary/hyperbolic_extra.rs 
b/crates/chainrules/src/unary/hyperbolic_extra.rs index 26b1b26..754fa97 100644 --- a/crates/chainrules/src/unary/hyperbolic_extra.rs +++ b/crates/chainrules/src/unary/hyperbolic_extra.rs @@ -18,12 +18,32 @@ pub fn sech(x: S) -> S { } /// Forward rule for `sech`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sech_frule; +/// +/// let (_, dy) = sech_frule(0.5_f64, 1.0); +/// let sech_x = 1.0 / 0.5_f64.cosh(); +/// assert!((dy + sech_x * 0.5_f64.tanh()).abs() < 1e-12); +/// ``` pub fn sech_frule(x: S, dx: S) -> (S, S) { let (y, dy) = cosh_frule(x, dx); inv_frule(y, dy) } /// Reverse rule for `sech`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sech_rrule; +/// +/// let dy = sech_rrule(0.5_f64, 1.0); +/// let sech_x = 1.0 / 0.5_f64.cosh(); +/// assert!((dy + sech_x * 0.5_f64.tanh()).abs() < 1e-12); +/// ``` pub fn sech_rrule(x: S, cotangent: S) -> S { let y = sech(x); let d_y = inv_rrule(y, cotangent); @@ -31,17 +51,45 @@ pub fn sech_rrule(x: S, cotangent: S) -> S { } /// Primal `csch`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::csch; +/// +/// assert!((csch(0.5_f64) - 1.0 / 0.5_f64.sinh()).abs() < 1e-12); +/// ``` pub fn csch(x: S) -> S { inv(sinh(x)) } /// Forward rule for `csch`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::csch_frule; +/// +/// let (_, dy) = csch_frule(0.5_f64, 1.0); +/// let csch_x = 1.0 / 0.5_f64.sinh(); +/// assert!((dy + csch_x * 0.5_f64.cosh() / 0.5_f64.sinh()).abs() < 1e-12); +/// ``` pub fn csch_frule(x: S, dx: S) -> (S, S) { let (y, dy) = sinh_frule(x, dx); inv_frule(y, dy) } /// Reverse rule for `csch`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::csch_rrule; +/// +/// let dy = csch_rrule(0.5_f64, 1.0); +/// let csch_x = 1.0 / 0.5_f64.sinh(); +/// assert!((dy + csch_x * 0.5_f64.cosh() / 0.5_f64.sinh()).abs() < 1e-12); +/// ``` pub fn csch_rrule(x: S, cotangent: S) -> S { let y = csch(x); let d_y = inv_rrule(y, cotangent); @@ -49,17 +97,43 @@ pub fn csch_rrule(x: S, cotangent: S) -> S { } /// Primal `coth`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::coth; +/// +/// assert!((coth(0.5_f64) - 1.0 / 0.5_f64.tanh()).abs() < 1e-12); +/// ``` pub fn coth(x: S) -> S { inv(tanh(x)) } /// Forward rule for `coth`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::coth_frule; +/// +/// let (_, dy) = coth_frule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sinh().powi(2)).abs() < 1e-12); +/// ``` pub fn coth_frule(x: S, dx: S) -> (S, S) { let (y, dy) = tanh_frule(x, dx); inv_frule(y, dy) } /// Reverse rule for `coth`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::coth_rrule; +/// +/// let dy = coth_rrule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sinh().powi(2)).abs() < 1e-12); +/// ``` pub fn coth_rrule(x: S, cotangent: S) -> S { let y = coth(x); let d_y = inv_rrule(y, cotangent); diff --git a/crates/chainrules/src/unary/trig_extra.rs b/crates/chainrules/src/unary/trig_extra.rs index c3eb878..ec80b28 100644 --- a/crates/chainrules/src/unary/trig_extra.rs +++ b/crates/chainrules/src/unary/trig_extra.rs @@ -77,7 +77,7 @@ fn tand_real(x: R) -> R { } else if reduced == forty_five { real::(1.0) } else if reduced == ninety { - R::infinity() + R::infinity().copysign(sinpi_real(x / one_eighty)) } else if reduced == one_thirty_five { real::(-1.0) } else { @@ -88,10 +88,8 @@ fn tand_real(x: R) -> R { /// Primal `sec`. 
/// /// # Examples -/// /// ```rust /// use chainrules::sec; -/// /// assert!((sec(0.5_f64) - 1.0 / 0.5_f64.cos()).abs() < 1e-12); /// ``` pub fn sec(x: S) -> S { @@ -101,10 +99,8 @@ pub fn sec(x: S) -> S { /// Forward rule for `sec`. /// /// # Examples -/// /// ```rust /// use chainrules::sec_frule; -/// /// let (y, dy) = sec_frule(0.5_f64, 1.0); /// assert!((y - 1.0 / 0.5_f64.cos()).abs() < 1e-12); /// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12); @@ -117,10 +113,8 @@ pub fn sec_frule(x: S, dx: S) -> (S, S) { /// Reverse rule for `sec`. /// /// # Examples -/// /// ```rust /// use chainrules::sec_rrule; -/// /// let dy = sec_rrule(0.5_f64, 1.0); /// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12); /// ``` @@ -131,17 +125,37 @@ pub fn sec_rrule(x: S, cotangent: S) -> S { } /// Primal `csc`. +/// +/// # Examples +/// ```rust +/// use chainrules::csc; +/// assert!((csc(0.5_f64) - 1.0 / 0.5_f64.sin()).abs() < 1e-12); +/// ``` pub fn csc(x: S) -> S { inv(sin(x)) } /// Forward rule for `csc`. +/// +/// # Examples +/// ```rust +/// use chainrules::csc_frule; +/// let (_, dy) = csc_frule(0.5_f64, 1.0); +/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` pub fn csc_frule(x: S, dx: S) -> (S, S) { let (y, dy) = sin_frule(x, dx); inv_frule(y, dy) } /// Reverse rule for `csc`. +/// +/// # Examples +/// ```rust +/// use chainrules::csc_rrule; +/// let dy = csc_rrule(0.5_f64, 1.0); +/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` pub fn csc_rrule(x: S, cotangent: S) -> S { let y = csc(x); let d_y = inv_rrule(y, cotangent); @@ -149,17 +163,37 @@ pub fn csc_rrule(x: S, cotangent: S) -> S { } /// Primal `cot`. +/// +/// # Examples +/// ```rust +/// use chainrules::cot; +/// assert!((cot(0.5_f64) - 1.0 / 0.5_f64.tan()).abs() < 1e-12); +/// ``` pub fn cot(x: S) -> S { inv(tan(x)) } /// Forward rule for `cot`. 
+/// +/// # Examples +/// ```rust +/// use chainrules::cot_frule; +/// let (_, dy) = cot_frule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` pub fn cot_frule(x: S, dx: S) -> (S, S) { let (y, dy) = tan_frule(x, dx); inv_frule(y, dy) } /// Reverse rule for `cot`. +/// +/// # Examples +/// ```rust +/// use chainrules::cot_rrule; +/// let dy = cot_rrule(0.5_f64, 1.0); +/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12); +/// ``` pub fn cot_rrule(x: S, cotangent: S) -> S { let y = cot(x); let d_y = inv_rrule(y, cotangent); @@ -183,6 +217,15 @@ pub fn sinpi(x: S) -> S { } /// Forward rule for `sinpi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sinpi_frule; +/// +/// let (_, dy) = sinpi_frule(0.25_f64, 1.0); +/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// ``` pub fn sinpi_frule(x: S, dx: S) -> (S, S) { let y = sinpi(x); let scale = pi::() * cospi(x); @@ -190,11 +233,28 @@ pub fn sinpi_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `sinpi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sinpi_rrule; +/// +/// let dy = sinpi_rrule(0.25_f64, 1.0); +/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// ``` pub fn sinpi_rrule(x: S, cotangent: S) -> S { cotangent * (pi::() * cospi(x)).conj() } /// Primal `cospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cospi; +/// +/// assert!((cospi(0.25_f64) - (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// ``` pub fn cospi(x: S) -> S { if let Some(x_real) = real_input(x) { return S::from_real(cospi_real(x_real)); @@ -203,6 +263,15 @@ pub fn cospi(x: S) -> S { } /// Forward rule for `cospi`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::cospi_frule; +/// +/// let (_, dy) = cospi_frule(0.25_f64, 1.0); +/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12); +/// ``` pub fn cospi_frule(x: S, dx: S) -> (S, S) { let y = cospi(x); let scale = -(pi::() * sinpi(x)); @@ -210,6 +279,15 @@ pub fn cospi_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `cospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cospi_rrule; +/// +/// let dy = cospi_rrule(0.25_f64, 1.0); +/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12); +/// ``` pub fn cospi_rrule(x: S, cotangent: S) -> S { cotangent * (-(pi::() * sinpi(x))).conj() } @@ -230,6 +308,16 @@ pub fn sincospi(x: S) -> (S, S) { } /// Forward rule for `sincospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincospi_frule; +/// +/// let ((_, _), (ds, dc)) = sincospi_frule(0.25_f64, 1.0); +/// assert!((ds - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12); +/// assert!((dc + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12); +/// ``` pub fn sincospi_frule(x: S, dx: S) -> ((S, S), (S, S)) { let sin_x = sinpi(x); let cos_x = cospi(x); @@ -243,17 +331,48 @@ pub fn sincospi_frule(x: S, dx: S) -> ((S, S), (S, S)) { } /// Reverse rule for `sincospi`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sincospi_rrule; +/// +/// let dx = sincospi_rrule(0.25_f64, (1.0, 1.0)); +/// assert!( +/// (dx - (std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos() +/// - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin())) +/// .abs() +/// < 1e-12 +/// ); +/// ``` pub fn sincospi_rrule(x: S, cotangents: (S, S)) -> S { let (cotangent_sin, cotangent_cos) = cotangents; sinpi_rrule(x, cotangent_sin) + cospi_rrule(x, cotangent_cos) } /// Primal `sind`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::sind; +/// +/// assert!((sind(30.0_f64) - 0.5_f64).abs() < 1e-12); +/// ``` pub fn sind(x: S) -> S { sinpi(x / S::from_real(real::(180.0))) } /// Forward rule for `sind`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sind_frule; +/// +/// let (_, dy) = sind_frule(30.0_f64, 1.0); +/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12); +/// ``` pub fn sind_frule(x: S, dx: S) -> (S, S) { let scale = S::from_real(real::(1.0 / 180.0)); let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); @@ -261,6 +380,15 @@ pub fn sind_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `sind`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sind_rrule; +/// +/// let dy = sind_rrule(30.0_f64, 1.0); +/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12); +/// ``` pub fn sind_rrule(x: S, cotangent: S) -> S { let scale = S::from_real(real::(1.0 / 180.0)); let scaled_x = scale * x; @@ -270,11 +398,28 @@ pub fn sind_rrule(x: S, cotangent: S) -> S { } /// Primal `cosd`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cosd; +/// +/// assert!((cosd(60.0_f64) - 0.5_f64).abs() < 1e-12); +/// ``` pub fn cosd(x: S) -> S { cospi(x / S::from_real(real::(180.0))) } /// Forward rule for `cosd`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::cosd_frule; +/// +/// let (_, dy) = cosd_frule(60.0_f64, 1.0); +/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12); +/// ``` pub fn cosd_frule(x: S, dx: S) -> (S, S) { let scale = S::from_real(real::(1.0 / 180.0)); let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx); @@ -282,6 +427,15 @@ pub fn cosd_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `cosd`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::cosd_rrule; +/// +/// let dy = cosd_rrule(60.0_f64, 1.0); +/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12); +/// ``` pub fn cosd_rrule(x: S, cotangent: S) -> S { let scale = S::from_real(real::(1.0 / 180.0)); let scaled_x = scale * x; @@ -291,6 +445,14 @@ pub fn cosd_rrule(x: S, cotangent: S) -> S { } /// Primal `tand`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tand; +/// +/// assert_eq!(tand(45.0_f64), 1.0); +/// ``` pub fn tand(x: S) -> S { if let Some(x_real) = real_input(x) { return S::from_real(tand_real(x_real)); @@ -299,6 +461,15 @@ pub fn tand(x: S) -> S { } /// Forward rule for `tand`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::tand_frule; +/// +/// let (_, dy) = tand_frule(45.0_f64, 1.0); +/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12); +/// ``` pub fn tand_frule(x: S, dx: S) -> (S, S) { let y = tand(x); let scale = deg2rad::() * (S::from_i32(1) + y * y); @@ -306,6 +477,15 @@ pub fn tand_frule(x: S, dx: S) -> (S, S) { } /// Reverse rule for `tand`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::tand_rrule; +/// +/// let dy = tand_rrule(45.0_f64, 1.0); +/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12); +/// ``` pub fn tand_rrule(x: S, cotangent: S) -> S { let y = tand(x); let scale = deg2rad::() * (S::from_i32(1) + y * y); diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs index 367fd9a..4a8cf0d 100644 --- a/crates/chainrules/tests/julia_compat_trig_tests.rs +++ b/crates/chainrules/tests/julia_compat_trig_tests.rs @@ -21,6 +21,9 @@ fn julia_compat_landmark_real_inputs_match_julia_style_values() { assert_eq!(sind(180.0_f64), 0.0_f64); assert_eq!(cosd(90.0_f64), 0.0_f64); assert_eq!(tand(45.0_f64), 1.0_f64); + assert_eq!(tand(90.0_f64), f64::INFINITY); + assert_eq!(tand(-90.0_f64), f64::NEG_INFINITY); + assert_eq!(tand(270.0_f64), f64::NEG_INFINITY); } #[test] @@ -125,21 +128,6 @@ fn julia_compat_derivative_helpers_match_expected_values() { ); } -#[test] -fn julia_compat_landmark_inputs_match_expected_values() { - assert_eq!(sinpi(1.0_f64), 0.0); - assert_eq!(cospi(0.5_f64), 0.0); - - let (s, c) = sincospi(0.5_f64); - assert_eq!(s, 1.0); - assert_eq!(c, 0.0); - - assert_eq!(sind(180.0_f64), 0.0); - assert_eq!(cosd(90.0_f64), 0.0); - assert_eq!(tand(45.0_f64), 1.0); - assert!(tand(90.0_f64).is_infinite()); -} - #[test] fn julia_compat_helpers_cover_complex_primal_and_cotangent_paths() { let z = Complex64::new(0.25, -0.1); From d5930a4925047b820980f74e22379872595ade81 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 12:39:11 +0900 Subject: [PATCH 21/32] feat: add nonsmooth scalar utility rules --- crates/chainrules/src/binary_special.rs | 128 +++++++ crates/chainrules/src/lib.rs | 27 +- crates/chainrules/src/unary/mod.rs | 5 + crates/chainrules/src/unary/nonsmooth.rs | 323 ++++++++++++++++++ .../tests/nonsmooth_scalar_tests.rs | 117 +++++++ 5 files changed, 588 insertions(+), 12 deletions(-) 
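Note on the tie and NaN policy for `min`/`max` in this patch: the selection helpers use a strict comparison, so a tie sends the derivative to the second argument, and a NaN in one argument sends it to the other. The sketch below restates that logic for plain `f64` so it runs without `num_traits`. The `f64`-only signatures are an assumption made for illustration; the versions in the patch are generic over `Float`.

```rust
// Illustrative f64-only restatement of the selection logic added to
// binary_special.rs. The concrete f64 signatures are an assumption for this
// sketch; the patch itself is generic over num_traits::Float.

/// True when the derivative of `min(x, y)` should flow through `x`.
fn select_first_for_min(x: f64, y: f64) -> bool {
    // Strict `<` means a tie (x == y) falls through to the second argument.
    !x.is_nan() && (y.is_nan() || x < y)
}

/// True when the derivative of `max(x, y)` should flow through `x`.
fn select_first_for_max(x: f64, y: f64) -> bool {
    // Strict `>` gives the same tie behavior as the `min` helper.
    !x.is_nan() && (y.is_nan() || x > y)
}

/// Reverse rule for `min`: the cotangent goes to exactly one argument.
fn min_rrule(x: f64, y: f64, cotangent: f64) -> (f64, f64) {
    if select_first_for_min(x, y) {
        (cotangent, 0.0)
    } else {
        (0.0, cotangent)
    }
}

/// Reverse rule for `max`, mirroring `min_rrule`.
fn max_rrule(x: f64, y: f64, cotangent: f64) -> (f64, f64) {
    if select_first_for_max(x, y) {
        (cotangent, 0.0)
    } else {
        (0.0, cotangent)
    }
}
```

With these definitions, `min_rrule(3.0, 3.0, 6.0)` gives `(0.0, 6.0)` and `min_rrule(1.0, f64::NAN, 0.5)` gives `(0.5, 0.0)`, which keeps the gradient on the argument that `f64::min` returns as the primal.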
create mode 100644 crates/chainrules/src/unary/nonsmooth.rs create mode 100644 crates/chainrules/tests/nonsmooth_scalar_tests.rs diff --git a/crates/chainrules/src/binary_special.rs b/crates/chainrules/src/binary_special.rs index f2e3b4b..e06b23f 100644 --- a/crates/chainrules/src/binary_special.rs +++ b/crates/chainrules/src/binary_special.rs @@ -1,5 +1,13 @@ use num_traits::Float; +fn select_first_for_min(x: R, y: R) -> bool { + !x.is_nan() && (y.is_nan() || x < y) +} + +fn select_first_for_max(x: R, y: R) -> bool { + !x.is_nan() && (y.is_nan() || x > y) +} + /// Primal `hypot`. /// /// # Examples @@ -46,3 +54,123 @@ pub fn hypot_rrule(x: R, y: R, cotangent: R) -> (R, R) { let inv_r = R::one() / r; (cotangent * (x * inv_r), cotangent * (y * inv_r)) } + +/// Primal `min`. +/// +/// The primal follows `Float::min`; tie behavior routes the derivative to the +/// second argument. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min; +/// +/// assert_eq!(min(1.5_f64, 2.5_f64), 1.5); +/// assert_eq!(min(2.0_f64, 2.0_f64), 2.0); +/// ``` +pub fn min(x: R, y: R) -> R { + x.min(y) +} + +/// Forward rule for `min`. +/// +/// When `x == y`, the tangent comes from `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min_frule; +/// +/// let (z, dz) = min_frule(1.0_f64, 2.0_f64, 0.25, 0.5); +/// assert_eq!(z, 1.0); +/// assert_eq!(dz, 0.25); +/// ``` +pub fn min_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let z = x.min(y); + if select_first_for_min(x, y) { + (z, dx) + } else { + (z, dy) + } +} + +/// Reverse rule for `min`. +/// +/// When `x == y`, the cotangent goes to `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min_rrule; +/// +/// let (dx, dy) = min_rrule(1.0_f64, 2.0_f64, 0.5); +/// assert_eq!(dx, 0.5); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn min_rrule(x: R, y: R, cotangent: R) -> (R, R) { + if select_first_for_min(x, y) { + (cotangent, R::zero()) + } else { + (R::zero(), cotangent) + } +} + +/// Primal `max`. 
+/// +/// The primal follows `Float::max`; tie behavior routes the derivative to the +/// second argument. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max; +/// +/// assert_eq!(max(1.5_f64, 2.5_f64), 2.5); +/// assert_eq!(max(2.0_f64, 2.0_f64), 2.0); +/// ``` +pub fn max(x: R, y: R) -> R { + x.max(y) +} + +/// Forward rule for `max`. +/// +/// When `x == y`, the tangent comes from `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max_frule; +/// +/// let (z, dz) = max_frule(1.0_f64, 2.0_f64, 0.25, 0.5); +/// assert_eq!(z, 2.0); +/// assert_eq!(dz, 0.5); +/// ``` +pub fn max_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let z = x.max(y); + if select_first_for_max(x, y) { + (z, dx) + } else { + (z, dy) + } +} + +/// Reverse rule for `max`. +/// +/// When `x == y`, the cotangent goes to `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max_rrule; +/// +/// let (dx, dy) = max_rrule(1.0_f64, 2.0_f64, 0.5); +/// assert_eq!(dx, 0.0); +/// assert_eq!(dy, 0.5); +/// ``` +pub fn max_rrule(x: R, y: R, cotangent: R) -> (R, R) { + if select_first_for_max(x, y) { + (cotangent, R::zero()) + } else { + (R::zero(), cotangent) + } +} diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index 7a040a4..aba0777 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -18,6 +18,8 @@ pub use binary::{ sub_frule, sub_rrule, }; #[doc(inline)] +pub use binary_special::{max, max_frule, max_rrule, min, min_frule, min_rrule}; +#[doc(inline)] pub use power::{powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule}; #[doc(inline)] pub use real_ops::{atan2, atan2_frule, atan2_rrule}; @@ -28,18 +30,19 @@ pub use unary::{ abs, abs2, abs2_frule, abs2_rrule, acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, angle, angle_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, atan, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, cbrt, cbrt_frule, cbrt_rrule, - complex, conj, 
conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosd, cosd_frule, cosd_rrule, - cosh, cosh_frule, cosh_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, coth, - coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, csch_frule, csch_rrule, exp, exp10, - exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1, - expm1_frule, expm1_rrule, hypot, hypot_frule, hypot_rrule, imag, imag_rrule, inv, inv_frule, - inv_rrule, log, log10, log10_frule, log10_rrule, log1p, log1p_frule, log1p_rrule, log2, - log2_frule, log2_rrule, log_frule, log_rrule, pow, pow_frule, pow_rrule, real, real_rrule, sec, - sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sin, sin_frule, sin_rrule, sincos, - sincos_frule, sincos_rrule, sincospi, sincospi_frule, sincospi_rrule, sind, sind_frule, - sind_rrule, sinh, sinh_frule, sinh_rrule, sinpi, sinpi_frule, sinpi_rrule, sqrt, sqrt_frule, - sqrt_rrule, tan, tan_frule, tan_rrule, tand, tand_frule, tand_rrule, tanh, tanh_frule, - tanh_rrule, + ceil, ceil_frule, ceil_rrule, complex, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, + cosd, cosd_frule, cosd_rrule, cosh, cosh_frule, cosh_rrule, cospi, cospi_frule, cospi_rrule, + cot, cot_frule, cot_rrule, coth, coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, + csch_frule, csch_rrule, exp, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, + exp_frule, exp_rrule, expm1, expm1_frule, expm1_rrule, floor, floor_frule, floor_rrule, hypot, + hypot_frule, hypot_rrule, imag, imag_rrule, inv, inv_frule, inv_rrule, log, log10, log10_frule, + log10_rrule, log1p, log1p_frule, log1p_rrule, log2, log2_frule, log2_rrule, log_frule, + log_rrule, pow, pow_frule, pow_rrule, real, real_rrule, round, round_frule, round_rrule, sec, + sec_frule, sec_rrule, sech, sech_frule, sech_rrule, sign, sign_frule, sign_rrule, sin, + sin_frule, sin_rrule, sincos, sincos_frule, sincos_rrule, sincospi, sincospi_frule, + sincospi_rrule, sind, sind_frule, 
sind_rrule, sinh, sinh_frule, sinh_rrule, sinpi, sinpi_frule, + sinpi_rrule, sqrt, sqrt_frule, sqrt_rrule, tan, tan_frule, tan_rrule, tand, tand_frule, + tand_rrule, tanh, tanh_frule, tanh_rrule, }; #[cfg(test)] diff --git a/crates/chainrules/src/unary/mod.rs b/crates/chainrules/src/unary/mod.rs index 66abb4d..dced019 100644 --- a/crates/chainrules/src/unary/mod.rs +++ b/crates/chainrules/src/unary/mod.rs @@ -3,6 +3,7 @@ mod complex_parts; mod exp_log; mod hyperbolic; mod hyperbolic_extra; +mod nonsmooth; mod roots; mod smooth; mod trig; @@ -32,6 +33,10 @@ pub use hyperbolic::{ pub use hyperbolic_extra::{ coth, coth_frule, coth_rrule, csch, csch_frule, csch_rrule, sech, sech_frule, sech_rrule, }; +pub use nonsmooth::{ + ceil, ceil_frule, ceil_rrule, floor, floor_frule, floor_rrule, round, round_frule, round_rrule, + sign, sign_frule, sign_rrule, +}; pub use roots::{cbrt, cbrt_frule, cbrt_rrule, inv, inv_frule, inv_rrule}; pub use smooth::{ hypot, hypot_frule, hypot_rrule, pow, pow_frule, pow_rrule, sincos, sincos_frule, sincos_rrule, diff --git a/crates/chainrules/src/unary/nonsmooth.rs b/crates/chainrules/src/unary/nonsmooth.rs new file mode 100644 index 0000000..b84b707 --- /dev/null +++ b/crates/chainrules/src/unary/nonsmooth.rs @@ -0,0 +1,323 @@ +#![allow(dead_code)] + +use num_traits::Float; + +fn select_first_for_min(x: R, y: R) -> bool { + !x.is_nan() && (y.is_nan() || x < y) +} + +fn select_first_for_max(x: R, y: R) -> bool { + !x.is_nan() && (y.is_nan() || x > y) +} + +/// Primal `round`. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point, including integer inputs. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::round; +/// +/// assert_eq!(round(1.4_f64), 1.0); +/// assert_eq!(round(1.5_f64), 2.0); +/// ``` +pub fn round(x: R) -> R { + x.round() +} + +/// Forward rule for `round`. +/// +/// The tangent is always zero. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::round_frule; +/// +/// let (y, dy) = round_frule(1.6_f64, 0.75); +/// assert_eq!(y, 2.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn round_frule(x: R, _dx: R) -> (R, R) { + (x.round(), R::zero()) +} + +/// Reverse rule for `round`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::round_rrule; +/// +/// assert_eq!(round_rrule(1.0_f64), 0.0); +/// ``` +pub fn round_rrule(_cotangent: R) -> R { + R::zero() +} + +/// Primal `floor`. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::floor; +/// +/// assert_eq!(floor(1.9_f64), 1.0); +/// ``` +pub fn floor(x: R) -> R { + x.floor() +} + +/// Forward rule for `floor`. +/// +/// The tangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::floor_frule; +/// +/// let (y, dy) = floor_frule(1.6_f64, 0.75); +/// assert_eq!(y, 1.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn floor_frule(x: R, _dx: R) -> (R, R) { + (x.floor(), R::zero()) +} + +/// Reverse rule for `floor`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::floor_rrule; +/// +/// assert_eq!(floor_rrule(1.0_f64), 0.0); +/// ``` +pub fn floor_rrule(_cotangent: R) -> R { + R::zero() +} + +/// Primal `ceil`. +/// +/// The corresponding forward and reverse rules use a zero-gradient policy at +/// every point. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ceil; +/// +/// assert_eq!(ceil(1.1_f64), 2.0); +/// ``` +pub fn ceil(x: R) -> R { + x.ceil() +} + +/// Forward rule for `ceil`. +/// +/// The tangent is always zero. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::ceil_frule; +/// +/// let (y, dy) = ceil_frule(1.1_f64, 0.75); +/// assert_eq!(y, 2.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn ceil_frule(x: R, _dx: R) -> (R, R) { + (x.ceil(), R::zero()) +} + +/// Reverse rule for `ceil`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::ceil_rrule; +/// +/// assert_eq!(ceil_rrule(1.0_f64), 0.0); +/// ``` +pub fn ceil_rrule(_cotangent: R) -> R { + R::zero() +} + +/// Primal `sign`. +/// +/// The primal follows `Float::signum`, and the corresponding forward and +/// reverse rules use a zero-gradient policy at every point. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sign; +/// +/// assert_eq!(sign(-3.0_f64), -1.0); +/// assert_eq!(sign(0.0_f64), 1.0); +/// ``` +pub fn sign(x: R) -> R { + x.signum() +} + +/// Forward rule for `sign`. +/// +/// The tangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sign_frule; +/// +/// let (y, dy) = sign_frule(-2.0_f64, 0.75); +/// assert_eq!(y, -1.0); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn sign_frule(x: R, _dx: R) -> (R, R) { + (x.signum(), R::zero()) +} + +/// Reverse rule for `sign`. +/// +/// The cotangent is always zero. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::sign_rrule; +/// +/// assert_eq!(sign_rrule(1.0_f64), 0.0); +/// ``` +pub fn sign_rrule(_cotangent: R) -> R { + R::zero() +} + +/// Primal `min`. +/// +/// Tie behavior routes the derivative to the second argument. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min; +/// +/// assert_eq!(min(1.5_f64, 2.5_f64), 1.5); +/// assert_eq!(min(2.0_f64, 2.0_f64), 2.0); +/// ``` +pub fn min(x: R, y: R) -> R { + x.min(y) +} + +/// Forward rule for `min`. +/// +/// When `x == y`, the tangent comes from `y`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::min_frule; +/// +/// let (z, dz) = min_frule(1.0_f64, 2.0_f64, 0.25, 0.5); +/// assert_eq!(z, 1.0); +/// assert_eq!(dz, 0.25); +/// ``` +pub fn min_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let z = x.min(y); + if select_first_for_min(x, y) { + (z, dx) + } else { + (z, dy) + } +} + +/// Reverse rule for `min`. +/// +/// When `x == y`, the cotangent goes to `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::min_rrule; +/// +/// let (dx, dy) = min_rrule(1.0_f64, 2.0_f64, 0.5); +/// assert_eq!(dx, 0.5); +/// assert_eq!(dy, 0.0); +/// ``` +pub fn min_rrule(x: R, y: R, cotangent: R) -> (R, R) { + if select_first_for_min(x, y) { + (cotangent, R::zero()) + } else { + (R::zero(), cotangent) + } +} + +/// Primal `max`. +/// +/// Tie behavior routes the derivative to the second argument. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max; +/// +/// assert_eq!(max(1.5_f64, 2.5_f64), 2.5); +/// assert_eq!(max(2.0_f64, 2.0_f64), 2.0); +/// ``` +pub fn max(x: R, y: R) -> R { + x.max(y) +} + +/// Forward rule for `max`. +/// +/// When `x == y`, the tangent comes from `y`. +/// +/// # Examples +/// +/// ```rust +/// use chainrules::max_frule; +/// +/// let (z, dz) = max_frule(1.0_f64, 2.0_f64, 0.25, 0.5); +/// assert_eq!(z, 2.0); +/// assert_eq!(dz, 0.5); +/// ``` +pub fn max_frule(x: R, y: R, dx: R, dy: R) -> (R, R) { + let z = x.max(y); + if select_first_for_max(x, y) { + (z, dx) + } else { + (z, dy) + } +} + +/// Reverse rule for `max`. +/// +/// When `x == y`, the cotangent goes to `y`. 
+/// +/// # Examples +/// +/// ```rust +/// use chainrules::max_rrule; +/// +/// let (dx, dy) = max_rrule(1.0_f64, 2.0_f64, 0.5); +/// assert_eq!(dx, 0.0); +/// assert_eq!(dy, 0.5); +/// ``` +pub fn max_rrule(x: R, y: R, cotangent: R) -> (R, R) { + if select_first_for_max(x, y) { + (cotangent, R::zero()) + } else { + (R::zero(), cotangent) + } +} diff --git a/crates/chainrules/tests/nonsmooth_scalar_tests.rs b/crates/chainrules/tests/nonsmooth_scalar_tests.rs new file mode 100644 index 0000000..5e242b2 --- /dev/null +++ b/crates/chainrules/tests/nonsmooth_scalar_tests.rs @@ -0,0 +1,117 @@ +use chainrules::{ + ceil, ceil_frule, ceil_rrule, floor, floor_frule, floor_rrule, max, max_frule, max_rrule, min, + min_frule, min_rrule, round, round_frule, round_rrule, sign, sign_frule, sign_rrule, +}; +use num_traits::Float; + +fn assert_close<T>(actual: T, expected: T) +where + T: core::fmt::Debug + PartialEq, +{ + assert_eq!(actual, expected); +} + +fn assert_zero<T>(actual: T) +where + T: core::fmt::Debug + PartialEq + Float, +{ + assert_eq!(actual, T::zero()); +} + +fn cast<T>(value: f32) -> T +where + T: Float, +{ + T::from(value).expect("cast to float") +} + +fn check_nonsmooth_scalar_rules<T>() +where + T: Copy + core::fmt::Debug + PartialEq + Float, +{ + let x = cast::<T>(1.6_f32); + let y = cast::<T>(-2.4_f32); + + assert_close(round(x), cast::<T>(2.0_f32)); + assert_close(floor(x), cast::<T>(1.0_f32)); + assert_close(ceil(y), cast::<T>(-2.0_f32)); + assert_close(sign(y), cast::<T>(-1.0_f32)); + + let (round_y, round_dy) = round_frule(x, cast::<T>(7.0_f32)); + assert_close(round_y, cast::<T>(2.0_f32)); + assert_zero(round_dy); + assert_zero(round_rrule(cast::<T>(7.0_f32))); + + let (floor_y, floor_dy) = floor_frule(x, cast::<T>(5.0_f32)); + assert_close(floor_y, cast::<T>(1.0_f32)); + assert_zero(floor_dy); + assert_zero(floor_rrule(cast::<T>(5.0_f32))); + + let (ceil_y, ceil_dy) = ceil_frule(y, cast::<T>(11.0_f32)); + assert_close(ceil_y, cast::<T>(-2.0_f32)); + assert_zero(ceil_dy); + 
assert_zero(ceil_rrule(cast::<T>(11.0_f32))); + + let (sign_y, sign_dy) = sign_frule(y, cast::<T>(3.0_f32)); + assert_close(sign_y, cast::<T>(-1.0_f32)); + assert_zero(sign_dy); + assert_zero(sign_rrule(cast::<T>(3.0_f32))); + + let (min_y, min_dy) = min_frule( + cast::<T>(1.0_f32), + cast::<T>(2.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close( + min(cast::<T>(1.0_f32), cast::<T>(2.0_f32)), + cast::<T>(1.0_f32), + ); + assert_close(min_y, cast::<T>(1.0_f32)); + assert_close(min_dy, cast::<T>(4.0_f32)); + let (min_tie_y, min_tie_dy) = min_frule( + cast::<T>(3.0_f32), + cast::<T>(3.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close(min_tie_y, cast::<T>(3.0_f32)); + assert_close(min_tie_dy, cast::<T>(8.0_f32)); + let (min_dx, min_dy) = min_rrule(cast::<T>(3.0_f32), cast::<T>(3.0_f32), cast::<T>(6.0_f32)); + assert_zero(min_dx); + assert_close(min_dy, cast::<T>(6.0_f32)); + + let (max_y, max_dy) = max_frule( + cast::<T>(1.0_f32), + cast::<T>(2.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close( + max(cast::<T>(1.0_f32), cast::<T>(2.0_f32)), + cast::<T>(2.0_f32), + ); + assert_close(max_y, cast::<T>(2.0_f32)); + assert_close(max_dy, cast::<T>(8.0_f32)); + let (max_tie_y, max_tie_dy) = max_frule( + cast::<T>(3.0_f32), + cast::<T>(3.0_f32), + cast::<T>(4.0_f32), + cast::<T>(8.0_f32), + ); + assert_close(max_tie_y, cast::<T>(3.0_f32)); + assert_close(max_tie_dy, cast::<T>(8.0_f32)); + let (max_dx, max_dy) = max_rrule(cast::<T>(3.0_f32), cast::<T>(3.0_f32), cast::<T>(6.0_f32)); + assert_zero(max_dx); + assert_close(max_dy, cast::<T>(6.0_f32)); +} + +#[test] +fn nonsmooth_scalar_rules_match_expected_policy_for_f64() { + check_nonsmooth_scalar_rules::<f64>(); +} + +#[test] +fn nonsmooth_scalar_rules_match_expected_policy_for_f32() { + check_nonsmooth_scalar_rules::<f32>(); +} From 65360fd293dd2b647d7b66976f5b967d8f3e7bcb Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 12:42:02 +0900 Subject: [PATCH 
22/32] fix: restore nonsmooth sign and rrule signatures --- crates/chainrules/src/unary/nonsmooth.rs | 162 +++--------------- 
 .../tests/nonsmooth_scalar_tests.rs           |  23 ++-
 2 files changed, 40 insertions(+), 145 deletions(-)

diff --git a/crates/chainrules/src/unary/nonsmooth.rs b/crates/chainrules/src/unary/nonsmooth.rs
index b84b707..8f8b573 100644
--- a/crates/chainrules/src/unary/nonsmooth.rs
+++ b/crates/chainrules/src/unary/nonsmooth.rs
@@ -1,15 +1,5 @@
-#![allow(dead_code)]
-
 use num_traits::Float;

-fn select_first_for_min<R: Float>(x: R, y: R) -> bool {
-    !x.is_nan() && (y.is_nan() || x < y)
-}
-
-fn select_first_for_max<R: Float>(x: R, y: R) -> bool {
-    !x.is_nan() && (y.is_nan() || x > y)
-}
-
 /// Primal `round`.
 ///
 /// The corresponding forward and reverse rules use a zero-gradient policy at
@@ -53,9 +43,9 @@ pub fn round_frule<R: Float>(x: R, _dx: R) -> (R, R) {
 /// ```rust
 /// use chainrules::round_rrule;
 ///
-/// assert_eq!(round_rrule(1.0_f64), 0.0);
+/// assert_eq!(round_rrule(1.0_f64, 0.5), 0.0);
 /// ```
-pub fn round_rrule<R: Float>(_cotangent: R) -> R {
+pub fn round_rrule<R: Float>(_x: R, _cotangent: R) -> R {
     R::zero()
 }

@@ -101,9 +91,9 @@ pub fn floor_frule<R: Float>(x: R, _dx: R) -> (R, R) {
 /// ```rust
 /// use chainrules::floor_rrule;
 ///
-/// assert_eq!(floor_rrule(1.0_f64), 0.0);
+/// assert_eq!(floor_rrule(1.0_f64, 0.5), 0.0);
 /// ```
-pub fn floor_rrule<R: Float>(_cotangent: R) -> R {
+pub fn floor_rrule<R: Float>(_x: R, _cotangent: R) -> R {
     R::zero()
 }

@@ -149,16 +139,19 @@ pub fn ceil_frule<R: Float>(x: R, _dx: R) -> (R, R) {
 /// ```rust
 /// use chainrules::ceil_rrule;
 ///
-/// assert_eq!(ceil_rrule(1.0_f64), 0.0);
+/// assert_eq!(ceil_rrule(1.0_f64, 0.5), 0.0);
 /// ```
-pub fn ceil_rrule<R: Float>(_cotangent: R) -> R {
+pub fn ceil_rrule<R: Float>(_x: R, _cotangent: R) -> R {
     R::zero()
 }

 /// Primal `sign`.
 ///
-/// The primal follows `Float::signum`, and the corresponding forward and
-/// reverse rules use a zero-gradient policy at every point.
+/// The primal follows Julia-style `sign`: it returns signed zero for zero
+/// inputs and otherwise `x / abs(x)`.
+///
+/// The corresponding forward and reverse rules use a zero-gradient policy at
+/// every point.
 ///
 /// # Examples
 ///
@@ -166,10 +159,15 @@ pub fn ceil_rrule<R: Float>(_cotangent: R) -> R {
 /// use chainrules::sign;
 ///
 /// assert_eq!(sign(-3.0_f64), -1.0);
-/// assert_eq!(sign(0.0_f64), 1.0);
+/// assert_eq!(sign(0.0_f64), 0.0);
+/// assert_eq!(sign(-0.0_f64).is_sign_negative(), true);
 /// ```
 pub fn sign<R: Float>(x: R) -> R {
-    x.signum()
+    if x == R::zero() {
+        x
+    } else {
+        x / x.abs()
+    }
 }

 /// Forward rule for `sign`.
@@ -186,7 +184,7 @@ pub fn sign<R: Float>(x: R) -> R {
 /// assert_eq!(dy, 0.0);
 /// ```
 pub fn sign_frule<R: Float>(x: R, _dx: R) -> (R, R) {
-    (x.signum(), R::zero())
+    (sign(x), R::zero())
 }

 /// Reverse rule for `sign`.
@@ -198,126 +196,8 @@ pub fn sign_frule<R: Float>(x: R, _dx: R) -> (R, R) {
 /// ```rust
 /// use chainrules::sign_rrule;
 ///
-/// assert_eq!(sign_rrule(1.0_f64), 0.0);
+/// assert_eq!(sign_rrule(1.0_f64, 0.5), 0.0);
 /// ```
-pub fn sign_rrule<R: Float>(_cotangent: R) -> R {
+pub fn sign_rrule<R: Float>(_x: R, _cotangent: R) -> R {
     R::zero()
 }
-
-/// Primal `min`.
-///
-/// Tie behavior routes the derivative to the second argument.
-///
-/// # Examples
-///
-/// ```rust
-/// use chainrules::min;
-///
-/// assert_eq!(min(1.5_f64, 2.5_f64), 1.5);
-/// assert_eq!(min(2.0_f64, 2.0_f64), 2.0);
-/// ```
-pub fn min<R: Float>(x: R, y: R) -> R {
-    x.min(y)
-}
-
-/// Forward rule for `min`.
-///
-/// When `x == y`, the tangent comes from `y`.
-///
-/// # Examples
-///
-/// ```rust
-/// use chainrules::min_frule;
-///
-/// let (z, dz) = min_frule(1.0_f64, 2.0_f64, 0.25, 0.5);
-/// assert_eq!(z, 1.0);
-/// assert_eq!(dz, 0.25);
-/// ```
-pub fn min_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {
-    let z = x.min(y);
-    if select_first_for_min(x, y) {
-        (z, dx)
-    } else {
-        (z, dy)
-    }
-}
-
-/// Reverse rule for `min`.
-///
-/// When `x == y`, the cotangent goes to `y`.
-///
-/// # Examples
-///
-/// ```rust
-/// use chainrules::min_rrule;
-///
-/// let (dx, dy) = min_rrule(1.0_f64, 2.0_f64, 0.5);
-/// assert_eq!(dx, 0.5);
-/// assert_eq!(dy, 0.0);
-/// ```
-pub fn min_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {
-    if select_first_for_min(x, y) {
-        (cotangent, R::zero())
-    } else {
-        (R::zero(), cotangent)
-    }
-}
-
-/// Primal `max`.
-///
-/// Tie behavior routes the derivative to the second argument.
-///
-/// # Examples
-///
-/// ```rust
-/// use chainrules::max;
-///
-/// assert_eq!(max(1.5_f64, 2.5_f64), 2.5);
-/// assert_eq!(max(2.0_f64, 2.0_f64), 2.0);
-/// ```
-pub fn max<R: Float>(x: R, y: R) -> R {
-    x.max(y)
-}
-
-/// Forward rule for `max`.
-///
-/// When `x == y`, the tangent comes from `y`.
-///
-/// # Examples
-///
-/// ```rust
-/// use chainrules::max_frule;
-///
-/// let (z, dz) = max_frule(1.0_f64, 2.0_f64, 0.25, 0.5);
-/// assert_eq!(z, 2.0);
-/// assert_eq!(dz, 0.5);
-/// ```
-pub fn max_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {
-    let z = x.max(y);
-    if select_first_for_max(x, y) {
-        (z, dx)
-    } else {
-        (z, dy)
-    }
-}
-
-/// Reverse rule for `max`.
-///
-/// When `x == y`, the cotangent goes to `y`.
-///
-/// # Examples
-///
-/// ```rust
-/// use chainrules::max_rrule;
-///
-/// let (dx, dy) = max_rrule(1.0_f64, 2.0_f64, 0.5);
-/// assert_eq!(dx, 0.0);
-/// assert_eq!(dy, 0.5);
-/// ```
-pub fn max_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {
-    if select_first_for_max(x, y) {
-        (cotangent, R::zero())
-    } else {
-        (R::zero(), cotangent)
-    }
-}
diff --git a/crates/chainrules/tests/nonsmooth_scalar_tests.rs b/crates/chainrules/tests/nonsmooth_scalar_tests.rs
index 5e242b2..8dfae91 100644
--- a/crates/chainrules/tests/nonsmooth_scalar_tests.rs
+++ b/crates/chainrules/tests/nonsmooth_scalar_tests.rs
@@ -18,6 +18,17 @@ where
     assert_eq!(actual, T::zero());
 }

+fn assert_negative_zero<T>(actual: T)
+where
+    T: core::fmt::Debug + PartialEq + Float,
+{
+    assert_eq!(actual, T::zero());
+    assert!(
+        actual.is_sign_negative(),
+        "expected negative zero, got {actual:?}"
+    );
+}
+
 fn cast<T>(value: f32) -> T
 where
     T: Float,
@@ -31,31 +42,35 @@ where
 {
     let x = cast::<T>(1.6_f32);
     let y = cast::<T>(-2.4_f32);
+    let zero = cast::<T>(0.0_f32);
+    let neg_zero = cast::<T>(-0.0_f32);

     assert_close(round(x), cast::<T>(2.0_f32));
     assert_close(floor(x), cast::<T>(1.0_f32));
     assert_close(ceil(y), cast::<T>(-2.0_f32));
     assert_close(sign(y), cast::<T>(-1.0_f32));
+    assert_zero(sign(zero));
+    assert_negative_zero(sign(neg_zero));

     let (round_y, round_dy) = round_frule(x, cast::<T>(7.0_f32));
     assert_close(round_y, cast::<T>(2.0_f32));
     assert_zero(round_dy);
-    assert_zero(round_rrule(cast::<T>(7.0_f32)));
+    assert_zero(round_rrule(x, cast::<T>(7.0_f32)));

     let (floor_y, floor_dy) = floor_frule(x, cast::<T>(5.0_f32));
     assert_close(floor_y, cast::<T>(1.0_f32));
     assert_zero(floor_dy);
-    assert_zero(floor_rrule(cast::<T>(5.0_f32)));
+    assert_zero(floor_rrule(x, cast::<T>(5.0_f32)));

     let (ceil_y, ceil_dy) = ceil_frule(y, cast::<T>(11.0_f32));
     assert_close(ceil_y, cast::<T>(-2.0_f32));
     assert_zero(ceil_dy);
-    assert_zero(ceil_rrule(cast::<T>(11.0_f32)));
+    assert_zero(ceil_rrule(y, cast::<T>(11.0_f32)));

     let (sign_y, sign_dy) = sign_frule(y, cast::<T>(3.0_f32));
     assert_close(sign_y, cast::<T>(-1.0_f32));
     assert_zero(sign_dy);
-    assert_zero(sign_rrule(cast::<T>(3.0_f32)));
+    assert_zero(sign_rrule(y, cast::<T>(3.0_f32)));

     let (min_y, min_dy) = min_frule(
         cast::<T>(1.0_f32),

From 392f24eff79e5a021bed7676628902b3bd50e29d Mon Sep 17 00:00:00 2001
From: Hiroshi Shinaoka
Date: Sat, 21 Mar 2026 12:45:04 +0900
Subject: [PATCH 23/32] test: add nonsmooth reverse-rule cases

---
 crates/chainrules/tests/nonsmooth_scalar_tests.rs | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/crates/chainrules/tests/nonsmooth_scalar_tests.rs b/crates/chainrules/tests/nonsmooth_scalar_tests.rs
index 8dfae91..e018c97 100644
--- a/crates/chainrules/tests/nonsmooth_scalar_tests.rs
+++ b/crates/chainrules/tests/nonsmooth_scalar_tests.rs
@@ -95,6 +95,9 @@ where
     let (min_dx, min_dy) = min_rrule(cast::<T>(3.0_f32), cast::<T>(3.0_f32), cast::<T>(6.0_f32));
     assert_zero(min_dx);
     assert_close(min_dy, cast::<T>(6.0_f32));
+    let (min_dx, min_dy) = min_rrule(cast::<T>(2.0_f32), cast::<T>(3.0_f32), cast::<T>(5.0_f32));
+    assert_close(min_dx, cast::<T>(5.0_f32));
+    assert_zero(min_dy);

     let (max_y, max_dy) = max_frule(
         cast::<T>(1.0_f32),
@@ -119,6 +122,9 @@ where
     let (max_dx, max_dy) = max_rrule(cast::<T>(3.0_f32), cast::<T>(3.0_f32), cast::<T>(6.0_f32));
     assert_zero(max_dx);
     assert_close(max_dy, cast::<T>(6.0_f32));
+    let (max_dx, max_dy) = max_rrule(cast::<T>(2.0_f32), cast::<T>(3.0_f32), cast::<T>(5.0_f32));
+    assert_zero(max_dx);
+    assert_close(max_dy, cast::<T>(5.0_f32));
 }

 #[test]

From 7c2cc61aacc837e9f4bf967f29f16809933bec15 Mon Sep 17 00:00:00 2001
From: Hiroshi Shinaoka
Date: Sat, 21 Mar 2026 12:53:46 +0900
Subject: [PATCH 24/32] fix: cover nonsmooth infinities and edge policies

---
 crates/chainrules/README.md                   |  3 ++
 crates/chainrules/src/binary_special.rs       | 16 +++++---
 crates/chainrules/src/tests/organization.rs   | 10 +++++
 crates/chainrules/src/unary/nonsmooth.rs      |  5 ++-
 .../tests/nonsmooth_scalar_tests.rs           | 38 +++++++++++++++++++
 5 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md
index cb4e71b..dde2ab8 100644
--- a/crates/chainrules/README.md
+++ b/crates/chainrules/README.md
@@ -32,6 +32,7 @@ Current shipped scalar families:
 - hyperbolic: `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh`
 - Julia-compatible trigonometric helpers: `sec`, `csc`, `cot`, `sinpi`, `cospi`, `sincospi`, `sind`, `cosd`, `tand`
 - Julia-compatible hyperbolic helpers: `sech`, `csch`, `coth`
+- non-smooth real helpers: `round`, `floor`, `ceil`, `sign`, `min`, `max`
 - smooth helpers: `cbrt`, `inv`, `exp2`, `exp10`, `log2`, `log10`, `hypot`, `pow`, `sincos`, `tan`
 - complex and projection helpers: `conj`, `abs`, `abs2`, `angle`, `real`, `imag`, `complex`, `handle_r_to_c_f32`, `handle_r_to_c_f64`
 - real-valued binary helpers: `atan2`
@@ -49,6 +50,8 @@ repository-local tests:
   real/complex behavior
 - `tests/julia_compat_trig_tests.rs` covers Julia migration helpers, including
   landmark real inputs and representative Complex64 behavior
+- `tests/nonsmooth_scalar_tests.rs` covers the documented zero-gradient and
+  tie-routing policies for non-smooth helpers
 - `tests/complex_helper_tests.rs` covers the projection helpers and complex
   constructor surface
 - `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from
diff --git a/crates/chainrules/src/binary_special.rs b/crates/chainrules/src/binary_special.rs
index e06b23f..98fa04d 100644
--- a/crates/chainrules/src/binary_special.rs
+++ b/crates/chainrules/src/binary_special.rs
@@ -57,8 +57,9 @@ pub fn hypot_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {

 /// Primal `min`.
 ///
-/// The primal follows `Float::min`; tie behavior routes the derivative to the
-/// second argument.
+/// The primal follows `Float::min`. For differentiation, ties route the
+/// tangent/cotangent to the second argument. If exactly one input is `NaN`,
+/// the non-`NaN` input receives the gradient.
 ///
 /// # Examples
 ///
@@ -96,7 +97,8 @@ pub fn min_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {

 /// Reverse rule for `min`.
 ///
-/// When `x == y`, the cotangent goes to `y`.
+/// When `x == y`, the cotangent goes to `y`. If exactly one input is `NaN`,
+/// the non-`NaN` input receives the cotangent.
 ///
 /// # Examples
 ///
@@ -117,8 +119,9 @@ pub fn min_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {

 /// Primal `max`.
 ///
-/// The primal follows `Float::max`; tie behavior routes the derivative to the
-/// second argument.
+/// The primal follows `Float::max`. For differentiation, ties route the
+/// tangent/cotangent to the second argument. If exactly one input is `NaN`,
+/// the non-`NaN` input receives the gradient.
 ///
 /// # Examples
 ///
@@ -156,7 +159,8 @@ pub fn max_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {

 /// Reverse rule for `max`.
 ///
-/// When `x == y`, the cotangent goes to `y`.
+/// When `x == y`, the cotangent goes to `y`. If exactly one input is `NaN`,
+/// the non-`NaN` input receives the cotangent.
 ///
 /// # Examples
 ///
diff --git a/crates/chainrules/src/tests/organization.rs b/crates/chainrules/src/tests/organization.rs
index c66ae0c..ff122bb 100644
--- a/crates/chainrules/src/tests/organization.rs
+++ b/crates/chainrules/src/tests/organization.rs
@@ -11,6 +11,11 @@ fn assert_line_count(path: &str, content: &str, max_lines: usize) {
 fn chainrules_modules_stay_under_size_guideline() {
     assert_line_count("../lib.rs", include_str!("../lib.rs"), 120);
     assert_line_count("../binary.rs", include_str!("../binary.rs"), 260);
+    assert_line_count(
+        "../binary_special.rs",
+        include_str!("../binary_special.rs"),
+        220,
+    );
     assert_line_count("../unary/mod.rs", include_str!("../unary/mod.rs"), 80);
     assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 80);
     assert_line_count(
@@ -34,6 +39,11 @@ fn chainrules_modules_stay_under_size_guideline() {
         include_str!("../unary/hyperbolic_extra.rs"),
         180,
     );
+    assert_line_count(
+        "../unary/nonsmooth.rs",
+        include_str!("../unary/nonsmooth.rs"),
+        220,
+    );
     assert_line_count(
         "../unary/smooth.rs",
         include_str!("../unary/smooth.rs"),
diff --git a/crates/chainrules/src/unary/nonsmooth.rs b/crates/chainrules/src/unary/nonsmooth.rs
index 8f8b573..7c2c958 100644
--- a/crates/chainrules/src/unary/nonsmooth.rs
+++ b/crates/chainrules/src/unary/nonsmooth.rs
@@ -148,7 +148,8 @@ pub fn ceil_rrule<R: Float>(_x: R, _cotangent: R) -> R {
 /// Primal `sign`.
 ///
 /// The primal follows Julia-style `sign`: it returns signed zero for zero
-/// inputs and otherwise `x / abs(x)`.
+/// inputs, `+1`/`-1` for positive/negative infinities, and `x.signum()`
+/// otherwise.
 ///
 /// The corresponding forward and reverse rules use a zero-gradient policy at
 /// every point.
@@ -166,7 +167,7 @@ pub fn sign<R: Float>(x: R) -> R {
     if x == R::zero() {
         x
     } else {
-        x / x.abs()
+        x.signum()
     }
 }

diff --git a/crates/chainrules/tests/nonsmooth_scalar_tests.rs b/crates/chainrules/tests/nonsmooth_scalar_tests.rs
index e018c97..6636b40 100644
--- a/crates/chainrules/tests/nonsmooth_scalar_tests.rs
+++ b/crates/chainrules/tests/nonsmooth_scalar_tests.rs
@@ -44,6 +44,8 @@ where
     let y = cast::<T>(-2.4_f32);
     let zero = cast::<T>(0.0_f32);
     let neg_zero = cast::<T>(-0.0_f32);
+    let inf = T::infinity();
+    let neg_inf = T::neg_infinity();

     assert_close(round(x), cast::<T>(2.0_f32));
     assert_close(floor(x), cast::<T>(1.0_f32));
@@ -51,6 +53,8 @@ where
     assert_close(sign(y), cast::<T>(-1.0_f32));
     assert_zero(sign(zero));
     assert_negative_zero(sign(neg_zero));
+    assert_close(sign(inf), T::one());
+    assert_close(sign(neg_inf), -T::one());

     let (round_y, round_dy) = round_frule(x, cast::<T>(7.0_f32));
     assert_close(round_y, cast::<T>(2.0_f32));
@@ -98,6 +102,23 @@ where
     let (min_dx, min_dy) = min_rrule(cast::<T>(2.0_f32), cast::<T>(3.0_f32), cast::<T>(5.0_f32));
     assert_close(min_dx, cast::<T>(5.0_f32));
     assert_zero(min_dy);
+    let (min_tie_y, min_tie_dy) = min_frule(neg_zero, zero, cast::<T>(4.0_f32), cast::<T>(8.0_f32));
+    assert_zero(min_tie_y);
+    assert_close(min_tie_dy, cast::<T>(8.0_f32));
+    let (min_dx, min_dy) = min_rrule(neg_zero, zero, cast::<T>(5.0_f32));
+    assert_zero(min_dx);
+    assert_close(min_dy, cast::<T>(5.0_f32));
+    let (min_nan_y, min_nan_dy) = min_frule(
+        cast::<T>(2.0_f32),
+        T::nan(),
+        cast::<T>(6.0_f32),
+        cast::<T>(9.0_f32),
+    );
+    assert_close(min_nan_y, cast::<T>(2.0_f32));
+    assert_close(min_nan_dy, cast::<T>(6.0_f32));
+    let (min_dx, min_dy) = min_rrule(cast::<T>(2.0_f32), T::nan(), cast::<T>(5.0_f32));
+    assert_close(min_dx, cast::<T>(5.0_f32));
+    assert_zero(min_dy);

     let (max_y, max_dy) = max_frule(
         cast::<T>(1.0_f32),
@@ -125,6 +146,23 @@ where
     let (max_dx, max_dy) = max_rrule(cast::<T>(2.0_f32), cast::<T>(3.0_f32), cast::<T>(5.0_f32));
     assert_zero(max_dx);
     assert_close(max_dy, cast::<T>(5.0_f32));
+    let (max_tie_y, max_tie_dy) = max_frule(neg_zero, zero, cast::<T>(4.0_f32), cast::<T>(8.0_f32));
+    assert_zero(max_tie_y);
+    assert_close(max_tie_dy, cast::<T>(8.0_f32));
+    let (max_dx, max_dy) = max_rrule(neg_zero, zero, cast::<T>(5.0_f32));
+    assert_zero(max_dx);
+    assert_close(max_dy, cast::<T>(5.0_f32));
+    let (max_nan_y, max_nan_dy) = max_frule(
+        T::nan(),
+        cast::<T>(3.0_f32),
+        cast::<T>(6.0_f32),
+        cast::<T>(9.0_f32),
+    );
+    assert_close(max_nan_y, cast::<T>(3.0_f32));
+    assert_close(max_nan_dy, cast::<T>(9.0_f32));
+    let (max_dx, max_dy) = max_rrule(T::nan(), cast::<T>(3.0_f32), cast::<T>(5.0_f32));
+    assert_zero(max_dx);
+    assert_close(max_dy, cast::<T>(5.0_f32));
 }

 #[test]

From 1fa4c44347ad54e7b81ea20650eeebb0cf1b0cb9 Mon Sep 17 00:00:00 2001
From: Hiroshi Shinaoka
Date: Sat, 21 Mar 2026 13:03:32 +0900
Subject: [PATCH 25/32] test: expand scalar oracle coverage

---
 crates/chainrules/src/tests/organization.rs   |  34 +-
 crates/chainrules/tests/common.rs             | 180 ++++++++++
 .../chainrules/tests/complex_helper_tests.rs  |  66 +++-
 .../tests/julia_compat_trig_tests.rs          | 309 ++++++++++++++----
 .../chainrules/tests/oracle_scalar_rules.rs   | 195 +++++------
 crates/chainrules/tests/smooth_basis_tests.rs | 172 +++++++---
 6 files changed, 678 insertions(+), 278 deletions(-)
 create mode 100644 crates/chainrules/tests/common.rs

diff --git a/crates/chainrules/src/tests/organization.rs b/crates/chainrules/src/tests/organization.rs
index ff122bb..0c5976f 100644
--- a/crates/chainrules/src/tests/organization.rs
+++ b/crates/chainrules/src/tests/organization.rs
@@ -9,30 +9,30 @@ fn assert_line_count(path: &str, content: &str, max_lines: usize) {
 // Do not delete or weaken this test: it protects the split scalar AD rule modules that keep this crate extensible.
 #[test]
 fn chainrules_modules_stay_under_size_guideline() {
-    assert_line_count("../lib.rs", include_str!("../lib.rs"), 120);
-    assert_line_count("../binary.rs", include_str!("../binary.rs"), 260);
+    assert_line_count("../lib.rs", include_str!("../lib.rs"), 60);
+    assert_line_count("../binary.rs", include_str!("../binary.rs"), 220);
     assert_line_count(
         "../binary_special.rs",
         include_str!("../binary_special.rs"),
-        220,
+        200,
     );
-    assert_line_count("../unary/mod.rs", include_str!("../unary/mod.rs"), 80);
-    assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 80);
+    assert_line_count("../unary/mod.rs", include_str!("../unary/mod.rs"), 60);
+    assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 40);
     assert_line_count(
         "../unary/exp_log.rs",
         include_str!("../unary/exp_log.rs"),
-        120,
+        130,
     );
-    assert_line_count("../unary/trig.rs", include_str!("../unary/trig.rs"), 140);
+    assert_line_count("../unary/trig.rs", include_str!("../unary/trig.rs"), 135);
     assert_line_count(
         "../unary/hyperbolic.rs",
         include_str!("../unary/hyperbolic.rs"),
-        140,
+        120,
     );
     assert_line_count(
         "../unary/trig_extra.rs",
         include_str!("../unary/trig_extra.rs"),
-        500,
+        520,
     );
     assert_line_count(
         "../unary/hyperbolic_extra.rs",
@@ -44,26 +44,22 @@ fn chainrules_modules_stay_under_size_guideline() {
         include_str!("../unary/nonsmooth.rs"),
         220,
     );
-    assert_line_count(
-        "../unary/smooth.rs",
-        include_str!("../unary/smooth.rs"),
-        120,
-    );
-    assert_line_count("../power.rs", include_str!("../power.rs"), 180);
-    assert_line_count("../real_ops.rs", include_str!("../real_ops.rs"), 120);
+    assert_line_count("../unary/smooth.rs", include_str!("../unary/smooth.rs"), 30);
+    assert_line_count("../power.rs", include_str!("../power.rs"), 170);
+    assert_line_count("../real_ops.rs", include_str!("../real_ops.rs"), 70);
     assert_line_count(
         "../scalar_ad/mod.rs",
         include_str!("../scalar_ad/mod.rs"),
-        220,
+        180,
     );
     assert_line_count(
         "../scalar_ad/real.rs",
         include_str!("../scalar_ad/real.rs"),
-        220,
+        160,
     );
     assert_line_count(
         "../scalar_ad/complex.rs",
         include_str!("../scalar_ad/complex.rs"),
-        220,
+        170,
     );
 }
diff --git a/crates/chainrules/tests/common.rs b/crates/chainrules/tests/common.rs
new file mode 100644
index 0000000..115ec86
--- /dev/null
+++ b/crates/chainrules/tests/common.rs
@@ -0,0 +1,180 @@
+#![allow(dead_code)]
+
+use std::fs::File;
+use std::io::{BufRead, BufReader};
+use std::path::PathBuf;
+
+use num_complex::Complex64;
+use serde_json::Value;
+
+pub struct UnaryOracleCase<T> {
+    pub op: &'static str,
+    pub frule: fn(T, T) -> (T, T),
+    pub rrule: fn(T, T, T) -> T,
+}
+
+pub trait OracleScalar: Copy + core::fmt::Debug {
+    fn dtype() -> &'static str;
+    fn from_json(value: &Value, path: &str) -> Self;
+    fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str);
+}
+
+impl OracleScalar for f64 {
+    fn dtype() -> &'static str {
+        "float64"
+    }
+
+    fn from_json(value: &Value, path: &str) -> Self {
+        scalar_f64(value, path)
+    }
+
+    fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str) {
+        assert_close_f64(actual, expected, atol, rtol, label);
+    }
+}
+
+impl OracleScalar for Complex64 {
+    fn dtype() -> &'static str {
+        "complex128"
+    }
+
+    fn from_json(value: &Value, path: &str) -> Self {
+        scalar_complex64(value, path)
+    }
+
+    fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str) {
+        assert_close_complex64(actual, expected, atol, rtol, label);
+    }
+}
+
+fn oracle_root() -> PathBuf {
+    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+        .join("..")
+        .join("..")
+        .join("third_party")
+        .join("tensor-ad-oracles")
+}
+
+pub fn first_successful_case(op: &str, dtype: &str) -> Value {
+    let path = oracle_root().join("cases").join(op).join("identity.jsonl");
+    let file = File::open(&path).unwrap_or_else(|err| panic!("open {}: {err}", path.display()));
+    let reader = BufReader::new(file);
+
+    for line in reader.lines() {
+        let line = line.unwrap_or_else(|err| panic!("read {}: {err}", path.display()));
+        let value: Value = serde_json::from_str(&line)
+            .unwrap_or_else(|err| panic!("parse {}: {err}", path.display()));
+        let case_dtype = value["dtype"].as_str();
+        let behavior = value["expected_behavior"].as_str();
+        if case_dtype == Some(dtype) && behavior == Some("success") {
+            return value;
+        }
+    }
+
+    panic!("no successful {dtype} case found in {}", path.display());
+}
+
+pub fn scalar_f64(value: &Value, path: &str) -> f64 {
+    value
+        .as_f64()
+        .unwrap_or_else(|| panic!("expected float64 at {path}, got {value}"))
+}
+
+pub fn scalar_complex64(value: &Value, path: &str) -> Complex64 {
+    let data = value
+        .as_array()
+        .unwrap_or_else(|| panic!("expected complex128 pair at {path}, got {value}"));
+    assert!(
+        data.len() == 2,
+        "expected complex128 pair at {path}, got {value}"
+    );
+    Complex64::new(scalar_f64(&data[0], path), scalar_f64(&data[1], path))
+}
+
+pub fn assert_close_f64(actual: f64, expected: f64, atol: f64, rtol: f64, label: &str) {
+    let tol = atol + rtol * expected.abs().max(actual.abs());
+    assert!(
+        (actual - expected).abs() <= tol,
+        "{label}: actual={actual}, expected={expected}, atol={atol}, rtol={rtol}",
+    );
+}
+
+pub fn assert_close_complex64(
+    actual: Complex64,
+    expected: Complex64,
+    atol: f64,
+    rtol: f64,
+    label: &str,
+) {
+    let tol = atol + rtol * expected.norm().max(actual.norm());
+    assert!(
+        (actual - expected).norm() <= tol,
+        "{label}: actual={actual:?}, expected={expected:?}, atol={atol}, rtol={rtol}",
+    );
+}
+
+pub fn run_unary_oracle_cases<T: OracleScalar>(cases: &[UnaryOracleCase<T>]) {
+    for case in cases {
+        let oracle = first_successful_case(case.op, T::dtype());
+        let input = T::from_json(&oracle["inputs"]["a"]["data"][0], "inputs.a.data[0]");
+        let probe = &oracle["probes"][0];
+        let tangent = T::from_json(
+            &probe["direction"]["a"]["data"][0],
+            "probes[0].direction.a.data[0]",
+        );
+        let cotangent = T::from_json(
+            &probe["cotangent"]["value"]["data"][0],
+            "probes[0].cotangent.value.data[0]",
+        );
+        let expected_jvp = T::from_json(
+            &probe["pytorch_ref"]["jvp"]["value"]["data"][0],
+            "probes[0].pytorch_ref.jvp.value.data[0]",
+        );
+        let expected_vjp = T::from_json(
+            &probe["pytorch_ref"]["vjp"]["a"]["data"][0],
+            "probes[0].pytorch_ref.vjp.a.data[0]",
+        );
+        let atol = scalar_f64(
+            &oracle["comparison"]["first_order"]["atol"],
+            "comparison.first_order.atol",
+        );
+        let rtol = scalar_f64(
+            &oracle["comparison"]["first_order"]["rtol"],
+            "comparison.first_order.rtol",
+        );
+
+        let (result, actual_jvp) = (case.frule)(input, tangent);
+        let actual_vjp = (case.rrule)(input, result, cotangent);
+
+        T::assert_close(actual_jvp, expected_jvp, atol, rtol, case.op);
+        T::assert_close(actual_vjp, expected_vjp, atol, rtol, case.op);
+    }
+}
+
+pub fn run_unary_oracle_reverse_cases_complex64(cases: &[UnaryOracleCase<Complex64>]) {
+    for case in cases {
+        let oracle = first_successful_case(case.op, Complex64::dtype());
+        let input = scalar_complex64(&oracle["inputs"]["a"]["data"][0], "inputs.a.data[0]");
+        let probe = &oracle["probes"][0];
+        let cotangent = scalar_complex64(
+            &probe["cotangent"]["value"]["data"][0],
+            "probes[0].cotangent.value.data[0]",
+        );
+        let expected_vjp = scalar_complex64(
+            &probe["pytorch_ref"]["vjp"]["a"]["data"][0],
+            "probes[0].pytorch_ref.vjp.a.data[0]",
+        );
+        let atol = scalar_f64(
+            &oracle["comparison"]["first_order"]["atol"],
+            "comparison.first_order.atol",
+        );
+        let rtol = scalar_f64(
+            &oracle["comparison"]["first_order"]["rtol"],
+            "comparison.first_order.rtol",
+        );
+
+        let (result, _) = (case.frule)(input, cotangent);
+        let actual_vjp = (case.rrule)(input, result, cotangent);
+        Complex64::assert_close(actual_vjp, expected_vjp, atol, rtol, case.op);
+    }
+}
diff --git a/crates/chainrules/tests/complex_helper_tests.rs b/crates/chainrules/tests/complex_helper_tests.rs
index 5883afd..8b8dfc9 100644
--- a/crates/chainrules/tests/complex_helper_tests.rs
+++ b/crates/chainrules/tests/complex_helper_tests.rs
@@ -1,9 +1,13 @@
+mod common;
+
 use chainrules::{
     abs, abs2, abs2_frule, abs2_rrule, angle, angle_rrule, complex, imag, imag_rrule, real,
     real_rrule,
 };
 use num_complex::Complex64;

+use common::{assert_close_complex64, assert_close_f64};
+
 #[test]
 fn complex_helpers_match_expected_formulas() {
     let x = 3.0_f64;
@@ -12,25 +16,55 @@ fn complex_helpers_match_expected_formulas() {
     let constructed: Complex64 = complex(3.0, 4.0);
     assert_eq!(constructed, z);

-    assert_eq!(abs(x), 3.0);
-    assert_eq!(abs2(x), 9.0);
-    assert_eq!(real(x), 3.0);
-    assert_eq!(imag(x), 0.0);
-    assert_eq!(angle(x), 0.0_f64.atan2(x));
-    assert_eq!(abs(z), 5.0);
-    assert_eq!(abs2(z), 25.0);
-    assert_eq!(real(z), 3.0);
-    assert_eq!(imag(z), 4.0);
-    assert_eq!(angle(z), z.arg());
+    assert_close_f64(abs(x), 3.0, 1.0e-12, 0.0, "abs(x)");
+    assert_close_f64(abs2(x), 9.0, 1.0e-12, 0.0, "abs2(x)");
+    assert_close_f64(real(x), 3.0, 1.0e-12, 0.0, "real(x)");
+    assert_close_f64(imag(x), 0.0, 1.0e-12, 0.0, "imag(x)");
+    assert_close_f64(angle(x), 0.0_f64.atan2(x), 1.0e-12, 0.0, "angle(x)");
+    assert_close_f64(abs(z), 5.0, 1.0e-12, 0.0, "abs(z)");
+    assert_close_f64(abs2(z), 25.0, 1.0e-12, 0.0, "abs2(z)");
+    assert_close_f64(real(z), 3.0, 1.0e-12, 0.0, "real(z)");
+    assert_close_f64(imag(z), 4.0, 1.0e-12, 0.0, "imag(z)");
+    assert_close_f64(angle(z), z.arg(), 1.0e-12, 0.0, "angle(z)");

     let (abs2_y, abs2_dy) = abs2_frule(z, dz);
-    assert_eq!(abs2_y, 25.0);
-    assert_eq!(abs2_dy, 2.0 * (z.re * dz.re + z.im * dz.im));
+    assert_close_f64(abs2_y, 25.0, 1.0e-12, 0.0, "abs2.y");
+    assert_close_f64(
+        abs2_dy,
+        2.0 * (z.re * dz.re + z.im * dz.im),
+        1.0e-12,
+        0.0,
+        "abs2.dy",
+    );

-    assert_eq!(abs2_rrule(z, 1.25), Complex64::new(7.5, 10.0));
+    assert_close_complex64(
+        abs2_rrule(z, 1.25),
+        Complex64::new(7.5, 10.0),
+        1.0e-12,
+        0.0,
+        "abs2.rrule",
+    );
     let real_grad: Complex64 = real_rrule(2.0);
-    assert_eq!(real_grad, Complex64::new(2.0, 0.0));
+    assert_close_complex64(
+        real_grad,
+        Complex64::new(2.0, 0.0),
+        1.0e-12,
+        0.0,
+        "real.rrule",
+    );
     let imag_grad: Complex64 = imag_rrule(2.0);
-    assert_eq!(imag_grad, Complex64::new(0.0, 2.0));
-    assert_eq!(angle_rrule(z, 1.0), Complex64::new(-0.16, 0.12));
+    assert_close_complex64(
+        imag_grad,
+        Complex64::new(0.0, 2.0),
+        1.0e-12,
+        0.0,
+        "imag.rrule",
+    );
+    assert_close_complex64(
+        angle_rrule(z, 1.0),
+        Complex64::new(-0.16, 0.12),
+        1.0e-12,
+        0.0,
+        "angle.rrule",
+    );
 }
diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs
index 4a8cf0d..8f32195 100644
--- a/crates/chainrules/tests/julia_compat_trig_tests.rs
+++ b/crates/chainrules/tests/julia_compat_trig_tests.rs
@@ -1,15 +1,14 @@
+mod common;
+
 use chainrules::{
     cosd, cosd_frule, cosd_rrule, cospi, cospi_frule, cospi_rrule, cot, cot_frule, cot_rrule, coth,
     coth_frule, coth_rrule, csc, csc_frule, csc_rrule, csch, csch_frule, csch_rrule, sec, sec_frule,
     sec_rrule, sech, sech_frule, sech_rrule, sincospi, sincospi_frule, sincospi_rrule, sind,
     sind_frule, sind_rrule, sinpi, sinpi_frule, sinpi_rrule, tand, tand_frule, tand_rrule,
 };
+use common::{assert_close_complex64, assert_close_f64};
 use num_complex::Complex64;

-fn assert_complex_close(actual: Complex64, expected: Complex64) {
-    assert!((actual - expected).norm() < 1e-12);
-}
-
 #[test]
 fn julia_compat_landmark_real_inputs_match_julia_style_values() {
     assert_eq!(sinpi(1.0_f64), 0.0_f64);
@@ -29,20 +28,44 @@ fn julia_compat_landmark_real_inputs_match_julia_style_values() {
 #[test]
 fn julia_compat_primal_helpers_match_expected_values() {
     let x = 0.25_f64;
-    assert!((sec(x) - (1.0 / x.cos())).abs() < 1e-12);
-    assert!((csc(x) - (1.0 / x.sin())).abs() < 1e-12);
-    assert!((cot(x) - (1.0 / x.tan())).abs() < 1e-12);
-    assert!((sinpi(x) - (std::f64::consts::PI * x).sin()).abs() < 1e-12);
-    assert!((cospi(x) - (std::f64::consts::PI * x).cos()).abs() < 1e-12);
+    assert_close_f64(sec(x), 1.0 / x.cos(), 1e-12, 0.0, "sec");
+    assert_close_f64(csc(x), 1.0 / x.sin(), 1e-12, 0.0, "csc");
+    assert_close_f64(cot(x), 1.0 / x.tan(), 1e-12, 0.0, "cot");
+    assert_close_f64(
+        sinpi(x),
+        (std::f64::consts::PI * x).sin(),
+        1e-12,
+        0.0,
+        "sinpi",
+    );
+    assert_close_f64(
+        cospi(x),
+        (std::f64::consts::PI * x).cos(),
+        1e-12,
+        0.0,
+        "cospi",
+    );
     let (s, c) = sincospi(x);
-    assert!((s - (std::f64::consts::PI * x).sin()).abs() < 1e-12);
-    assert!((c - (std::f64::consts::PI * x).cos()).abs() < 1e-12);
-    assert!((sind(30.0_f64) - 0.5_f64).abs() < 1e-12);
-    assert!((cosd(60.0_f64) - 0.5_f64).abs() < 1e-12);
-    assert!((tand(45.0_f64) - 1.0_f64).abs() < 1e-12);
-    assert!((sech(x) - (1.0 / x.cosh())).abs() < 1e-12);
-    assert!((csch(x) - (1.0 / x.sinh())).abs() < 1e-12);
-    assert!((coth(x) - (1.0 / x.tanh())).abs() < 1e-12);
+    assert_close_f64(
+        s,
+        (std::f64::consts::PI * x).sin(),
+        1e-12,
+        0.0,
+        "sincospi.sin",
+    );
+    assert_close_f64(
+        c,
+        (std::f64::consts::PI * x).cos(),
+        1e-12,
+        0.0,
+        "sincospi.cos",
+    );
+    assert_close_f64(sind(30.0_f64), 0.5_f64, 1e-12, 0.0, "sind");
+    assert_close_f64(cosd(60.0_f64), 0.5_f64, 1e-12, 0.0, "cosd");
+    assert_close_f64(tand(45.0_f64), 1.0_f64, 1e-12, 0.0, "tand");
+    assert_close_f64(sech(x), 1.0 / x.cosh(), 1e-12, 0.0, "sech");
+    assert_close_f64(csch(x), 1.0 / x.sinh(), 1e-12, 0.0, "csch");
+    assert_close_f64(coth(x), 1.0 / x.tanh(), 1e-12, 0.0, "coth");
 }

 #[test]
@@ -51,80 +74,207 @@ fn julia_compat_derivative_helpers_match_expected_values() {
     let g = 1.0_f64;

     let (_, dsec) = sec_frule(x, g);
-    assert!((dsec - (x.sin() / x.cos().powi(2))).abs() < 1e-12);
-    assert!((sec_rrule(x, g) - (x.sin() / x.cos().powi(2))).abs() < 1e-12);
+    assert_close_f64(dsec, x.sin() / x.cos().powi(2), 1e-12, 0.0, "sec_frule");
+    assert_close_f64(
+        sec_rrule(x, g),
+        x.sin() / x.cos().powi(2),
+        1e-12,
+        0.0,
+        "sec_rrule",
+    );

     let (_, dsinpi_landmark) = sinpi_frule(1.0_f64, g);
-    assert!((dsinpi_landmark + std::f64::consts::PI).abs() < 1e-12);
-    assert!((sinpi_rrule(1.0_f64, g) + std::f64::consts::PI).abs() < 1e-12);
+    assert_close_f64(
+        dsinpi_landmark,
+        -std::f64::consts::PI,
+        1e-12,
+        0.0,
+        "sinpi_frule landmark",
+    );
+    assert_close_f64(
+        sinpi_rrule(1.0_f64, g),
+        -std::f64::consts::PI,
+        1e-12,
+        0.0,
+        "sinpi_rrule landmark",
+    );

     let (_, dcospi_landmark) = cospi_frule(0.5_f64, g);
-    assert!((dcospi_landmark + std::f64::consts::PI).abs() < 1e-12);
-    assert!((cospi_rrule(0.5_f64, g) + std::f64::consts::PI).abs() < 1e-12);
+    assert_close_f64(
+        dcospi_landmark,
+        -std::f64::consts::PI,
+        1e-12,
+        0.0,
+        "cospi_frule landmark",
+    );
+    assert_close_f64(
+        cospi_rrule(0.5_f64, g),
+        -std::f64::consts::PI,
+        1e-12,
+        0.0,
+        "cospi_rrule landmark",
+    );

     let (_, dsinpi) = sinpi_frule(x, g);
-    assert!((dsinpi - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12);
-    assert!(
-        (sinpi_rrule(x, g) - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12
+    assert_close_f64(
+        dsinpi,
+        std::f64::consts::PI * (std::f64::consts::PI * x).cos(),
+        1e-12,
+        0.0,
+        "sinpi_frule",
+    );
+    assert_close_f64(
+        sinpi_rrule(x, g),
+        std::f64::consts::PI * (std::f64::consts::PI * x).cos(),
+        1e-12,
+        0.0,
+        "sinpi_rrule",
     );

     let (_, dtand) = tand_frule(45.0_f64, g);
-    assert!((dtand - std::f64::consts::PI / 180.0 * 2.0).abs() < 1e-12);
-    assert!((tand_rrule(45.0_f64, g) - std::f64::consts::PI / 180.0 * 2.0).abs() < 1e-12);
+    assert_close_f64(
+        dtand,
+        std::f64::consts::PI / 180.0 * 2.0,
+        1e-12,
+        0.0,
+        "tand_frule",
+    );
+    assert_close_f64(
+        tand_rrule(45.0_f64, g),
+        std::f64::consts::PI / 180.0 * 2.0,
+        1e-12,
+        0.0,
+        "tand_rrule",
+    );

     let (_, dsech) = sech_frule(x, g);
     let sech_x: f64 = 1.0 / x.cosh();
-    assert!((dsech - (-sech_x * x.tanh())).abs() < 1e-12);
-    assert!((sech_rrule(x, g) - (-sech_x * x.tanh())).abs() < 1e-12);
+    assert_close_f64(dsech, -sech_x * x.tanh(), 1e-12, 0.0, "sech_frule");
+    assert_close_f64(
+        sech_rrule(x, g),
+        -sech_x * x.tanh(),
+        1e-12,
+        0.0,
+        "sech_rrule",
+    );

     let (_, dcsc) = csc_frule(x, g);
-    assert!((dcsc - (-(x.cos() / x.sin().powi(2)))).abs() < 1e-12);
-    assert!((csc_rrule(x, g) - (-(x.cos() / x.sin().powi(2)))).abs() < 1e-12);
+    assert_close_f64(dcsc, -(x.cos() / x.sin().powi(2)), 1e-12, 0.0, "csc_frule");
+    assert_close_f64(
+        csc_rrule(x, g),
+        -(x.cos() / x.sin().powi(2)),
+        1e-12,
+        0.0,
+        "csc_rrule",
+    );

     let (_, dcot) = cot_frule(x, g);
-    assert!((dcot - (-(1.0 / x.sin().powi(2)))).abs() < 1e-12);
-    assert!((cot_rrule(x, g) - (-(1.0 / x.sin().powi(2)))).abs() < 1e-12);
+    assert_close_f64(dcot, -(1.0 / x.sin().powi(2)), 1e-12, 0.0, "cot_frule");
+    assert_close_f64(
+        cot_rrule(x, g),
+        -(1.0 / x.sin().powi(2)),
+        1e-12,
+        0.0,
+        "cot_rrule",
+    );

     let (_, dcsch) = csch_frule(x, g);
     let csch_x: f64 = 1.0 / x.sinh();
-    assert!((dcsch - (-csch_x * x.cosh() / x.sinh())).abs() < 1e-12);
-    assert!((csch_rrule(x, g) - (-csch_x * x.cosh() / x.sinh())).abs() < 1e-12);
+    assert_close_f64(
+        dcsch,
+        -csch_x * x.cosh() / x.sinh(),
+        1e-12,
+        0.0,
+        "csch_frule",
+    );
+    assert_close_f64(
+        csch_rrule(x, g),
+        -csch_x * x.cosh() / x.sinh(),
+        1e-12,
+        0.0,
+        "csch_rrule",
+    );

     let (_, dcoth) = coth_frule(x, g);
-    assert!((dcoth - (-(1.0 / x.sinh().powi(2)))).abs() < 1e-12);
-    assert!((coth_rrule(x, g) - (-(1.0 / x.sinh().powi(2)))).abs() < 1e-12);
+    assert_close_f64(dcoth, -(1.0 / x.sinh().powi(2)), 1e-12, 0.0, "coth_frule");
+    assert_close_f64(
+        coth_rrule(x, g),
+        -(1.0 / x.sinh().powi(2)),
+        1e-12,
+        0.0,
+        "coth_rrule",
+    );

     let (_, dsind) = sind_frule(30.0_f64, g);
-    assert!((dsind - (std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos())).abs() < 1e-12);
-    assert!(
-        (sind_rrule(30.0_f64, g) - (std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()))
-            .abs()
-            < 1e-12
+    assert_close_f64(
+        dsind,
+        std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos(),
+        1e-12,
+        0.0,
+        "sind_frule",
+    );
+    assert_close_f64(
+        sind_rrule(30.0_f64, g),
+        std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos(),
+        1e-12,
+        0.0,
+        "sind_rrule",
     );

     let (_, dcospi) = cospi_frule(x, g);
-    assert!((dcospi + std::f64::consts::PI * (std::f64::consts::PI * x).sin()).abs() < 1e-12);
-    assert!(
-        (cospi_rrule(x, g) + std::f64::consts::PI * (std::f64::consts::PI * x).sin()).abs() < 1e-12
+    assert_close_f64(
+        dcospi,
+        -std::f64::consts::PI * (std::f64::consts::PI * x).sin(),
+        1e-12,
+        0.0,
+        "cospi_frule",
+    );
+    assert_close_f64(
+        cospi_rrule(x, g),
+        -std::f64::consts::PI * (std::f64::consts::PI * x).sin(),
+        1e-12,
+        0.0,
+        "cospi_rrule",
     );

     let (_, dcosd) = cosd_frule(60.0_f64, g);
-    assert!((dcosd + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12);
-    assert!(
-        (cosd_rrule(60.0_f64, g) + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin())
-            .abs()
-            < 1e-12
+    assert_close_f64(
+        dcosd,
+        -std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin(),
+        1e-12,
+        0.0,
+        "cosd_frule",
+    );
+    assert_close_f64(
+        cosd_rrule(60.0_f64, g),
+        -std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin(),
+        1e-12,
+        0.0,
+        "cosd_rrule",
     );

     let (_, dsincospi) = sincospi_frule(x, g);
-    assert!((dsincospi.0 - std::f64::consts::PI * (std::f64::consts::PI * x).cos()).abs() < 1e-12);
-    assert!((dsincospi.1 + std::f64::consts::PI * (std::f64::consts::PI * x).sin()).abs() < 1e-12);
-    assert!(
-        (sincospi_rrule(x, (g, g))
-            - (std::f64::consts::PI * (std::f64::consts::PI * x).cos()
-                - std::f64::consts::PI * (std::f64::consts::PI * x).sin()))
-            .abs()
-            < 1e-12
+    assert_close_f64(
+        dsincospi.0,
+        std::f64::consts::PI * (std::f64::consts::PI * x).cos(),
+        1e-12,
+        0.0,
+        "sincospi_frule.sin",
+    );
+    assert_close_f64(
+        dsincospi.1,
+        -std::f64::consts::PI * (std::f64::consts::PI * x).sin(),
+        1e-12,
+        0.0,
+        "sincospi_frule.cos",
+    );
+    assert_close_f64(
+        sincospi_rrule(x, (g, g)),
+        std::f64::consts::PI * (std::f64::consts::PI * x).cos()
+            - std::f64::consts::PI * (std::f64::consts::PI * x).sin(),
+        1e-12,
+        0.0,
+
"sincospi_rrule", ); } @@ -135,22 +285,31 @@ fn julia_compat_helpers_cover_complex_primal_and_cotangent_paths() { let cotangent = Complex64::new(0.75, -0.5); let pi_z = Complex64::new(std::f64::consts::PI, 0.0) * z; - assert_complex_close(sinpi(z), pi_z.sin()); - assert_complex_close(cospi(z), pi_z.cos()); + assert_close_complex64(sinpi(z), pi_z.sin(), 1e-12, 0.0, "sinpi(z)"); + assert_close_complex64(cospi(z), pi_z.cos(), 1e-12, 0.0, "cospi(z)"); let (_, dsinpi) = sinpi_frule(z, dz); - assert_complex_close( + assert_close_complex64( dsinpi, dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()).conj(), + 1e-12, + 0.0, + "sinpi_frule(z)", ); - assert_complex_close( + assert_close_complex64( sec_rrule(z, cotangent), cotangent * (z.sin() / z.cos().powi(2)).conj(), + 1e-12, + 0.0, + "sec_rrule(z)", ); - assert_complex_close( + assert_close_complex64( sinpi_rrule(z, cotangent), cotangent * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()).conj(), + 1e-12, + 0.0, + "sinpi_rrule(z)", ); } @@ -159,15 +318,21 @@ fn julia_compat_complex_inputs_cover_generic_surface() { let z = Complex64::new(0.25, -0.5); let pi_z = Complex64::new(std::f64::consts::PI, 0.0) * z; - assert!((sinpi(z) - pi_z.sin()).norm() < 1e-12); - assert!((cospi(z) - pi_z.cos()).norm() < 1e-12); + assert_close_complex64(sinpi(z), pi_z.sin(), 1e-12, 0.0, "sinpi(z)"); + assert_close_complex64(cospi(z), pi_z.cos(), 1e-12, 0.0, "cospi(z)"); let (s, c) = sincospi(z); - assert!((s - pi_z.sin()).norm() < 1e-12); - assert!((c - pi_z.cos()).norm() < 1e-12); + assert_close_complex64(s, pi_z.sin(), 1e-12, 0.0, "sincospi.sin(z)"); + assert_close_complex64(c, pi_z.cos(), 1e-12, 0.0, "sincospi.cos(z)"); let deg_z = Complex64::new(std::f64::consts::PI / 180.0, 0.0) * z; - assert!((sind(z) - deg_z.sin()).norm() < 1e-12); - assert!((cosd(z) - deg_z.cos()).norm() < 1e-12); - assert!((tand(z) - deg_z.tan()).norm() < 1e-12); - assert!((coth(z) - Complex64::new(1.0, 0.0) / z.tanh()).norm() < 1e-12); + 
assert_close_complex64(sind(z), deg_z.sin(), 1e-12, 0.0, "sind(z)"); + assert_close_complex64(cosd(z), deg_z.cos(), 1e-12, 0.0, "cosd(z)"); + assert_close_complex64(tand(z), deg_z.tan(), 1e-12, 0.0, "tand(z)"); + assert_close_complex64( + coth(z), + Complex64::new(1.0, 0.0) / z.tanh(), + 1e-12, + 0.0, + "coth(z)", + ); } diff --git a/crates/chainrules/tests/oracle_scalar_rules.rs b/crates/chainrules/tests/oracle_scalar_rules.rs index 2a9aefd..5dc161f 100644 --- a/crates/chainrules/tests/oracle_scalar_rules.rs +++ b/crates/chainrules/tests/oracle_scalar_rules.rs @@ -1,181 +1,138 @@ -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::path::PathBuf; +mod common; use chainrules::{ acos_frule, acos_rrule, acosh_frule, acosh_rrule, asin_frule, asin_rrule, asinh_frule, asinh_rrule, atan_frule, atan_rrule, atanh_frule, atanh_rrule, cos_frule, cos_rrule, - cosh_frule, cosh_rrule, exp_frule, exp_rrule, expm1_frule, expm1_rrule, log1p_frule, - log1p_rrule, log_frule, log_rrule, sin_frule, sin_rrule, sinh_frule, sinh_rrule, sqrt_frule, - sqrt_rrule, tanh_frule, tanh_rrule, + cosh_frule, cosh_rrule, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1_frule, expm1_rrule, + log1p_frule, log1p_rrule, log2_frule, log2_rrule, log_frule, log_rrule, sin_frule, sin_rrule, + sinh_frule, sinh_rrule, sqrt_frule, sqrt_rrule, tan_frule, tan_rrule, tanh_frule, tanh_rrule, }; -use serde_json::Value; +use num_complex::Complex64; -struct UnaryRuleCase { - op: &'static str, - frule: fn(f64, f64) -> (f64, f64), - rrule: fn(f64, f64, f64) -> f64, -} - -fn oracle_root() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("..") - .join("..") - .join("third_party") - .join("tensor-ad-oracles") -} - -fn first_successful_float64_case(op: &str) -> Value { - let path = oracle_root().join("cases").join(op).join("identity.jsonl"); - let file = File::open(&path).unwrap_or_else(|err| panic!("open {}: {err}", path.display())); - let reader = BufReader::new(file); - - for line in 
reader.lines() { - let line = line.unwrap_or_else(|err| panic!("read {}: {err}", path.display())); - let value: Value = serde_json::from_str(&line) - .unwrap_or_else(|err| panic!("parse {}: {err}", path.display())); - let dtype = value["dtype"].as_str(); - let behavior = value["expected_behavior"].as_str(); - if dtype == Some("float64") && behavior == Some("success") { - return value; - } - } - - panic!("no successful float64 case found in {}", path.display()); -} - -fn scalar(value: &Value, path: &str) -> f64 { - value - .as_f64() - .unwrap_or_else(|| panic!("expected float64 at {path}, got {value}")) -} - -fn assert_close(actual: f64, expected: f64, atol: f64, rtol: f64, label: &str) { - let tol = atol + rtol * expected.abs().max(actual.abs()); - assert!( - (actual - expected).abs() <= tol, - "{label}: actual={actual}, expected={expected}, atol={atol}, rtol={rtol}", - ); -} +use common::{run_unary_oracle_cases, run_unary_oracle_reverse_cases_complex64, UnaryOracleCase}; #[test] fn published_float64_oracles_match_unary_rule_entrypoints() { - let cases = [ - UnaryRuleCase { + let cases: [UnaryOracleCase<f64>; 19] = [ + UnaryOracleCase { op: "sqrt", frule: sqrt_frule, - rrule: |x, _result, cotangent| sqrt_rrule(x.sqrt(), cotangent), + rrule: |x: f64, _result, cotangent| sqrt_rrule(x.sqrt(), cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "exp", frule: exp_frule, - rrule: |_x, result, cotangent| exp_rrule(result, cotangent), + rrule: |_x: f64, result, cotangent| exp_rrule(result, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "expm1", frule: expm1_frule, - rrule: |_x, result, cotangent| expm1_rrule(result, cotangent), + rrule: |_x: f64, result, cotangent| expm1_rrule(result, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "log", frule: log_frule, - rrule: |x, _result, cotangent| log_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| log_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "log1p", frule: log1p_frule, - 
rrule: |x, _result, cotangent| log1p_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| log1p_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "sin", frule: sin_frule, - rrule: |x, _result, cotangent| sin_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| sin_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "cos", frule: cos_frule, - rrule: |x, _result, cotangent| cos_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| cos_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "tanh", frule: tanh_frule, - rrule: |_x, result, cotangent| tanh_rrule(result, cotangent), + rrule: |_x: f64, result, cotangent| tanh_rrule(result, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "asin", frule: asin_frule, - rrule: |x, _result, cotangent| asin_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| asin_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "acos", frule: acos_frule, - rrule: |x, _result, cotangent| acos_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| acos_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "atan", frule: atan_frule, - rrule: |x, _result, cotangent| atan_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| atan_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "sinh", frule: sinh_frule, - rrule: |x, _result, cotangent| sinh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| sinh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "cosh", frule: cosh_frule, - rrule: |x, _result, cotangent| cosh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| cosh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "asinh", frule: asinh_frule, - rrule: |x, _result, cotangent| asinh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| asinh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "acosh", frule: acosh_frule, - rrule: |x, _result, cotangent| 
acosh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| acosh_rrule(x, cotangent), }, - UnaryRuleCase { + UnaryOracleCase { op: "atanh", frule: atanh_frule, - rrule: |x, _result, cotangent| atanh_rrule(x, cotangent), + rrule: |x: f64, _result, cotangent| atanh_rrule(x, cotangent), + }, + UnaryOracleCase { + op: "tan", + frule: tan_frule, + rrule: |_x: f64, result, cotangent| tan_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "exp2", + frule: exp2_frule, + rrule: |_x: f64, result, cotangent| exp2_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "log2", + frule: log2_frule, + rrule: |x: f64, _result, cotangent| log2_rrule(x, cotangent), }, ]; - for case in cases { - let oracle = first_successful_float64_case(case.op); - let input = scalar(&oracle["inputs"]["a"]["data"][0], "inputs.a.data[0]"); - let probe = &oracle["probes"][0]; - let tangent = scalar( - &probe["direction"]["a"]["data"][0], - "probes[0].direction.a.data[0]", - ); - let cotangent = scalar( - &probe["cotangent"]["value"]["data"][0], - "probes[0].cotangent.value.data[0]", - ); - let expected_jvp = scalar( - &probe["pytorch_ref"]["jvp"]["value"]["data"][0], - "probes[0].pytorch_ref.jvp.value.data[0]", - ); - let expected_vjp = scalar( - &probe["pytorch_ref"]["vjp"]["a"]["data"][0], - "probes[0].pytorch_ref.vjp.a.data[0]", - ); - let atol = scalar( - &oracle["comparison"]["first_order"]["atol"], - "comparison.first_order.atol", - ); - let rtol = scalar( - &oracle["comparison"]["first_order"]["rtol"], - "comparison.first_order.rtol", - ); + run_unary_oracle_cases(&cases); +} - let (result, actual_jvp) = (case.frule)(input, tangent); - let actual_vjp = (case.rrule)(input, result, cotangent); +#[test] +fn published_complex128_oracles_match_unary_rule_entrypoints() { + let cases: [UnaryOracleCase<Complex64>; 3] = [ + UnaryOracleCase { + op: "tan", + frule: tan_frule, + rrule: |_x: Complex64, result, cotangent| tan_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "exp2", + frule: 
exp2_frule, + rrule: |_x: Complex64, result, cotangent| exp2_rrule(result, cotangent), + }, + UnaryOracleCase { + op: "log2", + frule: log2_frule, + rrule: |x: Complex64, _result, cotangent| log2_rrule(x, cotangent), + }, + ]; - assert_close(actual_jvp, expected_jvp, atol, rtol, case.op); - assert_close(actual_vjp, expected_vjp, atol, rtol, case.op); - } + run_unary_oracle_reverse_cases_complex64(&cases); } diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index 48a0a43..cdf657d 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -1,3 +1,5 @@ +mod common; + use chainrules::{ cbrt, cbrt_frule, cbrt_rrule, exp10, exp10_frule, exp10_rrule, exp2, exp2_frule, exp2_rrule, hypot, hypot_frule, hypot_rrule, inv, inv_frule, inv_rrule, log10, log10_frule, log10_rrule, @@ -6,20 +8,22 @@ use chainrules::{ }; use num_complex::{Complex64, ComplexFloat}; +use common::assert_close_f64; + #[test] fn smooth_basis_helpers_are_reexported_from_chainrules() { - assert!((cbrt(8.0_f64) - 2.0).abs() < 1.0e-12); - assert!((inv(4.0_f64) - 0.25).abs() < 1.0e-12); - assert!((exp2(3.0_f64) - 8.0).abs() < 1.0e-12); - assert!((exp10(2.0_f64) - 100.0).abs() < 1.0e-12); - assert!((hypot(3.0_f64, 4.0_f64) - 5.0).abs() < 1.0e-12); - assert!((log2(8.0_f64) - 3.0).abs() < 1.0e-12); - assert!((log10(100.0_f64) - 2.0).abs() < 1.0e-12); - assert!((pow(2.0_f64, 3.0_f64) - 8.0).abs() < 1.0e-12); - assert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1.0e-12); + assert_close_f64(cbrt(8.0_f64), 2.0, 1.0e-12, 0.0, "cbrt"); + assert_close_f64(inv(4.0_f64), 0.25, 1.0e-12, 0.0, "inv"); + assert_close_f64(exp2(3.0_f64), 8.0, 1.0e-12, 0.0, "exp2"); + assert_close_f64(exp10(2.0_f64), 100.0, 1.0e-12, 0.0, "exp10"); + assert_close_f64(hypot(3.0_f64, 4.0_f64), 5.0, 1.0e-12, 0.0, "hypot"); + assert_close_f64(log2(8.0_f64), 3.0, 1.0e-12, 0.0, "log2"); + assert_close_f64(log10(100.0_f64), 2.0, 1.0e-12, 0.0, 
"log10"); + assert_close_f64(pow(2.0_f64, 3.0_f64), 8.0, 1.0e-12, 0.0, "pow"); + assert_close_f64(tan(0.5_f64), 0.5_f64.tan(), 1.0e-12, 0.0, "tan"); let (sin_x, cos_x) = sincos(0.5_f64); - assert!((sin_x - 0.5_f64.sin()).abs() < 1.0e-12); - assert!((cos_x - 0.5_f64.cos()).abs() < 1.0e-12); + assert_close_f64(sin_x, 0.5_f64.sin(), 1.0e-12, 0.0, "sincos.sin"); + assert_close_f64(cos_x, 0.5_f64.cos(), 1.0e-12, 0.0, "sincos.cos"); let z = Complex64::new(1.0, 2.0); let _ = pow(z, Complex64::new(2.0, 0.0)); @@ -28,70 +32,134 @@ fn smooth_basis_helpers_are_reexported_from_chainrules() { #[test] fn smooth_basis_frules_and_rrules_match_expected_derivatives() { let (tan_y, tan_dy) = tan_frule(0.25_f64, 1.0_f64); - assert!((tan_y - 0.25_f64.tan()).abs() < 1.0e-12); - assert!((tan_dy - (1.0_f64 + 0.25_f64.tan().powi(2))).abs() < 1.0e-12); - assert!( - (tan_rrule(0.25_f64.tan(), 1.0_f64) - (1.0_f64 + 0.25_f64.tan().powi(2))).abs() < 1.0e-12 + assert_close_f64(tan_y, 0.25_f64.tan(), 1.0e-12, 0.0, "tan.y"); + assert_close_f64( + tan_dy, + 1.0_f64 + 0.25_f64.tan().powi(2), + 1.0e-12, + 0.0, + "tan.dy", + ); + assert_close_f64( + tan_rrule(0.25_f64.tan(), 1.0_f64), + 1.0_f64 + 0.25_f64.tan().powi(2), + 1.0e-12, + 0.0, + "tan.rrule", ); let (exp2_y, exp2_dy) = exp2_frule(3.0_f64, 1.0_f64); - assert!((exp2_y - 8.0).abs() < 1.0e-12); - assert!((exp2_dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1.0e-12); - assert!((exp2_rrule(8.0_f64, 1.0_f64) - 8.0_f64 * std::f64::consts::LN_2).abs() < 1.0e-12); + assert_close_f64(exp2_y, 8.0, 1.0e-12, 0.0, "exp2.y"); + assert_close_f64( + exp2_dy, + 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "exp2.dy", + ); + assert_close_f64( + exp2_rrule(8.0_f64, 1.0_f64), + 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "exp2.rrule", + ); let (hypot_y, hypot_dy) = hypot_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64); - assert!((hypot_y - 5.0).abs() < 1.0e-12); - assert!((hypot_dy - 0.5).abs() < 1.0e-12); - assert!((hypot_rrule(3.0_f64, 4.0_f64, 
1.0_f64).0 - 0.6_f64).abs() < 1.0e-12); - assert!((hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64).1 - 0.8_f64).abs() < 1.0e-12); + assert_close_f64(hypot_y, 5.0, 1.0e-12, 0.0, "hypot.y"); + assert_close_f64(hypot_dy, 0.5, 1.0e-12, 0.0, "hypot.dy"); + let (hypot_dx, hypot_dy) = hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64); + assert_close_f64(hypot_dx, 0.6_f64, 1.0e-12, 0.0, "hypot.rrule.dx"); + assert_close_f64(hypot_dy, 0.8_f64, 1.0e-12, 0.0, "hypot.rrule.dy"); let (pow_y, pow_dy) = pow_frule(2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); - assert!((pow_y - 8.0).abs() < 1.0e-12); - assert!((pow_dy - 12.0).abs() < 1.0e-12); + assert_close_f64(pow_y, 8.0, 1.0e-12, 0.0, "pow.y"); + assert_close_f64(pow_dy, 12.0, 1.0e-12, 0.0, "pow.dy"); let (pow_dx, pow_dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0_f64); - assert!((pow_dx - 12.0).abs() < 1.0e-12); - assert!((pow_dexp - (8.0_f64 * std::f64::consts::LN_2)).abs() < 1.0e-12); + assert_close_f64(pow_dx, 12.0, 1.0e-12, 0.0, "pow.rrule.dx"); + assert_close_f64( + pow_dexp, + 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "pow.rrule.dexp", + ); let (sincos_y, sincos_dy) = sincos_frule(0.25_f64, 1.0_f64); - assert!((sincos_y.0 - 0.25_f64.sin()).abs() < 1.0e-12); - assert!((sincos_y.1 - 0.25_f64.cos()).abs() < 1.0e-12); - assert!((sincos_dy.0 - 0.25_f64.cos()).abs() < 1.0e-12); - assert!((sincos_dy.1 + 0.25_f64.sin()).abs() < 1.0e-12); - assert!( - (sincos_rrule(0.25_f64, (1.0_f64, 1.0_f64)) - (0.25_f64.cos() - 0.25_f64.sin())).abs() - < 1.0e-12 + assert_close_f64(sincos_y.0, 0.25_f64.sin(), 1.0e-12, 0.0, "sincos.y.sin"); + assert_close_f64(sincos_y.1, 0.25_f64.cos(), 1.0e-12, 0.0, "sincos.y.cos"); + assert_close_f64(sincos_dy.0, 0.25_f64.cos(), 1.0e-12, 0.0, "sincos.dy.sin"); + assert_close_f64(sincos_dy.1, -0.25_f64.sin(), 1.0e-12, 0.0, "sincos.dy.cos"); + assert_close_f64( + sincos_rrule(0.25_f64, (1.0_f64, 1.0_f64)), + 0.25_f64.cos() - 0.25_f64.sin(), + 1.0e-12, + 0.0, + "sincos.rrule", ); let (cbrt_y, cbrt_dy) = cbrt_frule(8.0_f64, 1.0_f64); - 
assert!((cbrt_y - 2.0).abs() < 1.0e-12); - assert!((cbrt_dy - (1.0_f64 / (3.0_f64 * 4.0_f64))).abs() < 1.0e-12); - assert!((cbrt_rrule(2.0_f64, 1.0_f64) - (1.0_f64 / (3.0_f64 * 4.0_f64))).abs() < 1.0e-12); + assert_close_f64(cbrt_y, 2.0, 1.0e-12, 0.0, "cbrt.y"); + assert_close_f64( + cbrt_dy, + 1.0_f64 / (3.0_f64 * 4.0_f64), + 1.0e-12, + 0.0, + "cbrt.dy", + ); + assert_close_f64( + cbrt_rrule(2.0_f64, 1.0_f64), + 1.0_f64 / (3.0_f64 * 4.0_f64), + 1.0e-12, + 0.0, + "cbrt.rrule", + ); let (inv_y, inv_dy) = inv_frule(4.0_f64, 2.0_f64); - assert!((inv_y - 0.25).abs() < 1.0e-12); - assert!((inv_dy + 0.125).abs() < 1.0e-12); - assert!((inv_rrule(0.25_f64, 2.0_f64) + 0.125).abs() < 1.0e-12); + assert_close_f64(inv_y, 0.25, 1.0e-12, 0.0, "inv.y"); + assert_close_f64(inv_dy, -0.125, 1.0e-12, 0.0, "inv.dy"); + assert_close_f64( + inv_rrule(0.25_f64, 2.0_f64), + -0.125, + 1.0e-12, + 0.0, + "inv.rrule", + ); let (log2_y, log2_dy) = log2_frule(8.0_f64, 2.0_f64); - assert!((log2_y - 3.0).abs() < 1.0e-12); + assert_close_f64(log2_y, 3.0, 1.0e-12, 0.0, "log2.y"); let expected_log2 = 2.0_f64 / (8.0_f64 * std::f64::consts::LN_2); - assert!((log2_dy - expected_log2).abs() < 1.0e-12); - assert!((log2_rrule(8.0_f64, 2.0_f64) - expected_log2).abs() < 1.0e-12); + assert_close_f64(log2_dy, expected_log2, 1.0e-12, 0.0, "log2.dy"); + assert_close_f64( + log2_rrule(8.0_f64, 2.0_f64), + expected_log2, + 1.0e-12, + 0.0, + "log2.rrule", + ); let (log10_y, log10_dy) = log10_frule(100.0_f64, 2.0_f64); - assert!((log10_y - 2.0).abs() < 1.0e-12); - assert!((log10_dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1.0e-12); - assert!( - (log10_rrule(100.0_f64, 2.0_f64) - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() - < 1.0e-12 + assert_close_f64(log10_y, 2.0, 1.0e-12, 0.0, "log10.y"); + let expected_log10 = 2.0_f64 / (100.0_f64 * std::f64::consts::LN_10); + assert_close_f64(log10_dy, expected_log10, 1.0e-12, 0.0, "log10.dy"); + assert_close_f64( + log10_rrule(100.0_f64, 
2.0_f64), + expected_log10, + 1.0e-12, + 0.0, + "log10.rrule", ); let (exp10_y, exp10_dy) = exp10_frule(2.0_f64, 0.5_f64); - assert!((exp10_y - 100.0).abs() < 1.0e-12); - assert!((exp10_dy - (100.0_f64 * std::f64::consts::LN_10 * 0.5_f64)).abs() < 1.0e-12); - assert!( - (exp10_rrule(100.0_f64, 0.5_f64) - (100.0_f64 * std::f64::consts::LN_10 * 0.5_f64)).abs() - < 1.0e-12 + assert_close_f64(exp10_y, 100.0, 1.0e-12, 0.0, "exp10.y"); + let expected_exp10 = 100.0_f64 * std::f64::consts::LN_10 * 0.5_f64; + assert_close_f64(exp10_dy, expected_exp10, 1.0e-12, 0.0, "exp10.dy"); + assert_close_f64( + exp10_rrule(100.0_f64, 0.5_f64), + expected_exp10, + 1.0e-12, + 0.0, + "exp10.rrule", ); } From f053a04d5cf7ce4f155f5db539f2fc78b7cf8ec0 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 13:17:27 +0900 Subject: [PATCH 26/32] test: fix scalar oracle replay --- crates/chainrules/src/tests/organization.rs | 6 + crates/chainrules/tests/common.rs | 238 +++++++++++++----- .../chainrules/tests/oracle_scalar_rules.rs | 28 ++- crates/chainrules/tests/smooth_basis_tests.rs | 41 ++- 4 files changed, 237 insertions(+), 76 deletions(-) diff --git a/crates/chainrules/src/tests/organization.rs b/crates/chainrules/src/tests/organization.rs index 0c5976f..30eec50 100644 --- a/crates/chainrules/src/tests/organization.rs +++ b/crates/chainrules/src/tests/organization.rs @@ -18,6 +18,12 @@ fn chainrules_modules_stay_under_size_guideline() { ); assert_line_count("../unary/mod.rs", include_str!("../unary/mod.rs"), 60); assert_line_count("../unary/basic.rs", include_str!("../unary/basic.rs"), 40); + assert_line_count("../unary/roots.rs", include_str!("../unary/roots.rs"), 100); + assert_line_count( + "../unary/complex_parts.rs", + include_str!("../unary/complex_parts.rs"), + 210, + ); assert_line_count( "../unary/exp_log.rs", include_str!("../unary/exp_log.rs"), diff --git a/crates/chainrules/tests/common.rs b/crates/chainrules/tests/common.rs index 115ec86..b2be005 100644 --- 
a/crates/chainrules/tests/common.rs +++ b/crates/chainrules/tests/common.rs @@ -13,8 +13,15 @@ pub struct UnaryOracleCase<T> { pub rrule: fn(T, T, T) -> T, } +pub struct UnaryReverseOracleCase<T> { + pub op: &'static str, + pub primal: fn(T) -> T, + pub rrule: fn(T, T, T) -> T, +} + pub trait OracleScalar: Copy + core::fmt::Debug { fn dtype() -> &'static str; + fn is_scalar_value(value: &Value) -> bool; fn from_json(value: &Value, path: &str) -> Self; fn assert_close(actual: Self, expected: Self, atol: f64, rtol: f64, label: &str); } @@ -24,6 +31,10 @@ impl OracleScalar for f64 { "float64" } + fn is_scalar_value(value: &Value) -> bool { + value.is_number() + } + fn from_json(value: &Value, path: &str) -> Self { scalar_f64(value, path) } @@ -38,6 +49,12 @@ impl OracleScalar for Complex64 { "complex128" } + fn is_scalar_value(value: &Value) -> bool { + value + .as_array() + .is_some_and(|items| items.len() == 2 && items.iter().all(Value::is_number)) + } + fn from_json(value: &Value, path: &str) -> Self { scalar_complex64(value, path) } @@ -55,10 +72,11 @@ fn oracle_root() -> PathBuf { .join("tensor-ad-oracles") } -pub fn first_successful_case(op: &str, dtype: &str) -> Value { +pub fn successful_cases(op: &str, dtype: &str) -> Vec<Value> { let path = oracle_root().join("cases").join(op).join("identity.jsonl"); let file = File::open(&path).unwrap_or_else(|err| panic!("open {}: {err}", path.display())); let reader = BufReader::new(file); + let mut cases = Vec::new(); for line in reader.lines() { let line = line.unwrap_or_else(|err| panic!("read {}: {err}", path.display())); @@ -67,11 +85,16 @@ pub fn first_successful_case(op: &str, dtype: &str) -> Value { let case_dtype = value["dtype"].as_str(); let behavior = value["expected_behavior"].as_str(); if case_dtype == Some(dtype) && behavior == Some("success") { - return value; + cases.push(value); } } - panic!("no successful {dtype} case found in {}", path.display()); + assert!( + !cases.is_empty(), + "no successful {dtype} cases found 
in {}", + path.display() + ); + cases } pub fn scalar_f64(value: &Value, path: &str) -> f64 { @@ -113,68 +136,163 @@ pub fn assert_close_complex64( ); } +fn collect_scalar_values<T: OracleScalar>(value: &Value, path: &str, out: &mut Vec<T>) { + if T::is_scalar_value(value) { + out.push(T::from_json(value, path)); + return; + } + + let items = value + .as_array() + .unwrap_or_else(|| panic!("expected array at {path}, got {value}")); + for (index, item) in items.iter().enumerate() { + collect_scalar_values::<T>(item, &format!("{path}[{index}]"), out); + } +} + +fn scalar_values<T: OracleScalar>(value: &Value, path: &str) -> Vec<T> { + let mut out = Vec::new(); + collect_scalar_values(value, path, &mut out); + out +} + pub fn run_unary_oracle_cases<T: OracleScalar>(cases: &[UnaryOracleCase<T>]) { for case in cases { - let oracle = first_successful_case(case.op, T::dtype()); - let input = T::from_json(&oracle["inputs"]["a"]["data"][0], "inputs.a.data[0]"); - let probe = &oracle["probes"][0]; - let tangent = T::from_json( - &probe["direction"]["a"]["data"][0], - "probes[0].direction.a.data[0]", - ); - let cotangent = T::from_json( - &probe["cotangent"]["value"]["data"][0], - "probes[0].cotangent.value.data[0]", - ); - let expected_jvp = T::from_json( - &probe["pytorch_ref"]["jvp"]["value"]["data"][0], - "probes[0].pytorch_ref.jvp.value.data[0]", - ); - let expected_vjp = T::from_json( - &probe["pytorch_ref"]["vjp"]["a"]["data"][0], - "probes[0].pytorch_ref.vjp.a.data[0]", - ); - let atol = scalar_f64( - &oracle["comparison"]["first_order"]["atol"], - "comparison.first_order.atol", - ); - let rtol = scalar_f64( - &oracle["comparison"]["first_order"]["rtol"], - "comparison.first_order.rtol", - ); - - let (result, actual_jvp) = (case.frule)(input, tangent); - let actual_vjp = (case.rrule)(input, result, cotangent); - - T::assert_close(actual_jvp, expected_jvp, atol, rtol, case.op); - T::assert_close(actual_vjp, expected_vjp, atol, rtol, case.op); + for (case_index, oracle) in successful_cases(case.op, T::dtype()) + .into_iter() + 
.enumerate() + { + let inputs = scalar_values::<T>(&oracle["inputs"]["a"]["data"], "inputs.a.data"); + let probes = oracle["probes"] + .as_array() + .unwrap_or_else(|| panic!("expected probes array for {}", case.op)); + let atol = scalar_f64( + &oracle["comparison"]["first_order"]["atol"], + "comparison.first_order.atol", + ); + let rtol = scalar_f64( + &oracle["comparison"]["first_order"]["rtol"], + "comparison.first_order.rtol", + ); + + for (probe_index, probe) in probes.iter().enumerate() { + let tangents = scalar_values::<T>( + &probe["direction"]["a"]["data"], + &format!("probes[{probe_index}].direction.a.data"), + ); + let cotangents = scalar_values::<T>( + &probe["cotangent"]["value"]["data"], + &format!("probes[{probe_index}].cotangent.value.data"), + ); + let expected_jvps = scalar_values::<T>( + &probe["pytorch_ref"]["jvp"]["value"]["data"], + &format!("probes[{probe_index}].pytorch_ref.jvp.value.data"), + ); + let expected_vjps = scalar_values::<T>( + &probe["pytorch_ref"]["vjp"]["a"]["data"], + &format!("probes[{probe_index}].pytorch_ref.vjp.a.data"), + ); + + assert_eq!( + inputs.len(), + tangents.len(), + "{} case {case_index} probe {probe_index}: input and tangent lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + cotangents.len(), + "{} case {case_index} probe {probe_index}: input and cotangent lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + expected_jvps.len(), + "{} case {case_index} probe {probe_index}: input and expected jvp lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + expected_vjps.len(), + "{} case {case_index} probe {probe_index}: input and expected vjp lengths differ", + case.op + ); + + for index in 0..inputs.len() { + let (result, actual_jvp) = (case.frule)(inputs[index], tangents[index]); + let actual_vjp = (case.rrule)(inputs[index], result, cotangents[index]); + let label = format!( + "{} case {case_index} probe {probe_index} element {index}", + case.op + ); + T::assert_close(actual_jvp, 
expected_jvps[index], atol, rtol, &label); + T::assert_close(actual_vjp, expected_vjps[index], atol, rtol, &label); + } + } + } } } -pub fn run_unary_oracle_reverse_cases_complex64(cases: &[UnaryOracleCase<Complex64>]) { +// Reverse-only complex oracle replay keeps the current ScalarAd convention explicit. +pub fn run_unary_oracle_reverse_cases_complex64(cases: &[UnaryReverseOracleCase<Complex64>]) { for case in cases { - let oracle = first_successful_case(case.op, Complex64::dtype()); - let input = scalar_complex64(&oracle["inputs"]["a"]["data"][0], "inputs.a.data[0]"); - let probe = &oracle["probes"][0]; - let cotangent = scalar_complex64( - &probe["cotangent"]["value"]["data"][0], - "probes[0].cotangent.value.data[0]", - ); - let expected_vjp = scalar_complex64( - &probe["pytorch_ref"]["vjp"]["a"]["data"][0], - "probes[0].pytorch_ref.vjp.a.data[0]", - ); - let atol = scalar_f64( - &oracle["comparison"]["first_order"]["atol"], - "comparison.first_order.atol", - ); - let rtol = scalar_f64( - &oracle["comparison"]["first_order"]["rtol"], - "comparison.first_order.rtol", - ); - - let (result, _) = (case.frule)(input, cotangent); - let actual_vjp = (case.rrule)(input, result, cotangent); - Complex64::assert_close(actual_vjp, expected_vjp, atol, rtol, case.op); + for (case_index, oracle) in successful_cases(case.op, <Complex64 as OracleScalar>::dtype()) + .into_iter() + .enumerate() + { + let inputs = + scalar_values::<Complex64>(&oracle["inputs"]["a"]["data"], "inputs.a.data"); + let probes = oracle["probes"] + .as_array() + .unwrap_or_else(|| panic!("expected probes array for {}", case.op)); + let atol = scalar_f64( + &oracle["comparison"]["first_order"]["atol"], + "comparison.first_order.atol", + ); + let rtol = scalar_f64( + &oracle["comparison"]["first_order"]["rtol"], + "comparison.first_order.rtol", + ); + + for (probe_index, probe) in probes.iter().enumerate() { + let cotangents = scalar_values::<Complex64>( + &probe["cotangent"]["value"]["data"], + &format!("probes[{probe_index}].cotangent.value.data"), + ); + let expected_vjps 
= scalar_values::<Complex64>( + &probe["pytorch_ref"]["vjp"]["a"]["data"], + &format!("probes[{probe_index}].pytorch_ref.vjp.a.data"), + ); + + assert_eq!( + inputs.len(), + cotangents.len(), + "{} case {case_index} probe {probe_index}: input and cotangent lengths differ", + case.op + ); + assert_eq!( + inputs.len(), + expected_vjps.len(), + "{} case {case_index} probe {probe_index}: input and expected vjp lengths differ", + case.op + ); + + for index in 0..inputs.len() { + let result = (case.primal)(inputs[index]); + let actual_vjp = (case.rrule)(inputs[index], result, cotangents[index]); + let label = format!( + "{} case {case_index} probe {probe_index} element {index}", + case.op + ); + <Complex64 as OracleScalar>::assert_close( + actual_vjp, + expected_vjps[index], + atol, + rtol, + &label, + ); + } + } + } } } diff --git a/crates/chainrules/tests/oracle_scalar_rules.rs b/crates/chainrules/tests/oracle_scalar_rules.rs index 5dc161f..d43154a 100644 --- a/crates/chainrules/tests/oracle_scalar_rules.rs +++ b/crates/chainrules/tests/oracle_scalar_rules.rs @@ -3,13 +3,17 @@ mod common; use chainrules::{ acos_frule, acos_rrule, acosh_frule, acosh_rrule, asin_frule, asin_rrule, asinh_frule, asinh_rrule, atan_frule, atan_rrule, atanh_frule, atanh_rrule, cos_frule, cos_rrule, - cosh_frule, cosh_rrule, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1_frule, expm1_rrule, - log1p_frule, log1p_rrule, log2_frule, log2_rrule, log_frule, log_rrule, sin_frule, sin_rrule, - sinh_frule, sinh_rrule, sqrt_frule, sqrt_rrule, tan_frule, tan_rrule, tanh_frule, tanh_rrule, + cosh_frule, cosh_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1_frule, + expm1_rrule, log1p_frule, log1p_rrule, log2, log2_frule, log2_rrule, log_frule, log_rrule, + sin_frule, sin_rrule, sinh_frule, sinh_rrule, sqrt_frule, sqrt_rrule, tan, tan_frule, + tan_rrule, tanh_frule, tanh_rrule, }; use num_complex::Complex64; -use common::{run_unary_oracle_cases, run_unary_oracle_reverse_cases_complex64, UnaryOracleCase}; +use common::{ + 
run_unary_oracle_cases, run_unary_oracle_reverse_cases_complex64, UnaryOracleCase, + UnaryReverseOracleCase, +}; #[test] fn published_float64_oracles_match_unary_rule_entrypoints() { @@ -115,21 +119,21 @@ fn published_float64_oracles_match_unary_rule_entrypoints() { } #[test] -fn published_complex128_oracles_match_unary_rule_entrypoints() { - let cases: [UnaryOracleCase; 3] = [ - UnaryOracleCase { +fn published_complex128_oracles_match_unary_rule_entrypoints_reverse_only() { + let cases = [ + UnaryReverseOracleCase { op: "tan", - frule: tan_frule, + primal: tan, rrule: |_x: Complex64, result, cotangent| tan_rrule(result, cotangent), }, - UnaryOracleCase { + UnaryReverseOracleCase { op: "exp2", - frule: exp2_frule, + primal: exp2, rrule: |_x: Complex64, result, cotangent| exp2_rrule(result, cotangent), }, - UnaryOracleCase { + UnaryReverseOracleCase { op: "log2", - frule: log2_frule, + primal: log2, rrule: |x: Complex64, _result, cotangent| log2_rrule(x, cotangent), }, ]; diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index cdf657d..81ca158 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -8,7 +8,7 @@ use chainrules::{ }; use num_complex::{Complex64, ComplexFloat}; -use common::assert_close_f64; +use common::{assert_close_complex64, assert_close_f64}; #[test] fn smooth_basis_helpers_are_reexported_from_chainrules() { @@ -85,6 +85,16 @@ fn smooth_basis_frules_and_rrules_match_expected_derivatives() { "pow.rrule.dexp", ); + let (pow_y, pow_dy) = pow_frule(2.0_f64, 3.0_f64, 1.0_f64, 0.5_f64); + assert_close_f64(pow_y, 8.0, 1.0e-12, 0.0, "pow.y.dexp"); + assert_close_f64( + pow_dy, + 12.0 + 0.5_f64 * 8.0_f64 * std::f64::consts::LN_2, + 1.0e-12, + 0.0, + "pow.dy.dexp", + ); + let (sincos_y, sincos_dy) = sincos_frule(0.25_f64, 1.0_f64); assert_close_f64(sincos_y.0, 0.25_f64.sin(), 1.0e-12, 0.0, "sincos.y.sin"); assert_close_f64(sincos_y.1, 
0.25_f64.cos(), 1.0e-12, 0.0, "sincos.y.cos"); @@ -163,6 +173,28 @@ fn smooth_basis_frules_and_rrules_match_expected_derivatives() { ); } +#[test] +fn smooth_basis_complex_frules_match_expected_derivatives() { + let z = Complex64::new(0.25, -0.5); + let dz = Complex64::new(0.5, -0.25); + + let (tan_y, tan_dy) = tan_frule(z, dz); + let tan_scale = (Complex64::new(1.0, 0.0) + tan_y * tan_y).conj(); + assert_close_complex64(tan_y, z.tan(), 1.0e-12, 0.0, "tan.z"); + assert_close_complex64(tan_dy, dz * tan_scale, 1.0e-12, 0.0, "tan.dz"); + + let (exp2_y, exp2_dy) = exp2_frule(z, dz); + let exp2_scale = (exp2_y * Complex64::new(std::f64::consts::LN_2, 0.0)).conj(); + assert_close_complex64(exp2_y, z.exp2(), 1.0e-12, 0.0, "exp2.z"); + assert_close_complex64(exp2_dy, dz * exp2_scale, 1.0e-12, 0.0, "exp2.dz"); + + let (log2_y, log2_dy) = log2_frule(z, dz); + let log2_scale = + (Complex64::new(1.0, 0.0) / (z * Complex64::new(std::f64::consts::LN_2, 0.0))).conj(); + assert_close_complex64(log2_y, z.log2(), 1.0e-12, 0.0, "log2.z"); + assert_close_complex64(log2_dy, dz * log2_scale, 1.0e-12, 0.0, "log2.dz"); +} + #[test] fn pow_rules_handle_zero_and_negative_real_paths() { let (neg_y, neg_dy) = pow_frule(-2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); @@ -181,13 +213,14 @@ fn pow_rules_handle_zero_and_negative_real_paths() { #[test] fn pow_rules_cover_complex_frule_and_rrule_paths() { let x = Complex64::new(1.0, 1.0); - let exponent = Complex64::new(2.0, 0.0); + let exponent = Complex64::new(2.0, 0.5); let dx = Complex64::new(0.5, -0.25); - let dexp = Complex64::new(0.0, 0.0); + let dexp = Complex64::new(0.1, -0.2); let (y, dy) = pow_frule(x, exponent, dx, dexp); let expected_y = x.powc(exponent); - let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))).conj(); + let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))).conj() + + dexp * (expected_y * x.ln()).conj(); assert!((y - expected_y).norm() < 1.0e-12); assert!((dy - 
expected_dy).norm() < 1.0e-12); From eca1e5a69d669af359524f70619dd298be5d2102 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 13:24:22 +0900 Subject: [PATCH 27/32] docs: clarify oracle validation split --- README.md | 6 +++++- crates/chainrules/README.md | 5 ++++- crates/chainrules/tests/common.rs | 5 ++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0b3c767..b33d655 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,11 @@ Scalar rules are checked in complementary ways: `crates/chainrules/tests/julia_compat_trig_tests.rs` and `crates/chainrules/tests/complex_helper_tests.rs` - oracle replay tests in `crates/chainrules/tests/oracle_scalar_rules.rs` - against vendored published cases from `third_party/tensor-ad-oracles` + against vendored published cases from `third_party/tensor-ad-oracles`, + including direct float64 replay and selected complex128 reverse-mode replay; + complex forward-mode checks stay in repository-local formula tests because the + crate's `frule` convention intentionally differs from the published JVP + convention ```bash cargo test --workspace --release diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index dde2ab8..03613e9 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -55,7 +55,10 @@ repository-local tests: - `tests/complex_helper_tests.rs` covers the projection helpers and complex constructor surface - `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from - `../../third_party/tensor-ad-oracles` + `../../third_party/tensor-ad-oracles`, with direct float64 replay and + selected Complex64 reverse-mode replay +- complex forward-mode checks that follow this crate's `ScalarAd` convention + stay in repository-local formula tests such as `tests/smooth_basis_tests.rs` ## Examples diff --git a/crates/chainrules/tests/common.rs b/crates/chainrules/tests/common.rs index b2be005..d9a91df 100644 --- 
a/crates/chainrules/tests/common.rs +++ b/crates/chainrules/tests/common.rs @@ -233,7 +233,10 @@ pub fn run_unary_oracle_cases(cases: &[UnaryOracleCase]) { } } -// Reverse-only complex oracle replay keeps the current ScalarAd convention explicit. +// Published complex oracle JVPs use the plain holomorphic derivative, while this +// crate's ScalarAd forward convention applies the conjugated derivative. Keep +// the oracle replay reverse-only and cover the complex forward convention in +// repository-local formula tests. pub fn run_unary_oracle_reverse_cases_complex64(cases: &[UnaryReverseOracleCase]) { for case in cases { for (case_index, oracle) in successful_cases(case.op, ::dtype()) From caf19f4c0420181166a614376e5999a287f28d69 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 13:38:21 +0900 Subject: [PATCH 28/32] fix: surface pow zero-base exponent singularities --- crates/chainrules/src/power.rs | 18 ++++++++++------ crates/chainrules/tests/smooth_basis_tests.rs | 21 +++++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/crates/chainrules/src/power.rs b/crates/chainrules/src/power.rs index 921127f..b93d799 100644 --- a/crates/chainrules/src/power.rs +++ b/crates/chainrules/src/power.rs @@ -1,4 +1,4 @@ -use num_traits::{One, Zero}; +use num_traits::{Float, One, Zero}; use crate::ScalarAd; @@ -115,6 +115,9 @@ pub fn pow(x: S, exponent: S) -> S { fn zero() -> S { S::from_i32(0) } +fn nan() -> S { + S::from_real(S::Real::nan()) +} fn pow_x_scale(x: S, exponent: S) -> S { if exponent == zero::() { zero::() @@ -123,14 +126,17 @@ fn pow_x_scale(x: S, exponent: S) -> S { } } fn pow_exp_scale(x: S, exponent: S) -> S { - if x == zero::() && exponent.imag() == S::Real::zero() && exponent.real() >= S::Real::zero() - { - zero::() + if x == zero::() && exponent.imag() == S::Real::zero() { + if exponent.real() > S::Real::zero() { + zero::() + } else { + nan::() + } } else { (x.pow(exponent) * x.ln()).conj() } } -#[doc = 
"Forward rule for `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow_frule;\n\nlet (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0);\nassert_eq!(y, 8.0);\nassert!((dy - 12.0).abs() < 1e-12);\n```"] +#[doc = "Forward rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-tangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_frule;\n\nlet (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0);\nassert_eq!(y, 8.0);\nassert!((dy - 12.0).abs() < 1e-12);\n```"] pub fn pow_frule(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) { let y = x.pow(exponent); let dfdx = if dx == zero::() { @@ -145,7 +151,7 @@ pub fn pow_frule(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) }; (y, dfdx + dfde) } -#[doc = "Reverse rule for `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow_rrule;\n\nlet (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0);\nassert_eq!(dx, 12.0);\nassert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] +#[doc = "Reverse rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-cotangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_rrule;\n\nlet (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0);\nassert_eq!(dx, 12.0);\nassert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] pub fn pow_rrule(x: S, exponent: S, cotangent: S) -> (S, S) { let dfdx = if cotangent == zero::() { zero::() diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index 81ca158..00d6131 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -210,6 +210,27 @@ fn pow_rules_handle_zero_and_negative_real_paths() { assert!((dexp - 0.0).abs() < 1.0e-12); } +#[test] +fn pow_rules_mark_zero_base_exponent_singularities_for_real_inputs() { + let (_, 
zero_zero_dy) = pow_frule(0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64); + assert!(zero_zero_dy.is_nan()); + + let (_, zero_neg_dy) = pow_frule(0.0_f64, -1.0_f64, 0.0_f64, 1.0_f64); + assert!(zero_neg_dy.is_nan()); + + let (_, zero_zero_dexp) = pow_rrule(0.0_f64, 0.0_f64, 1.0_f64); + assert!(zero_zero_dexp.is_nan()); + + let (_, zero_neg_dexp) = pow_rrule(0.0_f64, -1.0_f64, 1.0_f64); + assert!(zero_neg_dexp.is_nan()); + + let (_, zero_zero_dy32) = pow_frule(0.0_f32, 0.0_f32, 0.0_f32, 1.0_f32); + assert!(zero_zero_dy32.is_nan()); + + let (_, zero_neg_dexp32) = pow_rrule(0.0_f32, -1.0_f32, 1.0_f32); + assert!(zero_neg_dexp32.is_nan()); +} + #[test] fn pow_rules_cover_complex_frule_and_rrule_paths() { let x = Complex64::new(1.0, 1.0); From 3455e4e7d60bfe7feee444a7cd44981870b1e37c Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 18:32:54 +0900 Subject: [PATCH 29/32] refactor: align unary complex frules with standard jvp --- crates/chainrules/src/tests/behavior.rs | 21 +++--- crates/chainrules/src/unary/basic.rs | 2 +- crates/chainrules/src/unary/exp_log.rs | 16 ++-- crates/chainrules/src/unary/hyperbolic.rs | 12 +-- crates/chainrules/src/unary/roots.rs | 4 +- crates/chainrules/src/unary/trig.rs | 14 ++-- crates/chainrules/src/unary/trig_extra.rs | 11 +-- crates/chainrules/tests/common.rs | 73 ------------------- .../tests/julia_compat_trig_tests.rs | 54 +++++++++++++- .../chainrules/tests/oracle_scalar_rules.rs | 30 ++++---- crates/chainrules/tests/smooth_basis_tests.rs | 25 +++++-- 11 files changed, 123 insertions(+), 139 deletions(-) diff --git a/crates/chainrules/src/tests/behavior.rs b/crates/chainrules/src/tests/behavior.rs index ee52acf..bba6313 100644 --- a/crates/chainrules/src/tests/behavior.rs +++ b/crates/chainrules/src/tests/behavior.rs @@ -441,20 +441,20 @@ fn trig_and_hyperbolic_primal_entrypoints_match_std_ops() { } #[test] -fn extended_complex_unary_rules_conjugate_their_jacobians() { +fn 
extended_complex_unary_rules_use_standard_jvps_with_conjugate_rrules() { let x = Complex64::new(0.25, -0.5); let dx = Complex64::new(-0.75, 0.5); let cotangent = Complex64::new(0.5, -1.25); let (_sin_y, sin_dy) = sin_frule(x, dx); - assert_close_c64(sin_dy, dx * ComplexFloat::conj(ComplexFloat::cos(x))); + assert_close_c64(sin_dy, dx * ComplexFloat::cos(x)); assert_close_c64( sin_rrule(x, cotangent), cotangent * ComplexFloat::conj(ComplexFloat::cos(x)), ); let (_cos_y, cos_dy) = cos_frule(x, dx); - assert_close_c64(cos_dy, dx * ComplexFloat::conj(-ComplexFloat::sin(x))); + assert_close_c64(cos_dy, dx * -ComplexFloat::sin(x)); assert_close_c64( cos_rrule(x, cotangent), cotangent * ComplexFloat::conj(-ComplexFloat::sin(x)), @@ -462,19 +462,18 @@ fn extended_complex_unary_rules_conjugate_their_jacobians() { let tanh_y = ComplexFloat::tanh(x); let (_tanh_primal, tanh_dy) = tanh_frule(x, dx); - assert_close_c64( - tanh_dy, - dx * ComplexFloat::conj(Complex64::new(1.0, 0.0) - tanh_y * tanh_y), - ); + assert_close_c64(tanh_dy, dx * (Complex64::new(1.0, 0.0) - tanh_y * tanh_y)); assert_close_c64( tanh_rrule(tanh_y, cotangent), cotangent * ComplexFloat::conj(Complex64::new(1.0, 0.0) - tanh_y * tanh_y), ); let (_asinh_y, asinh_dy) = asinh_frule(x, dx); - let asinh_scale = ComplexFloat::conj( - Complex64::new(1.0, 0.0) / ComplexFloat::sqrt(Complex64::new(1.0, 0.0) + x * x), - ); + let asinh_scale = + Complex64::new(1.0, 0.0) / ComplexFloat::sqrt(Complex64::new(1.0, 0.0) + x * x); assert_close_c64(asinh_dy, dx * asinh_scale); - assert_close_c64(asinh_rrule(x, cotangent), cotangent * asinh_scale); + assert_close_c64( + asinh_rrule(x, cotangent), + cotangent * ComplexFloat::conj(asinh_scale), + ); } diff --git a/crates/chainrules/src/unary/basic.rs b/crates/chainrules/src/unary/basic.rs index 886c01b..7b76987 100644 --- a/crates/chainrules/src/unary/basic.rs +++ b/crates/chainrules/src/unary/basic.rs @@ -23,7 +23,7 @@ pub fn sqrt(x: S) -> S { /// Forward rule for `sqrt`. 
pub fn sqrt_frule(x: S, dx: S) -> (S, S) { let y = x.sqrt(); - let dy = dx / (S::from_i32(2) * y.conj()); + let dy = dx / (S::from_i32(2) * y); (y, dy) } diff --git a/crates/chainrules/src/unary/exp_log.rs b/crates/chainrules/src/unary/exp_log.rs index abb7457..7ebabbe 100644 --- a/crates/chainrules/src/unary/exp_log.rs +++ b/crates/chainrules/src/unary/exp_log.rs @@ -14,7 +14,7 @@ pub fn exp(x: S) -> S { /// Forward rule for `exp`. pub fn exp_frule(x: S, dx: S) -> (S, S) { let y = x.exp(); - (y, dx * y.conj()) + (y, dx * y) } /// Reverse rule for `exp`. pub fn exp_rrule(result: S, cotangent: S) -> S { @@ -28,7 +28,7 @@ pub fn expm1(x: S) -> S { pub fn expm1_frule(x: S, dx: S) -> (S, S) { let y = x.expm1(); let scale = y + one::(); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `exp(x) - 1`. pub fn expm1_rrule(result: S, cotangent: S) -> S { @@ -41,7 +41,7 @@ pub fn exp2(x: S) -> S { #[doc = "Forward rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_frule;\n\nlet (y, dy) = exp2_frule(3.0_f64, 1.0);\nassert!((y - 8.0).abs() < 1e-12);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] pub fn exp2_frule(x: S, dx: S) -> (S, S) { let y = x.exp2(); - (y, dx * (y * ln_2::()).conj()) + (y, dx * (y * ln_2::())) } #[doc = "Reverse rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_rrule;\n\nlet dy = exp2_rrule(8.0_f64, 1.0);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"] pub fn exp2_rrule(result: S, cotangent: S) -> S { @@ -54,7 +54,7 @@ pub fn exp10(x: S) -> S { #[doc = "Forward rule for `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10_frule;\n\nlet (y, dy) = exp10_frule(2.0_f64, 0.5);\nassert!((y - 100.0).abs() < 1e-12);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"] pub fn exp10_frule(x: S, dx: S) -> (S, S) { let y = x.exp10(); - (y, dx * (y * ln_10::()).conj()) + (y, dx * (y * ln_10::())) } #[doc = "Reverse rule for `10^x`.\n\n# 
Examples\n```rust\nuse chainrules::exp10_rrule;\n\nlet dy = exp10_rrule(100.0_f64, 0.5);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"] pub fn exp10_rrule(result: S, cotangent: S) -> S { @@ -67,7 +67,7 @@ pub fn log(x: S) -> S { /// Forward rule for `log`. pub fn log_frule(x: S, dx: S) -> (S, S) { let y = x.ln(); - let dy = dx * (one::() / x).conj(); + let dy = dx * (one::() / x); (y, dy) } /// Reverse rule for `log`. @@ -81,7 +81,7 @@ pub fn log1p(x: S) -> S { /// Forward rule for `log(1 + x)`. pub fn log1p_frule(x: S, dx: S) -> (S, S) { let y = x.log1p(); - let dy = dx * (one::() / (one::() + x)).conj(); + let dy = dx * (one::() / (one::() + x)); (y, dy) } /// Reverse rule for `log(1 + x)`. @@ -96,7 +96,7 @@ pub fn log2(x: S) -> S { pub fn log2_frule(x: S, dx: S) -> (S, S) { let y = x.log2(); let scale = one::() / (x * ln_2::()); - (y, dx * scale.conj()) + (y, dx * scale) } #[doc = "Reverse rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_rrule;\n\nlet dy = log2_rrule(8.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"] pub fn log2_rrule(x: S, cotangent: S) -> S { @@ -110,7 +110,7 @@ pub fn log10(x: S) -> S { pub fn log10_frule(x: S, dx: S) -> (S, S) { let y = x.log10(); let scale = one::() / (x * ln_10::()); - (y, dx * scale.conj()) + (y, dx * scale) } #[doc = "Reverse rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_rrule;\n\nlet dy = log10_rrule(100.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"] pub fn log10_rrule(x: S, cotangent: S) -> S { diff --git a/crates/chainrules/src/unary/hyperbolic.rs b/crates/chainrules/src/unary/hyperbolic.rs index 5d09e9c..07cfca5 100644 --- a/crates/chainrules/src/unary/hyperbolic.rs +++ b/crates/chainrules/src/unary/hyperbolic.rs @@ -10,7 +10,7 @@ pub fn tanh(x: S) -> S { pub fn tanh_frule(x: S, dx: S) -> (S, S) { let y = x.tanh(); let scale = one::() - y * y; - 
(y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `tanh`. @@ -26,7 +26,7 @@ pub fn sinh(x: S) -> S { /// Forward rule for `sinh`. pub fn sinh_frule(x: S, dx: S) -> (S, S) { let y = x.sinh(); - (y, dx * x.cosh().conj()) + (y, dx * x.cosh()) } /// Reverse rule for `sinh`. @@ -42,7 +42,7 @@ pub fn cosh(x: S) -> S { /// Forward rule for `cosh`. pub fn cosh_frule(x: S, dx: S) -> (S, S) { let y = x.cosh(); - (y, dx * x.sinh().conj()) + (y, dx * x.sinh()) } /// Reverse rule for `cosh`. @@ -63,7 +63,7 @@ pub fn asinh(x: S) -> S { pub fn asinh_frule(x: S, dx: S) -> (S, S) { let y = x.asinh(); let scale = inverse_sqrt_one_plus_square(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `asinh`. @@ -84,7 +84,7 @@ pub fn acosh(x: S) -> S { pub fn acosh_frule(x: S, dx: S) -> (S, S) { let y = x.acosh(); let scale = inverse_acosh_scale(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `acosh`. @@ -101,7 +101,7 @@ pub fn atanh(x: S) -> S { pub fn atanh_frule(x: S, dx: S) -> (S, S) { let y = x.atanh(); let scale = one::() / (one::() - x * x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `atanh`. diff --git a/crates/chainrules/src/unary/roots.rs b/crates/chainrules/src/unary/roots.rs index 18d7e7e..b1eed73 100644 --- a/crates/chainrules/src/unary/roots.rs +++ b/crates/chainrules/src/unary/roots.rs @@ -27,7 +27,7 @@ pub fn cbrt(x: S) -> S { pub fn cbrt_frule(x: S, dx: S) -> (S, S) { let y = x.cbrt(); let scale = S::from_i32(1) / (S::from_i32(3) * y * y); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `cbrt`. @@ -70,7 +70,7 @@ pub fn inv(x: S) -> S { /// ``` pub fn inv_frule(x: S, dx: S) -> (S, S) { let y = x.recip(); - (y, dx * (-(y * y)).conj()) + (y, dx * (-(y * y))) } /// Reverse rule for `inv`. 
diff --git a/crates/chainrules/src/unary/trig.rs b/crates/chainrules/src/unary/trig.rs index 3f73f1c..e740b1a 100644 --- a/crates/chainrules/src/unary/trig.rs +++ b/crates/chainrules/src/unary/trig.rs @@ -9,7 +9,7 @@ pub fn sin(x: S) -> S { /// Forward rule for `sin`. pub fn sin_frule(x: S, dx: S) -> (S, S) { let y = x.sin(); - (y, dx * x.cos().conj()) + (y, dx * x.cos()) } /// Reverse rule for `sin`. @@ -25,7 +25,7 @@ pub fn cos(x: S) -> S { /// Forward rule for `cos`. pub fn cos_frule(x: S, dx: S) -> (S, S) { let y = x.cos(); - (y, dx * (-x.sin()).conj()) + (y, dx * -x.sin()) } /// Reverse rule for `cos`. @@ -46,7 +46,7 @@ pub fn asin(x: S) -> S { pub fn asin_frule(x: S, dx: S) -> (S, S) { let y = x.asin(); let scale = inverse_sqrt_one_minus_square(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `asin`. @@ -63,7 +63,7 @@ pub fn acos(x: S) -> S { pub fn acos_frule(x: S, dx: S) -> (S, S) { let y = x.acos(); let scale = -inverse_sqrt_one_minus_square(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `acos`. @@ -80,7 +80,7 @@ pub fn atan(x: S) -> S { pub fn atan_frule(x: S, dx: S) -> (S, S) { let y = x.atan(); let scale = one::() / (one::() + x * x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `atan`. 
@@ -96,7 +96,7 @@ pub fn tan(x: S) -> S { #[doc = "Forward rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_frule;\n\nlet (y, dy) = tan_frule(0.25_f64, 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"] pub fn tan_frule(x: S, dx: S) -> (S, S) { let y = x.tan(); - (y, dx * (one::() + y * y).conj()) + (y, dx * (one::() + y * y)) } #[doc = "Reverse rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_rrule;\n\nlet dy = tan_rrule(0.25_f64.tan(), 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"] @@ -113,7 +113,7 @@ pub fn sincos(x: S) -> (S, S) { pub fn sincos_frule(x: S, dx: S) -> ((S, S), (S, S)) { let sin_x = x.sin(); let cos_x = x.cos(); - ((sin_x, cos_x), (dx * cos_x.conj(), dx * (-sin_x).conj())) + ((sin_x, cos_x), (dx * cos_x, dx * -sin_x)) } #[doc = "Reverse rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_rrule;\n\nlet dx = sincos_rrule(0.25_f64, (1.0, 1.0));\nassert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12);\n```"] diff --git a/crates/chainrules/src/unary/trig_extra.rs b/crates/chainrules/src/unary/trig_extra.rs index ec80b28..b6a52d8 100644 --- a/crates/chainrules/src/unary/trig_extra.rs +++ b/crates/chainrules/src/unary/trig_extra.rs @@ -229,7 +229,7 @@ pub fn sinpi(x: S) -> S { pub fn sinpi_frule(x: S, dx: S) -> (S, S) { let y = sinpi(x); let scale = pi::() * cospi(x); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `sinpi`. @@ -275,7 +275,7 @@ pub fn cospi(x: S) -> S { pub fn cospi_frule(x: S, dx: S) -> (S, S) { let y = cospi(x); let scale = -(pi::() * sinpi(x)); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `cospi`. 
@@ -323,10 +323,7 @@ pub fn sincospi_frule(x: S, dx: S) -> ((S, S), (S, S)) { let cos_x = cospi(x); ( (sin_x, cos_x), - ( - dx * (pi::() * cos_x).conj(), - dx * (-(pi::() * sin_x)).conj(), - ), + (dx * (pi::() * cos_x), dx * (-(pi::() * sin_x))), ) } @@ -473,7 +470,7 @@ pub fn tand(x: S) -> S { pub fn tand_frule(x: S, dx: S) -> (S, S) { let y = tand(x); let scale = deg2rad::() * (S::from_i32(1) + y * y); - (y, dx * scale.conj()) + (y, dx * scale) } /// Reverse rule for `tand`. diff --git a/crates/chainrules/tests/common.rs b/crates/chainrules/tests/common.rs index d9a91df..00b0c81 100644 --- a/crates/chainrules/tests/common.rs +++ b/crates/chainrules/tests/common.rs @@ -13,12 +13,6 @@ pub struct UnaryOracleCase { pub rrule: fn(T, T, T) -> T, } -pub struct UnaryReverseOracleCase { - pub op: &'static str, - pub primal: fn(T) -> T, - pub rrule: fn(T, T, T) -> T, -} - pub trait OracleScalar: Copy + core::fmt::Debug { fn dtype() -> &'static str; fn is_scalar_value(value: &Value) -> bool; @@ -232,70 +226,3 @@ pub fn run_unary_oracle_cases(cases: &[UnaryOracleCase]) { } } } - -// Published complex oracle JVPs use the plain holomorphic derivative, while this -// crate's ScalarAd forward convention applies the conjugated derivative. Keep -// the oracle replay reverse-only and cover the complex forward convention in -// repository-local formula tests. 
-pub fn run_unary_oracle_reverse_cases_complex64(cases: &[UnaryReverseOracleCase]) { - for case in cases { - for (case_index, oracle) in successful_cases(case.op, ::dtype()) - .into_iter() - .enumerate() - { - let inputs = - scalar_values::(&oracle["inputs"]["a"]["data"], "inputs.a.data"); - let probes = oracle["probes"] - .as_array() - .unwrap_or_else(|| panic!("expected probes array for {}", case.op)); - let atol = scalar_f64( - &oracle["comparison"]["first_order"]["atol"], - "comparison.first_order.atol", - ); - let rtol = scalar_f64( - &oracle["comparison"]["first_order"]["rtol"], - "comparison.first_order.rtol", - ); - - for (probe_index, probe) in probes.iter().enumerate() { - let cotangents = scalar_values::( - &probe["cotangent"]["value"]["data"], - &format!("probes[{probe_index}].cotangent.value.data"), - ); - let expected_vjps = scalar_values::( - &probe["pytorch_ref"]["vjp"]["a"]["data"], - &format!("probes[{probe_index}].pytorch_ref.vjp.a.data"), - ); - - assert_eq!( - inputs.len(), - cotangents.len(), - "{} case {case_index} probe {probe_index}: input and cotangent lengths differ", - case.op - ); - assert_eq!( - inputs.len(), - expected_vjps.len(), - "{} case {case_index} probe {probe_index}: input and expected vjp lengths differ", - case.op - ); - - for index in 0..inputs.len() { - let result = (case.primal)(inputs[index]); - let actual_vjp = (case.rrule)(inputs[index], result, cotangents[index]); - let label = format!( - "{} case {case_index} probe {probe_index} element {index}", - case.op - ); - ::assert_close( - actual_vjp, - expected_vjps[index], - atol, - rtol, - &label, - ); - } - } - } - } -} diff --git a/crates/chainrules/tests/julia_compat_trig_tests.rs b/crates/chainrules/tests/julia_compat_trig_tests.rs index 8f32195..51680c0 100644 --- a/crates/chainrules/tests/julia_compat_trig_tests.rs +++ b/crates/chainrules/tests/julia_compat_trig_tests.rs @@ -279,7 +279,7 @@ fn julia_compat_derivative_helpers_match_expected_values() { } #[test] -fn 
julia_compat_helpers_cover_complex_primal_and_cotangent_paths() { +fn julia_compat_helpers_cover_complex_primal_forward_and_cotangent_paths() { let z = Complex64::new(0.25, -0.1); let dz = Complex64::new(1.0, -0.25); let cotangent = Complex64::new(0.75, -0.5); @@ -291,11 +291,44 @@ fn julia_compat_helpers_cover_complex_primal_and_cotangent_paths() { let (_, dsinpi) = sinpi_frule(z, dz); assert_close_complex64( dsinpi, - dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()).conj(), + dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), 1e-12, 0.0, "sinpi_frule(z)", ); + let (_, dcospi) = cospi_frule(z, dz); + assert_close_complex64( + dcospi, + dz * (-(Complex64::new(std::f64::consts::PI, 0.0) * pi_z.sin())), + 1e-12, + 0.0, + "cospi_frule(z)", + ); + let ((_, _), (dsinpi_pair, dcospi_pair)) = sincospi_frule(z, dz); + assert_close_complex64( + dsinpi_pair, + dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), + 1e-12, + 0.0, + "sincospi_frule.sin(z)", + ); + assert_close_complex64( + dcospi_pair, + dz * (-(Complex64::new(std::f64::consts::PI, 0.0) * pi_z.sin())), + 1e-12, + 0.0, + "sincospi_frule.cos(z)", + ); + let (_, dtand) = tand_frule(z, dz); + let tand_z = (Complex64::new(std::f64::consts::PI / 180.0, 0.0) * z).tan(); + assert_close_complex64( + dtand, + dz * (Complex64::new(std::f64::consts::PI / 180.0, 0.0) + * (Complex64::new(1.0, 0.0) + tand_z * tand_z)), + 1e-12, + 0.0, + "tand_frule(z)", + ); assert_close_complex64( sec_rrule(z, cotangent), @@ -311,6 +344,23 @@ fn julia_compat_helpers_cover_complex_primal_and_cotangent_paths() { 0.0, "sinpi_rrule(z)", ); + let (_, dcot) = cot_frule(z, dz); + assert_close_complex64( + dcot, + dz * (-(Complex64::new(1.0, 0.0) / z.sin().powi(2))), + 1e-12, + 0.0, + "cot_frule(z)", + ); + let (_, dsech) = sech_frule(z, dz); + let sech_z = sech(z); + assert_close_complex64( + dsech, + dz * (-(sech_z * z.tanh())), + 1e-12, + 0.0, + "sech_frule(z)", + ); } #[test] diff --git 
a/crates/chainrules/tests/oracle_scalar_rules.rs b/crates/chainrules/tests/oracle_scalar_rules.rs index d43154a..0f3a01b 100644 --- a/crates/chainrules/tests/oracle_scalar_rules.rs +++ b/crates/chainrules/tests/oracle_scalar_rules.rs @@ -3,17 +3,13 @@ mod common; use chainrules::{ acos_frule, acos_rrule, acosh_frule, acosh_rrule, asin_frule, asin_rrule, asinh_frule, asinh_rrule, atan_frule, atan_rrule, atanh_frule, atanh_rrule, cos_frule, cos_rrule, - cosh_frule, cosh_rrule, exp2, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1_frule, - expm1_rrule, log1p_frule, log1p_rrule, log2, log2_frule, log2_rrule, log_frule, log_rrule, - sin_frule, sin_rrule, sinh_frule, sinh_rrule, sqrt_frule, sqrt_rrule, tan, tan_frule, - tan_rrule, tanh_frule, tanh_rrule, + cosh_frule, cosh_rrule, exp2_frule, exp2_rrule, exp_frule, exp_rrule, expm1_frule, expm1_rrule, + log1p_frule, log1p_rrule, log2_frule, log2_rrule, log_frule, log_rrule, sin_frule, sin_rrule, + sinh_frule, sinh_rrule, sqrt_frule, sqrt_rrule, tan_frule, tan_rrule, tanh_frule, tanh_rrule, }; use num_complex::Complex64; -use common::{ - run_unary_oracle_cases, run_unary_oracle_reverse_cases_complex64, UnaryOracleCase, - UnaryReverseOracleCase, -}; +use common::{run_unary_oracle_cases, UnaryOracleCase}; #[test] fn published_float64_oracles_match_unary_rule_entrypoints() { @@ -119,24 +115,24 @@ fn published_float64_oracles_match_unary_rule_entrypoints() { } #[test] -fn published_complex128_oracles_match_unary_rule_entrypoints_reverse_only() { - let cases = [ - UnaryReverseOracleCase { +fn published_complex128_oracles_match_unary_rule_entrypoints() { + let cases: [UnaryOracleCase; 3] = [ + UnaryOracleCase { op: "tan", - primal: tan, + frule: tan_frule, rrule: |_x: Complex64, result, cotangent| tan_rrule(result, cotangent), }, - UnaryReverseOracleCase { + UnaryOracleCase { op: "exp2", - primal: exp2, + frule: exp2_frule, rrule: |_x: Complex64, result, cotangent| exp2_rrule(result, cotangent), }, - UnaryReverseOracleCase { 
+ UnaryOracleCase { op: "log2", - primal: log2, + frule: log2_frule, rrule: |x: Complex64, _result, cotangent| log2_rrule(x, cotangent), }, ]; - run_unary_oracle_reverse_cases_complex64(&cases); + run_unary_oracle_cases(&cases); } diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index 00d6131..ec387a2 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ b/crates/chainrules/tests/smooth_basis_tests.rs @@ -174,27 +174,42 @@ fn smooth_basis_frules_and_rrules_match_expected_derivatives() { } #[test] -fn smooth_basis_complex_frules_match_expected_derivatives() { +fn smooth_basis_complex_frules_match_standard_jvps() { let z = Complex64::new(0.25, -0.5); let dz = Complex64::new(0.5, -0.25); let (tan_y, tan_dy) = tan_frule(z, dz); - let tan_scale = (Complex64::new(1.0, 0.0) + tan_y * tan_y).conj(); + let tan_scale = Complex64::new(1.0, 0.0) + tan_y * tan_y; assert_close_complex64(tan_y, z.tan(), 1.0e-12, 0.0, "tan.z"); assert_close_complex64(tan_dy, dz * tan_scale, 1.0e-12, 0.0, "tan.dz"); let (exp2_y, exp2_dy) = exp2_frule(z, dz); - let exp2_scale = (exp2_y * Complex64::new(std::f64::consts::LN_2, 0.0)).conj(); + let exp2_scale = exp2_y * Complex64::new(std::f64::consts::LN_2, 0.0); assert_close_complex64(exp2_y, z.exp2(), 1.0e-12, 0.0, "exp2.z"); assert_close_complex64(exp2_dy, dz * exp2_scale, 1.0e-12, 0.0, "exp2.dz"); let (log2_y, log2_dy) = log2_frule(z, dz); - let log2_scale = - (Complex64::new(1.0, 0.0) / (z * Complex64::new(std::f64::consts::LN_2, 0.0))).conj(); + let log2_scale = Complex64::new(1.0, 0.0) / (z * Complex64::new(std::f64::consts::LN_2, 0.0)); assert_close_complex64(log2_y, z.log2(), 1.0e-12, 0.0, "log2.z"); assert_close_complex64(log2_dy, dz * log2_scale, 1.0e-12, 0.0, "log2.dz"); } +#[test] +fn smooth_basis_complex_frules_cover_additional_standard_jvps() { + let z = Complex64::new(0.25, -0.5); + let dz = Complex64::new(0.5, -0.25); + + let ((sin_y, cos_y), (dsin_dz, dcos_dz)) = 
sincos_frule(z, dz); + assert_close_complex64(sin_y, z.sin(), 1.0e-12, 0.0, "sincos.z.sin"); + assert_close_complex64(cos_y, z.cos(), 1.0e-12, 0.0, "sincos.z.cos"); + assert_close_complex64(dsin_dz, dz * z.cos(), 1.0e-12, 0.0, "sincos.dz.sin"); + assert_close_complex64(dcos_dz, dz * -z.sin(), 1.0e-12, 0.0, "sincos.dz.cos"); + + let (inv_y, inv_dy) = inv_frule(z, dz); + assert_close_complex64(inv_y, z.recip(), 1.0e-12, 0.0, "inv.z"); + assert_close_complex64(inv_dy, dz * (-(inv_y * inv_y)), 1.0e-12, 0.0, "inv.dz"); +} + #[test] fn pow_rules_handle_zero_and_negative_real_paths() { let (neg_y, neg_dy) = pow_frule(-2.0_f64, 3.0_f64, 1.0_f64, 0.0_f64); From c4966b5fc81a1fda0e4cd01233444b88b6bc0402 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 18:39:14 +0900 Subject: [PATCH 30/32] refactor: align binary and power complex frules --- crates/chainrules/src/binary.rs | 6 +-- crates/chainrules/src/power.rs | 45 ++++++++----------- crates/chainrules/tests/scalarops_tests.rs | 18 ++++---- crates/chainrules/tests/smooth_basis_tests.rs | 4 +- 4 files changed, 33 insertions(+), 40 deletions(-) diff --git a/crates/chainrules/src/binary.rs b/crates/chainrules/src/binary.rs index 9d21d4c..9ce701a 100644 --- a/crates/chainrules/src/binary.rs +++ b/crates/chainrules/src/binary.rs @@ -122,7 +122,7 @@ pub fn mul(x: S, y: S) -> S { /// ``` pub fn mul_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { let primal = x * y; - let tangent = dx * y.conj() + dy * x.conj(); + let tangent = dx * y + dy * x; (primal, tangent) } @@ -175,8 +175,8 @@ pub fn div(x: S, y: S) -> S { pub fn div_frule(x: S, y: S, dx: S, dy: S) -> (S, S) { let primal = x / y; let inv_y = S::from_i32(1) / y; - let dfdx = inv_y.conj(); - let dfdy = (-(x * inv_y * inv_y)).conj(); + let dfdx = inv_y; + let dfdy = -(x * inv_y * inv_y); let tangent = dx * dfdx + dy * dfdy; (primal, tangent) } diff --git a/crates/chainrules/src/power.rs b/crates/chainrules/src/power.rs index b93d799..9a190f6 100644 --- 
a/crates/chainrules/src/power.rs +++ b/crates/chainrules/src/power.rs @@ -3,27 +3,17 @@ use num_traits::{Float, One, Zero}; use crate::ScalarAd; /// Primal `powf`. -/// -/// # Examples -/// /// ```rust /// use chainrules::powf; -/// /// assert_eq!(powf(2.0_f64, 3.0), 8.0); /// ``` pub fn powf(x: S, exponent: S::Real) -> S { x.powf(exponent) } -/// Forward rule for `powf` with fixed exponent. -/// -/// Returns `(primal, tangent)`. -/// -/// # Examples -/// +/// Forward rule for `powf` with fixed exponent. Returns `(primal, tangent)`. /// ```rust /// use chainrules::powf_frule; -/// /// let (y, dy) = powf_frule(2.0_f64, 3.0, 1.0); /// assert_eq!(y, 8.0); /// assert_eq!(dy, 12.0); @@ -33,7 +23,7 @@ pub fn powf_frule(x: S, exponent: S::Real, dx: S) -> (S, S) { let dy = if exponent == S::Real::zero() { S::from_real(S::Real::zero()) } else { - dx * (S::from_real(exponent) * x.powf(exponent - S::Real::one())).conj() + dx * (S::from_real(exponent) * x.powf(exponent - S::Real::one())) }; (y, dy) } @@ -56,27 +46,17 @@ pub fn powf_rrule(x: S, exponent: S::Real, cotangent: S) -> S { } /// Primal `powi`. -/// -/// # Examples -/// /// ```rust /// use chainrules::powi; -/// /// assert_eq!(powi(2.0_f64, 4), 16.0); /// ``` pub fn powi(x: S, exponent: i32) -> S { x.powi(exponent) } -/// Forward rule for `powi` with fixed integer exponent. -/// -/// Returns `(primal, tangent)`. -/// -/// # Examples -/// +/// Forward rule for `powi` with fixed integer exponent. Returns `(primal, tangent)`. 
/// ```rust /// use chainrules::powi_frule; -/// /// let (y, dy) = powi_frule(2.0_f64, 4, 1.0); /// assert_eq!(y, 16.0); /// assert_eq!(dy, 32.0); @@ -86,7 +66,7 @@ pub fn powi_frule(x: S, exponent: i32, dx: S) -> (S, S) { let dy = if exponent == 0 { S::from_i32(0) } else { - dx * (S::from_i32(exponent) * x.powi(exponent - 1)).conj() + dx * (S::from_i32(exponent) * x.powi(exponent - 1)) }; (y, dy) } @@ -142,12 +122,25 @@ pub fn pow_frule(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) let dfdx = if dx == zero::() { zero::() } else { - dx * pow_x_scale(x, exponent) + dx * if exponent == zero::() { + zero::() + } else { + exponent * x.pow(exponent - S::from_i32(1)) + } }; let dfde = if dexponent == zero::() { zero::() } else { - dexponent * pow_exp_scale(x, exponent) + dexponent + * if x == zero::() && exponent.imag() == S::Real::zero() { + if exponent.real() > S::Real::zero() { + zero::() + } else { + nan::() + } + } else { + x.pow(exponent) * x.ln() + } }; (y, dfdx + dfde) } diff --git a/crates/chainrules/tests/scalarops_tests.rs b/crates/chainrules/tests/scalarops_tests.rs index 11216bf..a80ef5b 100644 --- a/crates/chainrules/tests/scalarops_tests.rs +++ b/crates/chainrules/tests/scalarops_tests.rs @@ -128,7 +128,7 @@ fn mul_div_rules_match_formula_complex64() { let (mul_y, mul_dy) = mul_frule(x, y, dx, dy); assert!((mul_y - (x * y)).norm() < 1e-12); - let expected_mul_tangent = dx * y.conj() + dy * x.conj(); + let expected_mul_tangent = dx * y + dy * x; assert!((mul_dy - expected_mul_tangent).norm() < 1e-12); let (mul_dx, mul_dy_rr) = mul_rrule(x, y, g); assert!((mul_dx - g * y.conj()).norm() < 1e-12); @@ -136,8 +136,8 @@ fn mul_div_rules_match_formula_complex64() { let (div_y, div_dy) = div_frule(x, y, dx, dy); assert!((div_y - (x / y)).norm() < 1e-12); - let expected_div_tangent = dx * (Complex64::new(1.0, 0.0) / y).conj() - + dy * ((Complex64::new(-1.0, 0.0) * x) / (y * y)).conj(); + let expected_div_tangent = + dx * (Complex64::new(1.0, 0.0) / y) + dy * 
((Complex64::new(-1.0, 0.0) * x) / (y * y)); assert!((div_dy - expected_div_tangent).norm() < 1e-12); let (div_dx, div_dy_rr) = div_rrule(x, y, g); assert!((div_dx - g * (Complex64::new(1.0, 0.0) / y).conj()).norm() < 1e-12); @@ -155,11 +155,11 @@ fn powi_rules_match_formula_complex64() { let expected_y = x * x * x; assert!((y - expected_y).norm() < 1e-12); - let expected_scale = (Complex64::new(3.0, 0.0) * x.powi(2)).conj(); + let expected_scale = Complex64::new(3.0, 0.0) * x.powi(2); assert!((dy - (dx * expected_scale)).norm() < 1e-12); let grad = powi_rrule(x, exponent, g); - assert!((grad - (g * expected_scale)).norm() < 1e-12); + assert!((grad - (g * expected_scale.conj())).norm() < 1e-12); } #[test] @@ -204,16 +204,16 @@ fn complex_frules_and_rrules_cover_from_real_paths() { let g32 = Complex32::new(0.6, -0.2); let (_y32, dy32) = powf_frule(x32, 2.0_f32, dx32); let grad32 = powf_rrule(x32, 2.0_f32, g32); - let expected_scale32 = (Complex32::new(2.0, 0.0) * x32.powf(1.0_f32)).conj(); + let expected_scale32 = Complex32::new(2.0, 0.0) * x32.powf(1.0_f32); assert!((dy32 - dx32 * expected_scale32).norm() < 1e-5); - assert!((grad32 - g32 * expected_scale32).norm() < 1e-5); + assert!((grad32 - g32 * expected_scale32.conj()).norm() < 1e-5); let x64 = Complex64::new(-0.8, 1.1); let dx64 = Complex64::new(0.9, -0.4); let g64 = Complex64::new(0.3, 0.2); let (_y64, dy64) = powi_frule(x64, 4, dx64); let grad64 = powi_rrule(x64, 4, g64); - let expected_scale64 = (Complex64::new(4.0, 0.0) * x64.powi(3)).conj(); + let expected_scale64 = Complex64::new(4.0, 0.0) * x64.powi(3); assert!((dy64 - dx64 * expected_scale64).norm() < 1e-12); - assert!((grad64 - g64 * expected_scale64).norm() < 1e-12); + assert!((grad64 - g64 * expected_scale64.conj()).norm() < 1e-12); } diff --git a/crates/chainrules/tests/smooth_basis_tests.rs b/crates/chainrules/tests/smooth_basis_tests.rs index ec387a2..f5b7bb5 100644 --- a/crates/chainrules/tests/smooth_basis_tests.rs +++ 
b/crates/chainrules/tests/smooth_basis_tests.rs @@ -255,8 +255,8 @@ fn pow_rules_cover_complex_frule_and_rrule_paths() { let (y, dy) = pow_frule(x, exponent, dx, dexp); let expected_y = x.powc(exponent); - let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))).conj() - + dexp * (expected_y * x.ln()).conj(); + let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0))) + + dexp * (expected_y * x.ln()); assert!((y - expected_y).norm() < 1.0e-12); assert!((dy - expected_dy).norm() < 1.0e-12); From e17cb82483e7c53b48eb8c6e0ea59a6d269fed1b Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 18:47:00 +0900 Subject: [PATCH 31/32] cleanup: remove projection helpers and refresh docs --- README.md | 8 ++--- crates/chainrules/README.md | 15 +++++----- crates/chainrules/src/lib.rs | 2 +- crates/chainrules/src/scalar_ad/mod.rs | 34 ---------------------- crates/chainrules/src/tests/behavior.rs | 12 +++----- crates/chainrules/tests/scalarops_tests.rs | 14 ++------- 6 files changed, 20 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index b33d655..78a4c1c 100644 --- a/README.md +++ b/README.md @@ -70,10 +70,10 @@ Scalar rules are checked in complementary ways: `crates/chainrules/tests/complex_helper_tests.rs` - oracle replay tests in `crates/chainrules/tests/oracle_scalar_rules.rs` against vendored published cases from `third_party/tensor-ad-oracles`, - including direct float64 replay and selected complex128 reverse-mode replay; - complex forward-mode checks stay in repository-local formula tests because the - crate's `frule` convention intentionally differs from the published JVP - convention + including direct float64 replay and selected direct Complex64 replay for + `tan`, `exp2`, and `log2`; complex + forward-mode checks use the standard JVP convention on `C ~= R^2`, while + complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses ```bash cargo test --workspace --release 
diff --git a/crates/chainrules/README.md b/crates/chainrules/README.md index 03613e9..d3106d2 100644 --- a/crates/chainrules/README.md +++ b/crates/chainrules/README.md @@ -34,12 +34,12 @@ Current shipped scalar families: - Julia-compatible hyperbolic helpers: `sech`, `csch`, `coth` - non-smooth real helpers: `round`, `floor`, `ceil`, `sign`, `min`, `max` - smooth helpers: `cbrt`, `inv`, `exp2`, `exp10`, `log2`, `log10`, `hypot`, `pow`, `sincos`, `tan` -- complex and projection helpers: `conj`, `abs`, `abs2`, `angle`, `real`, `imag`, `complex`, `handle_r_to_c_f32`, `handle_r_to_c_f64` +- complex and projection helpers: `conj`, `abs`, `abs2`, `angle`, `real`, `imag`, `complex` - real-valued binary helpers: `atan2` -This crate is intended as a landing zone for scalar rules ported or adapted -from Julia's `ChainRules.jl` where they fit this crate boundary, but it is not -a full port of `ChainRules.jl`. +This crate is a landing zone for scalar rules ported or adapted from Julia's +`ChainRules.jl` where they fit this crate boundary, but it is not a full port +of `ChainRules.jl`. 
## Validation @@ -56,9 +56,10 @@ repository-local tests: constructor surface - `tests/oracle_scalar_rules.rs` replays vendored published oracle cases from `../../third_party/tensor-ad-oracles`, with direct float64 replay and - selected Complex64 reverse-mode replay -- complex forward-mode checks that follow this crate's `ScalarAd` convention - stay in repository-local formula tests such as `tests/smooth_basis_tests.rs` + selected direct Complex64 replay for `tan`, `exp2`, and `log2` +- complex forward-mode checks use the standard JVP convention on `C ~= R^2` + in repository-local formula tests such as `tests/smooth_basis_tests.rs` +- complex reverse-mode checks remain conjugate-Wirtinger for real-valued losses ## Examples diff --git a/crates/chainrules/src/lib.rs b/crates/chainrules/src/lib.rs index aba0777..c326045 100644 --- a/crates/chainrules/src/lib.rs +++ b/crates/chainrules/src/lib.rs @@ -24,7 +24,7 @@ pub use power::{powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule}; #[doc(inline)] pub use real_ops::{atan2, atan2_frule, atan2_rrule}; #[doc(inline)] -pub use scalar_ad::{handle_r_to_c_f32, handle_r_to_c_f64, ScalarAd}; +pub use scalar_ad::ScalarAd; #[doc(inline)] pub use unary::{ abs, abs2, abs2_frule, abs2_rrule, acos, acos_frule, acos_rrule, acosh, acosh_frule, diff --git a/crates/chainrules/src/scalar_ad/mod.rs b/crates/chainrules/src/scalar_ad/mod.rs index a40d67b..8db3e32 100644 --- a/crates/chainrules/src/scalar_ad/mod.rs +++ b/crates/chainrules/src/scalar_ad/mod.rs @@ -1,6 +1,5 @@ use core::ops::{Add, Div, Mul, Neg, Sub}; -use num_complex::{Complex32, Complex64}; use num_traits::{Float, FloatConst}; /// Scalar trait used by elementary AD rule helpers. @@ -132,36 +131,3 @@ pub trait ScalarAd: mod complex; mod real; - -/// PyTorch-style real-input / complex-gradient projection helper (`handle_r_to_c`). -/// -/// This is equivalent to taking the real part when a gradient for real input -/// becomes complex during intermediate algebra. 
-/// -/// # Examples -/// -/// ```rust -/// use chainrules::handle_r_to_c_f64; -/// use num_complex::Complex64; -/// -/// let g = Complex64::new(1.25, -3.0); -/// assert_eq!(handle_r_to_c_f64(g), 1.25); -/// ``` -pub fn handle_r_to_c_f64(gradient: Complex64) -> f64 { - gradient.re -} - -/// `f32` variant of [`handle_r_to_c_f64`]. -/// -/// # Examples -/// -/// ```rust -/// use chainrules::handle_r_to_c_f32; -/// use num_complex::Complex32; -/// -/// let g = Complex32::new(2.0, 4.0); -/// assert_eq!(handle_r_to_c_f32(g), 2.0); -/// ``` -pub fn handle_r_to_c_f32(gradient: Complex32) -> f32 { - gradient.re -} diff --git a/crates/chainrules/src/tests/behavior.rs b/crates/chainrules/src/tests/behavior.rs index bba6313..60ea911 100644 --- a/crates/chainrules/src/tests/behavior.rs +++ b/crates/chainrules/src/tests/behavior.rs @@ -4,10 +4,9 @@ use crate::{ acos, acos_frule, acos_rrule, acosh, acosh_frule, acosh_rrule, asin, asin_frule, asin_rrule, asinh, asinh_frule, asinh_rrule, atan, atan2, atan2_frule, atan_frule, atan_rrule, atanh, atanh_frule, atanh_rrule, conj, conj_frule, conj_rrule, cos, cos_frule, cos_rrule, cosh, - cosh_frule, cosh_rrule, exp, exp_frule, exp_rrule, expm1_frule, expm1_rrule, handle_r_to_c_f32, - handle_r_to_c_f64, log, log1p_frule, log1p_rrule, log_frule, log_rrule, sin, sin_frule, - sin_rrule, sinh, sinh_frule, sinh_rrule, sqrt, sqrt_frule, sqrt_rrule, tanh, tanh_frule, - tanh_rrule, ScalarAd, + cosh_frule, cosh_rrule, exp, exp_frule, exp_rrule, expm1_frule, expm1_rrule, log, log1p_frule, + log1p_rrule, log_frule, log_rrule, sin, sin_frule, sin_rrule, sinh, sinh_frule, sinh_rrule, + sqrt, sqrt_frule, sqrt_rrule, tanh, tanh_frule, tanh_rrule, ScalarAd, }; fn assert_close_f32(actual: f32, expected: f32) { @@ -271,10 +270,7 @@ fn scalar_ad_complex_extended_surface_matches_std_complex_ops() { } #[test] -fn direct_entrypoints_match_real_projection_and_atan2_formulas() { - assert_eq!(handle_r_to_c_f32(Complex32::new(2.0, -5.0)), 2.0); - 
assert_eq!(handle_r_to_c_f64(Complex64::new(-3.0, 1.5)), -3.0); - +fn direct_entrypoints_match_atan2_formulas() { let primal = atan2(3.0_f64, 4.0_f64); assert_close_f64(primal, 3.0_f64.atan2(4.0)); diff --git a/crates/chainrules/tests/scalarops_tests.rs b/crates/chainrules/tests/scalarops_tests.rs index a80ef5b..8ae3084 100644 --- a/crates/chainrules/tests/scalarops_tests.rs +++ b/crates/chainrules/tests/scalarops_tests.rs @@ -1,18 +1,10 @@ use chainrules::{ - add, add_frule, add_rrule, conj, conj_frule, conj_rrule, div, div_frule, div_rrule, - handle_r_to_c_f32, handle_r_to_c_f64, mul, mul_frule, mul_rrule, powf, powf_frule, powf_rrule, - powi, powi_frule, powi_rrule, sqrt, sqrt_frule, sqrt_rrule, sub, sub_frule, sub_rrule, + add, add_frule, add_rrule, conj, conj_frule, conj_rrule, div, div_frule, div_rrule, mul, + mul_frule, mul_rrule, powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule, sqrt, + sqrt_frule, sqrt_rrule, sub, sub_frule, sub_rrule, }; use num_complex::{Complex32, Complex64}; -#[test] -fn handle_r_to_c_projects_real_part() { - let g32 = Complex32::new(1.25, -9.0); - let g64 = Complex64::new(-3.5, 2.0); - assert_eq!(handle_r_to_c_f32(g32), 1.25_f32); - assert_eq!(handle_r_to_c_f64(g64), -3.5_f64); -} - #[test] fn conj_rules_match_formula_complex64() { let x = Complex64::new(2.0, -3.0); From 4056fb8977e2def2ee822fa5217965f380f43d5b Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 21 Mar 2026 18:47:32 +0900 Subject: [PATCH 32/32] docs: add complex jvp alignment plan --- .../plans/2026-03-21-complex-jvp-alignment.md | 437 ++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 docs/plans/2026-03-21-complex-jvp-alignment.md diff --git a/docs/plans/2026-03-21-complex-jvp-alignment.md b/docs/plans/2026-03-21-complex-jvp-alignment.md new file mode 100644 index 0000000..9869cee --- /dev/null +++ b/docs/plans/2026-03-21-complex-jvp-alignment.md @@ -0,0 +1,437 @@ +# Complex JVP Alignment Implementation Plan + +> **For Claude:** 
REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Align complex `frule` semantics with PyTorch and standard JVP over `C ~= R^2`, keep complex `rrule` on the existing conjugate-Wirtinger convention for real-valued losses, remove the unused `handle_r_to_c_*` helpers, and validate complex forward rules directly against vendored oracle JVPs. + +**Architecture:** Treat this as a family-wide semantic bug, not a one-off formula mismatch. Any complex `frule` that currently applies `conj(local_derivative)` must switch to the plain pushforward `J * dx`, while `rrule` keeps the existing `J^H * cotangent` behavior. Audit the shared scalar basis by module family, update local formula tests first, then re-enable direct complex oracle replay so the repository validates both float64 and Complex64 entrypoints against the same published PyTorch JVP/VJP data. + +**Tech Stack:** Rust 2021, `chainrules`, `chainrules-core`, `num-complex`, vendored `tensor-ad-oracles`, `cargo fmt`, `cargo nextest`, `cargo test --doc`, `cargo clippy`, `cargo llvm-cov`. 
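The convention split described in the Architecture paragraph can be sketched on `C ~= R^2` directly. This is a minimal illustration, not the crate's code: `C`, `cmul`, `conj`, `jacobian`, `jvp`, and `vjp` are hypothetical names introduced here, and `w` stands for the local derivative `f'(z)` of some holomorphic scalar function. Under those assumptions, the forward rule is `J * v`, which equals the plain product `w * v`, and the reverse rule is `J^T * g`, which equals `conj(w) * g`.

```rust
/// Hypothetical stand-in for a complex scalar; the crate itself relies on `num_complex`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct C {
    pub re: f64,
    pub im: f64,
}

/// Complex product `a * b`.
pub fn cmul(a: C, b: C) -> C {
    C {
        re: a.re * b.re - a.im * b.im,
        im: a.re * b.im + a.im * b.re,
    }
}

/// Complex conjugate.
pub fn conj(a: C) -> C {
    C { re: a.re, im: -a.im }
}

/// Real 2x2 Jacobian of the map `v -> w * v`, viewing `a + bi` as `(a, b)` in R^2.
pub fn jacobian(w: C) -> [[f64; 2]; 2] {
    [[w.re, -w.im], [w.im, w.re]]
}

/// Forward mode: the pushforward `J * v`, with no conjugation anywhere.
pub fn jvp(w: C, v: C) -> C {
    let j = jacobian(w);
    C {
        re: j[0][0] * v.re + j[0][1] * v.im,
        im: j[1][0] * v.re + j[1][1] * v.im,
    }
}

/// Reverse mode: the pullback `J^T * g`, which is where the conjugate appears.
pub fn vjp(w: C, g: C) -> C {
    let j = jacobian(w);
    C {
        re: j[0][0] * g.re + j[1][0] * g.im,
        im: j[0][1] * g.re + j[1][1] * g.im,
    }
}
```

For example, with `w = 1 + tan(z)^2` the forward tangent of `tan` would be `jvp(w, dz)`, which agrees with `cmul(w, dz)`, while the reverse gradient would be `vjp(w, g)`, which agrees with `cmul(conj(w), g)`.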
+
+---
+
+### Task 1: Re-enable Direct Complex Oracle Replay
+
+**Files:**
+- Modify: `crates/chainrules/tests/common.rs`
+- Modify: `crates/chainrules/tests/oracle_scalar_rules.rs`
+- Modify: `crates/chainrules/src/unary/trig.rs`
+- Modify: `crates/chainrules/src/unary/exp_log.rs`
+
+**Step 1: Write the failing test**
+
+Replace the reverse-only Complex64 oracle entrypoint with a direct `UnaryOracleCase` replay for the currently covered ops:
+
+```rust
+#[test]
+fn published_complex128_oracles_match_unary_rule_entrypoints() {
+    let cases: [UnaryOracleCase; 3] = [
+        UnaryOracleCase {
+            op: "tan",
+            frule: tan_frule,
+            rrule: |_x: Complex64, result, cotangent| tan_rrule(result, cotangent),
+        },
+        UnaryOracleCase {
+            op: "exp2",
+            frule: exp2_frule,
+            rrule: |_x: Complex64, result, cotangent| exp2_rrule(result, cotangent),
+        },
+        UnaryOracleCase {
+            op: "log2",
+            frule: log2_frule,
+            rrule: |x: Complex64, _result, cotangent| log2_rrule(x, cotangent),
+        },
+    ];
+
+    run_unary_oracle_cases(&cases);
+}
+```
+
+Delete the reverse-only helper path and its comment in `tests/common.rs`.
+
+**Step 2: Run test to verify it fails**
+
+Run:
+
+```bash
+cargo nextest run --release --test oracle_scalar_rules
+```
+
+Expected: FAIL with Complex64 JVP mismatches for `tan`, `exp2`, or `log2`.
+
+**Step 3: Write minimal implementation**
+
+Make the smallest source change that matches the oracle:
+
+```rust
+pub fn tan_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = x.tan();
+    (y, dx * (one::<S>() + y * y))
+}
+
+pub fn exp2_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = x.exp2();
+    (y, dx * (y * ln_2::<S>()))
+}
+
+pub fn log2_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = x.log2();
+    let scale = one::<S>() / (x * ln_2::<S>());
+    (y, dx * scale)
+}
+```
+
+Keep the corresponding `rrule` implementations unchanged.
+
+**Step 4: Run test to verify it passes**
+
+Run:
+
+```bash
+cargo fmt --all
+cargo nextest run --release --test oracle_scalar_rules
+```
+
+Expected: PASS for the Complex64 oracle replay test.
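Why the oracle rejects the conjugated forward formula can be seen with a finite-difference check. The sketch below is an assumption-laden illustration rather than the repository's test harness: the `C` type and every function name are hypothetical, and `exp2` is computed from `2^(a+bi) = 2^a (cos(b ln 2) + i sin(b ln 2))`. It compares the plain scale `y * ln 2` and the conjugated scale against a central difference along `dz`.

```rust
use std::f64::consts::LN_2;

/// Hypothetical minimal complex type used only for this sketch.
#[derive(Clone, Copy, Debug)]
pub struct C {
    pub re: f64,
    pub im: f64,
}

pub fn add(a: C, b: C) -> C {
    C { re: a.re + b.re, im: a.im + b.im }
}

pub fn sub(a: C, b: C) -> C {
    C { re: a.re - b.re, im: a.im - b.im }
}

pub fn scale(a: C, s: f64) -> C {
    C { re: a.re * s, im: a.im * s }
}

pub fn cmul(a: C, b: C) -> C {
    C {
        re: a.re * b.re - a.im * b.im,
        im: a.re * b.im + a.im * b.re,
    }
}

pub fn conj(a: C) -> C {
    C { re: a.re, im: -a.im }
}

pub fn norm(a: C) -> f64 {
    a.re.hypot(a.im)
}

/// `2^z` for complex `z`, via its polar form.
pub fn exp2(z: C) -> C {
    let radius = z.re.exp2();
    let theta = z.im * LN_2;
    C { re: radius * theta.cos(), im: radius * theta.sin() }
}

/// Plain pushforward: `dz * (2^z * ln 2)`.
pub fn exp2_jvp_plain(z: C, dz: C) -> C {
    cmul(dz, scale(exp2(z), LN_2))
}

/// The older convention: `dz * conj(2^z * ln 2)`.
pub fn exp2_jvp_conjugated(z: C, dz: C) -> C {
    cmul(dz, conj(scale(exp2(z), LN_2)))
}

/// Central difference of `exp2` along the direction `dz` with step `h`.
pub fn exp2_central_difference(z: C, dz: C, h: f64) -> C {
    let forward = exp2(add(z, scale(dz, h)));
    let backward = exp2(sub(z, scale(dz, h)));
    scale(sub(forward, backward), 1.0 / (2.0 * h))
}
```

At a point with a nonzero imaginary part, such as `z = 0.25 - 0.5i`, the plain formula tracks the central difference closely while the conjugated one does not, which is consistent with the failure expected in Step 2.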
+
+**Step 5: Commit**
+
+```bash
+git add crates/chainrules/tests/common.rs crates/chainrules/tests/oracle_scalar_rules.rs crates/chainrules/src/unary/trig.rs crates/chainrules/src/unary/exp_log.rs
+git commit -m "refactor: replay complex unary oracles with direct jvp"
+```
+
+### Task 2: Align Smooth Unary Complex `frule` Families
+
+**Files:**
+- Modify: `crates/chainrules/src/unary/basic.rs`
+- Modify: `crates/chainrules/src/unary/exp_log.rs`
+- Modify: `crates/chainrules/src/unary/trig.rs`
+- Modify: `crates/chainrules/src/unary/hyperbolic.rs`
+- Modify: `crates/chainrules/src/unary/roots.rs`
+- Modify: `crates/chainrules/src/tests/behavior.rs`
+- Modify: `crates/chainrules/tests/smooth_basis_tests.rs`
+
+**Step 1: Write the failing test**
+
+Rename the behavior test so it states the new rule, then switch the expectations to plain JVP:
+
+```rust
+#[test]
+fn extended_complex_unary_frules_match_standard_jvp_while_rrules_stay_conjugate() {
+    let x = Complex64::new(0.25, -0.5);
+    let dx = Complex64::new(-0.75, 0.5);
+    let cotangent = Complex64::new(0.5, -1.25);
+
+    let (_sin_y, sin_dy) = sin_frule(x, dx);
+    assert_close_c64(sin_dy, dx * ComplexFloat::cos(x));
+    assert_close_c64(
+        sin_rrule(x, cotangent),
+        cotangent * ComplexFloat::conj(ComplexFloat::cos(x)),
+    );
+}
+```
+
+Update `smooth_basis_complex_frules_match_expected_derivatives` the same way for `tan_frule`, `exp2_frule`, and `log2_frule`: remove `.conj()` from the expected forward scales but keep the reverse expectations unchanged.
+
+**Step 2: Run test to verify it fails**
+
+Run:
+
+```bash
+cargo nextest run --release --test smooth_basis_tests --test chainrules \
+    smooth_basis_complex_frules_match_expected_derivatives \
+    extended_complex_unary_frules_match_standard_jvp_while_rrules_stay_conjugate
+```
+
+Expected: FAIL because the remaining unary `frule` implementations still conjugate their local derivatives.
+
+**Step 3: Write minimal implementation**
+
+Remove forward-mode conjugation from every smooth unary helper in these files:
+
+```rust
+pub fn sqrt_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = x.sqrt();
+    let dy = dx / (S::from_i32(2) * y);
+    (y, dy)
+}
+
+pub fn exp_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = x.exp();
+    (y, dx * y)
+}
+
+pub fn asinh_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = x.asinh();
+    let scale = inverse_sqrt_one_plus_square(x);
+    (y, dx * scale)
+}
+```
+
+Audit every `*_frule` in the listed modules and change only the forward formulas. Keep `rrule` code as-is.
+
+**Step 4: Run test to verify it passes**
+
+Run:
+
+```bash
+cargo fmt --all
+cargo nextest run --release --test smooth_basis_tests
+cargo nextest run --release --test chainrules extended_complex_unary_frules_match_standard_jvp_while_rrules_stay_conjugate
+```
+
+Expected: PASS.
+
+**Step 5: Commit**
+
+```bash
+git add crates/chainrules/src/unary/basic.rs crates/chainrules/src/unary/exp_log.rs crates/chainrules/src/unary/trig.rs crates/chainrules/src/unary/hyperbolic.rs crates/chainrules/src/unary/roots.rs crates/chainrules/src/tests/behavior.rs crates/chainrules/tests/smooth_basis_tests.rs
+git commit -m "refactor: align smooth unary complex frules with standard jvp"
+```
+
+### Task 3: Align Binary And Power Complex `frule` Families
+
+**Files:**
+- Modify: `crates/chainrules/src/binary.rs`
+- Modify: `crates/chainrules/src/power.rs`
+- Modify: `crates/chainrules/tests/scalarops_tests.rs`
+- Modify: `crates/chainrules/tests/smooth_basis_tests.rs`
+
+**Step 1: Write the failing test**
+
+Change the complex forward expectations so they use the plain Jacobian instead of the conjugated one:
+
+```rust
+#[test]
+fn mul_div_rules_match_formula_complex64() {
+    let x = Complex64::new(1.5, -0.5);
+    let y = Complex64::new(-0.25, 2.0);
+    let dx = Complex64::new(0.3, -0.2);
+    let dy = Complex64::new(-0.1, 0.4);
+
+    let (mul_y, mul_dy) = mul_frule(x, y, dx, dy);
+    assert!((mul_y - (x * y)).norm() < 1e-12);
+    assert!((mul_dy - (dx * y + dy * x)).norm() < 1e-12);
+}
+```
+
+Update the power tests likewise:
+
+```rust
+let expected_dy = dx * (exponent * x.powc(exponent - Complex64::new(1.0, 0.0)))
+    + dexp * (expected_y * x.ln());
+assert!((dy - expected_dy).norm() < 1e-12);
+```
+
+Keep all `rrule` expectations on the conjugated derivative.
+
+**Step 2: Run test to verify it fails**
+
+Run:
+
+```bash
+cargo nextest run --release --test scalarops_tests --test smooth_basis_tests
+```
+
+Expected: FAIL in the Complex64 forward checks for `mul`, `div`, `powf`, `powi`, or `pow`.
+
+**Step 3: Write minimal implementation**
+
+Update only the forward formulas:
+
+```rust
+pub fn mul_frule<S: ScalarAd>(x: S, y: S, dx: S, dy: S) -> (S, S) {
+    let primal = x * y;
+    let tangent = dx * y + dy * x;
+    (primal, tangent)
+}
+
+pub fn div_frule<S: ScalarAd>(x: S, y: S, dx: S, dy: S) -> (S, S) {
+    let primal = x / y;
+    let inv_y = one::<S>() / y;
+    let dfdx = inv_y;
+    let dfdy = -(x * inv_y * inv_y);
+    (primal, dx * dfdx + dy * dfdy)
+}
+```
+
+Mirror the same change in `powf_frule`, `powi_frule`, and `pow_frule`. Keep the singularity behavior at `pow(0, 0)` and zero-base negative exponents unchanged.
+
+**Step 4: Run test to verify it passes**
+
+Run:
+
+```bash
+cargo fmt --all
+cargo nextest run --release --test scalarops_tests
+cargo nextest run --release --test smooth_basis_tests
+```
+
+Expected: PASS.
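One way to see that the plain forward rule and the conjugated reverse rule for `mul` are mutually consistent is an adjoint (dot-product) test under the real inner product `<a, b> = Re(conj(a) * b)`. The sketch below is illustrative only: `C`, `inner`, `mul_jvp`, and `mul_vjp` are hypothetical stand-ins, not the crate's `mul_frule` and `mul_rrule`, though they use the same formulas the tests in this task expect.

```rust
/// Hypothetical minimal complex type used only for this sketch.
#[derive(Clone, Copy, Debug)]
pub struct C {
    pub re: f64,
    pub im: f64,
}

pub fn add(a: C, b: C) -> C {
    C { re: a.re + b.re, im: a.im + b.im }
}

pub fn cmul(a: C, b: C) -> C {
    C {
        re: a.re * b.re - a.im * b.im,
        im: a.re * b.im + a.im * b.re,
    }
}

pub fn conj(a: C) -> C {
    C { re: a.re, im: -a.im }
}

/// Real inner product on C ~= R^2: `Re(conj(a) * b)`.
pub fn inner(a: C, b: C) -> f64 {
    a.re * b.re + a.im * b.im
}

/// Forward tangent of `x * y` with no conjugation: `dx * y + dy * x`.
pub fn mul_jvp(x: C, y: C, dx: C, dy: C) -> C {
    add(cmul(dx, y), cmul(dy, x))
}

/// Reverse gradients of `x * y`: `(g * conj(y), g * conj(x))`.
pub fn mul_vjp(x: C, y: C, g: C) -> (C, C) {
    (cmul(g, conj(y)), cmul(g, conj(x)))
}

/// Gap between `<g, J * (dx, dy)>` and `<J^T * g, (dx, dy)>`; near zero when the
/// forward and reverse rules describe the same Jacobian.
pub fn adjoint_gap(x: C, y: C, dx: C, dy: C, g: C) -> f64 {
    let lhs = inner(g, mul_jvp(x, y, dx, dy));
    let (gx, gy) = mul_vjp(x, y, g);
    let rhs = inner(gx, dx) + inner(gy, dy);
    (lhs - rhs).abs()
}
```

If the forward rule were conjugated as well, the two sides of this identity would generally disagree, so a check of this shape can serve as a cheap guard alongside the formula tests.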
+
+**Step 5: Commit**
+
+```bash
+git add crates/chainrules/src/binary.rs crates/chainrules/src/power.rs crates/chainrules/tests/scalarops_tests.rs crates/chainrules/tests/smooth_basis_tests.rs
+git commit -m "refactor: align binary and power complex frules with standard jvp"
+```
+
+### Task 4: Align Julia Compatibility Helpers With Standard JVP
+
+**Files:**
+- Modify: `crates/chainrules/src/unary/trig_extra.rs`
+- Modify: `crates/chainrules/src/unary/hyperbolic_extra.rs`
+- Modify: `crates/chainrules/tests/julia_compat_trig_tests.rs`
+
+**Step 1: Write the failing test**
+
+Change the Complex64 forward expectations in the Julia-compat tests and add one extra representative check so the family is covered:
+
+```rust
+let (_, dsinpi) = sinpi_frule(z, dz);
+assert_close_complex64(
+    dsinpi,
+    dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()),
+    1e-12,
+    0.0,
+    "sinpi_frule(z)",
+);
+
+let ((_sin_y, _cos_y), (dsin, dcos)) = sincospi_frule(z, dz);
+assert_close_complex64(dsin, dz * (Complex64::new(std::f64::consts::PI, 0.0) * pi_z.cos()), 1e-12, 0.0, "sincospi_frule.sin");
+assert_close_complex64(dcos, dz * (-(Complex64::new(std::f64::consts::PI, 0.0) * pi_z.sin())), 1e-12, 0.0, "sincospi_frule.cos");
+```
+
+Keep the `rrule` expectations on the conjugated derivative.
+
+**Step 2: Run test to verify it fails**
+
+Run:
+
+```bash
+cargo nextest run --release --test julia_compat_trig_tests
+```
+
+Expected: FAIL for Complex64 forward checks in `sinpi_frule`, `sincospi_frule`, or related helpers.
+
+**Step 3: Write minimal implementation**
+
+Audit the Julia migration helpers and remove forward-mode conjugation everywhere it still appears:
+
+```rust
+pub fn sinpi_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
+    let y = sinpi(x);
+    let scale = pi::<S>() * cospi(x);
+    (y, dx * scale)
+}
+```
+
+Repeat the same cleanup for `cospi_frule`, `sincospi_frule`, `tand_frule`, `sec_frule`, `csc_frule`, `cot_frule`, `sech_frule`, `csch_frule`, and `coth_frule`. If a degree-based helper delegates entirely to corrected primitives, prefer reuse over re-deriving a new formula.
+
+**Step 4: Run test to verify it passes**
+
+Run:
+
+```bash
+cargo fmt --all
+cargo nextest run --release --test julia_compat_trig_tests
+```
+
+Expected: PASS.
+
+**Step 5: Commit**
+
+```bash
+git add crates/chainrules/src/unary/trig_extra.rs crates/chainrules/src/unary/hyperbolic_extra.rs crates/chainrules/tests/julia_compat_trig_tests.rs
+git commit -m "refactor: align Julia compatibility frules with standard jvp"
+```
+
+### Task 5: Remove `handle_r_to_c_*` And Update Public Docs
+
+**Files:**
+- Modify: `crates/chainrules/src/scalar_ad/mod.rs`
+- Modify: `crates/chainrules/src/lib.rs`
+- Modify: `crates/chainrules/src/tests/behavior.rs`
+- Modify: `crates/chainrules/tests/scalarops_tests.rs`
+- Modify: `README.md`
+- Modify: `crates/chainrules/README.md`
+
+**Step 1: Write the failing cleanup**
+
+Remove the public re-exports first:
+
+```rust
+#[doc(inline)]
+pub use scalar_ad::ScalarAd;
+```
+
+Then update the test imports so any remaining helper references are explicit failures:
+
+```rust
+use chainrules::{
+    add, add_frule, add_rrule, conj, conj_frule, conj_rrule, div, div_frule, div_rrule, mul,
+    mul_frule, mul_rrule, powf, powf_frule, powf_rrule, powi, powi_frule, powi_rrule, sqrt,
+    sqrt_frule, sqrt_rrule, sub, sub_frule, sub_rrule,
+};
+```
+
+**Step 2: Run test to verify it fails**
+
+Run:
+
+```bash
+cargo nextest run --release --test scalarops_tests --test chainrules
+cargo test --doc --release -p chainrules
+```
+
+Expected: FAIL because `handle_r_to_c_*` doctests and dedicated tests still exist.
+
+**Step 3: Write minimal implementation**
+
+Delete the unused helpers and update the docs to state the new convention clearly:
+
+- remove `handle_r_to_c_f32` and `handle_r_to_c_f64` from `scalar_ad/mod.rs`
+- delete the dedicated helper checks from `src/tests/behavior.rs` and `tests/scalarops_tests.rs`
+- update `README.md` so the testing section says Complex64 oracle replay is direct, not reverse-only
+- update `crates/chainrules/README.md` so it:
+  - removes `handle_r_to_c_*` from the supported surface
+  - states that complex `frule` follows standard JVP on `C ~= R^2`
+  - states that complex `rrule` remains conjugate-Wirtinger for real-valued losses
+  - keeps the provenance note that this crate is a landing zone for scalar rules ported or adapted from `ChainRules.jl`, not a full port
+
+**Step 4: Run test to verify it passes**
+
+Run:
+
+```bash
+cargo fmt --all
+cargo nextest run --release --test scalarops_tests --test chainrules
+cargo test --doc --release -p chainrules
+```
+
+Expected: PASS.
+
+**Step 5: Commit**
+
+```bash
+git add crates/chainrules/src/scalar_ad/mod.rs crates/chainrules/src/lib.rs crates/chainrules/src/tests/behavior.rs crates/chainrules/tests/scalarops_tests.rs README.md crates/chainrules/README.md
+git commit -m "cleanup: remove projection helpers and document complex jvp"
+```
+
+## Final Verification
+
+Run the full repository verification after the last task:
+
+```bash
+cargo fmt --all --check
+cargo nextest run --release --workspace --no-fail-fast
+cargo test --doc --release --workspace
+cargo clippy --workspace --all-targets -- -D warnings
+cargo llvm-cov nextest --workspace --release --json --output-path coverage.json
+python3 scripts/check-coverage.py coverage.json
+cargo doc --workspace --no-deps
+python3 scripts/check-docs-site.py
+```
+
+Expected: all commands PASS with no direct-complex-oracle failures, no remaining `handle_r_to_c_*` references, and no docs claiming that complex forward replay is local-only.