From a61fcd45cb083a8dd62fc7e88667b28ee5c702f9 Mon Sep 17 00:00:00 2001 From: Jonathan Terhorst Date: Thu, 23 Oct 2025 15:58:57 -0400 Subject: [PATCH] enforce return type consistency in gpu_ll --- src/phlashlib/gpu.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/phlashlib/gpu.py b/src/phlashlib/gpu.py index aec186e..7e78116 100644 --- a/src/phlashlib/gpu.py +++ b/src/phlashlib/gpu.py @@ -166,7 +166,10 @@ def _gpu_ll_fwd( data, vmap_method="broadcast_all", ) - return f, jax.tree.map(lambda a, b: jnp.astype(a, b.dtype), df, log_pp) + return_dtype = jax.eval_shape(_gpu_ll_helper, log_pp, data).dtype + f = f.astype(return_dtype) + df = jax.tree.map(lambda a, b: jnp.astype(a, b.dtype), df, log_pp) + return f, df def _gpu_ll_bwd(df: PSMCParamsType, g: ScalarLike) -> tuple[PSMCParamsType, None]: