Numerical instability in _exact_wnorm / parameter information gain for small Dirichlet counts (MINVAL sensitivity, EFE oscillations) #340

@conorheins

Description

Summary

A user reported bumpy, oscillatory minima in the expected free energy (EFE) and degraded control performance on v1.0.0_alpha when use_param_info_gain=True and MINVAL = jnp.finfo(float).eps. Setting MINVAL = 1e-8 reduced the oscillations and improved performance. This points to a numerical stability issue, most likely in the parameter information gain path (calc_pA_info_gain / calc_pB_info_gain).

Suspected cause

_exact_wnorm currently computes terms including 1/A and digamma(A) for Dirichlet counts A. For very small concentrations, 1/A and digamma(A) become huge with opposite signs and should cancel (since digamma(x) ~ -1/x as x→0). In finite precision (especially float32 on GPU), this cancellation can be unstable, causing large fluctuations that propagate through wnorm → param_info_gain → EFE.
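
The cancellation can be illustrated in isolation with a minimal sketch. It is not the library code: it only evaluates 1/x + digamma(x) in float32 for a few small x (the specific values are chosen here for illustration, the smallest being close to float64 eps) and compares it with digamma(x + 1), which is the same quantity by the recurrence ψ(x+1) = ψ(x) + 1/x.

```python
import jax.numpy as jnp
from jax.scipy.special import digamma

# Illustrative concentrations, from roughly float64 eps up to 0.1, stored as float32.
x = jnp.asarray([2.2e-16, 1e-8, 1e-4, 1e-1], dtype=jnp.float32)

naive = 1.0 / x + digamma(x)   # two large terms of opposite sign
stable = digamma(x + 1.0)      # same quantity, no large intermediates

for xi, n, s in zip(x, naive, stable):
    print(f"x={float(xi):.1e}  naive={float(n): .6f}  stable={float(s): .6f}")
```

As x→0 the true value tends to digamma(1) = -γ ≈ -0.5772. In float32 the two terms of the naive sum are around 4.5e15 for the smallest x, where adjacent representable values are hundreds of millions apart, so the naive result there cannot be expected to resolve an O(1) answer.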

Raising MINVAL reduces intermediate magnitudes, but changes the effective model, so I prefer a principled fix.

Proposed fix

Rewrite _exact_wnorm to eliminate explicit 1/x terms using:

$$\psi(x+1) = \psi(x) + \frac{1}{x} $$

so that 1/A + digamma(A) becomes digamma(A+1). An equivalent but more numerically stable form is:

$$w = \log(\alpha_0) - \log(\alpha) + \psi(\alpha+1) - \psi(\alpha_0+1) $$

(up to the historical minus sign in the current implementation).
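
A minimal sketch of the stable form, written as a standalone function. The name `stable_wnorm` is hypothetical, the sum over axis 0 assumes each column of A holds the concentration parameters of one Dirichlet distribution, and the historical minus sign is deliberately left out. This is not a drop-in replacement for _exact_wnorm.

```python
import jax.numpy as jnp
from jax.scipy.special import digamma

def stable_wnorm(A):
    """w = log(a0) - log(a) + psi(a + 1) - psi(a0 + 1), per column of A.

    Hypothetical helper: assumes columns of A are Dirichlet concentrations
    and does not apply the sign convention of the existing implementation.
    """
    A0 = A.sum(axis=0, keepdims=True)  # total concentration per column
    return jnp.log(A0) - jnp.log(A) + digamma(A + 1.0) - digamma(A0 + 1.0)

# Uniform tiny counts: K categories, each at concentration t.
K, t = 4, 1e-8
A = jnp.full((K, 3), t, dtype=jnp.float32)
w = stable_wnorm(A)
print(w[:, 0], jnp.log(K))  # each entry should be close to log(K)
```

For uniform counts the log terms reduce to log(K) and the two digamma terms nearly cancel at O(1) magnitude, so no intermediate grows like 1/t.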

Tests to add

  • _exact_wnorm matches the stable Dirichlet-KL closed form above (within tolerance).
  • “Small symmetric concentration” stress-case (K categories with uniform tiny counts t) stays O(1) (≈ log(K)), not O(1/t), and gradients are finite.
  • calc_pA_info_gain matches a direct mean-field reference computed from _exact_wnorm.
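
The first two tests above could be sketched roughly as follows. All names here (`stable_wnorm`, `explicit_wnorm`, the test functions) are hypothetical and stand in for the real _exact_wnorm. The explicit form is simply the same expression with the 1/x terms written out via the recurrence, used as a reference only at moderate counts where it is well behaved. The calc_pA_info_gain comparison is omitted because it depends on the library's actual signature.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma

def stable_wnorm(A):
    # Hypothetical stand-in for the rewritten _exact_wnorm (sign convention omitted).
    A0 = A.sum(axis=0, keepdims=True)
    return jnp.log(A0) - jnp.log(A) + digamma(A + 1.0) - digamma(A0 + 1.0)

def explicit_wnorm(A):
    # Same quantity with explicit 1/x terms; trustworthy only for moderate counts.
    A0 = A.sum(axis=0, keepdims=True)
    return (jnp.log(A0) - jnp.log(A) + 1.0 / A - 1.0 / A0
            + digamma(A) - digamma(A0))

def test_matches_explicit_form_for_moderate_counts():
    key = jax.random.PRNGKey(0)
    A = jax.random.uniform(key, (5, 3), minval=0.5, maxval=5.0)
    assert jnp.allclose(stable_wnorm(A), explicit_wnorm(A), atol=1e-4)

def test_tiny_symmetric_counts_stay_order_one():
    K, t = 6, 1e-12
    A = jnp.full((K, 2), t, dtype=jnp.float32)
    w = stable_wnorm(A)
    assert jnp.allclose(w, jnp.log(K), atol=1e-3)   # O(1), not O(1/t)
    grads = jax.grad(lambda a: stable_wnorm(a).sum())(A)
    assert jnp.all(jnp.isfinite(grads))

test_matches_explicit_form_for_moderate_counts()
test_tiny_symmetric_counts_stay_order_one()
```

The tolerances are guesses for float32 and would likely need tuning against the real implementation.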

Links
