From 31154450dfcef037f61936dcb1a977d96c151836 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?=
Date: Thu, 27 Nov 2025 02:04:02 +0000
Subject: [PATCH 1/3] log-likelihood

---
 Cargo.toml                       |   2 +-
 src/algorithms/mod.rs            |  44 ++-
 src/algorithms/npag.rs           |  80 +++---
 src/algorithms/npod.rs           | 130 ++++-----
 src/algorithms/postprob.rs       |  12 +-
 src/bestdose/posterior.rs        |  14 +-
 src/routines/condensation/mod.rs |   6 +-
 src/routines/estimation/ipm.rs   | 460 +++++++++++++++++++++++++++++++
 src/routines/estimation/qr.rs    |  55 +++-
 src/routines/math.rs             | 176 ++++++++++++
 src/routines/mod.rs              |   2 +
 src/routines/output/mod.rs       |   4 +-
 src/routines/output/posterior.rs |  53 ++--
 src/routines/settings.rs         |   8 +
 src/structs/psi.rs               | 144 +++++++++-
 15 files changed, 1003 insertions(+), 187 deletions(-)
 create mode 100644 src/routines/math.rs

diff --git a/Cargo.toml b/Cargo.toml
index 69a8b0d2a..2341051bd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [
 ] }
 faer = "0.23.1"
 faer-ext = { version = "0.7.1", features = ["nalgebra", "ndarray"] }
-pharmsol = "=0.21.0"
+pharmsol = "=0.22.0"
 rand = "0.9.0"
 anyhow = "1.0.100"
 rayon = "1.10.0"
diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs
index ab5c16912..5c5dd1edb 100644
--- a/src/algorithms/mod.rs
+++ b/src/algorithms/mod.rs
@@ -1,6 +1,7 @@
 use std::fs;
 use std::path::Path;
 
+use crate::routines::math::logsumexp_rows;
 use crate::routines::output::NPResult;
 use crate::routines::settings::Settings;
 use crate::structs::psi::Psi;
@@ -36,35 +37,60 @@ pub trait Algorithms: Sync + Send + 'static {
         // Count problematic values in psi
         let mut nan_count = 0;
         let mut inf_count = 0;
+        let is_log_space = self.psi().is_log_space();
         let psi = self.psi().matrix().as_ref().into_ndarray();
-        // First coerce all NaN and infinite in psi to 0.0
+        // Count NaN and problematic infinite values in psi
+        // (in log-space, NEG_INFINITY is valid and represents zero probability)
         for i in 0..psi.nrows() {
             for j in 0..self.psi().matrix().ncols() {
                 let val = psi.get((i, j)).unwrap();
                 if val.is_nan() {
                     nan_count += 1;
-                    // *val = 0.0;
                 } else if val.is_infinite() {
-                    inf_count += 1;
-                    // *val = 0.0;
+                    // In log-space, NEG_INFINITY is valid (represents zero probability)
+                    // Only count positive infinity as problematic
+                    if !is_log_space || val.is_sign_positive() {
+                        inf_count += 1;
+                    }
                 }
             }
         }
 
         if nan_count + inf_count > 0 {
             tracing::warn!(
-                "Psi matrix contains {} NaN, {} Infinite values of {} total values",
+                "Psi matrix contains {} NaN and {} problematic infinite values out of {} total",
                 nan_count,
                 inf_count,
                 psi.ncols() * psi.nrows()
             );
         }
 
-        let (_, col) = psi.dim();
-        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
-        let plam = psi.dot(&ecol);
-        let w = 1. / &plam;
+        // Calculate row sums: for regular space: sum; for log-space: logsumexp
+        let plam: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
+            // For log-space, use logsumexp for each row
+            Array::from_vec(logsumexp_rows(psi.nrows(), psi.ncols(), |i, j| psi[(i, j)]))
+        } else {
+            // For regular space, sum each row
+            let (_, col) = psi.dim();
+            let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
+            psi.dot(&ecol)
+        };
+
+        // Check for subjects with zero probability
+        // In log-space: -inf means zero probability
+        // In regular space: 0 means zero probability
+        let w: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
+            // For log-space, check if logsumexp result is -inf
+            Array::from_shape_fn(plam.len(), |i| {
+                if plam[i].is_infinite() && plam[i].is_sign_negative() {
+                    f64::INFINITY // Will be flagged as problematic
+                } else {
+                    1.0 // Valid
+                }
+            })
+        } else {
+            1. / &plam
+        };
 
         // Get the index of each element in `w` that is NaN or infinite
         let indices: Vec<usize> = w
diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs
index 68ed04693..7c377b020 100644
--- a/src/algorithms/npag.rs
+++ b/src/algorithms/npag.rs
@@ -1,12 +1,13 @@
 use crate::algorithms::{Status, StopReason};
 use crate::prelude::algorithms::Algorithms;
-pub use crate::routines::estimation::ipm::burke;
+pub use crate::routines::estimation::ipm::{burke, burke_ipm, burke_log};
 pub use crate::routines::estimation::qr;
+use crate::routines::math::logsumexp;
 use crate::routines::settings::Settings;
 
 use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
 
-use crate::structs::psi::{calculate_psi, Psi};
+use crate::structs::psi::{calculate_psi_dispatch, Psi};
 use crate::structs::theta::Theta;
 use crate::structs::weights::Weights;
@@ -160,8 +161,24 @@
         if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
             self.eps /= 2.;
             if self.eps <= THETA_E {
-                let pyl = psi * w.weights();
-                self.f1 = pyl.iter().map(|x| x.ln()).sum();
+                // Compute f1 = sum(log(pyl)) where pyl = psi * w
+                self.f1 = if self.psi.is_log_space() {
+                    // For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
+                    let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
+                    (0..psi.nrows())
+                        .map(|i| {
+                            let combined: Vec<f64> = (0..psi.ncols())
+                                .map(|j| *psi.get(i, j) + log_w[j])
+                                .collect();
+                            logsumexp(&combined)
+                        })
+                        .sum()
+                } else {
+                    // For regular space: f1 = sum(log(psi * w))
+                    let pyl = psi * w.weights();
+                    pyl.iter().map(|x| x.ln()).sum()
+                };
+
                 if (self.f1 - self.f0).abs() <= THETA_F {
                     tracing::info!("The model converged after {} cycles", self.cycle,);
                     self.set_status(Status::Stop(StopReason::Converged));
@@ -197,31 +214,29 @@
     }
 
     fn estimation(&mut self) -> Result<()> {
-        self.psi = calculate_psi(
+        let use_log_space = self.settings.advanced().log_space;
+
+        self.psi = calculate_psi_dispatch(
             &self.equation,
             &self.data,
             &self.theta,
             &self.error_models,
             self.cycle == 1 && self.settings.config().progress,
             self.cycle != 1,
+            use_log_space,
         )?;
 
         if let Err(err) = self.validate_psi() {
             bail!(err);
         }
 
-        (self.lambda, _) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                bail!("Error in IPM during estimation: {:?}", err);
-            }
-        };
+        (self.lambda, _) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
 
         Ok(())
     }
 
     fn condensation(&mut self) -> Result<()> {
         // Filter out the support points with lambda < max(lambda)/1000
-
         let max_lambda = self
             .lambda
             .iter()
@@ -273,20 +288,16 @@ impl
Algorithms for NPAG { self.psi.filter_column_indices(keep.as_slice()); self.validate_psi()?; - (self.lambda, self.objf) = match burke(&self.psi) { - Ok((lambda, objf)) => (lambda, objf), - Err(err) => { - return Err(anyhow::anyhow!( - "Error in IPM during condensation: {:?}", - err - )); - } - }; + + (self.lambda, self.objf) = burke_ipm(&self.psi) + .map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?; self.w = self.lambda.clone(); Ok(()) } fn optimizations(&mut self) -> Result<()> { + let use_log_space = self.settings.advanced().log_space; + self.error_models .clone() .iter_mut() @@ -298,8 +309,6 @@ impl Algorithms for NPAG { } }) .try_for_each(|(outeq, em)| -> Result<()> { - // OPTIMIZATION - let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]); let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]); @@ -309,35 +318,32 @@ impl Algorithms for NPAG { let mut error_model_down = self.error_models.clone(); error_model_down.set_factor(outeq, gamma_down)?; - let psi_up = calculate_psi( + let psi_up = calculate_psi_dispatch( &self.equation, &self.data, &self.theta, &error_model_up, false, true, + use_log_space, )?; - let psi_down = calculate_psi( + + let psi_down = calculate_psi_dispatch( &self.equation, &self.data, &self.theta, &error_model_down, false, true, + use_log_space, )?; - let (lambda_up, objf_up) = match burke(&psi_up) { - Ok((lambda, objf)) => (lambda, objf), - Err(err) => { - bail!("Error in IPM during optim: {:?}", err); - } - }; - let (lambda_down, objf_down) = match burke(&psi_down) { - Ok((lambda, objf)) => (lambda, objf), - Err(err) => { - bail!("Error in IPM during optim: {:?}", err); - } - }; + let (lambda_up, objf_up) = burke_ipm(&psi_up) + .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?; + + let (lambda_down, objf_down) = burke_ipm(&psi_down) + .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?; + if objf_up > self.objf { self.error_models.set_factor(outeq, gamma_up)?; self.objf = objf_up; diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index ed962d971..50f2a299c 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -1,5 +1,6 @@ use crate::algorithms::StopReason; use crate::routines::initialization::sample_space; +use crate::routines::math::logsumexp; use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; use crate::structs::weights::Weights; use crate::{ @@ -7,12 +8,12 @@ use crate::{ prelude::{ algorithms::Algorithms, routines::{ - estimation::{ipm::burke, qr}, + estimation::{ipm::burke_ipm, qr}, settings::Settings, }, }, structs::{ - psi::{calculate_psi, Psi}, + psi::{calculate_psi_dispatch, Psi}, theta::Theta, }, }; @@ -21,15 +22,12 @@ use pharmsol::SppOptimizer; use anyhow::bail; use anyhow::Result; use faer_ext::IntoNdarray; +use pharmsol::prelude::{data::Data, simulator::Equation}; use pharmsol::{prelude::ErrorModel, ErrorModels}; -use pharmsol::{ - prelude::{data::Data, simulator::Equation}, - Subject, -}; use ndarray::{ parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator}, - Array, Array1, ArrayBase, Dim, OwnedRepr, + Array1, }; const THETA_F: f64 = 1e-2; @@ -207,27 +205,24 @@ impl Algorithms for NPOD { } fn estimation(&mut self) -> Result<()> { - let error_model: ErrorModels = self.error_models.clone(); + let use_log_space = self.settings.advanced().log_space; - self.psi = calculate_psi( + self.psi = calculate_psi_dispatch( &self.equation, &self.data, &self.theta, - &error_model, + &self.error_models, self.cycle 
== 1 && self.settings.config().progress,
             self.cycle != 1,
+            use_log_space,
         )?;
 
         if let Err(err) = self.validate_psi() {
             bail!(err);
         }
 
-        (self.lambda, _) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                bail!(err);
-            }
-        };
+        (self.lambda, _) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
 
         Ok(())
     }
@@ -280,17 +275,15 @@
         self.theta.filter_indices(keep.as_slice());
         self.psi.filter_column_indices(keep.as_slice());
 
-        (self.lambda, self.objf) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
-            }
-        };
+        (self.lambda, self.objf) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?;
         self.w = self.lambda.clone();
         Ok(())
     }
 
     fn optimizations(&mut self) -> Result<()> {
+        let use_log_space = self.settings.advanced().log_space;
+
         self.error_models
             .clone()
             .iter_mut()
@@ -313,35 +306,31 @@
             let mut error_model_down = self.error_models.clone();
             error_model_down.set_factor(outeq, gamma_down)?;
 
-                let psi_up = calculate_psi(
+                let psi_up = calculate_psi_dispatch(
                     &self.equation,
                     &self.data,
                     &self.theta,
                     &error_model_up,
                     false,
                     true,
+                    use_log_space,
                 )?;
-                let psi_down = calculate_psi(
+
+                let psi_down = calculate_psi_dispatch(
                     &self.equation,
                     &self.data,
                     &self.theta,
                     &error_model_down,
                     false,
                     true,
+                    use_log_space,
                 )?;
 
-                let (lambda_up, objf_up) = match burke(&psi_up) {
-                    Ok((lambda, objf)) => (lambda, objf),
-                    Err(err) => {
-                        bail!("Error in IPM during optim: {:?}", err);
-                    }
-                };
-                let (lambda_down, objf_down) = match burke(&psi_down) {
-                    Ok((lambda, objf)) => (lambda, objf),
-                    Err(err) => {
-                        bail!("Error in IPM during optim: {:?}", err);
-                    }
-                };
+                let (lambda_up, objf_up) = burke_ipm(&psi_up)
+                    .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
+
+                let (lambda_down, objf_down) = burke_ipm(&psi_down)
+                    .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
+
                 if objf_up > self.objf {
                     self.error_models.set_factor(outeq, gamma_up)?;
                     self.objf = objf_up;
@@ -368,9 +357,29 @@
     fn expansion(&mut self) -> Result<()> {
         // If no stop signal, add new point to theta based on the optimization of the D function
-        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
+        // Note: SppOptimizer expects regular-space psi for the D-optimizer
+        // If we're in log-space, we need to convert pyl to regular space
+
+        let psi_mat = self.psi().matrix().as_ref().into_ndarray().to_owned();
         let w: Array1<f64> = self.w.clone().iter().collect();
-        let pyl = psi.dot(&w);
+
+        // Compute pyl = P(Y|L) for each subject
+        // In log-space, we need to use logsumexp and then exp to get regular pyl
+        let pyl = if self.psi.is_log_space() {
+            // pyl[i] = sum_j(exp(log_psi[i,j]) * w[j]) = sum_j(exp(log_psi[i,j] + log(w[j])))
+            // Using logsumexp for stability, then exp to get regular values
+            let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
+            let mut pyl = Array1::zeros(psi_mat.nrows());
+            for i in 0..psi_mat.nrows() {
+                let combined: Vec<f64> = (0..psi_mat.ncols())
+                    .map(|j| psi_mat[[i, j]] + log_w[j])
+                    .collect();
+                pyl[i] = logsumexp(&combined).exp();
+            }
+            pyl
+        } else {
+            psi_mat.dot(&w)
+        };
 
         // Add new point to theta based on the optimization of the D function
         let error_model: ErrorModels = self.error_models.clone();
@@ -397,48 +406,3 @@
         Ok(())
     }
 }
-
-impl NPOD {
-    fn validate_psi(&mut self) -> Result<()> {
-        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
-        // First coerce all NaN and infinite in psi to 0.0
-        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
-            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
-            for i in 0..psi.nrows() {
-                for j in 0..psi.ncols() {
-                    let val = psi.get_mut((i, j)).unwrap();
-                    if val.is_nan() || val.is_infinite() {
-                        *val = 0.0;
-                    }
-                }
-            }
-        }
-
-        // Calculate the sum of each column in psi
-        let (_, col) = psi.dim();
-        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
-        let plam = psi.dot(&ecol);
-        let w = 1. / &plam;
-
-        // Get the index of each element in `w` that is NaN or infinite
-        let indices: Vec<usize> = w
-            .iter()
-            .enumerate()
-            .filter(|(_, x)| x.is_nan() || x.is_infinite())
-            .map(|(i, _)| i)
-            .collect::<Vec<usize>>();
-
-        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
-        if !indices.is_empty() {
-            let subject: Vec<&Subject> = self.data.subjects();
-            let zero_probability_subjects: Vec<&String> =
-                indices.iter().map(|&i| subject[i].id()).collect();
-
-            return Err(anyhow::anyhow!(
-                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
-            ));
-        }
-
-        Ok(())
-    }
-}
diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs
index 496b36e28..9e9250d74 100644
--- a/src/algorithms/postprob.rs
+++ b/src/algorithms/postprob.rs
@@ -2,7 +2,7 @@ use crate::{
     algorithms::{Status, StopReason},
     prelude::algorithms::Algorithms,
     structs::{
-        psi::{calculate_psi, Psi},
+        psi::{calculate_psi_dispatch, Psi},
         theta::Theta,
         weights::Weights,
     },
@@ -14,7 +14,7 @@ use pharmsol::prelude::{
     simulator::Equation,
 };
 
-use crate::routines::estimation::ipm::burke;
+use crate::routines::estimation::ipm::burke_ipm;
 use crate::routines::initialization;
 use crate::routines::output::{cycles::CycleLog, NPResult};
 use crate::routines::settings::Settings;
@@ -119,15 +119,19 @@ impl Algorithms for POSTPROB {
     }
 
     fn estimation(&mut self) -> Result<()> {
-        self.psi = calculate_psi(
+        let use_log_space = self.settings.advanced().log_space;
+
+        self.psi = calculate_psi_dispatch(
             &self.equation,
             &self.data,
             &self.theta,
             &self.error_models,
             false,
             false,
+            use_log_space,
         )?;
-        (self.w, self.objf) = burke(&self.psi).context("Error in IPM")?;
+
+        (self.w, self.objf) = burke_ipm(&self.psi).context("Error in IPM")?;
         Ok(())
     }
diff --git a/src/bestdose/posterior.rs b/src/bestdose/posterior.rs
index 9109f65c1..8df890117 100644
--- a/src/bestdose/posterior.rs
+++ b/src/bestdose/posterior.rs
@@ -53,12 +53,12 @@ use anyhow::Result;
 use faer::Mat;
 
-use crate::algorithms::npag::burke;
 use crate::algorithms::npag::NPAG;
 use crate::algorithms::Algorithms;
 use crate::algorithms::Status;
 use crate::prelude::*;
-use crate::structs::psi::calculate_psi;
+use crate::routines::estimation::ipm::burke_ipm;
+use crate::structs::psi::calculate_psi_dispatch;
 use crate::structs::theta::Theta;
 use crate::structs::weights::Weights;
 use pharmsol::prelude::*;
@@ -95,14 +95,16 @@ pub fn npagfull11_filter(
     past_data: &Data,
     eq: &ODE,
     error_models: &ErrorModels,
+    use_log_space: bool,
 ) -> Result<(Theta, Weights, Weights)> {
     tracing::info!("Stage 1.1: NPAGFULL11 Bayesian filtering");
 
     // Calculate psi matrix P(data|theta_i) for all support points
-    let psi = calculate_psi(eq, past_data, population_theta, error_models, false, true)?;
+    // Use log-space or regular space based on setting
+    let psi =
calculate_psi_dispatch(eq, past_data, population_theta, error_models, false, true, use_log_space)?; // First burke call to get initial posterior probabilities - let (initial_weights, _) = burke(&psi)?; + let (initial_weights, _) = burke_ipm(&psi)?; // NPAGFULL11 filtering: Keep all points within 1e-100 of the maximum weight // This is different from NPAG's condensation - NO QR decomposition here! @@ -317,6 +319,9 @@ pub fn calculate_two_step_posterior( ) -> Result<(Theta, Weights, Weights)> { tracing::info!("=== STAGE 1: Posterior Density Calculation ==="); + // Use log-space based on settings + let use_log_space = settings.advanced().log_space; + // Step 1.1: NPAGFULL11 filtering (returns filtered posterior AND filtered prior) let (filtered_theta, filtered_posterior_weights, filtered_population_weights) = npagfull11_filter( @@ -325,6 +330,7 @@ pub fn calculate_two_step_posterior( past_data, eq, error_models, + use_log_space, )?; // Step 1.2: NPAGFULL refinement diff --git a/src/routines/condensation/mod.rs b/src/routines/condensation/mod.rs index d01533b6b..2ef44fe6b 100644 --- a/src/routines/condensation/mod.rs +++ b/src/routines/condensation/mod.rs @@ -1,4 +1,4 @@ -use crate::algorithms::npag::{burke, qr}; +use crate::routines::estimation::{ipm::burke_ipm, qr}; use crate::structs::psi::Psi; use crate::structs::theta::Theta; use crate::structs::weights::Weights; @@ -93,8 +93,8 @@ pub fn condense_support_points( filtered_theta.filter_indices(&keep_qr); filtered_psi.filter_column_indices(&keep_qr); - // Step 3: Recalculate weights with Burke's IPM - let (final_weights, objf) = burke(&filtered_psi)?; + // Step 3: Recalculate weights with Burke's IPM (auto-dispatches based on psi.is_log_space()) + let (final_weights, objf) = burke_ipm(&filtered_psi)?; tracing::debug!( "Condensation complete: {} -> {} support points (objective: {:.4})", diff --git a/src/routines/estimation/ipm.rs b/src/routines/estimation/ipm.rs index fbb1768b2..41cf1055e 100644 --- a/src/routines/estimation/ipm.rs +++ b/src/routines/estimation/ipm.rs @@ -278,6 +278,298 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> { Ok((lam.into(), obj)) } +/// Applies Burke's Interior Point Method (IPM) operating in log space. +/// +/// This version works with log-likelihoods directly, which provides better numerical +/// stability when dealing with very small probabilities (many observations or extreme +/// parameter values). +/// +/// The objective function to maximize is: +/// f(x) = Σ(log(Σ(exp(log_ψ_ij) * x_j))) for i = 1 to n_sub +/// = Σ(logsumexp(log_ψ_ij + log(x_j))) +/// +/// subject to: +/// 1. x_j ≥ 0 for all j = 1 to n_point, +/// 2. Σ(x_j) = 1, +/// +/// # Arguments +/// +/// * `log_psi` - A reference to a Psi structure containing log-likelihoods. +/// +/// # Returns +/// +/// On success, returns a tuple `(weights, obj)` where: +/// - [Weights] contains the optimized weights (probabilities) for each support point. +/// - `obj` is the value of the objective function at the solution. +/// +/// # Errors +/// +/// This function returns an error if any step in the optimization fails. 
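+///
+/// # Example
+///
+/// A minimal sketch (not compiled as a doctest); it assumes the `PsiBuilder`
+/// helper introduced in `crate::structs::psi` by this patch:
+///
+/// ```ignore
+/// use ndarray::Array2;
+///
+/// // Two subjects and two support points, stored as log-likelihoods
+/// let log_mat = Array2::from_shape_vec((2, 2), vec![-1.0, -2.0, -3.0, -4.0]).unwrap();
+/// let psi = PsiBuilder::new(log_mat).log_space(true).build();
+///
+/// let (weights, obj) = burke_log(&psi)?;
+/// // The optimized weights form a probability vector over the support points
+/// assert!((weights.iter().sum::<f64>() - 1.0).abs() < 1e-10);
+/// ```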
+pub fn burke_log(log_psi: &Psi) -> anyhow::Result<(Weights, f64)> {
+    let log_psi_mat = log_psi.matrix();
+
+    // Validate that all entries are finite
+    for row in log_psi_mat.row_iter() {
+        for &x in row.iter() {
+            if !x.is_finite() {
+                bail!("Input log-psi matrix must have finite entries");
+            }
+        }
+    }
+
+    let (n_sub, n_point) = log_psi_mat.shape();
+
+    if n_sub == 0 || n_point == 0 {
+        bail!("Input matrix cannot be empty");
+    }
+
+    // Convert log_psi to regular psi for the IPM iterations
+    // We need to work in regular space for the IPM, but we use logsumexp for the weighted sums
+    // to maintain numerical stability.
+    //
+    // Key insight: The IPM needs to compute psi * lam, which in log space is logsumexp(log_psi + log_lam).
+    // However, the internal IPM computations (Hessian, gradients) are more complex in log space.
+    //
+    // Strategy: Convert log_psi to psi using exp, but handle potential underflow by using
+    // a shifted version. We'll keep track of the shift and adjust the objective function.
+
+    // Find the maximum log-likelihood per row to prevent underflow
+    let row_max: Vec<f64> = (0..n_sub)
+        .map(|i| {
+            (0..n_point)
+                .map(|j| *log_psi_mat.get(i, j))
+                .fold(f64::NEG_INFINITY, f64::max)
+        })
+        .collect();
+
+    // Create shifted psi matrix: psi_shifted[i,j] = exp(log_psi[i,j] - row_max[i])
+    // This ensures the maximum value in each row is 1.0, preventing underflow
+    let psi_shifted: Mat<f64> = Mat::from_fn(n_sub, n_point, |i, j| {
+        let log_val = *log_psi_mat.get(i, j);
+        (log_val - row_max[i]).exp()
+    });
+
+    // Now run the standard IPM on the shifted matrix
+    let ecol: Col<f64> = Col::from_fn(n_point, |_| 1.0);
+    let erow: Row<f64> = Row::from_fn(n_sub, |_| 1.0);
+
+    let mut plam: Col<f64> = &psi_shifted * &ecol;
+    let eps: f64 = 1e-8;
+    let mut sig: f64 = 0.0;
+
+    let mut lam = ecol.clone();
+
+    let mut w: Col<f64> = Col::from_fn(plam.nrows(), |i| 1.0 / plam.get(i));
+
+    let mut ptw: Col<f64> = psi_shifted.transpose() * &w;
+
+    let ptw_max = ptw.iter().fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
+    let shrink = 2.0 * ptw_max;
+    lam *= shrink;
+    plam *= shrink;
+    w /= shrink;
+    ptw /= shrink;
+
+    let mut y: Col<f64> = &ecol - &ptw;
+    let mut r: Col<f64> = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
+    let mut norm_r: f64 = r.iter().fold(0.0, |max, &val| max.max(val.abs()));
+
+    let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
+    let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
+    let mut gap: f64 = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam.abs());
+
+    let mut mu = lam.transpose() * &y / n_point as f64;
+
+    let mut psi_inner: Mat<f64> = Mat::zeros(n_sub, n_point);
+
+    let n_threads = faer::get_global_parallelism().degree();
+    let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(n_sub, n_sub)).collect();
+
+    let mut h: Mat<f64> = Mat::zeros(n_sub, n_sub);
+
+    while mu > eps || norm_r > eps || gap > eps {
+        let smu = sig * mu;
+        let inner = Col::from_fn(lam.nrows(), |i| lam.get(i) / y.get(i));
+        let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
+
+        // Scale columns and compute H matrix
+        if psi_shifted.ncols() > n_threads * 128 {
+            psi_inner
+                .par_col_partition_mut(n_threads)
+                .zip(psi_shifted.par_col_partition(n_threads))
+                .zip(inner.par_partition(n_threads))
+                .zip(output.par_iter_mut())
+                .for_each(|(((mut psi_inner, psi_part), inner_part), output)| {
+                    psi_inner
+                        .as_mut()
+                        .col_iter_mut()
+                        .zip(psi_part.col_iter())
+                        .zip(inner_part.iter())
+                        .for_each(|((col, psi_col), inner_val)| {
+                            col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
+                                *x = psi_val * inner_val;
+                            });
+                        });
+                    faer::linalg::matmul::triangular::matmul(
+                        output.as_mut(),
+                        faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
+                        faer::Accum::Replace,
+                        &psi_inner,
+                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                        psi_part.transpose(),
+                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                        1.0,
+                        faer::Par::Seq,
+                    );
+                });
+
+            let mut first_iter = true;
+            for out in &output {
+                if first_iter {
+                    h.copy_from(out);
+                    first_iter = false;
+                } else {
+                    h += out;
+                }
+            }
+        } else {
+            psi_inner
+                .as_mut()
+                .col_iter_mut()
+                .zip(psi_shifted.col_iter())
+                .zip(inner.iter())
+                .for_each(|((col, psi_col), inner_val)| {
+                    col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
+                        *x = psi_val * inner_val;
+                    });
+                });
+            faer::linalg::matmul::triangular::matmul(
+                h.as_mut(),
+                faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
+                faer::Accum::Replace,
+                &psi_inner,
+                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                psi_shifted.transpose(),
+                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                1.0,
+                faer::Par::Seq,
+            );
+        }
+
+        for i in 0..h.nrows() {
+            h[(i, i)] += w_plam[i];
+        }
+
+        let uph = match h.llt(faer::Side::Lower) {
+            Ok(llt) => llt,
+            Err(_) => {
+                bail!("Error during Cholesky decomposition in log-space IPM. The matrix might not be positive definite.")
+            }
+        };
+        let uph = uph.L().transpose().to_owned();
+
+        let smuyinv: Col<f64> = Col::from_fn(ecol.nrows(), |i| smu * (ecol[i] / y[i]));
+        let psi_dot_muyinv: Col<f64> = &psi_shifted * &smuyinv;
+        let rhsdw: Row<f64> = Row::from_fn(erow.ncols(), |i| erow[i] / w[i] - psi_dot_muyinv[i]);
+
+        let mut dw = Mat::from_fn(rhsdw.ncols(), 1, |i, _j| *rhsdw.get(i));
+
+        solve_lower_triangular_in_place(uph.transpose().as_ref(), dw.as_mut(), faer::Par::rayon(0));
+        solve_upper_triangular_in_place(uph.as_ref(), dw.as_mut(), faer::Par::rayon(0));
+
+        let dw = dw.col(0);
+        let dy = -(psi_shifted.transpose() * dw);
+
+        let inner_times_dy = Col::from_fn(ecol.nrows(), |i| inner[i] * dy[i]);
+        let dlam: Row<f64> =
+            Row::from_fn(ecol.nrows(), |i| smuyinv[i] - lam[i] - inner_times_dy[i]);
+
+        let ratio_dlam_lam = Row::from_fn(lam.nrows(), |i| dlam[i] / lam[i]);
+        let min_ratio_dlam = ratio_dlam_lam.iter().cloned().fold(f64::INFINITY, f64::min);
+        let mut alfpri: f64 = -1.0 / min_ratio_dlam.min(-0.5);
+        alfpri = (0.99995 * alfpri).min(1.0);
+
+        let ratio_dy_y = Row::from_fn(y.nrows(), |i| dy[i] / y[i]);
+        let min_ratio_dy = ratio_dy_y.iter().cloned().fold(f64::INFINITY, f64::min);
+        let ratio_dw_w = Row::from_fn(dw.nrows(), |i| dw[i] / w[i]);
+        let min_ratio_dw = ratio_dw_w.iter().cloned().fold(f64::INFINITY, f64::min);
+        let mut alfdual = -1.0 / min_ratio_dy.min(-0.5);
+        alfdual = alfdual.min(-1.0 / min_ratio_dw.min(-0.5));
+        alfdual = (0.99995 * alfdual).min(1.0);
+
+        lam += alfpri * dlam.transpose();
+        w += alfdual * dw;
+        y += alfdual * &dy;
+
+        mu = lam.transpose() * &y / n_point as f64;
+        plam = &psi_shifted * &lam;
+
+        r = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
+        ptw -= alfdual * dy;
+
+        norm_r = r.norm_max();
+        let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
+        let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
+        gap = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam.abs());
+
+        if mu < eps && norm_r > eps {
+            sig = 1.0;
+        } else {
+            let candidate1 = (1.0 - alfpri).powi(2);
+            let candidate2 = (1.0 - alfdual).powi(2);
+            let candidate3 = (norm_r - mu) / (norm_r + 100.0 * mu);
+            sig = candidate1.max(candidate2).max(candidate3).min(0.3);
+        }
+    }
+
+    // Scale lam
+    lam /= n_sub as f64;
+
+    // Compute the objective function value in log space
+    // obj = sum_i(log(sum_j(psi[i,j] * lam[j])))
+    //     = sum_i(log(sum_j(exp(log_psi[i,j]) * lam[j])))
+    //     = sum_i(logsumexp(log_psi[i,j] + log(lam[j])))
+    //
+    // But we worked with shifted psi, so:
+    // obj = sum_i(row_max[i] + log(sum_j(psi_shifted[i,j] * lam[j])))
+    //     = sum_i(row_max[i]) + sum_i(log(plam[i]))
+    let plam_final = &psi_shifted * &lam;
+    let obj: f64 = row_max.iter().sum::<f64>() + plam_final.iter().map(|x| x.ln()).sum::<f64>();
+
+    // Normalize lam to sum to 1
+    let lam_sum: f64 = lam.iter().sum();
+    lam = &lam / lam_sum;
+
+    Ok((lam.into(), obj))
+}
+
+/// Unified IPM dispatch that automatically chooses the correct algorithm based on psi type.
+///
+/// This function checks if `psi` is in log-space and calls the appropriate Burke IPM:
+/// - If `psi.is_log_space()` is true: calls `burke_log`
+/// - Otherwise: calls `burke`
+///
+/// # Arguments
+///
+/// * `psi` - A reference to a Psi structure (either regular or log-space)
+///
+/// # Returns
+///
+/// On success, returns a tuple `(weights, obj)` where:
+/// - [Weights] contains the optimized weights (probabilities) for each support point.
+/// - `obj` is the value of the objective function at the solution.
+///
+/// # Errors
+///
+/// Returns an error if the underlying IPM optimization fails.
+pub fn burke_ipm(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
+    if psi.is_log_space() {
+        burke_log(psi)
+    } else {
+        burke(psi)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -514,4 +806,172 @@ mod tests {
         // The objective function should be finite
         assert!(obj.is_finite(), "Objective function should be finite");
     }
+
+    // ========== Log-space IPM tests ==========
+
+    #[test]
+    fn test_burke_log_identity() {
+        // Test with identity matrix converted to log space
+        // log(1) = 0, log(0) = -inf, but we use a small positive value instead
+        use crate::structs::psi::PsiBuilder;
+        use ndarray::Array2;
+
+        let n = 10;
+        // Create identity matrix in log space: log(1) = 0, log(eps) for off-diagonal
+        let log_mat = Array2::from_shape_fn((n, n), |(i, j)| {
+            if i == j {
+                0.0 // log(1) = 0
+            } else {
+                -30.0 // very small probability, exp(-30) ≈ 0
+            }
+        });
+
+        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+
+        let (lam, obj) = burke_log(&psi).unwrap();
+
+        // For identity-like matrix, weights should be roughly equal
+        let expected = 1.0 / n as f64;
+        for i in 0..n {
+            assert_relative_eq!(lam[i], expected, epsilon = 1e-6);
+        }
+
+        // Check that lambda sums to 1
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+
+        // Objective should be finite
+        assert!(obj.is_finite(), "Objective function should be finite");
+    }
+
+    #[test]
+    fn test_burke_log_uniform() {
+        // Test with uniform matrix in log space
+        // log(1) = 0 everywhere
+        use crate::structs::psi::PsiBuilder;
+        use ndarray::Array2;
+
+        let n_sub = 10;
+        let n_point = 10;
+        let log_mat = Array2::from_shape_fn((n_sub, n_point), |_| 0.0); // log(1) = 0
+
+        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+
+        let (lam, obj) = burke_log(&psi).unwrap();
+
+        // Check that lambda sums to 1
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+
+        // For uniform matrix, all weights should be equal
+        let expected = 1.0 / n_point as f64;
+        for i in 0..n_point {
+            assert_relative_eq!(lam[i], expected, epsilon = 1e-6);
+        }
+
+        // Objective should be finite
+        assert!(obj.is_finite(), "Objective function should be finite");
+    }
+
+    #[test]
+    fn test_burke_log_consistency_with_regular() {
+        // Test that burke_log produces the same results as burke
+        // when given equivalent inputs
+        use crate::structs::psi::PsiBuilder;
+        use ndarray::Array2;
+
+        let n_sub = 5;
+        let n_point = 8;
+
+        // Create a regular psi matrix with positive values
+        let regular_mat = Array2::from_shape_fn((n_sub, n_point), |(i, j)| {
+            0.5 + 0.1 * (i as f64) + 0.05 * (j as f64)
+        });
+        let regular_psi = Psi::from(regular_mat.clone());
+
+        // Create the equivalent log-space matrix
+        let log_mat = regular_mat.mapv(|x| x.ln());
+        let log_psi = PsiBuilder::new(log_mat).log_space(true).build();
+
+        // Run both algorithms
+        let (lam_regular, obj_regular) = burke(&regular_psi).unwrap();
+        let (lam_log, obj_log) = burke_log(&log_psi).unwrap();
+
+        // The weights should be very similar
+        for i in 0..n_point {
+            assert_relative_eq!(lam_regular[i], lam_log[i], epsilon = 1e-6);
+        }
+
+        // The objective functions should be very similar
+        assert_relative_eq!(obj_regular, obj_log, epsilon = 1e-6);
+    }
+
+    #[test]
+    fn test_burke_log_handles_very_small_likelihoods() {
+        // Test that log-space IPM handles very small likelihoods that would
+        // underflow in regular space
+        use crate::structs::psi::PsiBuilder;
+        use ndarray::Array2;
+
+        let n_sub = 5;
+        let n_point = 5;
+
+        // Create log-likelihoods that would underflow if exponentiated directly
+        // These represent likelihoods of exp(-500) ≈ 10^(-217)
+        let log_mat = Array2::from_shape_fn((n_sub, n_point), |(i, j)| {
+            -500.0 + (i as f64) * 0.1 + (j as f64) * 0.05
+        });
+
+        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+
+        // This should succeed without underflow issues
+        let result = burke_log(&psi);
+        assert!(
+            result.is_ok(),
+            "Log-space IPM should handle very small likelihoods"
+        );
+
+        let (lam, obj) = result.unwrap();
+
+        // Check basic properties
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+        assert!(obj.is_finite(), "Objective function should be finite");
+
+        // All weights should be non-negative
+        for i in 0..n_point {
+            assert!(lam[i] >= 0.0, "Lambda values should be non-negative");
+        }
+    }
+
+    #[test]
+    fn test_burke_log_with_varying_magnitudes() {
+        // Test with log-likelihoods of varying magnitudes
+        use crate::structs::psi::PsiBuilder;
+        use ndarray::Array2;
+
+        let n_sub = 8;
+        let n_point = 12;
+
+        // Create varying log-likelihoods
+        let log_mat = Array2::from_shape_fn((n_sub, n_point), |(i, j)| {
+            // Range from -100 to -10, with column 0 having higher values (better fit)
+            if j == 0 {
+                -10.0 - (i as f64)
+            } else {
+                -50.0 - (i as f64) - (j as f64)
+            }
+        });
+
+        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+
+        let (lam, obj) = burke_log(&psi).unwrap();
+
+        // Check basic properties
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+        assert!(obj.is_finite(), "Objective function should be finite");
+
+        // First column should have higher weight since it has higher log-likelihoods
+        assert!(
+            lam[0] > lam[1],
+            "First support point should have higher weight"
+        );
+    }
 }
diff --git a/src/routines/estimation/qr.rs b/src/routines/estimation/qr.rs
index acc104d26..81fd66c23 100644
--- a/src/routines/estimation/qr.rs
+++ b/src/routines/estimation/qr.rs
@@ -3,29 +3,66 @@ use anyhow::{bail, Result};
 use faer::linalg::solvers::ColPivQr;
 use faer::Mat;
 
+/// Compute log-sum-exp of a slice for numerical stability
+#[inline]
+fn logsumexp(values: &[f64]) -> f64 {
+    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+    if max_val.is_infinite() {
+        return max_val;
+    }
+    max_val
+        + values
+            .iter()
+            .map(|&x| (x - max_val).exp())
+            .sum::<f64>()
+            .ln()
+}
+
 /// Perform a QR decomposition on the Psi matrix
 ///
 /// Normalizes each row of the matrix to sum to 1 before decomposition.
+/// For log-space matrices, applies softmax normalization (log-probabilities are
+/// converted to normalized probabilities before the decomposition).
 /// Returns the R matrix from QR decomposition and the column permutation vector.
 ///
 /// # Arguments
-/// * `psi` - The Psi matrix to decompose
+/// * `psi` - The Psi matrix to decompose (can be in regular or log space)
 ///
 /// # Returns
 /// * Tuple containing the R matrix (as [faer::Mat]) and permutation vector (as [Vec])
-/// * Error if any row in the matrix sums to zero
+/// * Error if any row in the matrix sums to zero (or has all -inf in log space)
 pub fn qrd(psi: &Psi) -> Result<(Mat<f64>, Vec<usize>)> {
     let mut mat = psi.matrix().to_owned();
 
-    // Normalize the rows to sum to 1
-    for (index, row) in mat.row_iter_mut().enumerate() {
-        let row_sum: f64 = row.as_ref().iter().sum();
+    if psi.is_log_space() {
+        // For log-space: apply softmax normalization
+        // softmax(x)_i = exp(x_i) / sum(exp(x_j)) = exp(x_i - logsumexp(x))
+        for index in 0..mat.nrows() {
+            let log_values: Vec<f64> = (0..mat.ncols()).map(|j| *mat.get(index, j)).collect();
+            let log_sum = logsumexp(&log_values);
+
+            if log_sum.is_infinite() && log_sum < 0.0 {
+                bail!(
+                    "In log_psi, the row with index {} has all -inf values (zero probability)",
+                    index
+                );
+            }
+
+            // Convert to normalized probabilities via softmax
+            for j in 0..mat.ncols() {
+                *mat.get_mut(index, j) = (log_values[j] - log_sum).exp();
+            }
+        }
+    } else {
+        // For regular space: normalize rows to sum to 1
+        for (index, row) in mat.row_iter_mut().enumerate() {
+            let row_sum: f64 = row.as_ref().iter().sum();
 
-        // Check if the row sum is zero
-        if row_sum.abs() == 0.0 {
-            bail!("In psi, the row with index {} sums to zero", index);
+            if row_sum.abs() == 0.0 {
+                bail!("In psi, the row with index {} sums to zero", index);
+            }
+            row.iter_mut().for_each(|x| *x /= row_sum);
         }
-        row.iter_mut().for_each(|x| *x /= row_sum);
     }
 
     // Perform column pivoted QR decomposition
diff --git a/src/routines/math.rs b/src/routines/math.rs
new file mode 100644
index 000000000..dbd24b514
--- /dev/null
+++ b/src/routines/math.rs
@@ -0,0 +1,176 @@
+//! Mathematical utility functions for numerical stability
+//!
+//! This module provides stable implementations of common numerical operations.
+
+/// Compute the log-sum-exp of a slice of values in a numerically stable way.
+///
+/// The log-sum-exp is defined as: `log(sum(exp(x_i)))` for all elements `x_i`.
+///
+/// This implementation uses the "shift by max" trick to avoid overflow:
+/// `logsumexp(x) = max(x) + log(sum(exp(x_i - max(x))))`
+///
+/// # Arguments
+/// * `values` - A slice of f64 values (typically log-likelihoods)
+///
+/// # Returns
+/// The log-sum-exp of the values. Returns `f64::NEG_INFINITY` if all values are `-inf`.
+///
+/// # Example
+/// ```ignore
+/// let log_probs = vec![-1.0, -2.0, -3.0];
+/// let result = logsumexp(&log_probs);
+/// // result ≈ log(exp(-1) + exp(-2) + exp(-3)) ≈ -0.592
+/// ```
+#[inline]
+pub fn logsumexp(values: &[f64]) -> f64 {
+    if values.is_empty() {
+        return f64::NEG_INFINITY;
+    }
+
+    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+
+    if max_val.is_infinite() && max_val.is_sign_negative() {
+        // All values are -inf, return -inf
+        f64::NEG_INFINITY
+    } else if max_val.is_infinite() && max_val.is_sign_positive() {
+        // At least one value is +inf
+        f64::INFINITY
+    } else {
+        max_val + values.iter().map(|&x| (x - max_val).exp()).sum::<f64>().ln()
+    }
+}
+
+/// Compute the weighted log-sum-exp: `logsumexp(log_values + log_weights)`.
+///
+/// This computes `log(sum(values_i * weights_i))` when `log_values` contains log-likelihoods.
+/// Equivalent to `log(sum(exp(log_values_i) * weights_i))`.
+///
+/// # Arguments
+/// * `log_values` - A slice of log-values (e.g., log-likelihoods)
+/// * `log_weights` - A slice of log-weights (should be same length as log_values)
+///
+/// # Returns
+/// The weighted log-sum-exp.
+///
+/// # Panics
+/// Panics if the slices have different lengths.
+#[inline]
+pub fn logsumexp_weighted(log_values: &[f64], log_weights: &[f64]) -> f64 {
+    assert_eq!(
+        log_values.len(),
+        log_weights.len(),
+        "log_values and log_weights must have the same length"
+    );
+
+    let combined: Vec<f64> = log_values
+        .iter()
+        .zip(log_weights.iter())
+        .map(|(&lv, &lw)| lv + lw)
+        .collect();
+
+    logsumexp(&combined)
+}
+
+/// Compute log-sum-exp for each row of a matrix represented as a closure.
+///
+/// # Arguments
+/// * `nrows` - Number of rows
+/// * `ncols` - Number of columns
+/// * `get_value` - Closure that returns the value at (row, col)
+///
+/// # Returns
+/// A vector of logsumexp values, one per row.
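+///
+/// # Example
+///
+/// A small sketch (not compiled as a doctest), using a nested `Vec` as the
+/// backing storage for the closure:
+///
+/// ```ignore
+/// let matrix = vec![vec![-1.0, -2.0], vec![-3.0, -4.0]];
+/// let row_sums = logsumexp_rows(2, 2, |i, j| matrix[i][j]);
+/// assert_eq!(row_sums.len(), 2);
+/// assert!((row_sums[0] - logsumexp(&[-1.0, -2.0])).abs() < 1e-10);
+/// ```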
+pub fn logsumexp_rows<F>(nrows: usize, ncols: usize, get_value: F) -> Vec<f64>
+where
+    F: Fn(usize, usize) -> f64,
+{
+    (0..nrows)
+        .map(|i| {
+            let row: Vec<f64> = (0..ncols).map(|j| get_value(i, j)).collect();
+            logsumexp(&row)
+        })
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_logsumexp_basic() {
+        let values = vec![-1.0, -2.0, -3.0];
+        let result = logsumexp(&values);
+        // log(exp(-1) + exp(-2) + exp(-3)) ≈ -0.5924
+        let expected = ((-1.0_f64).exp() + (-2.0_f64).exp() + (-3.0_f64).exp()).ln();
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_single_value() {
+        let values = vec![-5.0];
+        let result = logsumexp(&values);
+        assert!((result - (-5.0)).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_empty() {
+        let values: Vec<f64> = vec![];
+        let result = logsumexp(&values);
+        assert!(result.is_infinite() && result.is_sign_negative());
+    }
+
+    #[test]
+    fn test_logsumexp_all_neg_inf() {
+        let values = vec![f64::NEG_INFINITY, f64::NEG_INFINITY];
+        let result = logsumexp(&values);
+        assert!(result.is_infinite() && result.is_sign_negative());
+    }
+
+    #[test]
+    fn test_logsumexp_with_neg_inf() {
+        // logsumexp([-inf, 0]) = log(0 + 1) = 0
+        let values = vec![f64::NEG_INFINITY, 0.0];
+        let result = logsumexp(&values);
+        assert!((result - 0.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_large_values() {
+        // Test numerical stability with large values
+        let values = vec![1000.0, 1001.0, 1002.0];
+        let result = logsumexp(&values);
+        // Should be close to 1002 + log(exp(-2) + exp(-1) + 1) ≈ 1002.41
+        let expected = 1002.0 + ((-2.0_f64).exp() + (-1.0_f64).exp() + 1.0).ln();
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_very_negative() {
+        // Test with very negative values that would underflow with naive implementation
+        let values = vec![-1000.0, -1001.0, -1002.0];
+        let result = logsumexp(&values);
+        let expected = -1000.0 + (1.0 + (-1.0_f64).exp() + (-2.0_f64).exp()).ln();
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_weighted() {
+        let log_values = vec![-1.0, -2.0];
+        let log_weights = vec![0.0, 0.0]; // weights = 1
+        let result = logsumexp_weighted(&log_values, &log_weights);
+        let expected = logsumexp(&log_values);
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_rows() {
+        let matrix = vec![
+            vec![-1.0, -2.0],
+            vec![-3.0, -4.0],
+        ];
+        let result = logsumexp_rows(2, 2, |i, j| matrix[i][j]);
+
+        let expected_0 = logsumexp(&[-1.0, -2.0]);
+        let expected_1 = logsumexp(&[-3.0, -4.0]);
+
+        assert!((result[0] - expected_0).abs() < 1e-10);
+        assert!((result[1] - expected_1).abs() < 1e-10);
+    }
+}
diff --git a/src/routines/mod.rs b/src/routines/mod.rs
index af25d67e6..ac626d2c9 100644
--- a/src/routines/mod.rs
+++ b/src/routines/mod.rs
@@ -8,6 +8,8 @@ pub mod expansion;
 pub mod initialization;
 // Routines for logging
 pub mod logger;
+// Mathematical utilities
+pub mod math;
 // Routines for output
 pub mod output;
 // Routines for settings
diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs
index 2a9e43d69..533137b20 100644
--- a/src/routines/output/mod.rs
+++ b/src/routines/output/mod.rs
@@ -22,8 +22,6 @@ pub mod cycles;
 pub mod posterior;
 pub mod predictions;
 
-use posterior::posterior;
-
 /// Defines the result objects from an NPAG run
 /// An [NPResult] contains the necessary information to generate predictions and summary statistics
 #[derive(Debug, Serialize)]
@@ -61,7 +59,7 @@ impl NPResult {
         cyclelog: CycleLog,
     ) -> Result<Self> {
         // Calculate the posterior probabilities
-        let posterior = posterior(&psi, &w)
+        let posterior = Posterior::calculate(&psi, &w)
             .context("Failed to calculate posterior during initialization of NPResult")?;
 
         let result = Self {
diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs
index 008ce16c1..827c9ddb5 100644
--- a/src/routines/output/posterior.rs
+++ b/src/routines/output/posterior.rs
@@ -2,6 +2,7 @@ pub use anyhow::{bail, Result};
 use faer::Mat;
 use serde::{Deserialize, Serialize};
 
+use crate::routines::math::logsumexp;
 use crate::structs::{psi::Psi, weights::Weights};
 
 /// Posterior probabilities for each support points
@@ -37,10 +38,34 @@ impl Posterior {
         }
 
         let psi_matrix = psi.matrix();
-        let py = psi_matrix * w.weights();
-
+        let is_log_space = psi.is_log_space();
+
+        // Calculate py[i] = sum_j(psi[i,j] * w[j]) for each subject i
+        // In log-space: py[i] = logsumexp_j(log_psi[i,j] + log(w[j]))
+        let py: Vec<f64> = if is_log_space {
+            let log_w: Vec<f64> = (0..w.len()).map(|j| w.weights().get(j).ln()).collect();
+            (0..psi_matrix.nrows())
+                .map(|i| {
+                    let combined: Vec<f64> = (0..psi_matrix.ncols())
+                        .map(|j| *psi_matrix.get(i, j) + log_w[j])
+                        .collect();
+                    logsumexp(&combined)
+                })
+                .collect()
+        } else {
+            let py_mat = psi_matrix * w.weights();
+            (0..py_mat.nrows()).map(|i| *py_mat.get(i)).collect()
+        };
+
+        // Calculate posterior[i,j] = psi[i,j] * w[j] / py[i]
+        // In log-space: posterior[i,j] = exp(log_psi[i,j] + log(w[j]) - log_py[i])
         let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
-            psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
+            if is_log_space {
+                let log_w_j = w.weights().get(j).ln();
+                (*psi_matrix.get(i, j) + log_w_j - py[i]).exp()
+            } else {
+                psi_matrix.get(i, j) * w.weights().get(j) / py[i]
+            }
         });
 
         Ok(posterior.into())
@@ -180,25 +205,3 @@ impl<'de> Deserialize<'de> for Posterior {
         deserializer.deserialize_seq(PosteriorVisitor)
     }
 }
-
-/// Calculates the posterior probabilities for each support point given the weights
-///
-/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns.
-pub fn posterior(psi: &Psi, w: &Weights) -> Result<Posterior> {
-    if psi.matrix().ncols() != w.len() {
-        bail!(
-            "Number of rows in psi ({}) and number of weights ({}) do not match.",
-            psi.matrix().nrows(),
-            w.len()
-        );
-    }
-
-    let psi_matrix = psi.matrix();
-    let py = psi_matrix * w.weights();
-
-    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
-        psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
-    });
-
-    Ok(posterior.into())
-}
diff --git a/src/routines/settings.rs b/src/routines/settings.rs
index 3471a683f..b14409dab 100644
--- a/src/routines/settings.rs
+++ b/src/routines/settings.rs
@@ -270,6 +270,13 @@ pub struct Advanced {
     ///
     /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
     pub tolerance: f64,
+    /// Use log-space computations for improved numerical stability
+    ///
+    /// When true, likelihoods are computed and stored in log space throughout the algorithm.
+    /// This prevents underflow issues when dealing with many observations or extreme parameter values.
+    /// The log-sum-exp trick is used to maintain numerical stability in weighted sum operations.
+    /// Default is true for better numerical properties.
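+    ///
+    /// A rough illustration of the underflow this avoids (sketch, not a doctest):
+    ///
+    /// ```ignore
+    /// // Joint likelihood of two observations, each with density exp(-800):
+    /// let linear = (-800.0f64).exp() * (-800.0f64).exp(); // underflows to 0.0
+    /// let log_space = -800.0 + -800.0; // -1600.0 remains representable
+    /// assert_eq!(linear, 0.0);
+    /// assert!(log_space.is_finite());
+    /// ```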
+    pub log_space: bool,
 }
 
 impl Default for Advanced {
@@ -278,6 +285,7 @@ impl Default for Advanced {
             min_distance: 1e-4,
             nm_steps: 100,
             tolerance: 1e-6,
+            log_space: true,
         }
     }
 }
diff --git a/src/structs/psi.rs b/src/structs/psi.rs
index c63756cae..07a4d8f59 100644
--- a/src/structs/psi.rs
+++ b/src/structs/psi.rs
@@ -4,7 +4,7 @@ use faer::Mat;
 use faer_ext::IntoFaer;
 use faer_ext::IntoNdarray;
 use ndarray::{Array2, ArrayView2};
-use pharmsol::prelude::simulator::psi;
+use pharmsol::prelude::simulator::{log_psi, psi};
 use pharmsol::Data;
 use pharmsol::Equation;
 use pharmsol::ErrorModels;
@@ -13,20 +13,41 @@ use serde::{Deserialize, Serialize};
 use super::theta::Theta;
 
 /// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column)
+///
+/// The matrix can store either regular likelihoods or log-likelihoods depending on how it was constructed.
+/// Use the `is_log_space` flag to determine which representation is stored.
 #[derive(Debug, Clone, PartialEq)]
 pub struct Psi {
     matrix: Mat<f64>,
+    /// Whether the matrix contains log-likelihoods (true) or regular likelihoods (false)
+    is_log_space: bool,
 }
 
 impl Psi {
     pub fn new() -> Self {
-        Psi { matrix: Mat::new() }
+        Psi {
+            matrix: Mat::new(),
+            is_log_space: false,
+        }
+    }
+
+    /// Create a new Psi in log space
+    pub fn new_log() -> Self {
+        Psi {
+            matrix: Mat::new(),
+            is_log_space: true,
+        }
     }
 
     pub fn matrix(&self) -> &Mat<f64> {
         &self.matrix
     }
 
+    /// Returns true if the matrix stores log-likelihoods
+    pub fn is_log_space(&self) -> bool {
+        self.is_log_space
+    }
+
     pub fn nspp(&self) -> usize {
         self.matrix.nrows()
     }
@@ -101,7 +122,10 @@
         // Create matrix from rows
         let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
 
-        Ok(Psi { matrix: mat })
+        Ok(Psi {
+            matrix: mat,
+            is_log_space: false,
+        })
     }
 }
 
@@ -114,27 +138,39 @@ impl Default for Psi {
 impl From<Array2<f64>> for Psi {
     fn from(array: Array2<f64>) -> Self {
         let matrix = array.view().into_faer().to_owned();
-        Psi { matrix }
+        Psi {
+            matrix,
+            is_log_space: false,
+        }
     }
 }
 
 impl From<Mat<f64>> for Psi {
     fn from(matrix: Mat<f64>) -> Self {
-        Psi { matrix }
+        Psi {
+            matrix,
+            is_log_space: false,
+        }
     }
 }
 
 impl From<ArrayView2<'_, f64>> for Psi {
     fn from(array_view: ArrayView2<'_, f64>) -> Self {
         let matrix = array_view.into_faer().to_owned();
-        Psi { matrix }
+        Psi {
+            matrix,
+            is_log_space: false,
+        }
     }
 }
 
 impl From<&Array2<f64>> for Psi {
     fn from(array: &Array2<f64>) -> Self {
         let matrix = array.view().into_faer().to_owned();
-        Psi { matrix }
+        Psi {
+            matrix,
+            is_log_space: false,
+        }
     }
 }
@@ -208,7 +244,10 @@
             // Create matrix from rows
             let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
 
-            Ok(Psi { matrix: mat })
+            Ok(Psi {
+                matrix: mat,
+                is_log_space: false,
+            })
         }
     }
 
@@ -216,6 +255,35 @@
     }
 }
 
+/// Helper struct for creating Psi with log-space flag
+pub struct PsiBuilder {
+    array: Array2<f64>,
+    is_log_space: bool,
+}
+
+impl PsiBuilder {
+    pub fn new(array: Array2<f64>) -> Self {
+        Self {
+            array,
+            is_log_space: false,
+        }
+    }
+
+    pub fn log_space(mut self, is_log: bool) -> Self {
+        self.is_log_space = is_log;
+        self
+    }
+
+    pub fn build(self) -> Psi {
+        let matrix = self.array.view().into_faer().to_owned();
+        Psi {
+            matrix,
+            is_log_space: self.is_log_space,
+        }
+    }
+}
+
+/// Calculate the likelihood matrix (regular space)
 pub(crate) fn calculate_psi(
     equation: &impl Equation,
     subjects: &Data,
     theta: &Theta,
     error_models: &ErrorModels,
     progress: bool,
     cache: bool,
 ) -> Result<Psi> {
     let psi_ndarray = psi(
         equation,
         subjects,
         &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
         error_models,
         progress,
         cache,
     )?;
 
-    Ok(psi_ndarray.view().into())
+    Ok(PsiBuilder::new(psi_ndarray).log_space(false).build())
+}
+
+/// Calculate the log-likelihood matrix (log space)
+///
+/// This computes log-likelihoods directly, which is numerically more stable
+/// than computing likelihoods and then taking logarithms. This is especially
+/// important when dealing with many observations or extreme parameter values.
+pub(crate) fn calculate_log_psi(
+    equation: &impl Equation,
+    subjects: &Data,
+    theta: &Theta,
+    error_models: &ErrorModels,
+    progress: bool,
+    cache: bool,
+) -> Result<Psi> {
+    let log_psi_ndarray = log_psi(
+        equation,
+        subjects,
+        &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
+        error_models,
+        progress,
+        cache,
+    )?;
+
+    Ok(PsiBuilder::new(log_psi_ndarray).log_space(true).build())
+}
+
+/// Unified psi calculation that dispatches based on the `use_log_space` flag.
+///
+/// This function eliminates the need for repeated if/else blocks in algorithm code.
+///
+/// # Arguments
+///
+/// * `equation` - The model equation
+/// * `subjects` - The data
+/// * `theta` - The support points
+/// * `error_models` - The error models
+/// * `progress` - Whether to show progress
+/// * `cache` - Whether to use caching
+/// * `use_log_space` - If true, calculates log-likelihoods; otherwise, regular likelihoods
+///
+/// # Returns
+///
+/// A Psi matrix with the appropriate `is_log_space` flag set.
+pub(crate) fn calculate_psi_dispatch(
+    equation: &impl Equation,
+    subjects: &Data,
+    theta: &Theta,
+    error_models: &ErrorModels,
+    progress: bool,
+    cache: bool,
+    use_log_space: bool,
+) -> Result<Psi> {
+    if use_log_space {
+        calculate_log_psi(equation, subjects, theta, error_models, progress, cache)
+    } else {
+        calculate_psi(equation, subjects, theta, error_models, progress, cache)
+    }
 }
 
 #[cfg(test)]

From bcd365becf2a2cc5ebf7fd9e57da9aa53a16537c Mon Sep 17 00:00:00 2001
From: Markus <66058642+mhovd@users.noreply.github.com>
Date: Tue, 2 Dec 2025 15:37:27 +0100
Subject: [PATCH 2/3] Formatting

---
 src/bestdose/posterior.rs | 10 +++++++++-
 src/routines/math.rs      | 16 +++++++++-------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/bestdose/posterior.rs b/src/bestdose/posterior.rs
index 8df890117..b5147b269 100644
--- a/src/bestdose/posterior.rs
+++ b/src/bestdose/posterior.rs
@@ -101,7 +101,15 @@ pub fn npagfull11_filter(
 
     // Calculate psi matrix P(data|theta_i) for all support points
     // Use log-space or regular space based on setting
-    let psi = calculate_psi_dispatch(eq, past_data, population_theta, error_models, false, true, use_log_space)?;
+    let psi = calculate_psi_dispatch(
+        eq,
+        past_data,
+        population_theta,
+        error_models,
+        false,
+        true,
+        use_log_space,
+    )?;
 
     // First burke call to get initial posterior probabilities
     let (initial_weights, _) = burke_ipm(&psi)?;
diff --git a/src/routines/math.rs b/src/routines/math.rs
index dbd24b514..f606f2a77 100644
--- a/src/routines/math.rs
+++ b/src/routines/math.rs
@@ -36,7 +36,12 @@ pub fn logsumexp(values: &[f64]) -> f64 {
         // At least one value is +inf
         f64::INFINITY
     } else {
-        max_val + values.iter().map(|&x| (x - max_val).exp()).sum::<f64>().ln()
+        max_val
+            + values
+                .iter()
+                .map(|&x| (x - max_val).exp())
+                .sum::<f64>()
+                .ln()
     }
 }
 
@@ -161,15 +166,12 @@ mod tests {
 
     #[test]
     fn test_logsumexp_rows() {
-        let matrix = vec![
-            vec![-1.0, -2.0],
-            vec![-3.0, -4.0],
-        ];
+        let matrix = vec![vec![-1.0, -2.0], vec![-3.0, -4.0]];
         let result = logsumexp_rows(2, 2, |i, j| matrix[i][j]);
-
+
         let expected_0 = logsumexp(&[-1.0, -2.0]);
         let expected_1 = logsumexp(&[-3.0, -4.0]);
-
+
         assert!((result[0] - expected_0).abs() < 1e-10);
         assert!((result[1] - expected_1).abs() < 1e-10);
     }

From 764862498773ff64ba7973a349e83917ca27aa8c Mon Sep 17 00:00:00 2001
From: Markus Hovd
Date: Thu, 11 Dec 2025 16:51:03 +0100
Subject: [PATCH 3/3] chore: Log likelihood suggestions (#235)

* wip

* wip

* 2 tests are failing

* Fix tests

* Add to_space for converting
---
 src/algorithms/mod.rs            |   5 +-
 src/algorithms/npag.rs           |  20 ++-
 src/algorithms/npod.rs           |  47 +++----
 src/algorithms/postprob.rs       |   8 +-
 src/bestdose/posterior.rs        |  14 +-
 src/routines/estimation/ipm.rs   |  31 ++---
 src/routines/estimation/qr.rs    |  55 ++++----
 src/routines/output/posterior.rs |   5 +-
 src/routines/settings.rs         |   5 +-
 src/structs/psi.rs               | 212 ++++++++++++++-----------------
 10 files changed, 190 insertions(+), 212 deletions(-)

diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs
index 5c5dd1edb..35eb2bd34 100644
--- a/src/algorithms/mod.rs
+++ b/src/algorithms/mod.rs
@@ -37,7 +37,10 @@ pub trait Algorithms: Sync + Send + 'static {
         // Count problematic values in psi
         let mut nan_count = 0;
         let mut inf_count = 0;
-        let is_log_space = self.psi().is_log_space();
+        let is_log_space = match self.psi().space() {
+            crate::structs::psi::Space::Linear => false,
+            crate::structs::psi::Space::Log => true,
+        };
         let psi = self.psi().matrix().as_ref().into_ndarray();
diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs
index 7c377b020..e83f20d64 100644
--- a/src/algorithms/npag.rs
+++ b/src/algorithms/npag.rs
@@ -7,7 +7,7 @@ use crate::routines::math::logsumexp;
 use crate::routines::settings::Settings;
 
 use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
-use crate::structs::psi::{calculate_psi_dispatch, Psi};
+use crate::structs::psi::{calculate_psi, Psi};
 use crate::structs::theta::Theta;
 use crate::structs::weights::Weights;
@@ -162,7 +162,7 @@
             self.eps /= 2.;
             if self.eps <= THETA_E {
                 // Compute f1 = sum(log(pyl)) where pyl = psi * w
-                self.f1 = if self.psi.is_log_space() {
+                self.f1 = if self.psi.space() == crate::structs::psi::Space::Log {
                     // For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
                     let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
                     (0..psi.nrows())
@@ -214,16 +214,14 @@
     }
 
     fn estimation(&mut self) -> Result<()> {
-        let use_log_space = self.settings.advanced().log_space;
-
-        self.psi = calculate_psi_dispatch(
+        self.psi = calculate_psi(
             &self.equation,
             &self.data,
             &self.theta,
             &self.error_models,
             self.cycle == 1 && self.settings.config().progress,
             self.cycle != 1,
-            use_log_space,
+            self.settings.advanced().space,
         )?;
 
         if let Err(err) = self.validate_psi() {
@@ -296,8 +294,6 @@
     }
 
     fn optimizations(&mut self) -> Result<()> {
-        let use_log_space = self.settings.advanced().log_space;
-
         self.error_models
             .clone()
             .iter_mut()
@@ -318,24 +314,24 @@
                 let mut error_model_down = self.error_models.clone();
                 error_model_down.set_factor(outeq, gamma_down)?;
 
-                let psi_up = calculate_psi_dispatch(
+                let psi_up = calculate_psi(
                     &self.equation,
                     &self.data,
                     &self.theta,
                     &error_model_up,
                     false,
                     true,
-                    use_log_space,
+                    self.settings.advanced().space,
                 )?;
 
-                let psi_down = calculate_psi_dispatch(
+                let psi_down = calculate_psi(
                     &self.equation,
                     &self.data,
                     &self.theta,
                     &error_model_down,
                     false,
                     true,
-                    use_log_space,
+                    self.settings.advanced().space,
                 )?;
 
                 let (lambda_up, objf_up) = burke_ipm(&psi_up)
diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs
index 50f2a299c..0cb11da4e 100644
--- a/src/algorithms/npod.rs
+++ b/src/algorithms/npod.rs
@@ -2,6 +2,7 @@ use crate::algorithms::StopReason;
 use crate::routines::initialization::sample_space;
 use crate::routines::math::logsumexp;
 use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
+use crate::structs::psi::calculate_psi;
 use crate::structs::weights::Weights;
 use crate::{
     algorithms::Status,
@@ -12,10 +13,7 @@ use crate::{
     prelude::{
         algorithms::Algorithms,
         routines::{
             estimation::{ipm::burke_ipm, qr},
             settings::Settings,
         },
     },
-    structs::{
-        psi::{calculate_psi_dispatch, Psi},
-        theta::Theta,
-    },
+    structs::{psi::Psi, theta::Theta},
 };
@@ -205,16 +203,14 @@
     }
 
     fn estimation(&mut self) -> Result<()> {
-        let use_log_space = self.settings.advanced().log_space;
-
-        self.psi = calculate_psi_dispatch(
+        self.psi = calculate_psi(
             &self.equation,
             &self.data,
             &self.theta,
             &self.error_models,
             self.cycle == 1 && self.settings.config().progress,
             self.cycle != 1,
-            use_log_space,
+            self.settings.advanced().space,
         )?;
 
         if let Err(err) = self.validate_psi() {
@@ -282,8 +278,6 @@
     }
 
     fn optimizations(&mut self) -> Result<()> {
-        let use_log_space = self.settings.advanced().log_space;
-
         self.error_models
             .clone()
             .iter_mut()
@@ -306,23 +300,23 @@
                 let mut error_model_down = self.error_models.clone();
                 error_model_down.set_factor(outeq, gamma_down)?;
 
-                let psi_up = calculate_psi_dispatch(
+                let psi_up = calculate_psi(
                     &self.equation,
                     &self.data,
                     &self.theta,
                     &error_model_up,
                     false,
                     true,
-                    use_log_space,
+                    self.settings.advanced().space,
                 )?;
-                let psi_down = calculate_psi_dispatch(
+                let psi_down = calculate_psi(
                     &self.equation,
                     &self.data,
                     &self.theta,
                     &error_model_down,
                     false,
                     true,
-                    use_log_space,
+                    self.settings.advanced().space,
                 )?;
 
                 let (lambda_up, objf_up) = burke_ipm(&psi_up)
@@ -365,20 +359,19 @@
 
         // Compute pyl = P(Y|L) for each subject
         // In log-space, we need to use logsumexp and then exp to get regular pyl
-        let pyl = if self.psi.is_log_space() {
-            // pyl[i] = sum_j(exp(log_psi[i,j]) * w[j]) = sum_j(exp(log_psi[i,j] + log(w[j])))
-            // Using logsumexp for stability, then exp to get regular values
-            let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
-            let mut pyl = Array1::zeros(psi_mat.nrows());
-            for i in 0..psi_mat.nrows() {
-                let combined: Vec<f64> = (0..psi_mat.ncols())
-                    .map(|j| psi_mat[[i, j]] + log_w[j])
-                    .collect();
-                pyl[i] = logsumexp(&combined).exp();
+        let pyl = match self.settings.advanced().space {
+            crate::structs::psi::Space::Log => {
+                let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
+                let mut pyl = Array1::zeros(psi_mat.nrows());
+                for i in 0..psi_mat.nrows() {
+                    let combined: Vec<f64> = (0..psi_mat.ncols())
+                        .map(|j| psi_mat[[i, j]] + log_w[j])
+                        .collect();
+                    pyl[i] = logsumexp(&combined).exp();
+                }
+                pyl
             }
-            pyl
-        } else {
-            psi_mat.dot(&w)
+            crate::structs::psi::Space::Linear => psi_mat.dot(&w),
         };
diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs
index 9e9250d74..a909bfce7 100644
--- a/src/algorithms/postprob.rs
+++ b/src/algorithms/postprob.rs
@@ -2,7 +2,7 @@ use crate::{
     algorithms::{Status, StopReason},
     prelude::algorithms::Algorithms,
     structs::{
-        psi::{calculate_psi_dispatch, Psi},
+        psi::{calculate_psi, Psi},
         theta::Theta,
         weights::Weights,
     },
@@ -119,16 +119,14 @@ impl Algorithms for POSTPROB {
     }
 
     fn estimation(&mut self) -> Result<()> {
-        let
diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs
index 9e9250d74..a909bfce7 100644
--- a/src/algorithms/postprob.rs
+++ b/src/algorithms/postprob.rs
@@ -2,7 +2,7 @@ use crate::{
     algorithms::{Status, StopReason},
     prelude::algorithms::Algorithms,
     structs::{
-        psi::{calculate_psi_dispatch, Psi},
+        psi::{calculate_psi, Psi},
         theta::Theta,
         weights::Weights,
     },
@@ -119,16 +119,14 @@ impl Algorithms for POSTPROB {
     }

     fn estimation(&mut self) -> Result<()> {
-        let use_log_space = self.settings.advanced().log_space;
-
-        self.psi = calculate_psi_dispatch(
+        self.psi = calculate_psi(
             &self.equation,
             &self.data,
             &self.theta,
             &self.error_models,
             false,
             false,
-            use_log_space,
+            self.settings.advanced().space,
         )?;

         (self.w, self.objf) = burke_ipm(&self.psi).context("Error in IPM")?;
diff --git a/src/bestdose/posterior.rs b/src/bestdose/posterior.rs
index b5147b269..87b9a35fc 100644
--- a/src/bestdose/posterior.rs
+++ b/src/bestdose/posterior.rs
@@ -58,7 +58,8 @@ use crate::algorithms::Algorithms;
 use crate::algorithms::Status;
 use crate::prelude::*;
 use crate::routines::estimation::ipm::burke_ipm;
-use crate::structs::psi::calculate_psi_dispatch;
+use crate::structs::psi::calculate_psi;
+use crate::structs::psi::Space;
 use crate::structs::theta::Theta;
 use crate::structs::weights::Weights;
 use pharmsol::prelude::*;
@@ -95,20 +96,20 @@ pub fn npagfull11_filter(
     past_data: &Data,
     eq: &ODE,
     error_models: &ErrorModels,
-    use_log_space: bool,
+    space: Space,
 ) -> Result<(Theta, Weights, Weights)> {
     tracing::info!("Stage 1.1: NPAGFULL11 Bayesian filtering");

     // Calculate psi matrix P(data|theta_i) for all support points
     // Use log-space or regular space based on setting
-    let psi = calculate_psi_dispatch(
+    let psi = calculate_psi(
         eq,
         past_data,
         population_theta,
         error_models,
         false,
         true,
-        use_log_space,
+        space,
     )?;

     // First burke call to get initial posterior probabilities
@@ -327,9 +328,6 @@ pub fn calculate_two_step_posterior(
 ) -> Result<(Theta, Weights, Weights)> {
     tracing::info!("=== STAGE 1: Posterior Density Calculation ===");

-    // Use log-space based on settings
-    let use_log_space = settings.advanced().log_space;
-
     // Step 1.1: NPAGFULL11 filtering (returns filtered posterior AND filtered prior)
     let (filtered_theta, filtered_posterior_weights, filtered_population_weights) =
         npagfull11_filter(
@@ -338,7 +336,7 @@ pub fn calculate_two_step_posterior(
             past_data,
             eq,
             error_models,
-            use_log_space,
+            settings.advanced().space,
         )?;

     // Step 1.2: NPAGFULL refinement
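
The refactor above repeats one pattern: call sites pass a `Space` (typically `settings.advanced().space`) instead of a `use_log_space` boolean, and the returned `Psi` carries that same tag. A self-contained sketch of the pattern with stand-in types (only the `Space::Linear`/`Space::Log` variants match the crate; everything else is simplified for illustration):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Space {
    Linear,
    Log,
}

/// Stand-in for the likelihood matrix: rows are subjects, columns support points.
struct Psi {
    matrix: Vec<Vec<f64>>,
    space: Space,
}

fn linear_likelihoods() -> Vec<Vec<f64>> {
    vec![vec![0.2, 0.8], vec![0.5, 0.5]]
}

fn log_likelihoods() -> Vec<Vec<f64>> {
    linear_likelihoods()
        .into_iter()
        .map(|row| row.into_iter().map(f64::ln).collect())
        .collect()
}

/// One entry point that dispatches on `space`, so the tag stored in `Psi`
/// can never disagree with how the matrix was actually computed.
fn calculate_psi(space: Space) -> Psi {
    let matrix = match space {
        Space::Linear => linear_likelihoods(),
        Space::Log => log_likelihoods(),
    };
    Psi { matrix, space }
}

fn main() {
    let psi = calculate_psi(Space::Log);
    assert_eq!(psi.space, Space::Log);
    println!("{:?}", psi.matrix); // ln of the linear values
}
```
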
pub fn burke_ipm(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
-    if psi.is_log_space() {
-        burke_log(psi)
-    } else {
-        burke(psi)
+    match psi.space() {
+        crate::structs::psi::Space::Linear => burke(psi),
+        crate::structs::psi::Space::Log => burke_log(psi),
     }
 }

@@ -813,7 +812,6 @@ mod tests {
     fn test_burke_log_identity() {
         // Test with identity matrix converted to log space
         // log(1) = 0, log(0) = -inf, but we use a small positive value instead
-        use crate::structs::psi::PsiBuilder;
         use ndarray::Array2;

         let n = 10;
@@ -826,7 +824,8 @@ mod tests {
             }
         });

-        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+        let mat = Mat::from_fn(n, n, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);

         let (lam, obj) = burke_log(&psi).unwrap();

@@ -847,14 +846,15 @@ mod tests {
     fn test_burke_log_uniform() {
         // Test with uniform matrix in log space
         // log(1) = 0 everywhere
-        use crate::structs::psi::PsiBuilder;
+
         use ndarray::Array2;

         let n_sub = 10;
         let n_point = 10;

         let log_mat = Array2::from_shape_fn((n_sub, n_point), |_| 0.0); // log(1) = 0

-        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);

         let (lam, obj) = burke_log(&psi).unwrap();
@@ -875,7 +875,6 @@ mod tests {
     fn test_burke_log_consistency_with_regular() {
         // Test that burke_log produces the same results as burke
         // when given equivalent inputs
-        use crate::structs::psi::PsiBuilder;
         use ndarray::Array2;

         let n_sub = 5;
@@ -889,11 +888,13 @@ mod tests {

         // Create the equivalent log-space matrix
         let log_mat = regular_mat.mapv(|x| x.ln());
-        let log_psi = PsiBuilder::new(log_mat).log_space(true).build();
+
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);

         // Run both algorithms
         let (lam_regular, obj_regular) = burke(&regular_psi).unwrap();
-        let (lam_log, obj_log) = burke_log(&log_psi).unwrap();
+        let (lam_log, obj_log) = burke_log(&psi).unwrap();

         // The weights should be very similar
         for i in 0..n_point {
@@ -908,7 +909,6 @@ mod tests {
     fn test_burke_log_handles_very_small_likelihoods() {
         // Test that log-space IPM handles very small likelihoods that would
         // underflow in regular space
-        use crate::structs::psi::PsiBuilder;
         use ndarray::Array2;

         let n_sub = 5;
@@ -920,7 +920,8 @@ mod tests {
             -500.0 + (i as f64) * 0.1 + (j as f64) * 0.05
         });

-        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);

         // This should succeed without underflow issues
         let result = burke_log(&psi);
@@ -944,7 +945,6 @@ mod tests {
     #[test]
     fn test_burke_log_with_varying_magnitudes() {
         // Test with log-likelihoods of varying magnitudes
-        use crate::structs::psi::PsiBuilder;
         use ndarray::Array2;

         let n_sub = 8;
@@ -960,7 +960,8 @@ mod tests {
             }
         });

-        let psi = PsiBuilder::new(log_mat).log_space(true).build();
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);

         let (lam, obj) = burke_log(&psi).unwrap();
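
The tests above work with log-likelihoods around -500, which is exactly the regime where linear space breaks down: a product of many small likelihoods underflows f64 to 0.0, while the equivalent sum of logs is an ordinary number. A minimal demonstration (values are illustrative):

```rust
fn main() {
    // 200 observations, each with likelihood 1e-5.
    let likelihoods = [1e-5_f64; 200];

    // Linear space: 1e-5^200 = 1e-1000, far below f64's smallest
    // subnormal (~5e-324), so the product underflows to exactly 0.0.
    let product: f64 = likelihoods.iter().product();
    assert_eq!(product, 0.0);

    // Log space: the joint log-likelihood is a perfectly ordinary number,
    // which is what burke_log receives via a log-space Psi.
    let log_likelihood: f64 = likelihoods.iter().map(|x| x.ln()).sum();
    assert!(log_likelihood.is_finite());
    println!("joint log-likelihood = {log_likelihood:.1}"); // ~ -2302.6
}
```
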
diff --git a/src/routines/estimation/qr.rs b/src/routines/estimation/qr.rs
index 81fd66c23..bda5e21e0 100644
--- a/src/routines/estimation/qr.rs
+++ b/src/routines/estimation/qr.rs
@@ -34,36 +34,39 @@ fn logsumexp(values: &[f64]) -> f64 {
 pub fn qrd(psi: &Psi) -> Result<(Mat<f64>, Vec<usize>)> {
     let mut mat = psi.matrix().to_owned();

-    if psi.is_log_space() {
-        // For log-space: apply softmax normalization
-        // softmax(x)_i = exp(x_i) / sum(exp(x_j)) = exp(x_i - logsumexp(x))
-        for index in 0..mat.nrows() {
-            let log_values: Vec<f64> = (0..mat.ncols()).map(|j| *mat.get(index, j)).collect();
-            let log_sum = logsumexp(&log_values);
-
-            if log_sum.is_infinite() && log_sum < 0.0 {
-                bail!(
-                    "In log_psi, the row with index {} has all -inf values (zero probability)",
-                    index
-                );
-            }
-
-            // Convert to normalized probabilities via softmax
-            for j in 0..mat.ncols() {
-                *mat.get_mut(index, j) = (log_values[j] - log_sum).exp();
+    match psi.space() {
+        crate::structs::psi::Space::Linear => {
+            // For regular space: normalize rows to sum to 1
+            for (index, row) in mat.row_iter_mut().enumerate() {
+                let row_sum: f64 = row.as_ref().iter().sum();
+
+                if row_sum.abs() == 0.0 {
+                    bail!("In psi, the row with index {} sums to zero", index);
+                }
+                row.iter_mut().for_each(|x| *x /= row_sum);
             }
         }
-    } else {
-        // For regular space: normalize rows to sum to 1
-        for (index, row) in mat.row_iter_mut().enumerate() {
-            let row_sum: f64 = row.as_ref().iter().sum();
-
-            if row_sum.abs() == 0.0 {
-                bail!("In psi, the row with index {} sums to zero", index);
+        crate::structs::psi::Space::Log => {
+            // For log-space: apply softmax normalization
+            // softmax(x)_i = exp(x_i) / sum(exp(x_j)) = exp(x_i - logsumexp(x))
+            for index in 0..mat.nrows() {
+                let log_values: Vec<f64> = (0..mat.ncols()).map(|j| *mat.get(index, j)).collect();
+                let log_sum = logsumexp(&log_values);
+
+                if log_sum.is_infinite() && log_sum < 0.0 {
+                    bail!(
+                        "In log_psi, the row with index {} has all -inf values (zero probability)",
+                        index
+                    );
+                }
+
+                // Convert to normalized probabilities via softmax
+                for j in 0..mat.ncols() {
+                    *mat.get_mut(index, j) = (log_values[j] - log_sum).exp();
+                }
             }
-            row.iter_mut().for_each(|x| *x /= row_sum);
         }
-    }
+    };

     // Perform column pivoted QR decomposition
     let qr: ColPivQr<f64> = mat.col_piv_qr();
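
The Log arm of `qrd` is the log-space analogue of the Linear arm's row normalization: subtracting the row's `logsumexp` before exponentiating yields the same row-sums-to-one result without ever evaluating `exp` on a large negative value alone. A standalone sketch of the per-row transform (same math, plain slices):

```rust
/// Numerically stable log(sum(exp(values))).
fn logsumexp(values: &[f64]) -> f64 {
    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    max + values.iter().map(|v| (v - max).exp()).sum::<f64>().ln()
}

/// softmax(x)_j = exp(x_j - logsumexp(x)); the result sums to 1,
/// exactly like row / row_sum does for a linear-space row.
fn softmax_row(log_row: &[f64]) -> Vec<f64> {
    let log_sum = logsumexp(log_row);
    log_row.iter().map(|v| (v - log_sum).exp()).collect()
}

fn main() {
    // Each entry would be exp(-1000) = 0.0 in linear space, yet the
    // normalized row is perfectly well-defined.
    let row = [-1000.0, -1001.0, -1002.0];
    let probs = softmax_row(&row);
    let sum: f64 = probs.iter().sum();
    assert!((sum - 1.0).abs() < 1e-12);
    println!("{probs:?}"); // ~ [0.665, 0.245, 0.090]
}
```
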
diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs
index 827c9ddb5..3204acc9c 100644
--- a/src/routines/output/posterior.rs
+++ b/src/routines/output/posterior.rs
@@ -38,7 +38,10 @@ impl Posterior {
         }

         let psi_matrix = psi.matrix();
-        let is_log_space = psi.is_log_space();
+        let is_log_space = match psi.space() {
+            crate::structs::psi::Space::Linear => false,
+            crate::structs::psi::Space::Log => true,
+        };

         // Calculate py[i] = sum_j(psi[i,j] * w[j]) for each subject i
         // In log-space: py[i] = logsumexp_j(log_psi[i,j] + log(w[j]))
diff --git a/src/routines/settings.rs b/src/routines/settings.rs
index b14409dab..46dcb1515 100644
--- a/src/routines/settings.rs
+++ b/src/routines/settings.rs
@@ -1,6 +1,7 @@
 use crate::algorithms::Algorithm;
 use crate::routines::initialization::Prior;
 use crate::routines::output::OutputFile;
+use crate::structs::psi::Space;
 use anyhow::{bail, Result};
 use pharmsol::prelude::data::ErrorModels;

@@ -276,7 +277,7 @@ pub struct Advanced {
     /// This prevents underflow issues when dealing with many observations or extreme parameter values.
     /// The log-sum-exp trick is used to maintain numerical stability in weighted sum operations.
-    /// Default is true for better numerical properties.
-    pub log_space: bool,
+    /// Default is [Space::Log] for better numerical properties.
+    pub space: Space,
 }

 impl Default for Advanced {
@@ -285,7 +286,7 @@ impl Default for Advanced {
             min_distance: 1e-4,
             nm_steps: 100,
             tolerance: 1e-6,
-            log_space: true,
+            space: Space::Log,
         }
     }
 }
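
For configuration, the change means `Advanced` carries a typed `Space` rather than a boolean whose polarity has to be remembered. A sketch mirroring the struct as it appears in the hunks above (local stand-in types, and field types other than `space` are assumed; the real `Space` lives in `crate::structs::psi`):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Space {
    Linear,
    Log,
}

#[derive(Debug)]
struct Advanced {
    min_distance: f64,
    nm_steps: usize,
    tolerance: f64,
    space: Space,
}

impl Default for Advanced {
    fn default() -> Self {
        Advanced {
            min_distance: 1e-4,
            nm_steps: 100,
            tolerance: 1e-6,
            space: Space::Log, // log space by default, as in the patch
        }
    }
}

fn main() {
    // Opting back into linear-space likelihoods is an explicit,
    // self-documenting field update rather than a bare `false`.
    let advanced = Advanced {
        space: Space::Linear,
        ..Default::default()
    };
    println!("{advanced:?}");
}
```
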
diff --git a/src/structs/psi.rs b/src/structs/psi.rs
index 07a4d8f59..bbc986cd6 100644
--- a/src/structs/psi.rs
+++ b/src/structs/psi.rs
@@ -12,6 +12,15 @@ use serde::{Deserialize, Serialize};

 use super::theta::Theta;

+/// Enum to represent whether the [Psi] matrix is in linear space or log space
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
+pub enum Space {
+    /// Linear space (for regular likelihoods)
+    Linear,
+    /// Log space (for log-likelihoods)
+    Log,
+}
+
 /// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column)
 ///
 /// The matrix can store either regular likelihoods or log-likelihoods depending on how it was constructed.
@@ -19,23 +28,30 @@ use super::theta::Theta;
 #[derive(Debug, Clone, PartialEq)]
 pub struct Psi {
     matrix: Mat<f64>,
-    /// Whether the matrix contains log-likelihoods (true) or regular likelihoods (false)
-    is_log_space: bool,
+    space: Space,
 }

 impl Psi {
     pub fn new() -> Self {
         Psi {
             matrix: Mat::new(),
-            is_log_space: false,
+            space: Space::Linear,
         }
     }

-    /// Create a new Psi in log space
-    pub fn new_log() -> Self {
+    /// Create a new Psi in linear space
+    pub fn new_linear(mat: Mat<f64>) -> Self {
         Psi {
-            matrix: Mat::new(),
-            is_log_space: true,
+            matrix: mat,
+            space: Space::Linear,
+        }
+    }
+
+    /// Create a new Psi in log space
+    pub fn new_log(mat: Mat<f64>) -> Self {
+        Psi {
+            matrix: mat,
+            space: Space::Log,
         }
     }

@@ -43,9 +59,46 @@ impl Psi {
         &self.matrix
     }

-    /// Returns true if the matrix stores log-likelihoods
-    pub fn is_log_space(&self) -> bool {
-        self.is_log_space
+    /// Get the [Space] (Linear or Log) of the Psi matrix
+    pub fn space(&self) -> Space {
+        self.space
+    }
+
+    /// Set the [Space] (Linear or Log) of the Psi matrix
+    ///
+    /// Note: This does not update the actual matrix values, only the space flag.
+    pub fn set_space(&mut self, space: Space) {
+        self.space = space;
+    }
+
+    /// Convert the Psi matrix to the specified [Space] (Linear or Log)
+    /// This modifies the matrix values accordingly.
+    pub fn to_space(&mut self, space: Space) -> &mut Self {
+        match (space, self.space) {
+            (Space::Linear, Space::Log) => {
+                // Convert from log to linear
+                for col in self.matrix.col_iter_mut() {
+                    col.iter_mut().for_each(|val| {
+                        *val = val.exp();
+                    });
+                }
+            }
+            (Space::Log, Space::Linear) => {
+                // Convert from linear to log
+
+                for col in self.matrix.col_iter_mut() {
+                    col.iter_mut().for_each(|val| {
+                        *val = val.ln();
+                    });
+                }
+            }
+            _ => {
+                // No conversion needed
+            }
+        }
+
+        self.space = space;
+        self
     }

     pub fn nspp(&self) -> usize {
@@ -124,7 +177,7 @@ impl Psi {

         Ok(Psi {
             matrix: mat,
-            is_log_space: false,
+            space: Space::Linear,
         })
     }
 }
@@ -140,7 +193,7 @@ impl From<Array2<f64>> for Psi {
         let matrix = array.view().into_faer().to_owned();
         Psi {
             matrix,
-            is_log_space: false,
+            space: Space::Linear,
         }
     }
 }
@@ -149,7 +202,7 @@ impl From<Mat<f64>> for Psi {
     fn from(matrix: Mat<f64>) -> Self {
         Psi {
             matrix,
-            is_log_space: false,
+            space: Space::Linear,
         }
     }
 }
@@ -159,7 +212,7 @@ impl From<ArrayView2<'_, f64>> for Psi {
         let matrix = array_view.into_faer().to_owned();
         Psi {
             matrix,
-            is_log_space: false,
+            space: Space::Linear,
         }
     }
 }
@@ -169,7 +222,7 @@ impl From<&Array2<f64>> for Psi {
         let matrix = array.view().into_faer().to_owned();
         Psi {
             matrix,
-            is_log_space: false,
+            space: Space::Linear,
         }
     }
 }
@@ -196,7 +249,7 @@ impl Serialize for Psi {
 }

 impl<'de> Deserialize<'de> for Psi {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: serde::Deserializer<'de>,
     {
@@ -244,10 +297,9 @@ impl<'de> Deserialize<'de> for Psi {
         // Create matrix from rows
         let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);

-        Ok(Psi {
-            matrix: mat,
-            is_log_space: false,
-        })
+        let psi = Psi::new_linear(mat);
+
+        Ok(psi)
     }
 }
@@ -255,34 +307,6 @@ impl<'de> Deserialize<'de> for Psi {
     }
 }

-/// Helper struct for creating Psi with log-space flag
-pub struct PsiBuilder {
-    array: Array2<f64>,
-    is_log_space: bool,
-}
-
-impl PsiBuilder {
-    pub fn new(array: Array2<f64>) -> Self {
-        Self {
-            array,
-            is_log_space: false,
-        }
-    }
-
-    pub fn log_space(mut self, is_log: bool) -> Self {
-        self.is_log_space = is_log;
-        self
-    }
-
-    pub fn build(self) -> Psi {
-        let matrix = self.array.view().into_faer().to_owned();
-        Psi {
-            matrix,
-            is_log_space: self.is_log_space,
-        }
-    }
-}
-
-/// Calculate the likelihood matrix (regular space)
+/// Calculate the likelihood matrix in the requested [Space] (linear or log)
 pub(crate) fn calculate_psi(
     equation: &impl Equation,
@@ -291,75 +315,33 @@ pub(crate) fn calculate_psi(
     error_models: &ErrorModels,
     progress: bool,
     cache: bool,
+    space: Space,
 ) -> Result<Psi> {
-    let psi_ndarray = psi(
-        equation,
-        subjects,
-        &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
-        error_models,
-        progress,
-        cache,
-    )?;
-
-    Ok(PsiBuilder::new(psi_ndarray).log_space(false).build())
-}
-
-/// Calculate the log-likelihood matrix (log space)
-///
-/// This computes log-likelihoods directly, which is numerically more stable
-/// than computing likelihoods and then taking logarithms. This is especially
-/// important when dealing with many observations or extreme parameter values.
-pub(crate) fn calculate_log_psi(
-    equation: &impl Equation,
-    subjects: &Data,
-    theta: &Theta,
-    error_models: &ErrorModels,
-    progress: bool,
-    cache: bool,
-) -> Result<Psi> {
-    let log_psi_ndarray = log_psi(
-        equation,
-        subjects,
-        &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
-        error_models,
-        progress,
-        cache,
-    )?;
-
-    Ok(PsiBuilder::new(log_psi_ndarray).log_space(true).build())
-}
-
-/// Unified psi calculation that dispatches based on the `use_log_space` flag.
-///
-/// This function eliminates the need for repeated if/else blocks in algorithm code.
-///
-/// # Arguments
-///
-/// * `equation` - The model equation
-/// * `subjects` - The data
-/// * `theta` - The support points
-/// * `error_models` - The error models
-/// * `progress` - Whether to show progress
-/// * `cache` - Whether to use caching
-/// * `use_log_space` - If true, calculates log-likelihoods; otherwise, regular likelihoods
-///
-/// # Returns
-///
-/// A Psi matrix with the appropriate `is_log_space` flag set.
-pub(crate) fn calculate_psi_dispatch(
-    equation: &impl Equation,
-    subjects: &Data,
-    theta: &Theta,
-    error_models: &ErrorModels,
-    progress: bool,
-    cache: bool,
-    use_log_space: bool,
-) -> Result<Psi> {
-    if use_log_space {
-        calculate_log_psi(equation, subjects, theta, error_models, progress, cache)
-    } else {
-        calculate_psi(equation, subjects, theta, error_models, progress, cache)
-    }
+    let psi_mat = match space {
+        Space::Linear => psi(
+            equation,
+            subjects,
+            &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
+            error_models,
+            progress,
+            cache,
+        ),
+        Space::Log => log_psi(
+            equation,
+            subjects,
+            &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
+            error_models,
+            progress,
+            cache,
+        ),
+    }?;
+
+    let psi = Psi {
+        matrix: psi_mat.view().into_faer().to_owned(),
+        space,
+    };
+
+    Ok(psi)
 }

 #[cfg(test)]