Current Code:
p = 1. - jax.scipy.stats.norm.cdf(noise_required_to_win)
Recommended Replacement:
from jax.scipy.special import ndtr
p = 1. - ndtr(noise_required_to_win)
The current code calls jax.scipy.stats.norm.cdf. Unlike scipy.stats.norm in SciPy, which is an rv_continuous distribution object, the JAX version is a plain function: in current JAX releases it standardizes its input with the loc and scale arguments (defaults 0 and 1) and then evaluates the standard normal cumulative distribution function (CDF) through ndtr.
In contrast, jax.scipy.special.ndtr computes the standard normal CDF directly. It skips the loc/scale argument promotion and arithmetic, states the intent more plainly, and works equally well with JAX’s JIT compilation and vectorized operations.
Both forms are numerically equivalent for the standard normal distribution. Any speed difference is likely to be small once the function is jitted, so the main benefit of ndtr is a leaner, more explicit call when only the standard normal distribution is needed.
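The equivalence claimed above can be checked with a short sketch. The function names win_prob_cdf and win_prob_ndtr and the test grid are hypothetical, chosen only for illustration; only noise_required_to_win comes from the original code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr
from jax.scipy.stats import norm


@jax.jit
def win_prob_cdf(noise_required_to_win):
    # Current form: standard normal CDF via the stats wrapper (loc=0, scale=1).
    return 1.0 - norm.cdf(noise_required_to_win)


@jax.jit
def win_prob_ndtr(noise_required_to_win):
    # Recommended form: call the standard normal CDF directly.
    return 1.0 - ndtr(noise_required_to_win)


# Hypothetical inputs: a small grid of thresholds in standard-deviation units.
x = jnp.linspace(-4.0, 4.0, 9)
p_cdf = win_prob_cdf(x)
p_ndtr = win_prob_ndtr(x)

# Largest absolute disagreement between the two forms on this grid.
max_abs_diff = float(jnp.max(jnp.abs(p_cdf - p_ndtr)))
print(max_abs_diff)
```

If the two forms agree on inputs representative of the real pipeline, the replacement should be a drop-in change. Whether it is measurably faster is best confirmed by timing the jitted functions on actual workloads rather than assumed.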