
equinox.error_if calls can hurt computational efficiency #529

@michael-0brien

Description


We call `equinox.error_if` in various `__init__` methods (e.g. those of the pose, CTF, and volume classes) to check that array values lie within valid bounds. This wraps `jax.lax.cond`, which is inefficient on GPU and may be hurting some users' performance if they initialize these objects inside a JIT-compiled function. There are two improvements to computational efficiency I'd suggest:

  • When `equinox==0.13.3` is released, make `EQX_ON_ERROR=off` the default for cryoJAX, so that runtime error checking is only performed when the user opts in.
  • Atom volumes currently check that some of the arrays passed to them (Gaussian amplitudes, B-factors) are non-negative. Because verifying non-negativity requires a reduction over every atom, this check scales as $O(N_{\text{atoms}})$ even with `EQX_ON_ERROR=off`, which is not a good pattern. I propose removing these checks altogether.

In the meantime, users who need performant code should avoid instantiating the affected cryoJAX objects inside JIT-compiled functions, or should install the development branch of equinox (at the time of writing) to access the `EQX_ON_ERROR=off` option.
