
Replace jax.scipy.stats.norm.cdf with jax.scipy.special.ndtr for improved efficiency #200

@SaFE-APIOpt

Description


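Current Code:

```python
import jax.scipy.stats  # also makes the top-level name `jax` available
p = 1. - jax.scipy.stats.norm.cdf(noise_required_to_win)  # P(Z > noise_required_to_win), Z ~ N(0, 1)
```

Recommended Replacement:

```python
from jax.scipy.special import ndtr

p = 1. - ndtr(noise_required_to_win)  # same upper-tail probability, without the loc/scale step
```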

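The current code uses jax.scipy.stats.norm.cdf. Unlike scipy.stats.norm in SciPy, this is not a frozen distribution object (rv_continuous); it is a plain function, cdf(x, loc=0, scale=1), that standardizes x as (x - loc) / scale and passes the result to jax.scipy.special.ndtr. The loc/scale handling makes it flexible, but when only the standard normal cumulative distribution function (CDF) is needed, it adds an elementwise subtraction and division that do no useful work.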
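A quick way to see the relationship is to compare the two functions directly. This is a sketch: the sample points and the loc/scale values are arbitrary, and it only checks that norm.cdf agrees with ndtr applied to the standardized input, in the general case and in the standard-normal case used in this issue.

```python
# Sketch: norm.cdf(x, loc, scale) agrees with ndtr((x - loc) / scale).
import jax.numpy as jnp
from jax.scipy.special import ndtr
from jax.scipy.stats import norm

x = jnp.linspace(-4.0, 4.0, 9)  # float32 sample points (arbitrary)
loc, scale = 1.5, 2.0           # arbitrary example parameters

# General normal CDF: standardize, then evaluate the standard normal CDF.
via_norm = norm.cdf(x, loc=loc, scale=scale)
via_ndtr = ndtr((x - loc) / scale)

# Standard normal: norm.cdf with its defaults loc=0, scale=1 versus ndtr alone.
std_norm = norm.cdf(x)
std_ndtr = ndtr(x)

print(jnp.allclose(via_norm, via_ndtr), jnp.allclose(std_norm, std_ndtr))
```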

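In contrast, jax.scipy.special.ndtr computes the standard normal CDF directly and takes no loc or scale arguments. Both functions are ordinary traceable JAX functions, so both work with jax.jit and jax.vmap; calling ndtr simply removes the extra arithmetic. Under JIT that arithmetic is cheap and may be simplified away by the compiler, so the speedup is likely to be small.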
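As an illustration of the jit/vmap point, the sketch below wraps the recommended expression in a jitted function and applies it to an array and to a batch. The function name `win_probability` and the example thresholds are hypothetical; only the expression `1. - ndtr(...)` comes from the issue.

```python
# Sketch: ndtr composes with jax.jit and jax.vmap like other jax.numpy functions.
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr


@jax.jit
def win_probability(noise_required_to_win):  # hypothetical helper name
    # Upper-tail probability P(Z > t) for a standard normal Z, as in the issue.
    return 1.0 - ndtr(noise_required_to_win)


thresholds = jnp.array([-1.0, 0.0, 1.0, 2.0])
p = win_probability(thresholds)  # evaluated elementwise over the whole array

# vmap maps over the leading axis of a 2 x 4 batch; for an elementwise function
# this matches calling win_probability on the 2-D array directly.
batch = jnp.stack([thresholds, 2.0 * thresholds])
p_batched = jax.vmap(win_probability)(batch)
print(p, p_batched.shape)
```

Because the function is elementwise, vmap is not strictly needed here; it is shown only to confirm that ndtr traces cleanly under both transformations.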

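For the standard normal, both calls should return the same values, because subtracting 0 and dividing by 1 are exact in floating-point arithmetic. ndtr is therefore the more direct choice when only the standard normal distribution is needed, though any speedup should be measured in the actual pipeline. One caveat: ndtr accepts only float32 and float64 inputs and raises a TypeError otherwise (for example, on an integer array), whereas norm.cdf promotes its arguments to a floating-point dtype first.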
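To check the efficiency claim rather than assume it, a minimal benchmark sketch could time both variants under jax.jit. The array size and repetition count are arbitrary assumptions, and results will vary by backend and JAX version; the names `tail_via_norm` and `tail_via_ndtr` are hypothetical.

```python
# Sketch: measure, rather than assume, the cost difference under jax.jit.
import timeit

import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr
from jax.scipy.stats import norm

x = jnp.linspace(-5.0, 5.0, 1_000_000)  # float32 input; size chosen arbitrarily

tail_via_norm = jax.jit(lambda t: 1.0 - norm.cdf(t))
tail_via_ndtr = jax.jit(lambda t: 1.0 - ndtr(t))

# Warm up: the first call compiles each function, which should not be timed.
tail_via_norm(x).block_until_ready()
tail_via_ndtr(x).block_until_ready()

for name, fn in [("norm.cdf", tail_via_norm), ("ndtr", tail_via_ndtr)]:
    # block_until_ready() waits for JAX's asynchronous dispatch to finish.
    seconds = timeit.timeit(lambda: fn(x).block_until_ready(), number=20) / 20
    print(f"{name}: {seconds * 1e3:.3f} ms per call")
```

Warming up and calling block_until_ready() keep compilation time and asynchronous dispatch out of the measurement; without them, the comparison would mostly reflect tracing overhead.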
