Summary
A user reported bumpy, oscillatory minima in the expected free energy (EFE) and degraded control performance on v1.0.0_alpha when use_param_info_gain=True and MINVAL = jnp.finfo(float).eps. Setting MINVAL = 1e-8 reduced the oscillations and improved performance. This points to a numerical stability issue, likely in the parameter information gain path (calc_pA_info_gain / calc_pB_info_gain).
Suspected cause
_exact_wnorm currently computes terms that include 1/A and digamma(A) for Dirichlet counts A. For very small concentrations, 1/A and digamma(A) become huge with opposite signs and should cancel (since digamma(x) ~ -1/x as x→0). In finite precision (especially float32 on GPU), this is a catastrophic cancellation: an O(1) result is obtained as the difference of two O(1/A) numbers, so rounding errors of order eps/A survive, causing large fluctuations in wnorm → param_info_gain → EFE.
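A minimal sketch of the cancellation, assuming JAX's default float32 and `jax.scipy.special.digamma`. It evaluates 1/x + digamma(x) directly and via digamma(x + 1) for a few tiny x. Both should equal roughly -0.5772 (digamma(1) is minus the Euler-Mascheroni constant), but the direct form adds two numbers near +1/x and -1/x.

```python
import jax.numpy as jnp
from jax.scipy.special import digamma

EULER_GAMMA = 0.5772156649015329  # digamma(1) == -EULER_GAMMA

# Tiny Dirichlet counts in float32 (around float32 eps ~1.19e-7 and above).
x = jnp.array([1e-7, 1e-6, 1e-5], dtype=jnp.float32)

# Direct form: ~(+1/x) + ~(-1/x); each term carries rounding error of order eps / x.
naive = 1.0 / x + digamma(x)

# Recurrence form: psi(x + 1) = psi(x) + 1/x, so no large intermediates appear.
stable = digamma(x + 1.0)

for xi, n, s in zip(x.tolist(), naive.tolist(), stable.tolist()):
    print(f"x={xi:.0e}  naive error={abs(n + EULER_GAMMA):.3g}  "
          f"stable error={abs(s + EULER_GAMMA):.3g}")
```

At x = 1e-7 both float32 terms are near 1e7, where the float32 spacing is 1, so the direct form's error is expected to be of order 1, while the recurrence form stays accurate to float32 precision.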
Raising MINVAL reduces intermediate magnitudes, but changes the effective model, so I prefer a principled fix.
Proposed fix
Rewrite _exact_wnorm to eliminate explicit 1/x terms using the digamma recurrence:

$$
\psi(x + 1) = \psi(x) + \frac{1}{x}
$$
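so that 1/A + digamma(A) becomes digamma(A+1). The equivalent but more numerically stable form, with $A_{0j} = \sum_i A_{ij}$, is

$$
w_{ij} = \big[\psi(A_{ij} + 1) - \ln A_{ij}\big] - \big[\psi(A_{0j} + 1) - \ln A_{0j}\big]
$$

(up to the historical minus sign in the current implementation). Each $w_{ij}$ equals $\mathrm{KL}\big(\mathrm{Dir}(A_{:j} + e_i)\,\|\,\mathrm{Dir}(A_{:j})\big)$, so the per-column parameter information gain $\sum_i (A_{ij}/A_{0j})\, w_{ij}$ is the closed-form expected Dirichlet KL.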
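A minimal JAX sketch of the rewrite, assuming strictly positive Dirichlet counts with outcomes along axis 0. `wnorm_naive`, `wnorm_stable`, and `info_gain` are hypothetical names, not pymdp's `_exact_wnorm` or `calc_pA_info_gain`; the sketch uses the sign for which each entry is a nonnegative KL divergence, which may be opposite to the historical convention.

```python
import jax.numpy as jnp
from jax.scipy.special import digamma

def wnorm_naive(A):
    # Hypothetical "before" form: psi(A) and 1/A kept as separate (huge) terms.
    A0 = A.sum(axis=0, keepdims=True)
    return (digamma(A) + 1.0 / A - jnp.log(A)) - (digamma(A0) + 1.0 / A0 - jnp.log(A0))

def wnorm_stable(A):
    # Hypothetical "after" form: psi(A) + 1/A folded into psi(A + 1).
    A0 = A.sum(axis=0, keepdims=True)
    return (digamma(A + 1.0) - jnp.log(A)) - (digamma(A0 + 1.0) - jnp.log(A0))

def info_gain(A, wnorm):
    # Per-column expected Dirichlet KL: sum_i (A_ij / A0_j) * wnorm(A)_ij.
    p = A / A.sum(axis=0, keepdims=True)
    return (p * wnorm(A)).sum(axis=0)

# Moderate counts: both forms should agree to float32 precision.
A_big = jnp.array([[2.0, 0.5], [3.0, 1.5], [5.0, 4.0]], dtype=jnp.float32)
# Tiny symmetric counts (K = 4, t = 1e-7): the exact answer is close to log(4).
A_tiny = jnp.full((4, 1), 1e-7, dtype=jnp.float32)

print(info_gain(A_big, wnorm_naive), info_gain(A_big, wnorm_stable))
print(info_gain(A_tiny, wnorm_naive), info_gain(A_tiny, wnorm_stable))
```

The recurrence removes the 1/A terms but keeps log(A). The log of a tiny positive count is moderate (about -16 at 1e-7 in float32), so the remaining subtraction of logs does not amplify rounding error the way the 1/A terms do.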
Tests to add
- `_exact_wnorm` matches the stable Dirichlet-KL closed form above (within tolerance).
- “Small symmetric concentration” stress case (uniform tiny counts t): the result stays O(1) (≈ log(K)), not O(1/t), and gradients are finite.
- `calc_pA_info_gain` matches a direct mean-field reference computed from `_exact_wnorm`.
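A self-contained, pytest-style sketch of the first two tests above, under stated assumptions. `wnorm_stable`, `dirichlet_kl`, and `reference_info_gain` are hypothetical helpers, not pymdp API; in the real suite, pymdp's `_exact_wnorm` (with its own sign convention) would replace `wnorm_stable`. The `calc_pA_info_gain` test is omitted because it depends on pymdp's actual signature.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln

def wnorm_stable(A):
    # Hypothetical stable wnorm (nonnegative sign convention), outcomes along axis 0.
    A0 = A.sum(axis=0, keepdims=True)
    return (digamma(A + 1.0) - jnp.log(A)) - (digamma(A0 + 1.0) - jnp.log(A0))

def dirichlet_kl(alpha, beta):
    # KL(Dir(alpha) || Dir(beta)) for 1-D concentration vectors, via log-gamma.
    a0, b0 = alpha.sum(), beta.sum()
    return (gammaln(a0) - gammaln(alpha).sum() - gammaln(b0) + gammaln(beta).sum()
            + ((alpha - beta) * (digamma(alpha) - digamma(a0))).sum())

def reference_info_gain(a):
    # Direct reference: E_{i ~ a/a0}[KL(Dir(a + e_i) || Dir(a))] for one column a.
    p = a / a.sum()
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    kls = jnp.stack([dirichlet_kl(a + eye[i], a) for i in range(a.shape[0])])
    return (p * kls).sum()

def test_matches_closed_form():
    A = jnp.array([[2.0, 0.5], [3.0, 1.5], [5.0, 4.0]], dtype=jnp.float32)
    p = A / A.sum(axis=0, keepdims=True)
    ours = (p * wnorm_stable(A)).sum(axis=0)
    ref = jnp.stack([reference_info_gain(A[:, j]) for j in range(A.shape[1])])
    assert bool(jnp.allclose(ours, ref, atol=1e-4))

def test_small_symmetric_concentration():
    K, t = 4, 1e-6
    A = jnp.full((K, 1), t, dtype=jnp.float32)

    def gain(A):
        p = A / A.sum(axis=0, keepdims=True)
        return (p * wnorm_stable(A)).sum()

    # O(1), close to log(K), rather than O(1/t).
    assert abs(float(gain(A)) - math.log(K)) < 1e-3
    # Gradients may be large for tiny counts but must stay finite.
    assert bool(jnp.all(jnp.isfinite(jax.grad(gain)(A))))
```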
Links
- Related: `get_likelihood_single_modality` computes expected probability instead of expected log probability for soft observations #336 (distributed obs likelihood issue)
- Fix applied: Fix distributed obs likelihood #337
- Related: Implementing exact computation of expected information gain about parameters, i.e. novelty, in pymdp #191 (replacing approximate spm_wnorm with exact info gain)
- Fix applied: Fixed param info gain computation and created tests - Fixes Issue #191 #193