Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions jax/_src/pallas/triton/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import jax
from jax._src import core as jax_core
from jax._src import state
from jax._src import tree_util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas import primitives as pallas_primitives
Expand All @@ -48,6 +51,11 @@ def approx_tanh(x: jax.Array) -> jax.Array:
elif x.dtype == jnp.float32:
asm = "tanh.approx.f32 $0, $1;"
constraint = "f"
elif x.dtype == jnp.float64:
# f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64)
# CUDA does not have a PTX instruction for f64 approximate tanh
asm = "tanh.approx.f64 $0, $1;"
constraint = "d"
else:
raise TypeError(f"approx_tanh does not accept {x.dtype} arrays")

Expand Down Expand Up @@ -119,6 +127,16 @@ def _elementwise_inline_asm_lowering(
result_shape_dtypes,
):
del result_shape_dtypes # Unused.

if "tanh.approx" in asm:
if ctx.context.platform == "rocm":
return _approx_tanh_rocm_lowering(ctx, *args)
if ctx.avals_in[0].dtype == jnp.float64:
raise TypeError(
"approx_tanh does not support float64 on CUDA; it is only"
" supported on ROCm"
)

return tt_dialect.ElementwiseInlineAsmOp(
[*map(mlir.aval_to_ir_type, ctx.avals_out)],
asm,
Expand All @@ -129,6 +147,58 @@ def _elementwise_inline_asm_lowering(
).result


def _approx_tanh_rocm_lowering(
    ctx: lowering.LoweringRuleContext,
    *args,
):
  """ROCm lowering for ``approx_tanh``.

  AMD CDNA3 (MI300X/gfx942) has no hardware tanh instruction, so the
  primitive is lowered to OCML/libdevice library calls instead.
  See: https://github.com/triton-lang/triton/pull/7780
  """
  [operand] = args
  [out_aval] = ctx.avals_out
  src_dtype = ctx.avals_in[0].dtype

  # Double precision maps directly onto the OCML f64 tanh.
  if src_dtype == jnp.float64:
    return [
        tt_dialect.extern_elementwise(
            mlir.aval_to_ir_type(out_aval),
            [operand],
            libname="",
            libpath="",
            symbol="__ocml_tanh_f64",
            pure=True,
        )
    ]

  # The fast libdevice tanh works on f32, so 16-bit inputs are widened
  # before the call and the result truncated back afterwards.
  widen = src_dtype in (jnp.float16, jnp.bfloat16)

  call_type = mlir.aval_to_ir_type(out_aval)
  if widen:
    f32 = mlir.dtype_to_ir_type(jnp.dtype(jnp.float32))
    if out_aval.shape:
      call_type = ir.RankedTensorType.get(out_aval.shape, f32)
    else:
      call_type = f32
    operand = arith_dialect.extf(call_type, operand)

  value = tt_dialect.extern_elementwise(
      call_type,
      [operand],
      libname="libdevice",
      libpath="",
      symbol="__triton_hip_fast_tanhf",
      pure=True,
  )

  if widen:
    value = arith_dialect.truncf(mlir.aval_to_ir_type(out_aval), value)

  return [value]


def debug_barrier() -> None:
  """Synchronizes all kernel executions in the grid."""
  result = debug_barrier_p.bind()
  return result
Expand Down
2 changes: 1 addition & 1 deletion tests/aot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_topology_jit_serialize(self):

if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
raise unittest.SkipTest('Compilation caching not yet supported.')
if jtu.is_device_cuda():
if jtu.test_device_matches(['gpu']):
raise unittest.SkipTest('Broken on GPU: b/442353988')

@jax.jit
Expand Down
16 changes: 13 additions & 3 deletions tests/pallas/triton_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def setUp(self):
if not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
elif jtu.test_device_matches(["gpu"]):
if jtu.test_device_matches(["cuda"]) and \
not jtu.is_cuda_compute_capability_at_least("9.0"):
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Only works on GPU with capability >= sm90")
if plgpu is None:
self.skipTest("plgpu not available on this platform")
Expand Down Expand Up @@ -287,15 +287,25 @@ def kernel(x_ref, mask_ref, o_ref):
with self.assertRaisesRegex(ValueError, "Cannot broadcast"):
kernel(x, mask)

@parameterized.parameters("float16", "bfloat16", "float32")
@parameterized.parameters("float16", "bfloat16", "float32", "float64")
def test_approx_tanh(self, dtype):
if self.INTERPRET:
self.skipTest("approx_tanh is not supported in interpret mode")

if (dtype == "bfloat16" and
jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")

if dtype == "float64":
if jtu.test_device_matches(["cuda"]):
self.skipTest("f64 approx_tanh is only supported on ROCm")

original_x64 = jax.config.x64_enabled
if dtype == "float64" and not original_x64:
jax.config.update("jax_enable_x64", True)
self.addCleanup(lambda: jax.config.update("jax_enable_x64", False))

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype),
)
Expand Down