diff --git a/qwix/_src/core/qarray.py b/qwix/_src/core/qarray.py index 4b17872..7ac5d0d 100644 --- a/qwix/_src/core/qarray.py +++ b/qwix/_src/core/qarray.py @@ -204,6 +204,9 @@ def take(x: jax.Array) -> jax.Array: return jax.tree.map(take, array) +_VALID_SCALE_DTYPES = (jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64) + + def validate_qarray(array: QArray): """Validates the internal consistency of a QArray.""" if not isinstance(array.qvalue, jax.Array): @@ -220,7 +223,7 @@ def validate_qarray(array: QArray): ) if array.qvalue.dtype.itemsize > 1: raise ValueError(f'{array.qvalue.dtype} is not a valid type for qvalue.') - if array.scale.dtype not in (jnp.bfloat16, jnp.float32, jnp.float64): + if array.scale.dtype not in _VALID_SCALE_DTYPES: raise ValueError(f'{array.scale.dtype} is not a valid type for scale.') if array.zero_point is not None: if array.zero_point.ndim != array.qvalue.ndim: