From be3564d7ee163eb4584a154dbea7550234670562 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Fri, 6 Mar 2026 22:52:36 +0000 Subject: [PATCH 1/4] Implement approx_tanh for ROCm using OCML tanh function AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction like NVIDIA's PTX tanh.approx. This implements approx_tanh for ROCm using: - For f32 (and f16/bf16 via casting): Triton's __triton_hip_fast_tanhf which uses a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) - For f64: OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) Changes: - Add f64 support to approx_tanh function - Add ROCm platform detection in _elementwise_inline_asm_lowering - Add _approx_tanh_rocm_lowering function for ROCm-specific lowering - Add test_approx_tanh test with f16/bf16/f32/f64 support See: https://github.com/triton-lang/triton/pull/7780 (cherry picked from commit 39ceb951a029a63aba783b6dd61ed74548408916) --- jax/_src/pallas/triton/primitives.py | 92 ++++++++++++++++++++++++++++ tests/pallas/ops_test.py | 41 +++++++++++++ 2 files changed, 133 insertions(+) diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8e763c3d8e6a..25423ebba471 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -48,6 +48,11 @@ def approx_tanh(x: jax.Array) -> jax.Array: elif x.dtype == jnp.float32: asm = "tanh.approx.f32 $0, $1;" constraint = "f" + elif x.dtype == jnp.float64: + # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64) + # CUDA does not have a PTX instruction for f64 approximate tanh + asm = "tanh.approx.f64 $0, $1;" + constraint = "d" else: raise TypeError(f"approx_tanh does not accept {x.dtype} arrays") @@ -119,6 +124,13 @@ def _elementwise_inline_asm_lowering( result_shape_dtypes, ): del result_shape_dtypes # Unused. + + # For ROCm, PTX inline assembly is not supported. For tanh.approx, we use + # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and + # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780 + if ctx.context.platform == "rocm" and "tanh.approx" in asm: + return _approx_tanh_rocm_lowering(ctx, *args) + return tt_dialect.ElementwiseInlineAsmOp( [*map(mlir.aval_to_ir_type, ctx.avals_out)], asm, @@ -129,6 +141,86 @@ def _elementwise_inline_asm_lowering( ).result +def _approx_tanh_rocm_lowering( + ctx: lowering.LoweringRuleContext, + *args, +): + """Lower approx_tanh for ROCm. + + AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction. + + For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf + which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + See: https://github.com/triton-lang/triton/pull/7780 + + For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) + since fast_tanhf only supports f32. + """ + from jax._src.lib.mlir import ir + from jax._src.lib.mlir.dialects import arith as arith_dialect + + [arg] = args + [out_aval] = ctx.avals_out + in_dtype = ctx.avals_in[0].dtype + + # Helper to get IR type for a dtype + def dtype_to_ir_type(dtype): + dtype = jnp.dtype(dtype) + return mlir.dtype_to_ir_type(dtype) + + # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32) + if in_dtype == jnp.float64: + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="", + libpath="", + symbol="__ocml_tanh_f64", + pure=True, + ) + return [result] + + # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back. + needs_cast = in_dtype in (jnp.float16, jnp.bfloat16) + + if needs_cast: + # Cast input to f32 (extend) + f32_type = dtype_to_ir_type(jnp.float32) + if out_aval.shape: + f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type) + else: + f32_result_type = f32_type + arg_f32 = arith_dialect.extf(f32_result_type, arg) + + # Call __triton_hip_fast_tanhf (fast exp-based implementation) + tanh_result = tt_dialect.extern_elementwise( + f32_result_type, + [arg_f32], + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + # Cast result back to original dtype (truncate) + out_type = mlir.aval_to_ir_type(out_aval) + result = arith_dialect.truncf(out_type, tanh_result) + else: + # f32: call __triton_hip_fast_tanhf directly + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + return [result] + + def debug_barrier() -> None: """Synchronizes all kernel executions in the grid.""" return debug_barrier_p.bind() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e8933618b0e0..40ae54089ee2 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1878,6 +1878,47 @@ def kernel(o_ref): np.testing.assert_allclose(f(), kernel()) + @parameterized.parameters("float16", "bfloat16", "float32", "float64") + def test_approx_tanh(self, dtype): + self.skip_if_mosaic_gpu() + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpret mode") + + if (dtype == "bfloat16" and + jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + if dtype == "float64": + if jtu.test_device_matches(["cuda"]): + self.skipTest("f64 approx_tanh is only supported on ROCm") + + # Enable x64 for f64 test if not already enabled, restore after test + original_x64 = jax.config.x64_enabled + if dtype == "float64" and not original_x64: + jax.config.update("jax_enable_x64", True) + self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu_triton.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/jax-ml/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + @parameterized.parameters( ((2, 4), (8,)), ((2, 4), (8, 1)), From 6f87beb247728902d6eaf89c09815a3fa9693386 Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Wed, 4 Mar 2026 20:52:37 +0000 Subject: [PATCH 2/4] Address review comments for approx_tanh ROCm implementation - Remove verbose comment in _elementwise_inline_asm_lowering - Inline dtype_to_ir_type helper, use mlir.dtype_to_ir_type directly - Move ir and arith_dialect imports to top-level - Add TypeError for float64 on non-ROCm platforms - Simplify _approx_tanh_rocm_lowering with needs_cast pattern - Move test_approx_tanh from ops_test.py to triton_pallas_test.py - Fix triton_pallas_test setUp to allow ROCm devices (cherry picked from commit 600fbd3968501504991f29a49bebd9433aa9dcdd) --- jax/_src/pallas/triton/primitives.py | 72 ++++++++++------------------ tests/pallas/ops_test.py | 41 ---------------- tests/pallas/triton_pallas_test.py | 16 +++++-- 3 files changed, 38 insertions(+), 91 deletions(-) diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 25423ebba471..38f1a1e0c979 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -22,6 +22,9 @@ import jax from jax._src import core as jax_core from jax._src import state +from jax._src import tree_util +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas import primitives as pallas_primitives @@ -125,11 +128,14 @@ def _elementwise_inline_asm_lowering( ): del result_shape_dtypes # Unused. - # For ROCm, PTX inline assembly is not supported. For tanh.approx, we use - # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and - # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780 - if ctx.context.platform == "rocm" and "tanh.approx" in asm: - return _approx_tanh_rocm_lowering(ctx, *args) + if "tanh.approx" in asm: + if ctx.context.platform == "rocm": + return _approx_tanh_rocm_lowering(ctx, *args) + if ctx.avals_in[0].dtype == jnp.float64: + raise TypeError( + "approx_tanh does not support float64 on CUDA; it is only" + " supported on ROCm" + ) return tt_dialect.ElementwiseInlineAsmOp( [*map(mlir.aval_to_ir_type, ctx.avals_out)], @@ -148,27 +154,12 @@ def _approx_tanh_rocm_lowering( """Lower approx_tanh for ROCm. AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction. - - For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf - which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) See: https://github.com/triton-lang/triton/pull/7780 - - For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) - since fast_tanhf only supports f32. """ - from jax._src.lib.mlir import ir - from jax._src.lib.mlir.dialects import arith as arith_dialect - [arg] = args [out_aval] = ctx.avals_out in_dtype = ctx.avals_in[0].dtype - # Helper to get IR type for a dtype - def dtype_to_ir_type(dtype): - dtype = jnp.dtype(dtype) - return mlir.dtype_to_ir_type(dtype) - - # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32) if in_dtype == jnp.float64: result_type = mlir.aval_to_ir_type(out_aval) result = tt_dialect.extern_elementwise( @@ -181,42 +172,29 @@ def dtype_to_ir_type(dtype): ) return [result] - # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back. needs_cast = in_dtype in (jnp.float16, jnp.bfloat16) if needs_cast: - # Cast input to f32 (extend) - f32_type = dtype_to_ir_type(jnp.float32) + f32_type = mlir.dtype_to_ir_type(jnp.dtype(jnp.float32)) if out_aval.shape: f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type) else: f32_result_type = f32_type - arg_f32 = arith_dialect.extf(f32_result_type, arg) - - # Call __triton_hip_fast_tanhf (fast exp-based implementation) - tanh_result = tt_dialect.extern_elementwise( - f32_result_type, - [arg_f32], - libname="libdevice", - libpath="", - symbol="__triton_hip_fast_tanhf", - pure=True, - ) + arg = arith_dialect.extf(f32_result_type, arg) + + result_type = f32_result_type if needs_cast else mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + [arg], + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) - # Cast result back to original dtype (truncate) + if needs_cast: out_type = mlir.aval_to_ir_type(out_aval) - result = arith_dialect.truncf(out_type, tanh_result) - else: - # f32: call __triton_hip_fast_tanhf directly - result_type = mlir.aval_to_ir_type(out_aval) - result = tt_dialect.extern_elementwise( - result_type, - list(args), - libname="libdevice", - libpath="", - symbol="__triton_hip_fast_tanhf", - pure=True, - ) + result = arith_dialect.truncf(out_type, result) return [result] diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 40ae54089ee2..e8933618b0e0 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1878,47 +1878,6 @@ def kernel(o_ref): np.testing.assert_allclose(f(), kernel()) - @parameterized.parameters("float16", "bfloat16", "float32", "float64") - def test_approx_tanh(self, dtype): - self.skip_if_mosaic_gpu() - - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented on TPU") - - if self.INTERPRET: - self.skipTest("approx_tanh is not supported in interpret mode") - - if (dtype == "bfloat16" and - jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") - - if dtype == "float64": - if jtu.test_device_matches(["cuda"]): - self.skipTest("f64 approx_tanh is only supported on ROCm") - - # Enable x64 for f64 test if not already enabled, restore after test - original_x64 = jax.config.x64_enabled - if dtype == "float64" and not original_x64: - jax.config.update("jax_enable_x64", True) - self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) - - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), - ) - def kernel(x_ref, o_ref): - o_ref[...] = plgpu_triton.approx_tanh(x_ref[...]) - - x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) - # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/jax-ml/jax/issues/11014. - np.testing.assert_allclose( - kernel(x).astype(jnp.float32), - jnp.tanh(x).astype(jnp.float32), - atol=5e-3, - rtol=5e-3, - ) - @parameterized.parameters( ((2, 4), (8,)), ((2, 4), (8, 1)), diff --git a/tests/pallas/triton_pallas_test.py b/tests/pallas/triton_pallas_test.py index e113ad63823e..723f2ef98fca 100644 --- a/tests/pallas/triton_pallas_test.py +++ b/tests/pallas/triton_pallas_test.py @@ -47,8 +47,8 @@ def setUp(self): if not self.INTERPRET: self.skipTest("On CPU the test works only in interpret mode") elif jtu.test_device_matches(["gpu"]): - if jtu.test_device_matches(["cuda"]) and \ - not jtu.is_cuda_compute_capability_at_least("9.0"): + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") if plgpu is None: self.skipTest("plgpu not available on this platform") @@ -287,15 +287,25 @@ def kernel(x_ref, mask_ref, o_ref): with self.assertRaisesRegex(ValueError, "Cannot broadcast"): kernel(x, mask) - @parameterized.parameters("float16", "bfloat16", "float32") + @parameterized.parameters("float16", "bfloat16", "float32", "float64") def test_approx_tanh(self, dtype): if self.INTERPRET: self.skipTest("approx_tanh is not supported in interpret mode") if (dtype == "bfloat16" and + jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + if dtype == "float64": + if jtu.test_device_matches(["cuda"]): + self.skipTest("f64 approx_tanh is only supported on ROCm") + + original_x64 = jax.config.x64_enabled + if dtype == "float64" and not original_x64: + jax.config.update("jax_enable_x64", True) + self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), ) From 7fa83ec1fd8e3ce742d805b6d5e52d4210dd5688 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Mon, 2 Mar 2026 21:01:43 +0000 Subject: [PATCH 3/4] Skip test_topology_jit_serialize on ROCm GPU The test is already skipped on CUDA (b/442353988) due to HLO debug metadata (source column numbers) being embedded in compiled output, causing semantically identical compilations to produce different as_text() results. The same issue occurs on ROCm. (cherry picked from commit 70b2b99a7fd9c604a8fb3652afb518a9d5b5e38f) --- tests/aot_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/aot_test.py b/tests/aot_test.py index 74da4c8c9923..aeecb1ada52b 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -75,7 +75,7 @@ def test_topology_jit_serialize(self): if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: raise unittest.SkipTest('Compilation caching not yet supported.') - if jtu.is_device_cuda(): + if jtu.is_device_cuda() or jtu.is_device_rocm(): raise unittest.SkipTest('Broken on GPU: b/442353988') @jax.jit From d2c6ccaaa57b300ebf739e9a42199df0555d7046 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Mon, 2 Mar 2026 15:09:32 -0600 Subject: [PATCH 4/4] Update tests/aot_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> (cherry picked from commit 5ec841984ad994292defb75ba0882ec78275ee88) --- tests/aot_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/aot_test.py b/tests/aot_test.py index aeecb1ada52b..b3760b1fcf05 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -75,7 +75,7 @@ def test_topology_jit_serialize(self): if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: raise unittest.SkipTest('Compilation caching not yet supported.') - if jtu.is_device_cuda() or jtu.is_device_rocm(): + if jtu.test_device_matches(['gpu']): raise unittest.SkipTest('Broken on GPU: b/442353988') @jax.jit