diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py
index 8e763c3d8e6a..38f1a1e0c979 100644
--- a/jax/_src/pallas/triton/primitives.py
+++ b/jax/_src/pallas/triton/primitives.py
@@ -22,6 +22,9 @@
 import jax
 from jax._src import core as jax_core
 from jax._src import state
+from jax._src import tree_util
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import gpu as gpu_dialect
 from jax._src.lib.triton import dialect as tt_dialect
 from jax._src.pallas import primitives as pallas_primitives
@@ -48,6 +51,11 @@ def approx_tanh(x: jax.Array) -> jax.Array:
   elif x.dtype == jnp.float32:
     asm = "tanh.approx.f32 $0, $1;"
     constraint = "f"
+  elif x.dtype == jnp.float64:
+    # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64)
+    # CUDA does not have a PTX instruction for f64 approximate tanh
+    asm = "tanh.approx.f64 $0, $1;"
+    constraint = "d"
   else:
     raise TypeError(f"approx_tanh does not accept {x.dtype} arrays")
 
@@ -119,6 +127,16 @@ def _elementwise_inline_asm_lowering(
     result_shape_dtypes,
 ):
   del result_shape_dtypes  # Unused.
+
+  if "tanh.approx" in asm:
+    if ctx.context.platform == "rocm":
+      return _approx_tanh_rocm_lowering(ctx, *args)
+    if ctx.avals_in[0].dtype == jnp.float64:
+      raise TypeError(
+          "approx_tanh does not support float64 on CUDA; it is only"
+          " supported on ROCm"
+      )
+
   return tt_dialect.ElementwiseInlineAsmOp(
       [*map(mlir.aval_to_ir_type, ctx.avals_out)],
       asm,
@@ -129,6 +147,59 @@
   ).result
 
 
+def _approx_tanh_rocm_lowering(
+    ctx: lowering.LoweringRuleContext,
+    *args,
+):
+  """Lower approx_tanh for ROCm.
+
+  AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction.
+
+  See: https://github.com/triton-lang/triton/pull/7780
+  """
+  [arg] = args
+  [out_aval] = ctx.avals_out
+  in_dtype = ctx.avals_in[0].dtype
+
+  if in_dtype == jnp.float64:
+    result_type = mlir.aval_to_ir_type(out_aval)
+    result = tt_dialect.extern_elementwise(
+        result_type,
+        list(args),
+        libname="",
+        libpath="",
+        symbol="__ocml_tanh_f64",
+        pure=True,
+    )
+    return [result]
+
+  needs_cast = in_dtype in (jnp.float16, jnp.bfloat16)
+
+  if needs_cast:
+    f32_type = mlir.dtype_to_ir_type(jnp.dtype(jnp.float32))
+    if out_aval.shape:
+      f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type)
+    else:
+      f32_result_type = f32_type
+    arg = arith_dialect.extf(f32_result_type, arg)
+
+  result_type = f32_result_type if needs_cast else mlir.aval_to_ir_type(out_aval)
+  result = tt_dialect.extern_elementwise(
+      result_type,
+      [arg],
+      libname="libdevice",
+      libpath="",
+      symbol="__triton_hip_fast_tanhf",
+      pure=True,
+  )
+
+  if needs_cast:
+    out_type = mlir.aval_to_ir_type(out_aval)
+    result = arith_dialect.truncf(out_type, result)
+
+  return [result]
+
+
 def debug_barrier() -> None:
   """Synchronizes all kernel executions in the grid."""
   return debug_barrier_p.bind()
diff --git a/tests/aot_test.py b/tests/aot_test.py
index 74da4c8c9923..b3760b1fcf05 100644
--- a/tests/aot_test.py
+++ b/tests/aot_test.py
@@ -75,7 +75,7 @@ def test_topology_jit_serialize(self):
     if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
       raise unittest.SkipTest('Compilation caching not yet supported.')
 
-    if jtu.is_device_cuda():
+    if jtu.test_device_matches(['gpu']):
       raise unittest.SkipTest('Broken on GPU: b/442353988')
 
     @jax.jit
diff --git a/tests/pallas/triton_pallas_test.py b/tests/pallas/triton_pallas_test.py
index e113ad63823e..723f2ef98fca 100644
--- a/tests/pallas/triton_pallas_test.py
+++ b/tests/pallas/triton_pallas_test.py
@@ -47,8 +47,8 @@ def setUp(self):
       if not self.INTERPRET:
         self.skipTest("On CPU the test works only in interpret mode")
     elif jtu.test_device_matches(["gpu"]):
-      if jtu.test_device_matches(["cuda"]) and \
-          not jtu.is_cuda_compute_capability_at_least("9.0"):
+      if (jtu.test_device_matches(["cuda"]) and
+          not jtu.is_cuda_compute_capability_at_least("9.0")):
         self.skipTest("Only works on GPU with capability >= sm90")
       if plgpu is None:
         self.skipTest("plgpu not available on this platform")
@@ -287,15 +287,25 @@ def kernel(x_ref, mask_ref, o_ref):
     with self.assertRaisesRegex(ValueError, "Cannot broadcast"):
       kernel(x, mask)
 
-  @parameterized.parameters("float16", "bfloat16", "float32")
+  @parameterized.parameters("float16", "bfloat16", "float32", "float64")
   def test_approx_tanh(self, dtype):
     if self.INTERPRET:
       self.skipTest("approx_tanh is not supported in interpret mode")
 
     if (dtype == "bfloat16" and
+        jtu.test_device_matches(["cuda"]) and
         not jtu.is_cuda_compute_capability_at_least("9.0")):
       self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
 
+    if dtype == "float64":
+      if jtu.test_device_matches(["cuda"]):
+        self.skipTest("f64 approx_tanh is only supported on ROCm")
+
+    original_x64 = jax.config.x64_enabled
+    if dtype == "float64" and not original_x64:
+      jax.config.update("jax_enable_x64", True)
+      self.addCleanup(lambda: jax.config.update("jax_enable_x64", False))
+
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((4,), dtype),