diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py index 0b43ac8d..790fe333 100644 --- a/init2winit/model_lib/rope_nanodo.py +++ b/init2winit/model_lib/rope_nanodo.py @@ -231,6 +231,12 @@ def __call__(self, x_BxLxD: jax.Array): q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + # Ensure Q, K, V have consistent dtypes. Operations like qk_norm and + # attn_scale multiplication can promote Q/K to a higher precision (e.g. + # float32) while V stays in the computation dtype (e.g. bfloat16). + # jax.nn.dot_product_attention requires K and V to have matching dtypes. + v_BxLxHxDh = v_BxLxHxDh.astype(k_BxLxHxDh.dtype) + out_BxLxHxDh = jax.nn.dot_product_attention( q_BxLxHxDh, k_BxLxHxDh,