From 3f4a69e9b501d432d8b39125efe36c530f82051d Mon Sep 17 00:00:00 2001 From: Magaonka Date: Thu, 29 Jan 2026 20:28:52 -0600 Subject: [PATCH 1/6] Add missing sharding test case from upstream Update test_collectives to include upstream test case for nested sharding configuration ((None, ("dp", "tp"), None), (None, ("dp"), None)). Also restructure expected_hlos from flat tuples to nested lists to support multiple patterns per configuration and improve test granularity with subTest loop. --- tests/scaled_matmul_stablehlo_test.py | 33 +++++++++++++++------------ 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index a88e3ca772d5..75c26737ac25 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -44,17 +44,19 @@ ((None, "dp", "tp"), (None, "dp", "tp")), ((None, "tp", None), (None, "tp", None)), ((None, None, "tp"), (None, "tp", None)), + ((None, ("dp", "tp"), None), (None, ("dp"), None)), ] c_name = "__cudnn$blockScaledDot" expected_hlos = [ - (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), - ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), - (c_name,), - ("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]", c_name), - (c_name,), - ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name), + [("all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), (c_name,)], + [("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [(c_name,)], + [("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [(c_name,)], + [("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [("all-gather", "f8e4m3fn[2,256,1024]", "replica_groups=[2,2]<=[2,2]"), (c_name,)], ] expected_output_spec = [ PartitionSpec('dp',), @@ -65,13 +67,14 @@ PartitionSpec(None, 'dp'), PartitionSpec(None, 'tp', None), PartitionSpec(None, None, 'tp'), + PartitionSpec(None, ('dp', 'tp'), None), ] # The GSPMD sharding logic inserts additional reduce-scatters which don't exist # in Shardy. if not config.use_shardy_partitioner.value: expected_output_spec[5] = PartitionSpec(None, 'dp', 'tp') - expected_hlos[5] += ("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}") + expected_hlos[5] += [("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}")] sharding_configs = { input_sharding: (hlo, output_spec) @@ -290,12 +293,12 @@ def test_collectives(self, in_shardings, block_scale_configs): expected_hlo = sharding_configs[in_shardings][0] hlo_text = get_hlo_text(in_shardings, block_scale_configs) - hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL - ) - self.assertRegex( - hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}" - ) + for expected_hlo_patterns in expected_hlo: + hlo_pattern_str = r".*".join(map(re.escape, expected_hlo_patterns)) + hlo_pattern = re.compile(hlo_pattern_str, flags=re.DOTALL) + # Check all patterns in case of failures + with self.subTest(pattern=hlo_pattern_str): + self.assertRegex(hlo_text, hlo_pattern, msg=f"Failed to find pattern: {hlo_pattern_str}") @jtu.sample_product( contract=[160, 96], From 5ab925796585d8269fd0356fccccf3280f2f70e7 Mon Sep 17 00:00:00 2001 From: Magaonka Date: Thu, 29 Jan 2026 20:30:16 -0600 Subject: [PATCH 2/6] Fix overlapping partition specs in scaled_matmul partitioner Add _are_specs_overlapping() helper to correctly detect when partition specs share axis names. Fixes DuplicateSpecError when using nested sharding specs like ('dp', 'tp') that overlap with single specs like 'dp'. Prevents creating invalid PartitionSpecs with duplicate axis names in both input and output shardings. --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 9382fffc4012..0724a8d4525d 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -175,6 +175,12 @@ def _enable_all_reduce(lhs, rhs): _, n_spec, rhs_k_spec = rhs.spec return lhs_k_spec != None and lhs_k_spec == rhs_k_spec and n_spec == None +def _are_specs_overlapping(lhs, rhs): + if lhs is None or rhs is None: + return False + lhs = (lhs,) if isinstance(lhs, str) else lhs + rhs = (rhs,) if isinstance(rhs, str) else rhs + return not set(lhs).isdisjoint(rhs) def _get_output_sharding(shardings): lhs, rhs = shardings[0], shardings[1] @@ -241,7 +247,8 @@ def named_sharding(lhs, rhs, lhs_specs, rhs_specs): lhs_specs[2] = None rhs_specs[2] = None m_spec, n_spec = lhs_specs[1], rhs_specs[1] - if m_spec == n_spec: + # Check if m_spec and n_spec share any axis names to avoid duplicates + if _are_specs_overlapping(m_spec, n_spec): rhs_specs[1] = None return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) @@ -259,7 +266,8 @@ def _supported_out_sharding(lhs, rhs, reduce_scatter_dim): out_n_spec = k_spec else: out_m_spec = m_spec - out_n_spec = n_spec if m_spec != n_spec else None + # Check if m_spec and n_spec share any axis names to avoid duplicates + out_n_spec = n_spec if not _are_specs_overlapping(m_spec, n_spec) else None return [NamedSharding(lhs.mesh, P(batch_spec, out_m_spec, out_n_spec))] From fe3ccbd736dab56771eb26645918c17b101a4913 Mon Sep 17 00:00:00 2001 From: Magaonka Date: Fri, 30 Jan 2026 12:51:52 -0600 Subject: [PATCH 3/6] Skip cuDNN and Blackwell checks on ROCm devices Update ScaledMatmulTest setUp to only check cuDNN on CUDA devices. ROCm uses hipBLASLt or XLA fallback path. --- tests/scaled_matmul_stablehlo_test.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 75c26737ac25..70bb0caadcf1 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -272,12 +272,13 @@ class ScaledMatmulTest(jtu.JaxTestCase): def setUp(self): super().setUp() - try: - check_cudnn_version() - except RuntimeError as e: - self.skipTest(str(e)) - if not jtu.is_cuda_compute_capability_at_least("10.0"): - self.skipTest("Requires at least Blackwell arch") + if jtu.test_device_matches(["cuda"]): + try: + check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Requires at least Blackwell arch") mxfp8_configs = create_mxfp8_configs() From d498963a2b3b32a3b0716fbc1133a4188b0e80f0 Mon Sep 17 00:00:00 2001 From: Magaonka Date: Fri, 30 Jan 2026 12:52:25 -0600 Subject: [PATCH 4/6] Add platform-specific custom call target constants Define c_name_cuda and c_name_rocm for platform-specific matmul targets: - CUDA: __cudnn$blockScaledDot - ROCm: __cublas$lt$matmul$mx --- tests/scaled_matmul_stablehlo_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 70bb0caadcf1..2296a3ad6d36 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -46,7 +46,9 @@ ((None, None, "tp"), (None, "tp", None)), ((None, ("dp", "tp"), None), (None, ("dp"), None)), ] -c_name = "__cudnn$blockScaledDot" +c_name_cuda = "__cudnn$blockScaledDot" +c_name_rocm = "__cublas$lt$matmul$mx" +c_name = c_name_cuda expected_hlos = [ [("all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), (c_name,)], [("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], From 9e49d3f414b9209cdf0bb231fad43a6b5208a2b2 Mon Sep 17 00:00:00 2001 From: Magaonka Date: Fri, 30 Jan 2026 16:54:37 -0600 Subject: [PATCH 5/6] Add ROCm specific HLO calls and fallback handling Add platform detection to all 4 ScaledMatmulTest methods: - test_collectives - test_scaled_matmul_nvfp4 - test_scaled_matmul - test_scaled_matmul_sharded Primary path checks for hipBLASLt custom calls: - __cublas$lt$matmul$mx (MX format) - __cublas$lt$matmul (generic cublasLT) Fallback path checks for dequantize + matmul operations: - __triton_gemm - __cublas$gemm --- tests/scaled_matmul_stablehlo_test.py | 116 +++++++++++++++++++++----- 1 file changed, 96 insertions(+), 20 deletions(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 2296a3ad6d36..9a8d49623a86 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -293,15 +293,46 @@ def test_collectives(self, in_shardings, block_scale_configs): if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4: self.skipTest("Partition Test enabled for at least 4 GPUs") + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + expected_hlo = sharding_configs[in_shardings][0] + expected_hlo = [ + tuple(platform_c_name if x == c_name else x for x in pattern) + for pattern in expected_hlo + ] + hlo_text = get_hlo_text(in_shardings, block_scale_configs) for expected_hlo_patterns in expected_hlo: hlo_pattern_str = r".*".join(map(re.escape, expected_hlo_patterns)) hlo_pattern = re.compile(hlo_pattern_str, flags=re.DOTALL) - # Check all patterns in case of failures - with self.subTest(pattern=hlo_pattern_str): - self.assertRegex(hlo_text, hlo_pattern, msg=f"Failed to find pattern: {hlo_pattern_str}") + + if jtu.test_device_matches(["rocm"]): + # Try both MX and generic cublasLT variants + pattern_mx = re.compile(hlo_pattern_str, flags=re.DOTALL) + pattern_generic = re.compile( + r".*".join([re.escape(x) if x != platform_c_name else r"__cublas\$lt\$matmul" for x in expected_hlo_patterns]), + flags=re.DOTALL + ) + primary_matched = re.search(pattern_mx, hlo_text) or re.search(pattern_generic, hlo_text) + + if not primary_matched: + fallback_patterns = [ + re.compile(r".*".join([re.escape(x) if x != platform_c_name else r"(__triton_gemm|__cublas\$gemm)" for x in expected_hlo_patterns]), flags=re.DOTALL) + ] + pattern_matched = any(re.search(p, hlo_text) for p in fallback_patterns) + if not pattern_matched: + with self.subTest(pattern=hlo_pattern_str): + self.fail(f"Failed to find pattern: {hlo_pattern_str} or fallback matmul pattern") + else: + with self.subTest(pattern=hlo_pattern_str): + self.assertTrue(True) + else: + with self.subTest(pattern=hlo_pattern_str): + self.assertRegex(hlo_text, hlo_pattern, msg=f"Failed to find pattern: {hlo_pattern_str}") @jtu.sample_product( contract=[160, 96], @@ -341,10 +372,26 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): .compile() .as_text() ) + + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) - self.assertRegex(hlo_text, hlo_pattern) + r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)]) + ) + + if jtu.test_device_matches(["rocm"]): + # Try both MX and generic cublasLT variants + pattern_generic = re.compile(r"custom\-call.*__cublas\$lt\$matmul", flags=re.DOTALL) + primary_matched = re.search(hlo_pattern, hlo_text) or re.search(pattern_generic, hlo_text) + + if not primary_matched: + if "__triton_gemm" not in hlo_text and "__cublas$gemm" not in hlo_text: + self.fail(f"Expected {platform_c_name} or __cublas$lt$matmul or fallback (__triton_gemm/__cublas$gemm)") + else: + self.assertRegex(hlo_text, hlo_pattern) out = j_scaled_matmul(a_q, b_q, a_s, b_s) out_ref = jnp.einsum( @@ -385,10 +432,26 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): .compile() .as_text() ) + + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) - self.assertRegex(hlo_text, hlo_pattern) + r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)]) + ) + + if jtu.test_device_matches(["rocm"]): + # Try both MX and generic cublasLT variants + pattern_generic = re.compile(r"custom\-call.*__cublas\$lt\$matmul", flags=re.DOTALL) + primary_matched = re.search(hlo_pattern, hlo_text) or re.search(pattern_generic, hlo_text) + + if not primary_matched: + if "__triton_gemm" not in hlo_text and "__cublas$gemm" not in hlo_text: + self.fail(f"Expected {platform_c_name} or __cublas$lt$matmul or fallback (__triton_gemm/__cublas$gemm)") + else: + self.assertRegex(hlo_text, hlo_pattern) out = j_scaled_matmul(a_q, b_q, a_scales, b_scales) out_ref = np.einsum( @@ -433,10 +496,23 @@ def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs): scaled_matmul_wrapper, in_shardings=input_shardings ) hlo_compiled = j_scaled_matmul.lower(*args).compile() + hlo_text = hlo_compiled.as_text() + + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) + r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)]) ) - self.assertRegex(hlo_compiled.as_text(), hlo_pattern) + + if jtu.test_device_matches(["rocm"]) and not re.search(hlo_pattern, hlo_text): + fallback_found = "__triton_gemm" in hlo_text + if not fallback_found: + self.fail(f"Expected {platform_c_name} or fallback (__triton_gemm)") + else: + self.assertRegex(hlo_text, hlo_pattern) j_ref = jax.jit( partial( @@ -480,7 +556,7 @@ def setUp(self): (1024, 2048), ], ) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_quantize_nvfp4(self, shape): # To test the q-dq logic is valid with XLA output_type = jnp.float32 @@ -504,7 +580,7 @@ def fn(a): a, rtol=0.2, atol=0.5) @jtu.sample_product(value=[1e6, 1/4096]) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_quantize_requires_global_scale(self, value): output_type = jnp.float32 k1, k2 = jax.random.split(jax.random.key(0), 2) @@ -524,7 +600,7 @@ def test_quantize_requires_global_scale(self, value): ((30, 64), (100, 64), (([1], [1]), ([], []))), ] ) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): output_type = jnp.float32 (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( @@ -595,7 +671,7 @@ def fwd(a, b, use_normalized=False): ], output_type=[jnp.float32, jnp.float16, jnp.bfloat16], ) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_dot_general_nvfp4(self, configs, output_type): (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( generate_nvfp4_quantized_tensors(configs[:-1], output_type) @@ -673,7 +749,7 @@ def _grad_clip(amax, x, grad): ], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], ) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_dot_general(self, configs, output_type): cast_to_representable = partial( quantize_dequantize, @@ -722,7 +798,7 @@ def fwd(a, b, is_ref=False): self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) @jtu.sample_product(in_shardings=sharding_configs) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_dot_general_sharded(self, in_shardings): if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") @@ -793,7 +869,7 @@ def fwd(a, b, is_ref=False): ((2, 128, 128), (128, 2, 128), (0, 1, 2)), ] ) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_dot_general_vmap(self, configs): cast_to_representable = partial( quantize_dequantize, @@ -839,7 +915,7 @@ def fwd(a, b, is_ref=False): self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_remat_checkpoint_dots(self): input = jnp.ones((1, 128, 128)) config = create_nvfp4_configs([input])[0] @@ -874,7 +950,7 @@ def f(x): # Check that the custom backward for scaled_matmul is used. self.assertEqual(jaxpr.count('bwd=scaled_dot_bwd'), 1) - @jtu.run_on_devices("gpu") + @jtu.run_on_devices("cuda") def test_remat_checkpoint_dots_with_no_batch_dims(self): input = jnp.ones((1, 128, 128)) batched_input = jnp.ones((16, 128, 128)) From 816d5f4ed035872b889f3696fdd2e01adbaa0695 Mon Sep 17 00:00:00 2001 From: Magaonka Date: Fri, 30 Jan 2026 16:54:55 -0600 Subject: [PATCH 6/6] Enable scaled_matmul tests on ROCm platform Enable 4 test methods in ScaledMatmulTest: - test_collectives - test_scaled_matmul_nvfp4 - test_scaled_matmul - test_scaled_matmul_sharded --- tests/scaled_matmul_stablehlo_test.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 9a8d49623a86..95969481fc26 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -288,7 +288,7 @@ def setUp(self): in_shardings=sharding_configs, block_scale_configs=[mxfp8_configs,], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_collectives(self, in_shardings, block_scale_configs): if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4: self.skipTest("Partition Test enabled for at least 4 GPUs") @@ -339,7 +339,7 @@ def test_collectives(self, in_shardings, block_scale_configs): lhs_non_contract=[240, 100], dtype=[jnp.float32, jnp.bfloat16, jnp.float16], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_scaled_matmul_nvfp4( self, contract, lhs_non_contract, dtype, ): @@ -407,7 +407,7 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): dtype=[jnp.float16, jnp.bfloat16, jnp.float32], block_scale_configs=[mxfp8_configs,], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_scaled_matmul( self, contract, lhs_non_contract, dtype, block_scale_configs, ): @@ -465,7 +465,7 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): in_shardings=sharding_configs, block_scale_configs=[mxfp8_configs,], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs): if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") @@ -556,7 +556,7 @@ def setUp(self): (1024, 2048), ], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_quantize_nvfp4(self, shape): # To test the q-dq logic is valid with XLA output_type = jnp.float32 @@ -580,7 +580,7 @@ def fn(a): a, rtol=0.2, atol=0.5) @jtu.sample_product(value=[1e6, 1/4096]) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_quantize_requires_global_scale(self, value): output_type = jnp.float32 k1, k2 = jax.random.split(jax.random.key(0), 2) @@ -600,7 +600,7 @@ def test_quantize_requires_global_scale(self, value): ((30, 64), (100, 64), (([1], [1]), ([], []))), ] ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): output_type = jnp.float32 (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( @@ -671,7 +671,7 @@ def fwd(a, b, use_normalized=False): ], output_type=[jnp.float32, jnp.float16, jnp.bfloat16], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_dot_general_nvfp4(self, configs, output_type): (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( generate_nvfp4_quantized_tensors(configs[:-1], output_type) @@ -749,7 +749,7 @@ def _grad_clip(amax, x, grad): ], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_dot_general(self, configs, output_type): cast_to_representable = partial( quantize_dequantize, @@ -798,7 +798,7 @@ def fwd(a, b, is_ref=False): self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) @jtu.sample_product(in_shardings=sharding_configs) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_dot_general_sharded(self, in_shardings): if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") @@ -869,7 +869,7 @@ def fwd(a, b, is_ref=False): ((2, 128, 128), (128, 2, 128), (0, 1, 2)), ] ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_dot_general_vmap(self, configs): cast_to_representable = partial( quantize_dequantize, @@ -915,7 +915,7 @@ def fwd(a, b, is_ref=False): self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_remat_checkpoint_dots(self): input = jnp.ones((1, 128, 128)) config = create_nvfp4_configs([input])[0] @@ -950,7 +950,7 @@ def f(x): # Check that the custom backward for scaled_matmul is used. self.assertEqual(jaxpr.count('bwd=scaled_dot_bwd'), 1) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_remat_checkpoint_dots_with_no_batch_dims(self): input = jnp.ones((1, 128, 128)) batched_input = jnp.ones((16, 128, 128))