Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub struct NPAG<E: Equation + Send + 'static> {
ranges: Vec<(f64, f64)>,
psi: Psi,
theta: Theta,
theta_old: Option<Theta>, // Store previous theta for CHECKBIG calculation
lambda: Weights,
w: Weights,
eps: f64,
Expand All @@ -57,6 +58,7 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
ranges: settings.parameters().ranges(),
psi: Psi::new(),
theta: Theta::new(),
theta_old: None, // Initialize as None (no previous theta yet)
lambda: Weights::default(),
w: Weights::default(),
eps: 0.2,
Expand Down Expand Up @@ -162,18 +164,56 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
if self.eps <= THETA_E {
let pyl = psi * w.weights();
self.f1 = pyl.iter().map(|x| x.ln()).sum();
if (self.f1 - self.f0).abs() <= THETA_F {
tracing::info!("The model converged after {} cycles", self.cycle,);

// Calculate CHECKBIG if we have a previous theta
let checkbig = if let Some(ref old_theta) = self.theta_old {
Some(self.theta.max_relative_difference(&old_theta)?)
} else {
None
};

let f1_f0_diff = (self.f1 - self.f0).abs();

// Log convergence metrics for diagnostics
match checkbig {
Some(cb) => tracing::debug!(
"f1-f0={:.6e} (threshold={:.6e}), CHECKBIG={:.6e} (threshold={:.6e})",
f1_f0_diff,
THETA_F,
cb,
THETA_E
),
None => tracing::debug!(
"f1-f0={:.6e} (threshold={:.6e}), CHECKBIG=N/A (no previous theta)",
f1_f0_diff,
THETA_F
),
}

// Standard likelihood convergence check
if f1_f0_diff <= THETA_F {
tracing::info!("The model converged according to the LIKELIHOOD criteria",);
self.set_status(Status::Stop(StopReason::Converged));
self.log_cycle_state();
return Ok(self.status().clone());
} else if let Some(cb) = checkbig {
// Additional CHECKBIG convergence check
if cb <= THETA_E {
tracing::info!("The model converged according to the CHECKBIG criteria",);
self.set_status(Status::Stop(StopReason::Converged));
self.log_cycle_state();
return Ok(self.status().clone());
}
} else {
self.f0 = self.f1;
self.eps = 0.2;
}
}
}

// Save current theta for next cycle's CHECKBIG calculation
self.theta_old = Some(self.theta.clone());

// Stop if we have reached maximum number of cycles
if self.cycle >= self.settings.config().cycles {
tracing::warn!("Maximum number of cycles reached");
Expand Down
88 changes: 87 additions & 1 deletion src/structs/theta.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Debug;

use anyhow::{bail, Result};
use faer::Mat;
use faer::{ColRef, Mat};
use serde::{Deserialize, Serialize};

use crate::prelude::Parameters;
Expand Down Expand Up @@ -201,6 +201,41 @@ impl Theta {

Theta::from_parts(mat, parameters)
}

/// Compute the maximum relative difference in medians across parameters between two Thetas
///
/// This is useful for assessing convergence between iterations
/// # Errors
/// Returns an error if the number of parameters (columns) do not match between the two Thetas
pub fn max_relative_difference(&self, other: &Theta) -> Result<f64> {
if self.matrix.ncols() != other.matrix.ncols() {
bail!("Number of parameters (columns) do not match between Thetas");
}

fn median_col(col: ColRef<f64>) -> f64 {
let mut vals: Vec<&f64> = col.iter().collect();
vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mid = vals.len() / 2;
if vals.len() % 2 == 0 {
(vals[mid - 1] + vals[mid]) / 2.0
} else {
*vals[mid]
}
}

let mut max_rel_diff = 0.0;
for i in 0..self.matrix.ncols() {
let current_median = median_col(self.matrix.col(i));
let other_median = median_col(other.matrix.col(i));

let denom = current_median.abs().max(other_median.abs()).max(1e-8); // Avoid division by zero
let rel_diff = ((current_median - other_median).abs()) / denom;
if rel_diff > max_rel_diff {
max_rel_diff = rel_diff;
}
}
Ok(max_rel_diff)
}
}

impl Debug for Theta {
Expand Down Expand Up @@ -379,4 +414,55 @@ mod tests {

assert_eq!(theta.matrix(), &new_matrix);
}

#[test]
fn test_max_relative_difference() {
let matrix1 = mat![[2.0, 4.0], [6.0, 8.0]];
let matrix2 = mat![[2.0, 4.0], [8.0, 8.0]];
let parameters = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
let theta1 = Theta::from_parts(matrix1, parameters.clone()).unwrap();
let theta2 = Theta::from_parts(matrix2, parameters).unwrap();
let max_rel_diff = theta1.max_relative_difference(&theta2).unwrap();
println!("Max relative difference: {}", max_rel_diff);
assert!((max_rel_diff - 0.2).abs() < 1e-6);
}

#[test]
fn test_max_relative_difference_same_theta() {
let matrix1 = mat![[1.0, 2.0], [3.0, 4.0]];
let parameters = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
let theta1 = Theta::from_parts(matrix1, parameters.clone()).unwrap();
let theta2 = theta1.clone();
let max_rel_diff = theta1.max_relative_difference(&theta2).unwrap();
println!("Max relative difference: {}", max_rel_diff);
assert!((max_rel_diff - 0.0).abs() < 1e-6);
}

#[test]
fn test_max_relative_difference_shape_error() {
let matrix1 = mat![[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]];
let matrix2 = mat![[2.0, 4.0], [8.0, 8.0]];
let parameters1 = Parameters::new()
.add("A", 0.0, 10.0)
.add("B", 0.0, 10.0)
.add("C", 0.0, 10.0);
let parameters2 = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
let theta1 = Theta::from_parts(matrix1, parameters1).unwrap();
let theta2 = Theta::from_parts(matrix2, parameters2).unwrap();
let result = theta1.max_relative_difference(&theta2);
assert!(result.is_err());
}

#[test]
fn test_max_relative_difference_odd_length() {
let matrix1 = mat![[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]];
let matrix2 = mat![[1.0, 2.0], [4.0, 6.0], [5.0, 10.0]];
let parameters = Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0);
let theta1 = Theta::from_parts(matrix1, parameters.clone()).unwrap();
let theta2 = Theta::from_parts(matrix2, parameters).unwrap();
let max_rel_diff = theta1.max_relative_difference(&theta2).unwrap();
println!("Max relative difference (odd length): {}", max_rel_diff);

assert!((max_rel_diff - 0.25).abs() < 1e-6);
}
}
11 changes: 11 additions & 0 deletions src/structs/weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,25 @@ impl Weights {
self.weights.nrows()
}

/// Check if there are no weights.
pub fn is_empty(&self) -> bool {
self.weights.nrows() == 0
}

/// Get a vector representation of the weights.
pub fn to_vec(&self) -> Vec<f64> {
self.weights.iter().cloned().collect()
}

/// Get an iterator over the weights.
pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
self.weights.iter().cloned()
}

/// Get a mutable iterator over the weights.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut f64> + '_ {
self.weights.iter_mut()
}
}

impl Serialize for Weights {
Expand Down
Loading