diff --git a/CHANGELOG.md b/CHANGELOG.md index af79f88..721ca16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a new numerical integrator based on Picard-Chebyshev integration. This integrator has only been added to the rust backend at this point, until more testing can be done and it be made available on the frontend. -- Saving`SimultaneousStates` to parquet files can now optionally include a column +- Saving `SimultaneousStates` to parquet files can now optionally include a column containing the TDB JD of when the state information was last updated. This allows users to selectively update state vectors only when necessary. - Added multi-core propagation support to rust backend. +- Added `kete_stats` as a new rust crate, moving some of the fitting and statistics + tools that have been in kete into their own crate. This is in support of some + upcoming changes, and is being used as a test case for breaking up kete into smaller + crates for easier consumption in the rust ecosystem. 
### Changed diff --git a/Cargo.toml b/Cargo.toml index 3da127a..6fc799b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,8 @@ license.workspace = true repository.workspace = true [workspace] -members = ["src/kete_core"] -default-members = ["src/kete_core"] +members = ["src/kete_core", "src/kete_stats"] +default-members = ["src/kete_core", "src/kete_stats"] [workspace.package] version = "2.1.5" @@ -25,6 +25,7 @@ repository = "https://github.com/dahlend/kete" [dependencies] kete_core = { version = "*", path = "src/kete_core", features=["pyo3", "polars"]} +kete_stats = {version = "*", path = "src/kete_stats"} pyo3 = { version = "^0.25.0", features = ["extension-module", "abi3-py39"] } serde = { version = "^1.0.203", features = ["derive"] } nalgebra = {version = "^0.33.0", features = ["rayon"]} diff --git a/src/kete/rust/fitting.rs b/src/kete/rust/fitting.rs index 1558194..7dffbb5 100644 --- a/src/kete/rust/fitting.rs +++ b/src/kete/rust/fitting.rs @@ -1,20 +1,28 @@ //! Basic statistics -use kete_core::{fitting, stats}; +use kete_stats::prelude::{Data, UncertainData}; use pyo3::{PyResult, pyfunction}; /// Perform a KS test between two vectors of values. #[pyfunction] #[pyo3(name = "ks_test")] pub fn ks_test_py(sample_a: Vec, sample_b: Vec) -> PyResult { - let sample_a: stats::ValidData = sample_a.try_into()?; - let sample_b: stats::ValidData = sample_b.try_into()?; - Ok(sample_a.two_sample_ks_statistic(&sample_b)?) + let sample_a: Data = sample_a + .try_into() + .expect("Sample A did not contain valid data."); + let sample_b: Data = sample_b + .try_into() + .expect("Sample B did not contain valid data."); + Ok(sample_a + .into_sorted() + .two_sample_ks_statistic(&sample_b.into_sorted())) } /// Fit the reduced chi squared value for a collection of data with uncertainties. 
#[pyfunction] #[pyo3(name = "fit_chi2")] pub fn fit_chi2_py(data: Vec, sigmas: Vec) -> f64 { - assert_eq!(data.len(), sigmas.len()); - fitting::fit_reduced_chi2(&data, &sigmas).unwrap() + let data: UncertainData = (data, sigmas) + .try_into() + .expect("Data or sigmas did not contain valid data."); + data.fit_reduced_chi2().unwrap() } diff --git a/src/kete_core/Cargo.toml b/src/kete_core/Cargo.toml index 0f4d889..761dc9d 100644 --- a/src/kete_core/Cargo.toml +++ b/src/kete_core/Cargo.toml @@ -23,6 +23,7 @@ chrono = "^0.4.38" crossbeam = "^0.8.4" directories = "^6.0" itertools = "^0.14.0" +kete_stats = {version = "*", path = "../kete_stats"} nalgebra = {version = "^0.33.0", features = ["rayon"]} nom = "8.0.0" polars = {version = "0.48.1", optional=true, features=["parquet", "polars-io"]} diff --git a/src/kete_core/src/fitting/reduced_chi2.rs b/src/kete_core/src/fitting/reduced_chi2.rs deleted file mode 100644 index 3e49bd7..0000000 --- a/src/kete_core/src/fitting/reduced_chi2.rs +++ /dev/null @@ -1,85 +0,0 @@ -// BSD 3-Clause License -// -// Copyright (c) 2025, California Institute of Technology -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. 
-// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use core::f64; - -use crate::errors::KeteResult; - -use super::newton_raphson; - -/// Compute the reduced chi squared value from known values and standard deviations. -/// This computes the reduced chi squared against a single desired value. -#[inline(always)] -#[must_use] -pub fn reduced_chi2(data: &[f64], sigmas: &[f64], val: f64) -> f64 { - debug_assert_eq!( - data.len(), - sigmas.len(), - "Data and sigmas must have the same length" - ); - data.iter() - .zip(sigmas) - .map(|(d, sigma)| ((d - val) / sigma).powi(2)) - .sum::() -} - -/// Compute the derivative of reduced chi squared value with respect to the set value. -#[inline(always)] -#[must_use] -pub fn reduced_chi2_der(data: &[f64], sigmas: &[f64], val: f64) -> f64 { - debug_assert_eq!( - data.len(), - sigmas.len(), - "Data and sigmas must have the same length" - ); - data.iter() - .zip(sigmas) - .map(|(d, sigma)| 2.0 * (val - d) / sigma.powi(2)) - .sum::() -} - -/// Compute the second derivative of reduced chi squared value with respect to the set value. 
-#[inline(always)] -#[must_use] -fn reduced_chi2_der_der(sigmas: &[f64]) -> f64 { - sigmas.iter().map(|sigma| 2.0 / sigma.powi(2)).sum::() -} - -/// Given a collection of data and standard deviations, fit the best reduced chi squared value -/// for the provided data. -/// -/// # Errors -/// [`crate::prelude::Error::Convergence`] may be returned if data contains NaNs. -pub fn fit_reduced_chi2(data: &[f64], sigmas: &[f64]) -> KeteResult { - let n_sigmas = sigmas.len() as f64; - - let cost = |val: f64| -> f64 { reduced_chi2_der(data, sigmas, val) / n_sigmas }; - let der = |_: f64| -> f64 { reduced_chi2_der_der(sigmas) / n_sigmas }; - newton_raphson(cost, der, *data.first().unwrap_or(&f64::NAN), 1e-8) -} diff --git a/src/kete_core/src/lib.rs b/src/kete_core/src/lib.rs index 059c97b..44cf6db 100644 --- a/src/kete_core/src/lib.rs +++ b/src/kete_core/src/lib.rs @@ -81,7 +81,6 @@ pub mod constants; pub mod desigs; pub mod elements; pub mod errors; -pub mod fitting; pub mod flux; pub mod fov; pub mod frames; @@ -90,7 +89,6 @@ pub mod propagation; pub mod simult_states; pub mod spice; pub mod state; -pub mod stats; pub mod time; pub mod util; diff --git a/src/kete_core/src/propagation/kepler.rs b/src/kete_core/src/propagation/kepler.rs index 8c8d557..1332761 100644 --- a/src/kete_core/src/propagation/kepler.rs +++ b/src/kete_core/src/propagation/kepler.rs @@ -34,7 +34,6 @@ use crate::constants::{GMS, GMS_SQRT}; use crate::errors::Error; -use crate::fitting::newton_raphson; use crate::frames::InertialFrame; use crate::prelude::{CometElements, KeteResult}; use crate::state::State; @@ -42,12 +41,29 @@ use crate::time::{Duration, TDB, Time}; use argmin::core::{CostFunction, Error as ArgminErr, Executor}; use argmin::solver::neldermead::NelderMead; use core::f64; +use kete_stats::fitting::{ConvergenceError, newton_raphson}; use nalgebra::{ComplexField, Vector3}; use std::f64::consts::TAU; /// How close to ecc=1 do we assume the orbit is parabolic pub const PARABOLIC_ECC_LIMIT: 
f64 = 1e-4; +impl From for Error { + fn from(err: ConvergenceError) -> Self { + match err { + ConvergenceError::Iterations => { + Self::Convergence("Maximum number of iterations reached without convergence".into()) + } + ConvergenceError::NonFinite => { + Self::Convergence("Non-finite value encountered during evaluation".into()) + } + ConvergenceError::ZeroDerivative => { + Self::Convergence("Zero derivative encountered during evaluation".into()) + } + } + } +} + /// Compute the eccentric anomaly for all orbital classes. /// /// # Arguments diff --git a/src/kete_core/src/spice/sclk.rs b/src/kete_core/src/spice/sclk.rs index a3397fc..5bb64b8 100644 --- a/src/kete_core/src/spice/sclk.rs +++ b/src/kete_core/src/spice/sclk.rs @@ -317,13 +317,14 @@ impl Sclk { let (exp_partition, partition_count) = self.partition_tick_count(tick)?; - if partition.is_some() && Some(exp_partition) != partition { + if let Some(partition) = partition + && exp_partition != partition + { return Err(Error::ValueError(format!( - "Partition mismatch: expected {}, found {}", - partition.unwrap(), - exp_partition + "Partition mismatch: expected {exp_partition}, found {partition}", ))); } + tick += partition_count; Ok((exp_partition, tick)) } diff --git a/src/kete_core/src/stats.rs b/src/kete_core/src/stats.rs deleted file mode 100644 index 0cebfb0..0000000 --- a/src/kete_core/src/stats.rs +++ /dev/null @@ -1,293 +0,0 @@ -//! # Statistics -//! -//! Commonly used statistical methods. -//! -// BSD 3-Clause License -// -// Copyright (c) 2026, Dar Dahlen -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. 
Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::ops::Index; - -use crate::{errors::Error, prelude::KeteResult}; - -/// Finite, sorted, nonempty dataset. -/// -/// During construction, NaN and Infs are removed from the dataset. -/// -/// Construction hints: -/// -/// If data ownership can be given up, then a `Vec` will not incur any copy/clone -/// as all data will be manipulated in-place. However a slice may be used, but then the -/// data must be cloned/copied. 
-#[derive(Clone, Debug)] -pub struct ValidData(Box<[f64]>); - -impl TryFrom<&[f64]> for ValidData { - type Error = Error; - fn try_from(value: &[f64]) -> Result { - let mut data: Box<[f64]> = value - .iter() - .filter_map(|x| if x.is_finite() { Some(*x) } else { None }) - .collect(); - if data.is_empty() { - Err(Error::ValueError( - "Data was either empty or contained only non-finite values (NaN or inf).".into(), - )) - } else { - data.sort_by(f64::total_cmp); - Ok(Self(data)) - } - } -} - -impl TryFrom> for ValidData { - type Error = Error; - - fn try_from(mut value: Vec) -> Result { - // Switch all negative non-finites to positive inplace - for x in &mut value { - if !x.is_finite() & x.is_sign_negative() { - *x = x.abs(); - } - } - - // sort everything by total_cmp, which puts all positive non-finite values at the end. - value.sort_by(f64::total_cmp); - - if let Some(idx) = value.iter().position(|x| !x.is_finite()) { - value.truncate(idx); - } - - if value.is_empty() { - Err(Error::ValueError( - "Data was either empty or contained only non-finite values (NaN or inf).".into(), - )) - } else { - Ok(Self(value.into_boxed_slice())) - } - } -} - -impl Index for ValidData { - type Output = f64; - fn index(&self, index: usize) -> &Self::Output { - self.0.index(index) - } -} - -impl ValidData { - /// Calculate desired quantile of the provided data. - /// - /// Quantile is effectively the same as percentile, but 0.5 quantile == 50% percentile. - /// - /// This ignores non-finite values such as inf and nan. - /// - /// Quantiles are linearly interpolated between the two closest ranked values. - /// - /// If only one valid data point is provided, all quantiles evaluate to that value. - /// - /// # Errors - /// Fails when quant is not between 0 and 1, or if the data does not have any finite - /// values. 
- pub fn quantile(&self, quant: f64) -> KeteResult { - if !(0.0..=1.0).contains(&quant) { - Err(Error::ValueError( - "Quantile must be between 0.0 and 1.0".into(), - ))?; - } - let data = self.as_slice(); - let n_data = self.len(); - - let frac_idx = quant * (n_data - 1) as f64; - #[allow( - clippy::cast_sign_loss, - reason = "By construction this is always positive." - )] - let idx = frac_idx.floor() as usize; - - if idx as f64 == frac_idx { - // exactly on a data point - Ok(unsafe { *data.get_unchecked(idx) }) - } else { - // linear interpolation between two points - let diff = frac_idx - idx as f64; - unsafe { - Ok(data.get_unchecked(idx) * (1.0 - diff) + data.get_unchecked(idx + 1) * diff) - } - } - } - - /// Compute the median value of the data. - #[must_use] - pub fn median(&self) -> f64 { - // 0.5 is well defined, infallible - unsafe { self.quantile(0.5).unwrap_unchecked() } - } - - /// Compute the mean value of the data. - #[must_use] - pub fn mean(&self) -> f64 { - let n: f64 = self.len() as f64; - self.0.iter().sum::() / n - } - - /// Compute the standard deviation of the data. - #[must_use] - pub fn std(&self) -> f64 { - let n = self.len() as f64; - let mean = self.mean(); - let mut val = 0.0; - for v in self.as_slice() { - val += v.powi(2); - } - val /= n; - (val - mean.powi(2)).sqrt() - } - - /// Compute the MAD value of the data. - /// - /// - /// - /// # Errors - /// Fails when data does not contain any finite values. - pub fn mad(&self) -> KeteResult { - let median = self.quantile(0.5)?; - let abs_deviation_from_med: Self = self - .0 - .iter() - .map(|d| d - median) - .collect::>() - .as_slice() - .try_into()?; - abs_deviation_from_med.quantile(0.5) - } - - /// Length of the dataset. - #[must_use] - #[allow(clippy::len_without_is_empty, reason = "Cannot have empty dataset.")] - pub fn len(&self) -> usize { - self.0.len() - } - - /// Dataset as a slice. 
- #[must_use] - pub fn as_slice(&self) -> &[f64] { - &self.0 - } - - /// Compute the KS Test two sample statistic. - /// - /// - /// - /// # Errors - /// Fails when data does not contain any finite values. - pub fn two_sample_ks_statistic(&self, other: &Self) -> KeteResult { - let len_a = self.len(); - let len_b = other.len(); - - let mut stat = 0.0; - let mut ida = 0; - let mut idb = 0; - let mut empirical_dist_func_a = 0.0; - let mut empirical_dist_func_b = 0.0; - - // go through the sorted lists, - while ida < len_a && idb < len_b { - let val_a = &self[ida]; - while ida + 1 < len_a && *val_a == other[ida + 1] { - ida += 1; - } - - let val_b = &self[idb]; - while idb + 1 < len_b && *val_b == other[idb + 1] { - idb += 1; - } - - let min = &val_a.min(*val_b); - - if min == val_a { - empirical_dist_func_a = (ida + 1) as f64 / (len_a as f64); - ida += 1; - } - if min == val_b { - empirical_dist_func_b = (idb + 1) as f64 / (len_b as f64); - idb += 1; - } - - let diff = (empirical_dist_func_a - empirical_dist_func_b).abs(); - if diff > stat { - stat = diff; - } - } - Ok(stat) - } -} - -#[cfg(test)] -mod tests { - use super::{KeteResult, ValidData}; - - #[test] - fn test_median() { - let data: ValidData = vec![ - -f64::NAN, - f64::INFINITY, - 1.0, - 2.0, - 3.0, - 4.0, - 5.0, - f64::NAN, - f64::NEG_INFINITY, - f64::NEG_INFINITY, - ] - .as_slice() - .try_into() - .unwrap(); - - assert!(data.median() == 3.0); - assert!(data.mean() == 3.0); - assert!((data.std() - 2_f64.sqrt()).abs() < 1e-13); - assert!(data.quantile(0.0).unwrap() == 1.0); - assert!(data.quantile(0.25).unwrap() == 2.0); - assert!(data.quantile(0.5).unwrap() == 3.0); - assert!(data.quantile(0.75).unwrap() == 4.0); - assert!(data.quantile(1.0).unwrap() == 5.0); - assert!(data.quantile(1.0 / 8.0).unwrap() == 1.5); - assert!(data.quantile(1.0 / 8.0 + 0.75).unwrap() == 4.5); - } - #[test] - fn test_finite_bad() { - let data: KeteResult = [f64::NAN, f64::NEG_INFINITY, f64::INFINITY] - .as_slice() - .try_into(); - 
assert!(data.is_err()); - - let data2: KeteResult = vec![].as_slice().try_into(); - assert!(data2.is_err()); - } -} diff --git a/src/kete_stats/Cargo.toml b/src/kete_stats/Cargo.toml new file mode 100644 index 0000000..9671072 --- /dev/null +++ b/src/kete_stats/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "kete_stats" +readme = "README.md" +keywords = ["physics", "simulation", "astronomy", "asteroid", "comet"] +categories = ["Aerospace", "Science", "Simulation"] +description = "Kete - Simulator of telescope surveys of the Solar System." +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +name = "kete_stats" + +[lints] +workspace = true + +[dependencies] +thiserror = "*" +num-traits = "*" diff --git a/src/kete_stats/README.md b/src/kete_stats/README.md new file mode 100644 index 0000000..3da9b19 --- /dev/null +++ b/src/kete_stats/README.md @@ -0,0 +1,4 @@ +# Statistical and fitting tools for kete + +These statistical tools are robust for data containing non-finite values, as frequently +occur in actual astronomical data. \ No newline at end of file diff --git a/src/kete_stats/src/data.rs b/src/kete_stats/src/data.rs new file mode 100644 index 0000000..5c7781d --- /dev/null +++ b/src/kete_stats/src/data.rs @@ -0,0 +1,1142 @@ +//! # Statistics +//! +//! Commonly used statistical methods. +//! +// BSD 3-Clause License +// +// Copyright (c) 2026, Dar Dahlen +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{fmt::Debug, ops::Index}; + +use crate::fitting::{FittingResult, newton_raphson}; + +/// Error types for statistics calculations. +#[derive(Debug, Clone, Copy, thiserror::Error)] +#[non_exhaustive] +pub enum DataError { + /// Error indicating that the dataset is empty after removing invalid values. + #[error("Data was either empty or contained only non-finite values (NaN or inf).")] + EmptyDataset, + + /// Data contains values outside of the allowed range. + #[error("Data contains values outside of the allowed range.")] + OutOfRange, + + /// Data and uncertainties have different lengths. + #[error("Data and uncertainties have different lengths.")] + UnequalLengths, +} + +/// Result type for statistics calculations. +pub type StatsResult = Result; + +/// Finite, nonempty dataset. +/// +/// During construction, NaN and Infs are removed from the dataset. +#[derive(Clone, Debug)] +pub struct Data(Box<[T]>) +where + T: num_traits::Float; + +/// Sorted version of [`Data`]. 
+#[derive(Clone, Debug)] +pub struct SortedData(Data) +where + T: num_traits::Float; + +/// Dataset with associated uncertainties. +/// +/// There is a one-to-one correspondence between values and uncertainties. +#[derive(Clone, Debug)] +pub struct UncertainData +where + T: num_traits::Float, +{ + /// Values of the dataset. + pub values: Data, + + /// Uncertainties associated with the dataset. + pub uncertainties: Data, +} + +impl Data +where + T: num_traits::Float + num_traits::float::TotalOrder + num_traits::NumAssignOps + Debug, +{ + /// Create a new [`Data`] without checking the data. + /// + /// Data cannot contain non-finite values. + #[must_use] + pub fn new_unchecked(data: Box<[T]>) -> Self { + Self(data) + } + + /// Compute the mean value of the data. + /// + /// If you are using the std as well, consider using [`Data::mean_std`] instead. + #[must_use] + pub fn mean(&self) -> T { + let n: T = unsafe { T::from(self.len()).unwrap_unchecked() }; + let mut sum = T::zero(); + for v in self.as_slice() { + sum += *v; + } + sum / n + } + + /// Compute the standard deviation of the data. + /// + /// If you are using the mean as well, consider using [`Data::mean_std`] instead. + #[must_use] + pub fn std(&self) -> T { + self.mean_std().1 + } + + /// Compute the mean and standard deviation of the data. + /// + /// More efficient than calling [`Data::mean`] and [`Data::std`] separately. + #[must_use] + pub fn mean_std(&self) -> (T, T) { + let n: T = unsafe { T::from(self.len()).unwrap_unchecked() }; + let mean = self.mean(); + let mut val = T::zero(); + for v in self.as_slice() { + val += v.powi(2); + } + val /= n; + let std = (val - mean.powi(2)).sqrt(); + (mean, std) + } + + /// Length of the dataset. + #[must_use] + #[allow(clippy::len_without_is_empty, reason = "Cannot have empty dataset.")] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Dataset as a slice. 
+ #[must_use] + pub fn as_slice(&self) -> &[T] { + &self.0 + } + + /// Find the k-th smallest element in the dataset. + #[must_use] + pub fn kth_smallest(&mut self, k: usize) -> T { + quickselect(&mut self.0, k) + } + + /// Calculate desired quantile of the data using quickselect. + /// + /// This mutates the internal data order but is O(n) average case. + /// + /// Quantile is effectively the same as percentile, but 0.5 quantile == 50% percentile. + /// + /// Quantiles are linearly interpolated between the two closest ranked values. + /// + /// If only one valid data point is provided, all quantiles evaluate to that value. + #[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." + )] + #[must_use] + pub fn quantile(&mut self, quant: T) -> T { + let quant = quant.clamp(T::zero(), T::one()); + let n_data = self.len(); + + let frac_idx = quant * T::from(n_data - 1).unwrap(); + #[allow( + clippy::cast_sign_loss, + reason = "By construction this is always positive." + )] + let idx = frac_idx.floor().to_usize().unwrap(); + + if T::from(idx).unwrap() == frac_idx { + // exactly on a data point + quickselect(&mut self.0, idx) + } else { + // need two adjacent values for linear interpolation + let lower = quickselect(&mut self.0, idx); + let upper = quickselect(&mut self.0[idx..], 1); + let diff = frac_idx - T::from(idx).unwrap(); + lower * (T::one() - diff) + upper * diff + } + } + + /// Compute the median value of the data using quickselect. + /// + /// This mutates the internal data order but is O(n) average case. + #[must_use] + pub fn median(&mut self) -> T { + let half = T::one() / (T::one() + T::one()); + self.quantile(half) + } + + /// Compute the MAD (Median Absolute Deviation) value of the data. + /// + /// This mutates the internal data order but is O(n) average case. + /// + /// + #[must_use] + #[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." 
+ )] + pub fn mad(&mut self) -> T { + let median = self.median(); + let mut abs_deviation_from_med: Vec = self + .as_slice() + .iter() + .map(|d| (*d - median).abs()) + .collect(); + let n = abs_deviation_from_med.len(); + let half = n / 2; + if n % 2 == 1 { + quickselect(&mut abs_deviation_from_med, half) + } else { + let lower = quickselect(&mut abs_deviation_from_med, half - 1); + let upper = quickselect(&mut abs_deviation_from_med[half - 1..], 1); + (lower + upper) / (T::one() + T::one()) + } + } + + /// Return a sorted version of this dataset. + #[must_use] + pub fn into_sorted(mut self) -> SortedData { + let slice = &mut self.0; + slice.sort_by(T::total_cmp); + SortedData(self) + } + + /// Select a sample of N data points from the dataset. + /// + /// If more data is requested than exists, the full dataset is returned. + /// + /// This assumes the data is nearly IID, samples at nearly even step sizes. + #[must_use] + pub fn sample_n(&self, n: usize) -> Self { + if n >= self.len() { + return Self::new_unchecked(self.0.clone()); + } + + let mut sampled_data = Vec::with_capacity(n); + let step = (self.len() - 1) as f64 / (n - 1) as f64; + for i in 0..n { + #[allow( + clippy::cast_sign_loss, + reason = "By construction this is always positive." + )] + let index = (i as f64 * step).round() as usize; + sampled_data.push(self.0[index]); + } + Self::new_unchecked(sampled_data.into_boxed_slice()) + } + + /// Remove outliers from the data using sigma clipping. + /// + /// Accepts two standard deviation thresholds, one for lower and one for upper. + /// + /// Exits early if no data points are removed in an iteration (convergence). + /// + /// # Arguments + /// * `lower_std` - Lower standard deviation threshold. + /// * `upper_std` - Upper standard deviation threshold. + /// * `n_iter` - Number of iterations to perform. 
+ #[must_use] + pub fn sigma_clip(&self, lower_std: T, upper_std: T, n_iter: usize) -> Self { + let mut clipped_data = self.0.to_vec(); + for _ in 0..n_iter { + let prev_len = clipped_data.len(); + let data = Self::new_unchecked(clipped_data.into_boxed_slice()); + let (mean, std) = data.mean_std(); + clipped_data = data + .0 + .iter() + .copied() + .filter(|&x| (x - mean) >= lower_std * std && (x - mean) <= upper_std * std) + .collect(); + + // Exit early if converged (no points removed) + if clipped_data.len() == prev_len { + break; + } + } + Self::new_unchecked(clipped_data.into_boxed_slice()) + } +} + +impl UncertainData +where + T: num_traits::Float + num_traits::NumAssignOps + std::iter::Sum, +{ + /// Compute the reduced chi squared value from known values and standard deviations. + /// This computes the reduced chi squared against a single desired value. + #[inline(always)] + pub fn reduced_chi2(&self, val: T) -> T { + self.values + .0 + .iter() + .zip(self.uncertainties.0.iter()) + .map(|(d, sigma)| ((*d - val) / *sigma).powi(2)) + .sum::() + } + + /// Compute the derivative of reduced chi squared value with respect to the set value. + #[inline(always)] + fn reduced_chi2_der(&self, val: T) -> T { + let two = T::from(2.0).unwrap(); + self.values + .0 + .iter() + .zip(self.uncertainties.0.iter()) + .map(|(d, sigma)| two * (val - *d) / sigma.powi(2)) + .sum::() + } + + /// Compute the second derivative of reduced chi squared value with respect to the set value. + #[inline(always)] + fn reduced_chi2_der_der(&self) -> T { + let two = T::from(2.0).unwrap(); + self.uncertainties + .0 + .iter() + .map(|sigma| two / sigma.powi(2)) + .sum::() + } + + /// Given a collection of data and standard deviations, fit the best reduced chi squared value + /// for the provided data. + /// + /// # Errors + /// [`crate::fitting::ConvergenceError`] may be returned if newton raphson fails to converge. 
+ #[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." + )] + pub fn fit_reduced_chi2(&self) -> FittingResult { + let n_sigmas = T::from(self.uncertainties.0.len()).unwrap(); + let cost = |val: T| -> T { self.reduced_chi2_der(val) / n_sigmas }; + let der = |_: T| -> T { self.reduced_chi2_der_der() / n_sigmas }; + newton_raphson( + cost, + der, + self.values.0[0], + T::epsilon() * T::from(1000.0).unwrap(), + ) + } +} + +impl SortedData +where + T: num_traits::Float + num_traits::float::TotalOrder + num_traits::NumAssignOps + Debug, +{ + /// Compute the KS Test two sample statistic. + /// + /// + /// + /// # Errors + /// Fails when data does not contain any finite values. + #[must_use] + #[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." + )] + pub fn two_sample_ks_statistic(&self, other: &Self) -> T { + let len_a = self.0.len(); + let len_b = other.0.len(); + + let mut stat = T::zero(); + let mut ida = 0; + let mut idb = 0; + let mut empirical_dist_func_a = T::zero(); + let mut empirical_dist_func_b = T::zero(); + + // go through the sorted lists, + while ida < len_a && idb < len_b { + let val_a = &self.0[ida]; + while ida + 1 < len_a && *val_a == self.0[ida + 1] { + ida += 1; + } + + let val_b = &other.0[idb]; + while idb + 1 < len_b && *val_b == other.0[idb + 1] { + idb += 1; + } + + let min = &val_a.min(*val_b); + + if min == val_a { + empirical_dist_func_a = T::from(ida + 1).unwrap() / T::from(len_a).unwrap(); + ida += 1; + } + if min == val_b { + empirical_dist_func_b = T::from(idb + 1).unwrap() / T::from(len_b).unwrap(); + idb += 1; + } + + let diff = (empirical_dist_func_a - empirical_dist_func_b).abs(); + if diff > stat { + stat = diff; + } + } + stat + } + + /// Create a new [`SortedData`] without checking the data. + #[must_use] + pub fn new_unchecked(data: Data) -> Self { + Self(data) + } + + /// Unwrap to get the inner [`Data`]. 
+ #[must_use] + pub fn unwrap_inner(self) -> Data { + self.0 + } + + /// Dataset as a slice. + #[must_use] + pub fn as_slice(&self) -> &[T] { + self.0.as_slice() + } + + /// Calculate desired quantile of the sorted data. + /// + /// Quantile is effectively the same as percentile, but 0.5 quantile == 50% percentile. + /// + /// Quantiles are linearly interpolated between the two closest ranked values. + /// + /// If only one valid data point is provided, all quantiles evaluate to that value. + #[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." + )] + #[must_use] + pub fn quantile(&self, quant: T) -> T { + let quant = quant.clamp(T::zero(), T::one()); + let data = self.0.as_slice(); + let n_data = self.0.len(); + + let frac_idx = quant * T::from(n_data - 1).unwrap(); + #[allow( + clippy::cast_sign_loss, + reason = "By construction this is always positive." + )] + let idx = frac_idx.floor().to_usize().unwrap(); + + if T::from(idx).unwrap() == frac_idx { + // exactly on a data point + unsafe { *data.get_unchecked(idx) } + } else { + // linear interpolation between two points + let diff = frac_idx - T::from(idx).unwrap(); + unsafe { + *data.get_unchecked(idx) * (T::one() - diff) + *data.get_unchecked(idx + 1) * diff + } + } + } + + /// Compute the median value of the sorted data. + /// + /// This is O(1) since the data is already sorted. + #[must_use] + pub fn median(&self) -> T { + let half = T::one() / (T::one() + T::one()); + self.quantile(half) + } + + /// Compute the MAD value of the data. + /// + /// + /// + #[must_use] + #[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." 
+ )] + pub fn mad(&self) -> T { + let median = self.median(); + let mut abs_deviation_from_med: Vec = self + .as_slice() + .iter() + .map(|d| (*d - median).abs()) + .collect(); + let n = abs_deviation_from_med.len(); + let half = n / 2; + if n % 2 == 1 { + quickselect(&mut abs_deviation_from_med, half) + } else { + let lower = quickselect(&mut abs_deviation_from_med, half - 1); + let upper = quickselect(&mut abs_deviation_from_med[half - 1..], 1); + (lower + upper) / (T::one() + T::one()) + } + } + + /// Compute the mean value of the sorted data. + #[must_use] + pub fn mean(&self) -> T { + self.0.mean() + } + + /// Compute the standard deviation of the sorted data. + #[must_use] + pub fn std(&self) -> T { + self.0.std() + } + + /// Compute the mean and standard deviation of the sorted data. + #[must_use] + pub fn mean_std(&self) -> (T, T) { + self.0.mean_std() + } +} + +fn quickselect(arr: &mut [T], k: usize) -> T +where + T: Copy + PartialOrd, +{ + if arr.len() == 1 { + return arr[0]; + } + + // Use median-of-three for pivot selection to avoid worst case + let mid = arr.len() / 2; + let last = arr.len() - 1; + + // Sort first, mid, last and use mid as pivot + if arr[0] > arr[mid] { + arr.swap(0, mid); + } + if arr[mid] > arr[last] { + arr.swap(mid, last); + } + if arr[0] > arr[mid] { + arr.swap(0, mid); + } + + let pivot = arr[mid]; + + // Partition + let mut i = 0; + let mut j = arr.len() - 1; + + loop { + while arr[i] < pivot { + i += 1; + } + while arr[j] > pivot { + j -= 1; + } + if i >= j { + break; + } + arr.swap(i, j); + i += 1; + j = j.saturating_sub(1); + } + + // Recurse on the partition containing k + if k < i { + quickselect(&mut arr[..i], k) + } else if k > j { + quickselect(&mut arr[j + 1..], k - j - 1) + } else { + arr[k] + } +} + +/// Try to convert from a slice, removing non-finite values. +/// +/// This will fail if there are no valid data points. 
+impl TryFrom<&[T]> for Data +where + T: num_traits::Float, +{ + type Error = DataError; + fn try_from(value: &[T]) -> Result { + let data: Box<[T]> = value + .iter() + .filter_map(|x| if x.is_finite() { Some(*x) } else { None }) + .collect(); + if data.is_empty() { + Err(DataError::EmptyDataset) + } else { + Ok(Self(data)) + } + } +} + +impl TryFrom> for Data +where + T: Copy + num_traits::Float, +{ + type Error = DataError; + + fn try_from(mut value: Vec) -> Result { + // Filter out all non-finite values, keeping only finite data + value.retain(|x| x.is_finite()); + + if value.is_empty() { + Err(DataError::EmptyDataset) + } else { + Ok(Self(value.into_boxed_slice())) + } + } +} + +impl TryFrom<(&[T], &[T])> for UncertainData +where + T: Copy + num_traits::Float + num_traits::float::TotalOrder + num_traits::NumAssignOps + Debug, +{ + type Error = DataError; + + fn try_from(value: (&[T], &[T])) -> Result { + if value.0.len() != value.1.len() { + return Err(DataError::UnequalLengths); + } + // Filter out all non-finite values, keeping only finite data + let mut filtered_values = Vec::with_capacity(value.0.len()); + let mut filtered_uncertainties = Vec::with_capacity(value.1.len()); + for (v, u) in value.0.iter().zip(value.1.iter()) { + if v.is_finite() && u.is_finite() { + filtered_values.push(*v); + filtered_uncertainties.push(*u); + } + } + if filtered_values.is_empty() { + Err(DataError::EmptyDataset) + } else { + Ok(Self { + values: Data::new_unchecked(filtered_values.into_boxed_slice()), + uncertainties: Data::new_unchecked(filtered_uncertainties.into_boxed_slice()), + }) + } + } +} + +impl TryFrom<(Vec, Vec)> for UncertainData +where + T: Copy + num_traits::Float + num_traits::float::TotalOrder + num_traits::NumAssignOps + Debug, +{ + type Error = DataError; + + fn try_from(value: (Vec, Vec)) -> Result { + Self::try_from((value.0.as_slice(), value.1.as_slice())) + } +} + +impl Index for Data +where + T: num_traits::Float, +{ + type Output = T; + fn 
index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} + +#[cfg(test)] +mod tests { + use super::{Data, StatsResult}; + + #[test] + fn test_median() { + let mut data: Data<_> = vec![ + -f64::NAN, + f64::INFINITY, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + f64::NAN, + f64::NEG_INFINITY, + f64::NEG_INFINITY, + ] + .try_into() + .unwrap(); + + assert!((data.std() - 2_f64.sqrt()).abs() < 1e-13); + assert_eq!(data.median(), 3.0); + assert_eq!(data.mean(), 3.0); + assert_eq!(data.quantile(0.0), 1.0); + assert_eq!(data.quantile(0.25), 2.0); + assert_eq!(data.quantile(0.5), 3.0); + assert_eq!(data.quantile(0.75), 4.0); + assert_eq!(data.quantile(1.0), 5.0); + assert_eq!(data.quantile(1.0 / 8.0), 1.5); + assert_eq!(data.quantile(1.0 / 8.0 + 0.75), 4.5); + + let data = vec![1.0, 2.0, 3.0, 4.0]; + let data: Data<_> = data.try_into().unwrap(); + let data = data.into_sorted(); + assert_eq!(data.median(), 2.5); + + let data = vec![1.5, 0.5, 0.5, 1.5]; + let data: Data<_> = data.try_into().unwrap(); + let data = data.into_sorted(); + assert_eq!(data.median(), 1.0); + } + + #[test] + fn test_finite_bad() { + let data: StatsResult> = [f64::NAN, f64::NEG_INFINITY, f64::INFINITY] + .as_slice() + .try_into(); + assert!(data.is_err()); + + let data2: StatsResult> = vec![].try_into(); + assert!(data2.is_err()); + } + + #[test] + fn test_valid_data_from_vec() { + // Test with Vec that transfers ownership + // Vec implementation now filters out all non-finite values + let vec_data = vec![1.0, 2.0, f64::NAN, 3.0, f64::INFINITY, 4.0, 5.0]; + let data: Data<_> = vec_data.try_into().unwrap(); + assert_eq!(data.len(), 5); + assert_eq!(data.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0]); + } + + #[test] + fn test_valid_data_from_vec_negative_nan() { + // Test with negative NaN and infinity + // Vec implementation now filters out all non-finite values + let vec_data = vec![1.0, -f64::NAN, 2.0, -f64::INFINITY, 3.0]; + let data: Data<_> = vec_data.try_into().unwrap(); + 
assert_eq!(data.len(), 3); + assert_eq!(data.as_slice(), &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_mad() { + // Test MAD (Median Absolute Deviation) calculation + let mut data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let mad = data.mad(); + // Median is 3.0, absolute deviations are [2, 1, 0, 1, 2] + // MAD is the median of [0, 1, 1, 2, 2] (sorted) = 1.0 + assert_eq!(mad, 1.0); + + let data = data.into_sorted(); + let mad = data.mad(); + assert_eq!(mad, 1.0); + } + + #[test] + fn test_mad_even_data() { + let mut data: Data<_> = vec![1.0, 2.0, 3.0, 4.0].try_into().unwrap(); + assert_eq!(data.median(), 2.5); + let mad = data.mad(); + // Median is 2.5, abs of deviations are [1.5, 0.5, 0.5, 1.5] + // MAD is median of these deviations, which is 1.0 + assert_eq!(mad, 1.0); + + let data = data.into_sorted(); + assert_eq!(data.median(), 2.5); + let mad = data.mad(); + assert_eq!(mad, 1.0); + } + + #[test] + fn test_kth_smallest() { + let mut data: Data<_> = vec![5.0, 2.0, 8.0, 1.0, 9.0, 3.0, 7.0].try_into().unwrap(); + assert_eq!(data.kth_smallest(0), 1.0); + assert_eq!(data.kth_smallest(3), 5.0); + assert_eq!(data.kth_smallest(6), 9.0); + } + + #[test] + fn test_kth_smallest_single() { + let mut data: Data<_> = vec![42.0].try_into().unwrap(); + assert_eq!(data.kth_smallest(0), 42.0); + } + + #[test] + fn test_into_sorted() { + let data: Data<_> = vec![5.0, 2.0, 8.0, 1.0, 9.0].try_into().unwrap(); + let sorted = data.into_sorted(); + assert_eq!(sorted.unwrap_inner().as_slice(), &[1.0, 2.0, 5.0, 8.0, 9.0]); + } + + #[test] + fn test_sample_n() { + let data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .try_into() + .unwrap(); + + // Sample 5 points from 10 + let sampled = data.sample_n(5); + assert_eq!(sampled.len(), 5); + // Should get approximately evenly spaced values + assert!(sampled[0] <= sampled[1]); + assert!(sampled[1] <= sampled[2]); + assert!(sampled[2] <= sampled[3]); + assert!(sampled[3] <= sampled[4]); + } + + 
#[test] + fn test_sample_n_more_than_data() { + let data: Data<_> = vec![1.0, 2.0, 3.0].try_into().unwrap(); + let sampled = data.sample_n(10); + assert_eq!(sampled.len(), 3); + assert_eq!(sampled.as_slice(), data.as_slice()); + } + + #[test] + fn test_sample_n_single() { + let data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let sampled = data.sample_n(1); + assert_eq!(sampled.len(), 1); + assert_eq!(sampled[0], 1.0); + } + + #[test] + fn test_two_sample_ks_statistic_identical() { + let data1: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let data2: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + + let sorted1 = data1.into_sorted(); + let sorted2 = data2.into_sorted(); + + let ks_stat: f64 = sorted1.two_sample_ks_statistic(&sorted2); + assert!(ks_stat.abs() < 1e-10); // Should be 0 for identical distributions + } + + #[test] + fn test_two_sample_ks_statistic_different() { + let data1: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let data2: Data<_> = vec![6.0, 7.0, 8.0, 9.0, 10.0].try_into().unwrap(); + + let sorted1 = data1.into_sorted(); + let sorted2 = data2.into_sorted(); + + let ks_stat: f64 = sorted1.two_sample_ks_statistic(&sorted2); + // For completely separate distributions, KS statistic should be positive + assert!(ks_stat >= 0.0); + } + + #[test] + fn test_two_sample_ks_statistic_overlapping() { + let data1: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let data2: Data<_> = vec![3.0, 4.0, 5.0, 6.0, 7.0].try_into().unwrap(); + + let sorted1 = data1.into_sorted(); + let sorted2 = data2.into_sorted(); + + let ks_stat: f64 = sorted1.two_sample_ks_statistic(&sorted2); + // For overlapping distributions, KS statistic should be non-negative + assert!(ks_stat >= 0.0); + assert!(ks_stat <= 1.0); + } + + #[test] + fn test_std_single_value() { + let data: Data<_> = vec![5.0].try_into().unwrap(); + assert_eq!(data.std(), 0.0); + } + + #[test] + fn test_std_known_values() { + // 
Standard deviation of [2, 4, 6, 8] with mean 5 + // Variance = ((2-5)^2 + (4-5)^2 + (6-5)^2 + (8-5)^2) / 4 = (9 + 1 + 1 + 9) / 4 = 5 + // Std = sqrt(5) ≈ 2.236 + let data: Data<_> = vec![2.0, 4.0, 6.0, 8.0].try_into().unwrap(); + assert!((data.std() - 5_f64.sqrt()).abs() < 1e-10); + } + + #[test] + fn test_quantile_bounds() { + let data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let data = data.into_sorted(); + + // Test clamping + assert_eq!(data.quantile(-0.5), 1.0); // Should clamp to 0 + assert_eq!(data.quantile(1.5), 5.0); // Should clamp to 1 + } + + #[test] + fn test_quantile_single_value() { + let mut data: Data<_> = vec![42.0].try_into().unwrap(); + assert_eq!(data.quantile(0.0), 42.0); + assert_eq!(data.quantile(0.5), 42.0); + assert_eq!(data.quantile(1.0), 42.0); + + let data = data.into_sorted(); + assert_eq!(data.quantile(0.0), 42.0); + assert_eq!(data.quantile(0.5), 42.0); + assert_eq!(data.quantile(1.0), 42.0); + } + + #[test] + fn test_median_even_odd() { + // Odd number of elements + let odd_data: Data<_> = vec![1.0, 2.0, 3.0].try_into().unwrap(); + let odd_data = odd_data.into_sorted(); + assert_eq!(odd_data.median(), 2.0); + + // Even number of elements (should interpolate) + let even_data: Data<_> = vec![1.0, 2.0, 3.0, 4.0].try_into().unwrap(); + let even_data = even_data.into_sorted(); + assert_eq!(even_data.median(), 2.5); + } + + #[test] + fn test_index() { + let data: Data<_> = vec![1.0, 2.0, 3.0].try_into().unwrap(); + assert_eq!(data[0], 1.0); + assert_eq!(data[1], 2.0); + assert_eq!(data[2], 3.0); + } + + #[test] + fn test_len() { + let data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + assert_eq!(data.len(), 5); + } + + #[test] + fn test_as_slice() { + let original = vec![1.0, 2.0, 3.0]; + let data: Data<_> = original.as_slice().try_into().unwrap(); + assert_eq!(data.as_slice(), &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_f32_support() { + // Test that f32 works too + let data: Data = vec![1.0_f32, 
2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + let data = data.into_sorted(); + assert!((data.mean() - 3.0).abs() < 1e-6); + assert!((data.median() - 3.0).abs() < 1e-6); + } + + #[test] + fn test_clone() { + let data: Data<_> = vec![1.0, 2.0, 3.0].try_into().unwrap(); + let cloned = data.clone(); + assert_eq!(data.as_slice(), cloned.as_slice()); + } + + #[test] + fn test_sorted_clone() { + let data: Data<_> = vec![3.0, 1.0, 2.0].try_into().unwrap(); + let sorted = data.into_sorted(); + let cloned = sorted.clone(); + assert_eq!( + sorted.unwrap_inner().as_slice(), + cloned.unwrap_inner().as_slice() + ); + } + + #[test] + fn test_sigma_clip() { + // Test data with outliers: mean=5, std≈3.16 + // Values at ±3σ would be roughly -4.5 and 14.5 + let data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 100.0] + .try_into() + .unwrap(); + + // Clip at 2 sigma (should remove 100.0 as outlier) + let clipped = data.sigma_clip(2.0, 2.0, 1); + + // After removing 100.0, should have 9 values + assert!(clipped.len() < data.len()); + + // The clipped data should not contain the extreme outlier + assert!(!clipped.as_slice().contains(&100.0)); + } + + #[test] + fn test_sigma_clip_no_outliers() { + // Test data without outliers + let data: Data<_> = vec![1.0, 2.0, 3.0, 4.0, 5.0].try_into().unwrap(); + + // Clip at 3 sigma (should keep all data) + let clipped = data.sigma_clip(3.0, 3.0, 1); + + // All values should be within 3 sigma + assert!(clipped.len() <= data.len()); + } + + #[test] + fn test_uncertain_data_creation() { + use super::UncertainData; + + // Test creating from slices + let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let uncertainties = vec![0.1, 0.2, 0.1, 0.15, 0.1]; + + let data: UncertainData<_> = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + assert_eq!(data.values.len(), 5); + assert_eq!(data.uncertainties.len(), 5); + } + + #[test] + fn test_uncertain_data_filters_non_finite() { + use super::UncertainData; + + // Test that 
non-finite values are filtered out + let values = vec![1.0, f64::NAN, 3.0, f64::INFINITY, 5.0]; + let uncertainties = vec![0.1, 0.2, 0.1, 0.15, 0.1]; + + let data: UncertainData<_> = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + // Should have 3 values (1.0, 3.0, 5.0) + assert_eq!(data.values.len(), 3); + assert_eq!(data.uncertainties.len(), 3); + } + + #[test] + fn test_uncertain_data_unequal_lengths() { + use super::{DataError, UncertainData}; + + let values = vec![1.0, 2.0, 3.0]; + let uncertainties = vec![0.1, 0.2]; + + let result: Result, _> = + (values.as_slice(), uncertainties.as_slice()).try_into(); + + assert!(matches!(result, Err(DataError::UnequalLengths))); + } + + #[test] + fn test_reduced_chi2_perfect_fit() { + use super::UncertainData; + + // Test chi2 when all data points match the test value + let values = vec![5.0, 5.0, 5.0, 5.0]; + let uncertainties = vec![1.0, 1.0, 1.0, 1.0]; + + let data: UncertainData<_> = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + let chi2 = data.reduced_chi2(5.0); + assert_eq!(chi2, 0.0); + } + + #[test] + fn test_reduced_chi2_known_value() { + use super::UncertainData; + + // Test chi2 with known calculation + // If values are [4.0, 6.0] with uncertainties [1.0, 1.0] and test value is 5.0 + // chi2 = ((4-5)/1)^2 + ((6-5)/1)^2 = 1 + 1 = 2 + let values = vec![4.0, 6.0]; + let uncertainties = vec![1.0, 1.0]; + + let data: UncertainData<_> = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + let chi2 = data.reduced_chi2(5.0); + assert_eq!(chi2, 2.0); + } + + #[test] + fn test_reduced_chi2_with_different_uncertainties() { + use super::UncertainData; + + // Test chi2 with different uncertainties + // values = [3.0, 7.0], uncertainties = [2.0, 1.0], test value = 5.0 + // chi2 = ((3-5)/2)^2 + ((7-5)/1)^2 = 1.0 + 4.0 = 5.0 + let values = vec![3.0, 7.0]; + let uncertainties = vec![2.0, 1.0]; + + let data: UncertainData<_> = (values.as_slice(), 
uncertainties.as_slice()) + .try_into() + .unwrap(); + + let chi2 = data.reduced_chi2(5.0); + assert_eq!(chi2, 5.0); + } + + #[test] + fn test_fit_reduced_chi2_simple() { + use super::UncertainData; + + // Test fitting chi2 - should return the mean when uncertainties are equal + let values = vec![4.0, 5.0, 6.0]; + let uncertainties = vec![1.0, 1.0, 1.0]; + + let data: UncertainData = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + let fitted = data.fit_reduced_chi2().unwrap(); + // Should converge to the mean (5.0) + assert!((fitted - 5.0).abs() < 1e-6); + } + + #[test] + fn test_fit_reduced_chi2_weighted() { + use super::UncertainData; + + // Test fitting with different uncertainties + // The fit should be weighted by 1/sigma^2 + let values = vec![1.0, 10.0]; + let uncertainties = vec![0.1, 10.0]; // First point has much smaller uncertainty + + let data: UncertainData = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + let fitted = data.fit_reduced_chi2().unwrap(); + // Should be much closer to 1.0 than 10.0 due to weighting + assert!(fitted < 5.0); + assert!((fitted - 1.0).abs() < 0.5); + } + + #[test] + fn test_fit_reduced_chi2_convergence() { + use super::UncertainData; + + // Test with more data points to ensure convergence + let values = vec![2.0, 3.0, 4.0, 5.0, 6.0]; + let uncertainties = vec![0.5, 0.5, 0.5, 0.5, 0.5]; + + let data: UncertainData = (values.as_slice(), uncertainties.as_slice()) + .try_into() + .unwrap(); + + let fitted = data.fit_reduced_chi2().unwrap(); + // Should converge to the mean (4.0) + assert!((fitted - 4.0).abs() < 1e-6); + + // Verify the chi2 at the fitted value is minimized + let chi2_at_fit = data.reduced_chi2(fitted); + let chi2_slightly_off = data.reduced_chi2(fitted + 0.1); + assert!(chi2_at_fit < chi2_slightly_off); + } +} diff --git a/src/kete_core/src/fitting/halley.rs b/src/kete_stats/src/fitting/halley.rs similarity index 71% rename from 
src/kete_core/src/fitting/halley.rs rename to src/kete_stats/src/fitting/halley.rs index fcaffee..bb1ca78 100644 --- a/src/kete_core/src/fitting/halley.rs +++ b/src/kete_stats/src/fitting/halley.rs @@ -5,7 +5,7 @@ // // BSD 3-Clause License // -// Copyright (c) 2025, California Institute of Technology +// Copyright (c) 2026, Dar Dahlen // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: @@ -32,7 +32,7 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{errors::Error, prelude::KeteResult}; +use crate::fitting::{ConvergenceError, FittingResult}; /// Solve root using Halley's method. /// @@ -41,8 +41,15 @@ use crate::{errors::Error, prelude::KeteResult}; /// respect to the input variable. The third is the second derivative. /// /// ``` -/// use kete_core::fitting::halley; -/// let f = |x| { 1.0 * x * x - 1.0 }; +/// use kete_stats::fitting::halley; +/// let f = |x: f64| { 1.0 * x * x - 1.0 }; +/// let d = |x| { 2.0 * x }; +/// let dd = |_| { 2.0}; +/// let root = halley(f, d, dd, 0.0, 1e-10).unwrap(); +/// assert!((root - 1.0).abs() < 1e-12); +/// +/// // Same but with f32 +/// let f = |x: f32| { 1.0 * x * x - 1.0 }; /// let d = |x| { 2.0 * x }; /// let dd = |_| { 2.0}; /// let root = halley(f, d, dd, 0.0, 1e-10).unwrap(); @@ -58,29 +65,39 @@ use crate::{errors::Error, prelude::KeteResult}; /// /// # Errors /// -/// [`Error::Convergence`] may be returned in the following cases: +/// [`ConvergenceError`] may be returned in the following cases: /// - Any function evaluation return a non-finite value. /// - Derivative is zero but not converged. /// - Failed to converge within 100 iterations. 
#[inline(always)] -pub fn halley( - func: impl Fn(f64) -> f64, - der: impl Fn(f64) -> f64, - sec_der: impl Fn(f64) -> f64, - start: f64, - atol: f64, -) -> KeteResult { +#[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." +)] +pub fn halley( + func: impl Fn(T) -> T, + der: impl Fn(T) -> T, + sec_der: impl Fn(T) -> T, + start: T, + atol: T, +) -> FittingResult +where + T: num_traits::Float + num_traits::ToPrimitive + num_traits::NumAssignOps, +{ let mut x = start; + let eps = T::epsilon() * T::from(1000.0).unwrap(); + let two = T::from(2.0).unwrap(); + // if the starting position has derivative of 0, nudge it a bit. - if der(x).abs() < 1e-12 { - x += 0.1; + if der(x).abs() < eps { + x += T::from(0.1).unwrap(); } - let mut f_eval: f64; - let mut d_eval: f64; - let mut d_d_eval: f64; - let mut step: f64; + let mut f_eval: T; + let mut d_eval: T; + let mut d_d_eval: T; + let mut step: T; for _ in 0..100 { f_eval = func(x); if f_eval.abs() < atol { @@ -89,27 +106,21 @@ pub fn halley( d_eval = der(x); // Derivative is 0, cannot solve - if d_eval.abs() < 1e-12 { - Err(Error::Convergence( - "Halley's root finding failed to converge due to zero derivative.".into(), - ))?; + if d_eval.abs() < eps { + Err(ConvergenceError::ZeroDerivative)?; } d_d_eval = sec_der(x); if !d_d_eval.is_finite() || !d_eval.is_finite() || !f_eval.is_finite() { - Err(Error::Convergence( - "Halley root finding failed to converge due to non-finite evaluations".into(), - ))?; + Err(ConvergenceError::NonFinite)?; } step = f_eval / d_eval; - step = step / (1.0 - step * d_d_eval / (2.0 * d_eval)); + step = step / (T::one() - step * d_d_eval / (two * d_eval)); x -= step; } - Err(Error::Convergence( - "Halley's root finding hit iteration limit without converging.".into(), - ))? + Err(ConvergenceError::Iterations)? 
} #[cfg(test)] @@ -118,7 +129,7 @@ mod tests { #[test] fn test_haley() { - let f = |x| 1.0 * x * x - 1.0; + let f = |x: f64| 1.0 * x * x - 1.0; let d = |x| 2.0 * x; let dd = |_| 2.0; let root = halley(f, d, dd, 0.0, 1e-10).unwrap(); diff --git a/src/kete_core/src/fitting/mod.rs b/src/kete_stats/src/fitting/mod.rs similarity index 70% rename from src/kete_core/src/fitting/mod.rs rename to src/kete_stats/src/fitting/mod.rs index 80a299f..ef22e42 100644 --- a/src/kete_core/src/fitting/mod.rs +++ b/src/kete_stats/src/fitting/mod.rs @@ -2,7 +2,7 @@ //! Fitting tools, including root finding. // BSD 3-Clause License // -// Copyright (c) 2025, California Institute of Technology +// Copyright (c) 2026, Dar Dahlen // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: @@ -31,8 +31,25 @@ mod halley; mod newton; -mod reduced_chi2; pub use self::halley::halley; pub use self::newton::newton_raphson; -pub use self::reduced_chi2::{fit_reduced_chi2, reduced_chi2, reduced_chi2_der}; + +/// Error type for fitting operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum ConvergenceError { + /// Maximum number of iterations reached without convergence. + #[error("Maximum number of iterations reached without convergence")] + Iterations, + + /// Non-finite value encountered during evaluation. + #[error("Non-finite value encountered during evaluation")] + NonFinite, + + /// Zero derivative encountered during evaluation. + #[error("Zero derivative encountered during evaluation")] + ZeroDerivative, +} + +/// Result type for fitting operations. 
+pub type FittingResult = Result; diff --git a/src/kete_core/src/fitting/newton.rs b/src/kete_stats/src/fitting/newton.rs similarity index 71% rename from src/kete_core/src/fitting/newton.rs rename to src/kete_stats/src/fitting/newton.rs index 3e645dc..e208105 100644 --- a/src/kete_core/src/fitting/newton.rs +++ b/src/kete_stats/src/fitting/newton.rs @@ -1,6 +1,6 @@ // BSD 3-Clause License // -// Copyright (c) 2025, California Institute of Technology +// Copyright (c) 2026, Dar Dahlen // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: @@ -27,7 +27,7 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{errors::Error, prelude::KeteResult}; +use crate::fitting::{ConvergenceError, FittingResult}; /// Solve root using the Newton-Raphson method. /// @@ -36,8 +36,8 @@ use crate::{errors::Error, prelude::KeteResult}; /// respect to the input variable. /// /// ``` -/// use kete_core::fitting::newton_raphson; -/// let f = |x| { 1.0 * x * x - 1.0 }; +/// use kete_stats::fitting::newton_raphson; +/// let f = |x: f64| { 1.0 * x * x - 1.0 }; /// let d = |x| { 2.0 * x }; /// let root = newton_raphson(f, d, 0.0, 1e-10).unwrap(); /// assert!((root - 1.0).abs() < 1e-12); @@ -45,25 +45,36 @@ use crate::{errors::Error, prelude::KeteResult}; /// /// # Errors /// -/// [`Error::Convergence`] may be returned in the following cases: +/// [`ConvergenceError`] may be returned in the following cases: /// - Any function evaluation return a non-finite value. /// - Derivative is zero but not converged. /// - Failed to converge within 100 iterations. #[inline(always)] -pub fn newton_raphson(func: Func, der: Der, start: f64, atol: f64) -> KeteResult +#[allow( + clippy::missing_panics_doc, + reason = "By construction this cannot panic." 
+)] +pub fn newton_raphson( + func: impl Fn(T) -> T, + der: impl Fn(T) -> T, + start: T, + atol: T, +) -> FittingResult where - Func: Fn(f64) -> f64, - Der: Fn(f64) -> f64, + T: num_traits::Float + num_traits::ToPrimitive + num_traits::NumAssignOps, { let mut x = start; + let eps = T::epsilon() * T::from(1000.0).unwrap(); + let half = T::from(0.5).unwrap(); + // if the starting position has derivative of 0, nudge it a bit. - if der(x).abs() < 1e-12 { - x += 0.1; + if der(x).abs() < eps { + x += T::from(0.1).unwrap(); } - let mut f_eval: f64; - let mut d_eval: f64; + let mut f_eval: T; + let mut d_eval: T; for _ in 0..100 { f_eval = func(x); if f_eval.abs() < atol { @@ -72,31 +83,24 @@ where d_eval = der(x); // Derivative is 0, cannot solve - if d_eval.abs() < 1e-12 { - Err(Error::Convergence( - "Newton-Raphson root finding failed to converge due to zero derivative.".into(), - ))?; + if d_eval.abs() < eps { + Err(ConvergenceError::ZeroDerivative)?; } if !d_eval.is_finite() || !f_eval.is_finite() { - Err(Error::Convergence( - "Newton-Raphson root finding failed to converge due to non-finite evaluations" - .into(), - ))?; + Err(ConvergenceError::NonFinite)?; } // 0.5 reduces the step size to slow down the rate of convergence. - x -= 0.5 * f_eval / d_eval; + x -= half * f_eval / d_eval; d_eval = der(x); - if d_eval.abs() < 1e-3 { + if d_eval.abs() < T::from(1e-3).unwrap() { f_eval = func(x); } - x -= 0.5 * f_eval / d_eval; + x -= half * f_eval / d_eval; } - Err(Error::Convergence( - "Newton-Raphson root finding hit iteration limit without converging.".into(), - ))? + Err(ConvergenceError::Iterations)? 
} #[cfg(test)] @@ -108,7 +112,7 @@ mod tests { let f = |x| 1.0 * x * x - 1.0; let d = |x| 2.0 * x; - let root = newton_raphson(f, d, 0.0, 1e-10).unwrap(); + let root: f64 = newton_raphson(f, d, 0.0, 1e-10).unwrap(); assert!((root - 1.0).abs() < 1e-12); } } diff --git a/src/kete_stats/src/lib.rs b/src/kete_stats/src/lib.rs new file mode 100644 index 0000000..c25a61d --- /dev/null +++ b/src/kete_stats/src/lib.rs @@ -0,0 +1,10 @@ +//! # Basic Statistics for astronomical data. +//! +//! This handles NaN gracefully for astronomical data sets. +mod data; +pub mod fitting; + +/// export all stats functionality +pub mod prelude { + pub use crate::data::{Data, DataError, SortedData, StatsResult, UncertainData}; +}