24 changes: 24 additions & 0 deletions math_explorer/src/applied/algorithms/error.rs
@@ -0,0 +1,24 @@
use std::fmt;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AlgorithmError {
SingularMatrix,
DimensionMismatch { expected: usize, actual: usize },
}

impl fmt::Display for AlgorithmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::SingularMatrix => write!(f, "Matrix is singular (non-invertible)"),
Self::DimensionMismatch { expected, actual } => {
write!(
f,
"Dimension mismatch: expected {}, got {}",
expected, actual
)
}
}
}
}

impl std::error::Error for AlgorithmError {}
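
A minimal sketch (illustrative only, not part of this diff) of how the new `AlgorithmError` renders through its `Display` impl and boxes as a `std::error::Error`:

```rust
use math_explorer::applied::algorithms::AlgorithmError;

fn main() {
    // Prints "Dimension mismatch: expected 3, got 2"
    let mismatch = AlgorithmError::DimensionMismatch { expected: 3, actual: 2 };
    println!("{}", mismatch);

    // Because the type implements std::error::Error, it can be boxed and
    // propagated alongside other error types.
    let boxed: Box<dyn std::error::Error> = Box::new(AlgorithmError::SingularMatrix);
    println!("{}", boxed); // "Matrix is singular (non-invertible)"
}
```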
7 changes: 3 additions & 4 deletions math_explorer/src/applied/algorithms/kalman.rs
@@ -3,6 +3,7 @@
//! A dimension-agnostic implementation of the Discrete Kalman Filter using `nalgebra::DMatrix`.
//! This module allows for state estimation of linear systems with arbitrary state and measurement dimensions.

use super::error::AlgorithmError;
use nalgebra::{DMatrix, DVector};

/// Defines the physics/dynamics model for the Kalman Filter.
@@ -99,7 +100,7 @@ impl<M: KalmanModel> KalmanFilter<M> {
/// # Arguments
///
/// * `measurement` - The measurement vector $z_k$.
pub fn update(&mut self, measurement: &DVector<f64>) -> Result<(), String> {
pub fn update(&mut self, measurement: &DVector<f64>) -> Result<(), AlgorithmError> {
let h = self.model.measurement_matrix();
let r = self.model.measurement_noise();

@@ -112,9 +113,7 @@ impl<M: KalmanModel> KalmanFilter<M> {
// Invert S.
// For 1D measurements, this is trivial. For nD, we need matrix inversion.
// Kalman Filter requires S to be invertible (positive definite).
let s_inv = s
.try_inverse()
.ok_or("Failed to invert innovation covariance matrix (singular)")?;
let s_inv = s.try_inverse().ok_or(AlgorithmError::SingularMatrix)?;

// Kalman Gain K = P H^T S^-1
let k = &self.covariance * h.transpose() * s_inv;
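
Caller-side handling of the new `Result` could look like the hedged sketch below (not part of this diff); `M` is any `KalmanModel` implementation, e.g. the `MockModel` added in the new test file:

```rust
use math_explorer::applied::algorithms::AlgorithmError;
use math_explorer::applied::algorithms::kalman::{KalmanFilter, KalmanModel};
use nalgebra::DVector;

fn step<M: KalmanModel>(kf: &mut KalmanFilter<M>, z: &DVector<f64>) {
    match kf.update(z) {
        Ok(()) => { /* state and covariance were updated in place */ }
        Err(AlgorithmError::SingularMatrix) => {
            // The innovation covariance S = H P H^T + R was not invertible;
            // a caller might skip this measurement or re-inflate the covariance.
            eprintln!("skipping measurement: innovation covariance is singular");
        }
        Err(e) => eprintln!("update failed: {}", e),
    }
}
```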
3 changes: 3 additions & 0 deletions math_explorer/src/applied/algorithms/mod.rs
@@ -2,5 +2,8 @@
//!
//! A collection of general-purpose algorithms.

pub mod error;
pub mod kalman;
pub mod sorting;

pub use error::AlgorithmError;
16 changes: 16 additions & 0 deletions math_explorer/src/applied/favoritism/error.rs
@@ -0,0 +1,16 @@
use std::fmt;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FavoritismError {
InvalidInput(String),
}

impl fmt::Display for FavoritismError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidInput(msg) => write!(f, "Invalid favoritism input: {}", msg),
}
}
}

impl std::error::Error for FavoritismError {}
4 changes: 3 additions & 1 deletion math_explorer/src/applied/favoritism/mod.rs
@@ -76,14 +76,16 @@
//! println!("Your Favoritism Score: {:.2}", score);
//! ```

pub mod error;
pub mod favorite_child;
pub mod scoring;
pub mod strategies;
pub mod types;

pub use error::FavoritismError;
pub use scoring::{
calculate_favoritism_score, calculate_favoritism_score_full,
calculate_favoritism_score_with_rng,
calculate_favoritism_score_with_rng, try_calculate_favoritism_score,
};
pub use types::{
ComplimentParams, ContactParams, FamilyParams, FavoritismInputs, GiftParams, PersonalityParams,
9 changes: 9 additions & 0 deletions math_explorer/src/applied/favoritism/scoring.rs
@@ -1,3 +1,4 @@
use super::error::FavoritismError;
use super::strategies::UnifiedFavoritismModel;
use super::types::FavoritismInputs;
use crate::pure_math::analysis::integration::{ClenshawCurtis, Integrator};
@@ -39,6 +40,14 @@ pub fn calculate_favoritism_score(inputs: &FavoritismInputs) -> f64 {
calculate_favoritism_score_with_rng(inputs, &mut rng)
}

/// Tries to calculate the favoritism score, returning an error if inputs are invalid.
///
/// This is the safe version of `calculate_favoritism_score`.
pub fn try_calculate_favoritism_score(inputs: &FavoritismInputs) -> Result<f64, FavoritismError> {
inputs.validate()?;
Ok(calculate_favoritism_score(inputs))
}

/// Calculates the favoritism score using an injected RNG.
///
/// See `calculate_favoritism_score` for details.
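
An illustrative use of the new `try_calculate_favoritism_score` entry point (not part of this diff); it assumes `FavoritismInputs::default()` produces valid parameters, as the new integration tests do:

```rust
use math_explorer::applied::favoritism::{
    FavoritismError, FavoritismInputs, try_calculate_favoritism_score,
};

fn main() {
    let mut inputs = FavoritismInputs::default();
    inputs.time.t = f64::NAN; // deliberately invalid

    match try_calculate_favoritism_score(&inputs) {
        Ok(score) => println!("Favoritism score: {:.2}", score),
        Err(FavoritismError::InvalidInput(msg)) => eprintln!("rejected: {}", msg),
    }
}
```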
58 changes: 58 additions & 0 deletions math_explorer/src/applied/favoritism/types.rs
@@ -1,3 +1,4 @@
use super::error::FavoritismError;
use nalgebra::DVector;

/// Time and proximity related parameters.
@@ -203,3 +204,60 @@ pub struct FavoritismInputs {
/// Family context (siblings).
pub family: FamilyParams,
}

impl FavoritismInputs {
/// Validates the input parameters.
///
/// Checks for NaN, infinite, or invalid negative parameters.
pub fn validate(&self) -> Result<(), FavoritismError> {
// Validate Time Params
if self.time.t < 0.0 || !self.time.t.is_finite() {
return Err(FavoritismError::InvalidInput(format!(
"Time t must be non-negative and finite, got {}",
self.time.t
)));
}
if self.time.x_0 < 0.0 || !self.time.x_0.is_finite() {
return Err(FavoritismError::InvalidInput(format!(
"Initial distance x_0 must be non-negative and finite, got {}",
self.time.x_0
)));
}

// Validate Gift Params
if self.gifts.g_emotional < 0.0 || !self.gifts.g_emotional.is_finite() {
return Err(FavoritismError::InvalidInput(format!(
"Emotional gift value must be non-negative, got {}",
self.gifts.g_emotional
)));
}
if self.gifts.g_practical < 0.0 || !self.gifts.g_practical.is_finite() {
return Err(FavoritismError::InvalidInput(format!(
"Practical gift value must be non-negative, got {}",
self.gifts.g_practical
)));
}

// Validate Contact Params
if self.contact.time_since_last_contact < 0.0
|| !self.contact.time_since_last_contact.is_finite()
{
return Err(FavoritismError::InvalidInput(format!(
"Time since last contact must be non-negative, got {}",
self.contact.time_since_last_contact
)));
}

// Validate Family Params (siblings)
for (i, d) in self.family.sibling_distances.iter().enumerate() {
if *d <= 0.0 || !d.is_finite() {
return Err(FavoritismError::InvalidInput(format!(
"Sibling distance must be positive and finite. Sibling {} has distance {}",
i, d
)));
}
}

Ok(())
}
}
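
One subtlety worth noting for reviewers: sibling distances must be strictly positive, while the time, gift, and contact fields only need to be non-negative. A small hedged check in the style of the new tests (illustrative, not part of this diff):

```rust
use math_explorer::applied::favoritism::{FavoritismError, FavoritismInputs};

fn main() {
    let mut inputs = FavoritismInputs::default();
    // A zero sibling distance is rejected (d <= 0.0), unlike time.t = 0.0.
    inputs.family.sibling_distances = vec![0.0];

    assert!(matches!(
        inputs.validate(),
        Err(FavoritismError::InvalidInput(_))
    ));
}
```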
5 changes: 3 additions & 2 deletions math_explorer/src/applied/lorahub/ensemble.rs
@@ -1,3 +1,4 @@
use super::error::LoraHubError;
use super::strategies::{
CombinationStrategy, L1RegularizationStrategy, LinearCombinationStrategy, ObjectiveStrategy,
};
@@ -51,9 +52,9 @@ impl LoraEnsemble {
/// * `weights` - A slice of weights corresponding to each LoRA module.
///
/// # Returns
/// A `Result` containing the combined `LoraStateDict`, or an error message
/// A `Result` containing the combined `LoraStateDict`, or a `LoraHubError`
/// if the inputs are invalid.
pub fn combine(&self, weights: &[f64]) -> Result<LoraStateDict, &'static str> {
pub fn combine(&self, weights: &[f64]) -> Result<LoraStateDict, LoraHubError> {
self.combination_strategy.combine(&self.modules, weights)
}

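
Callers of `combine` can now match on concrete `LoraHubError` variants instead of comparing strings; an illustrative sketch (not part of this diff):

```rust
use math_explorer::applied::lorahub::{LoraEnsemble, LoraHubError, LoraStateDict};
use nalgebra::DMatrix;

fn main() {
    let mut lora = LoraStateDict::new();
    lora.insert("layer1".to_string(), DMatrix::from_element(2, 2, 1.0));
    let ensemble = LoraEnsemble::new(vec![lora]);

    // One module but two weights triggers LengthMismatch.
    match ensemble.combine(&[0.5, 0.5]) {
        Ok(_combined) => println!("combined successfully"),
        Err(LoraHubError::LengthMismatch) => eprintln!("weights/modules length mismatch"),
        Err(e) => eprintln!("combine failed: {}", e),
    }
}
```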
27 changes: 27 additions & 0 deletions math_explorer/src/applied/lorahub/error.rs
@@ -0,0 +1,27 @@
use std::fmt;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoraHubError {
EmptyEnsemble,
EmptyWeights,
LengthMismatch,
ShapeMismatch,
StrategyError(String),
}

impl fmt::Display for LoraHubError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::EmptyEnsemble => write!(f, "Ensemble is empty; cannot combine."),
Self::EmptyWeights => write!(f, "Weights cannot be empty."),
Self::LengthMismatch => write!(
f,
"The number of weights must match the number of modules in the ensemble."
),
Self::ShapeMismatch => write!(f, "Mismatched tensor shapes for the same key."),
Self::StrategyError(msg) => write!(f, "Strategy error: {}", msg),
}
}
}

impl std::error::Error for LoraHubError {}
2 changes: 2 additions & 0 deletions math_explorer/src/applied/lorahub/mod.rs
@@ -75,8 +75,10 @@
//! ```

pub mod ensemble;
pub mod error;
pub mod strategies;
pub mod types;

pub use ensemble::LoraEnsemble;
pub use error::LoraHubError;
pub use types::LoraStateDict;
17 changes: 10 additions & 7 deletions math_explorer/src/applied/lorahub/strategies.rs
@@ -1,3 +1,4 @@
use super::error::LoraHubError;
use super::types::LoraStateDict;

/// Strategy for combining multiple LoRA state dictionaries.
@@ -6,7 +7,7 @@ pub trait CombinationStrategy {
&self,
modules: &[LoraStateDict],
weights: &[f64],
) -> Result<LoraStateDict, &'static str>;
) -> Result<LoraStateDict, LoraHubError>;
}

/// Strategy for calculating the objective score.
@@ -22,15 +23,15 @@ impl CombinationStrategy for LinearCombinationStrategy {
&self,
modules: &[LoraStateDict],
weights: &[f64],
) -> Result<LoraStateDict, &'static str> {
) -> Result<LoraStateDict, LoraHubError> {
if modules.is_empty() {
return Err("Ensemble is empty; cannot combine.");
return Err(LoraHubError::EmptyEnsemble);
}
if weights.is_empty() {
return Err("Weights cannot be empty.");
return Err(LoraHubError::EmptyWeights);
}
if modules.len() != weights.len() {
return Err("The number of weights must match the number of modules in the ensemble.");
return Err(LoraHubError::LengthMismatch);
}

let first_lora = &modules[0];
@@ -45,7 +46,7 @@ impl CombinationStrategy for LinearCombinationStrategy {
for (key, final_tensor) in &mut final_state_dict {
if let Some(tensor) = lora_state_dict.get(key) {
if final_tensor.shape() != tensor.shape() {
return Err("Mismatched tensor shapes for the same key.");
return Err(LoraHubError::ShapeMismatch);
}
// Optimized in-place addition to avoid allocating intermediate matrix
let weight = weights[i];
@@ -55,7 +56,9 @@ impl CombinationStrategy for LinearCombinationStrategy {
*f += *t * weight;
}
} else {
return Err("Mismatched keys between LoRA modules.");
return Err(LoraHubError::StrategyError(
"Mismatched keys between LoRA modules.".to_string(),
));
}
}
}
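
Because the new error types all implement `std::error::Error`, callers can propagate them through a single `Box<dyn std::error::Error>` with `?`. A hedged end-to-end sketch (not part of this diff; assumes default favoritism inputs validate, as the new tests do):

```rust
use math_explorer::applied::favoritism::{FavoritismInputs, try_calculate_favoritism_score};
use math_explorer::applied::lorahub::{LoraEnsemble, LoraStateDict};
use nalgebra::DMatrix;

fn run() -> Result<(), Box<dyn std::error::Error>> {
    // FavoritismError converts via `?` thanks to the Error impl.
    let score = try_calculate_favoritism_score(&FavoritismInputs::default())?;
    println!("score: {:.2}", score);

    // LoraHubError converts the same way.
    let mut lora = LoraStateDict::new();
    lora.insert("layer1".to_string(), DMatrix::from_element(2, 2, 0.5));
    let _combined = LoraEnsemble::new(vec![lora]).combine(&[1.0])?;
    Ok(())
}

fn main() {
    if let Err(e) = run() {
        eprintln!("error: {}", e);
    }
}
```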
78 changes: 78 additions & 0 deletions math_explorer/tests/refactor_validation.rs
@@ -0,0 +1,78 @@
use math_explorer::applied::algorithms::kalman::KalmanModel;
use math_explorer::applied::algorithms::{AlgorithmError, kalman::KalmanFilter};
use math_explorer::applied::favoritism::{
FavoritismError, FavoritismInputs, try_calculate_favoritism_score,
};
use math_explorer::applied::lorahub::{LoraEnsemble, LoraHubError, LoraStateDict};
use nalgebra::{DMatrix, DVector};

#[test]
fn test_favoritism_validation_negative_time() {
let mut inputs = FavoritismInputs::default();
inputs.time.t = -1.0;
let result = inputs.validate();
assert!(matches!(result, Err(FavoritismError::InvalidInput(_))));

let result = try_calculate_favoritism_score(&inputs);
assert!(matches!(result, Err(FavoritismError::InvalidInput(_))));
}

#[test]
fn test_favoritism_validation_negative_sibling_distance() {
let mut inputs = FavoritismInputs::default();
inputs.family.sibling_distances = vec![100.0, -10.0];
let result = inputs.validate();
assert!(matches!(result, Err(FavoritismError::InvalidInput(_))));
}

#[test]
fn test_lorahub_empty_ensemble_error() {
let modules = vec![];
let ensemble = LoraEnsemble::new(modules);
let weights = vec![0.5];
let result = ensemble.combine(&weights);
assert_eq!(result.unwrap_err(), LoraHubError::EmptyEnsemble);
}

#[test]
fn test_lorahub_length_mismatch_error() {
let mut lora_1 = LoraStateDict::new();
lora_1.insert("layer1".to_string(), DMatrix::from_element(1, 1, 1.0));
let modules = vec![lora_1];
let ensemble = LoraEnsemble::new(modules);
let weights = vec![0.5, 0.5]; // Mismatch
let result = ensemble.combine(&weights);
assert_eq!(result.unwrap_err(), LoraHubError::LengthMismatch);
}

// Mock Model for testing Kalman error
struct MockModel;
impl KalmanModel for MockModel {
fn transition_matrix(&self, _dt: f64) -> DMatrix<f64> {
DMatrix::identity(1, 1)
}
fn measurement_matrix(&self) -> DMatrix<f64> {
DMatrix::identity(1, 1)
}
fn process_noise(&self, _dt: f64) -> DMatrix<f64> {
DMatrix::identity(1, 1)
}
fn measurement_noise(&self) -> DMatrix<f64> {
DMatrix::zeros(1, 1)
} // Zero noise might cause singularity if covariance is also zero/singular
}

#[test]
fn test_kalman_singular_error() {
let model = MockModel;
let x_init = DVector::from_element(1, 0.0);
// Zero covariance and zero measurement noise -> Innovation Covariance S = HPH^T + R = 0.
// Inverting 0 will fail.
let p_init = DMatrix::zeros(1, 1);

let mut kf = KalmanFilter::new(x_init, p_init, model, 1.0);
let measurement = DVector::from_element(1, 1.0);

let result = kf.update(&measurement);
assert_eq!(result.unwrap_err(), AlgorithmError::SingularMatrix);
}