diff --git a/math_explorer/src/applied/algorithms/error.rs b/math_explorer/src/applied/algorithms/error.rs new file mode 100644 index 0000000..dd56926 --- /dev/null +++ b/math_explorer/src/applied/algorithms/error.rs @@ -0,0 +1,24 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AlgorithmError { + SingularMatrix, + DimensionMismatch { expected: usize, actual: usize }, +} + +impl fmt::Display for AlgorithmError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::SingularMatrix => write!(f, "Matrix is singular (non-invertible)"), + Self::DimensionMismatch { expected, actual } => { + write!( + f, + "Dimension mismatch: expected {}, got {}", + expected, actual + ) + } + } + } +} + +impl std::error::Error for AlgorithmError {} diff --git a/math_explorer/src/applied/algorithms/kalman.rs b/math_explorer/src/applied/algorithms/kalman.rs index cf28573..d865985 100644 --- a/math_explorer/src/applied/algorithms/kalman.rs +++ b/math_explorer/src/applied/algorithms/kalman.rs @@ -3,6 +3,7 @@ //! A dimension-agnostic implementation of the Discrete Kalman Filter using `nalgebra::DMatrix`. //! This module allows for state estimation of linear systems with arbitrary state and measurement dimensions. +use super::error::AlgorithmError; use nalgebra::{DMatrix, DVector}; /// Defines the physics/dynamics model for the Kalman Filter. @@ -99,7 +100,7 @@ impl KalmanFilter { /// # Arguments /// /// * `measurement` - The measurement vector $z_k$. - pub fn update(&mut self, measurement: &DVector) -> Result<(), String> { + pub fn update(&mut self, measurement: &DVector) -> Result<(), AlgorithmError> { let h = self.model.measurement_matrix(); let r = self.model.measurement_noise(); @@ -112,9 +113,7 @@ impl KalmanFilter { // Invert S. // For 1D measurements, this is trivial. For nD, we need matrix inversion. // Kalman Filter requires S to be invertible (positive definite). - let s_inv = s - .try_inverse() - .ok_or("Failed to invert innovation covariance matrix (singular)")?; + let s_inv = s.try_inverse().ok_or(AlgorithmError::SingularMatrix)?; // Kalman Gain K = P H^T S^-1 let k = &self.covariance * h.transpose() * s_inv; diff --git a/math_explorer/src/applied/algorithms/mod.rs b/math_explorer/src/applied/algorithms/mod.rs index 3db160d..7d8f866 100644 --- a/math_explorer/src/applied/algorithms/mod.rs +++ b/math_explorer/src/applied/algorithms/mod.rs @@ -2,5 +2,8 @@ //! //! A collection of general-purpose algorithms. +pub mod error; pub mod kalman; pub mod sorting; + +pub use error::AlgorithmError; diff --git a/math_explorer/src/applied/favoritism/error.rs b/math_explorer/src/applied/favoritism/error.rs new file mode 100644 index 0000000..b87c63c --- /dev/null +++ b/math_explorer/src/applied/favoritism/error.rs @@ -0,0 +1,16 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FavoritismError { + InvalidInput(String), +} + +impl fmt::Display for FavoritismError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidInput(msg) => write!(f, "Invalid favoritism input: {}", msg), + } + } +} + +impl std::error::Error for FavoritismError {} diff --git a/math_explorer/src/applied/favoritism/mod.rs b/math_explorer/src/applied/favoritism/mod.rs index 5bec953..c66a39b 100644 --- a/math_explorer/src/applied/favoritism/mod.rs +++ b/math_explorer/src/applied/favoritism/mod.rs @@ -76,14 +76,16 @@ //! println!("Your Favoritism Score: {:.2}", score); //! ``` +pub mod error; pub mod favorite_child; pub mod scoring; pub mod strategies; pub mod types; +pub use error::FavoritismError; pub use scoring::{ calculate_favoritism_score, calculate_favoritism_score_full, - calculate_favoritism_score_with_rng, + calculate_favoritism_score_with_rng, try_calculate_favoritism_score, }; pub use types::{ ComplimentParams, ContactParams, FamilyParams, FavoritismInputs, GiftParams, PersonalityParams, diff --git a/math_explorer/src/applied/favoritism/scoring.rs b/math_explorer/src/applied/favoritism/scoring.rs index a96567d..3054094 100644 --- a/math_explorer/src/applied/favoritism/scoring.rs +++ b/math_explorer/src/applied/favoritism/scoring.rs @@ -1,3 +1,4 @@ +use super::error::FavoritismError; use super::strategies::UnifiedFavoritismModel; use super::types::FavoritismInputs; use crate::pure_math::analysis::integration::{ClenshawCurtis, Integrator}; @@ -39,6 +40,14 @@ pub fn calculate_favoritism_score(inputs: &FavoritismInputs) -> f64 { calculate_favoritism_score_with_rng(inputs, &mut rng) } +/// Tries to calculate the favoritism score, returning an error if inputs are invalid. +/// +/// This is the safe version of `calculate_favoritism_score`. +pub fn try_calculate_favoritism_score(inputs: &FavoritismInputs) -> Result { + inputs.validate()?; + Ok(calculate_favoritism_score(inputs)) +} + /// Calculates the favoritism score using an injected RNG. /// /// See `calculate_favoritism_score` for details. diff --git a/math_explorer/src/applied/favoritism/types.rs b/math_explorer/src/applied/favoritism/types.rs index 89b8078..a125d48 100644 --- a/math_explorer/src/applied/favoritism/types.rs +++ b/math_explorer/src/applied/favoritism/types.rs @@ -1,3 +1,4 @@ +use super::error::FavoritismError; use nalgebra::DVector; /// Time and proximity related parameters. @@ -203,3 +204,60 @@ pub struct FavoritismInputs { /// Family context (siblings). pub family: FamilyParams, } + +impl FavoritismInputs { + /// Validates the input parameters. + /// + /// Checks for NaN, infinite, or invalid negative parameters. + pub fn validate(&self) -> Result<(), FavoritismError> { + // Validate Time Params + if self.time.t < 0.0 || !self.time.t.is_finite() { + return Err(FavoritismError::InvalidInput(format!( + "Time t must be non-negative and finite, got {}", + self.time.t + ))); + } + if self.time.x_0 < 0.0 || !self.time.x_0.is_finite() { + return Err(FavoritismError::InvalidInput(format!( + "Initial distance x_0 must be non-negative and finite, got {}", + self.time.x_0 + ))); + } + + // Validate Gift Params + if self.gifts.g_emotional < 0.0 || !self.gifts.g_emotional.is_finite() { + return Err(FavoritismError::InvalidInput(format!( + "Emotional gift value must be non-negative, got {}", + self.gifts.g_emotional + ))); + } + if self.gifts.g_practical < 0.0 || !self.gifts.g_practical.is_finite() { + return Err(FavoritismError::InvalidInput(format!( + "Practical gift value must be non-negative, got {}", + self.gifts.g_practical + ))); + } + + // Validate Contact Params + if self.contact.time_since_last_contact < 0.0 + || !self.contact.time_since_last_contact.is_finite() + { + return Err(FavoritismError::InvalidInput(format!( + "Time since last contact must be non-negative, got {}", + self.contact.time_since_last_contact + ))); + } + + // Validate Family Params (siblings) + for (i, d) in self.family.sibling_distances.iter().enumerate() { + if *d <= 0.0 || !d.is_finite() { + return Err(FavoritismError::InvalidInput(format!( + "Sibling distance must be positive and finite. Sibling {} has distance {}", + i, d + ))); + } + } + + Ok(()) + } +} diff --git a/math_explorer/src/applied/lorahub/ensemble.rs b/math_explorer/src/applied/lorahub/ensemble.rs index 3025f04..476844f 100644 --- a/math_explorer/src/applied/lorahub/ensemble.rs +++ b/math_explorer/src/applied/lorahub/ensemble.rs @@ -1,3 +1,4 @@ +use super::error::LoraHubError; use super::strategies::{ CombinationStrategy, L1RegularizationStrategy, LinearCombinationStrategy, ObjectiveStrategy, }; @@ -51,9 +52,9 @@ impl LoraEnsemble { /// * `weights` - A slice of weights corresponding to each LoRA module. /// /// # Returns - /// A `Result` containing the combined `LoraStateDict`, or an error message + /// A `Result` containing the combined `LoraStateDict`, or a `LoraHubError` /// if the inputs are invalid. - pub fn combine(&self, weights: &[f64]) -> Result { + pub fn combine(&self, weights: &[f64]) -> Result { self.combination_strategy.combine(&self.modules, weights) } diff --git a/math_explorer/src/applied/lorahub/error.rs b/math_explorer/src/applied/lorahub/error.rs new file mode 100644 index 0000000..551d089 --- /dev/null +++ b/math_explorer/src/applied/lorahub/error.rs @@ -0,0 +1,27 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LoraHubError { + EmptyEnsemble, + EmptyWeights, + LengthMismatch, + ShapeMismatch, + StrategyError(String), +} + +impl fmt::Display for LoraHubError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::EmptyEnsemble => write!(f, "Ensemble is empty; cannot combine."), + Self::EmptyWeights => write!(f, "Weights cannot be empty."), + Self::LengthMismatch => write!( + f, + "The number of weights must match the number of modules in the ensemble." + ), + Self::ShapeMismatch => write!(f, "Mismatched tensor shapes for the same key."), + Self::StrategyError(msg) => write!(f, "Strategy error: {}", msg), + } + } +} + +impl std::error::Error for LoraHubError {} diff --git a/math_explorer/src/applied/lorahub/mod.rs b/math_explorer/src/applied/lorahub/mod.rs index da81ad0..be6bfe7 100644 --- a/math_explorer/src/applied/lorahub/mod.rs +++ b/math_explorer/src/applied/lorahub/mod.rs @@ -75,8 +75,10 @@ //! ``` pub mod ensemble; +pub mod error; pub mod strategies; pub mod types; pub use ensemble::LoraEnsemble; +pub use error::LoraHubError; pub use types::LoraStateDict; diff --git a/math_explorer/src/applied/lorahub/strategies.rs b/math_explorer/src/applied/lorahub/strategies.rs index 6b863b7..cae7379 100644 --- a/math_explorer/src/applied/lorahub/strategies.rs +++ b/math_explorer/src/applied/lorahub/strategies.rs @@ -1,3 +1,4 @@ +use super::error::LoraHubError; use super::types::LoraStateDict; /// Strategy for combining multiple LoRA state dictionaries. @@ -6,7 +7,7 @@ pub trait CombinationStrategy { &self, modules: &[LoraStateDict], weights: &[f64], - ) -> Result; + ) -> Result; } /// Strategy for calculating the objective score. @@ -22,15 +23,15 @@ impl CombinationStrategy for LinearCombinationStrategy { &self, modules: &[LoraStateDict], weights: &[f64], - ) -> Result { + ) -> Result { if modules.is_empty() { - return Err("Ensemble is empty; cannot combine."); + return Err(LoraHubError::EmptyEnsemble); } if weights.is_empty() { - return Err("Weights cannot be empty."); + return Err(LoraHubError::EmptyWeights); } if modules.len() != weights.len() { - return Err("The number of weights must match the number of modules in the ensemble."); + return Err(LoraHubError::LengthMismatch); } let first_lora = &modules[0]; @@ -45,7 +46,7 @@ impl CombinationStrategy for LinearCombinationStrategy { for (key, final_tensor) in &mut final_state_dict { if let Some(tensor) = lora_state_dict.get(key) { if final_tensor.shape() != tensor.shape() { - return Err("Mismatched tensor shapes for the same key."); + return Err(LoraHubError::ShapeMismatch); } // Optimized in-place addition to avoid allocating intermediate matrix let weight = weights[i]; @@ -55,7 +56,9 @@ impl CombinationStrategy for LinearCombinationStrategy { *f += *t * weight; } } else { - return Err("Mismatched keys between LoRA modules."); + return Err(LoraHubError::StrategyError( + "Mismatched keys between LoRA modules.".to_string(), + )); } } } diff --git a/math_explorer/tests/refactor_validation.rs b/math_explorer/tests/refactor_validation.rs new file mode 100644 index 0000000..e93d19b --- /dev/null +++ b/math_explorer/tests/refactor_validation.rs @@ -0,0 +1,78 @@ +use math_explorer::applied::algorithms::kalman::KalmanModel; +use math_explorer::applied::algorithms::{AlgorithmError, kalman::KalmanFilter}; +use math_explorer::applied::favoritism::{ + FavoritismError, FavoritismInputs, try_calculate_favoritism_score, +}; +use math_explorer::applied::lorahub::{LoraEnsemble, LoraHubError, LoraStateDict}; +use nalgebra::{DMatrix, DVector}; + +#[test] +fn test_favoritism_validation_negative_time() { + let mut inputs = FavoritismInputs::default(); + inputs.time.t = -1.0; + let result = inputs.validate(); + assert!(matches!(result, Err(FavoritismError::InvalidInput(_)))); + + let result = try_calculate_favoritism_score(&inputs); + assert!(matches!(result, Err(FavoritismError::InvalidInput(_)))); +} + +#[test] +fn test_favoritism_validation_negative_sibling_distance() { + let mut inputs = FavoritismInputs::default(); + inputs.family.sibling_distances = vec![100.0, -10.0]; + let result = inputs.validate(); + assert!(matches!(result, Err(FavoritismError::InvalidInput(_)))); +} + +#[test] +fn test_lorahub_empty_ensemble_error() { + let modules = vec![]; + let ensemble = LoraEnsemble::new(modules); + let weights = vec![0.5]; + let result = ensemble.combine(&weights); + assert_eq!(result.unwrap_err(), LoraHubError::EmptyEnsemble); +} + +#[test] +fn test_lorahub_length_mismatch_error() { + let mut lora_1 = LoraStateDict::new(); + lora_1.insert("layer1".to_string(), DMatrix::from_element(1, 1, 1.0)); + let modules = vec![lora_1]; + let ensemble = LoraEnsemble::new(modules); + let weights = vec![0.5, 0.5]; // Mismatch + let result = ensemble.combine(&weights); + assert_eq!(result.unwrap_err(), LoraHubError::LengthMismatch); +} + +// Mock Model for testing Kalman error +struct MockModel; +impl KalmanModel for MockModel { + fn transition_matrix(&self, _dt: f64) -> DMatrix { + DMatrix::identity(1, 1) + } + fn measurement_matrix(&self) -> DMatrix { + DMatrix::identity(1, 1) + } + fn process_noise(&self, _dt: f64) -> DMatrix { + DMatrix::identity(1, 1) + } + fn measurement_noise(&self) -> DMatrix { + DMatrix::zeros(1, 1) + } // Zero noise might cause singularity if covariance is also zero/singular +} + +#[test] +fn test_kalman_singular_error() { + let model = MockModel; + let x_init = DVector::from_element(1, 0.0); + // Zero covariance and zero measurement noise -> Innovation Covariance S = HPH^T + R = 0. + // Inverting 0 will fail. + let p_init = DMatrix::zeros(1, 1); + + let mut kf = KalmanFilter::new(x_init, p_init, model, 1.0); + let measurement = DVector::from_element(1, 1.0); + + let result = kf.update(&measurement); + assert_eq!(result.unwrap_err(), AlgorithmError::SingularMatrix); +}