diff --git a/src/phlashlib/gpu.py b/src/phlashlib/gpu.py index aec186e..7e78116 100644 --- a/src/phlashlib/gpu.py +++ b/src/phlashlib/gpu.py @@ -166,7 +166,10 @@ def _gpu_ll_fwd( data, vmap_method="broadcast_all", ) - return f, jax.tree.map(lambda a, b: jnp.astype(a, b.dtype), df, log_pp) + return_dtype = jax.eval_shape(_gpu_ll_helper, log_pp, data).dtype + f = f.astype(return_dtype) + df = jax.tree.map(lambda a, b: jnp.astype(a, b.dtype), df, log_pp) + return f, df def _gpu_ll_bwd(df: PSMCParamsType, g: ScalarLike) -> tuple[PSMCParamsType, None]: