diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 9382fffc4012..0724a8d4525d 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -175,6 +175,12 @@ def _enable_all_reduce(lhs, rhs): _, n_spec, rhs_k_spec = rhs.spec return lhs_k_spec != None and lhs_k_spec == rhs_k_spec and n_spec == None +def _are_specs_overlapping(lhs, rhs): + if lhs is None or rhs is None: + return False + lhs = (lhs,) if isinstance(lhs, str) else lhs + rhs = (rhs,) if isinstance(rhs, str) else rhs + return not set(lhs).isdisjoint(rhs) def _get_output_sharding(shardings): lhs, rhs = shardings[0], shardings[1] @@ -241,7 +247,8 @@ def named_sharding(lhs, rhs, lhs_specs, rhs_specs): lhs_specs[2] = None rhs_specs[2] = None m_spec, n_spec = lhs_specs[1], rhs_specs[1] - if m_spec == n_spec: + # Check if m_spec and n_spec share any axis names to avoid duplicates + if _are_specs_overlapping(m_spec, n_spec): rhs_specs[1] = None return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) @@ -259,7 +266,8 @@ def _supported_out_sharding(lhs, rhs, reduce_scatter_dim): out_n_spec = k_spec else: out_m_spec = m_spec - out_n_spec = n_spec if m_spec != n_spec else None + # Check if m_spec and n_spec share any axis names to avoid duplicates + out_n_spec = n_spec if not _are_specs_overlapping(m_spec, n_spec) else None return [NamedSharding(lhs.mesh, P(batch_spec, out_m_spec, out_n_spec))] diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index a88e3ca772d5..95969481fc26 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -44,17 +44,21 @@ ((None, "dp", "tp"), (None, "dp", "tp")), ((None, "tp", None), (None, "tp", None)), ((None, None, "tp"), (None, "tp", None)), + ((None, ("dp", "tp"), None), (None, ("dp"), None)), ] -c_name = "__cudnn$blockScaledDot" +c_name_cuda = "__cudnn$blockScaledDot" +c_name_rocm = "__cublas$lt$matmul$mx" +c_name = c_name_cuda expected_hlos = [ - (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), - ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), - (c_name,), - ("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]", c_name), - (c_name,), - ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name), + [("all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), (c_name,)], + [("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [(c_name,)], + [("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [(c_name,)], + [("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)], + [("all-gather", "f8e4m3fn[2,256,1024]", "replica_groups=[2,2]<=[2,2]"), (c_name,)], ] expected_output_spec = [ PartitionSpec('dp',), @@ -65,13 +69,14 @@ PartitionSpec(None, 'dp'), PartitionSpec(None, 'tp', None), PartitionSpec(None, None, 'tp'), + PartitionSpec(None, ('dp', 'tp'), None), ] # The GSPMD sharding logic inserts additional reduce-scatters which don't exist # in Shardy. if not config.use_shardy_partitioner.value: expected_output_spec[5] = PartitionSpec(None, 'dp', 'tp') - expected_hlos[5] += ("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}") + expected_hlos[5] += [("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}")] sharding_configs = { input_sharding: (hlo, output_spec) @@ -269,12 +274,13 @@ class ScaledMatmulTest(jtu.JaxTestCase): def setUp(self): super().setUp() - try: - check_cudnn_version() - except RuntimeError as e: - self.skipTest(str(e)) - if not jtu.is_cuda_compute_capability_at_least("10.0"): - self.skipTest("Requires at least Blackwell arch") + if jtu.test_device_matches(["cuda"]): + try: + check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Requires at least Blackwell arch") mxfp8_configs = create_mxfp8_configs() @@ -282,27 +288,58 @@ def setUp(self): in_shardings=sharding_configs, block_scale_configs=[mxfp8_configs,], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_collectives(self, in_shardings, block_scale_configs): if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4: self.skipTest("Partition Test enabled for at least 4 GPUs") + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + expected_hlo = sharding_configs[in_shardings][0] + expected_hlo = [ + tuple(platform_c_name if x == c_name else x for x in pattern) + for pattern in expected_hlo + ] + hlo_text = get_hlo_text(in_shardings, block_scale_configs) - hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL - ) - self.assertRegex( - hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}" - ) + for expected_hlo_patterns in expected_hlo: + hlo_pattern_str = r".*".join(map(re.escape, expected_hlo_patterns)) + hlo_pattern = re.compile(hlo_pattern_str, flags=re.DOTALL) + + if jtu.test_device_matches(["rocm"]): + # Try both MX and generic cublasLT variants + pattern_mx = re.compile(hlo_pattern_str, flags=re.DOTALL) + pattern_generic = re.compile( + r".*".join([re.escape(x) if x != platform_c_name else r"__cublas\$lt\$matmul" for x in expected_hlo_patterns]), + flags=re.DOTALL + ) + primary_matched = re.search(pattern_mx, hlo_text) or re.search(pattern_generic, hlo_text) + + if not primary_matched: + fallback_patterns = [ + re.compile(r".*".join([re.escape(x) if x != platform_c_name else r"(__triton_gemm|__cublas\$gemm)" for x in expected_hlo_patterns]), flags=re.DOTALL) + ] + pattern_matched = any(re.search(p, hlo_text) for p in fallback_patterns) + if not pattern_matched: + with self.subTest(pattern=hlo_pattern_str): + self.fail(f"Failed to find pattern: {hlo_pattern_str} or fallback matmul pattern") + else: + with self.subTest(pattern=hlo_pattern_str): + self.assertTrue(True) + else: + with self.subTest(pattern=hlo_pattern_str): + self.assertRegex(hlo_text, hlo_pattern, msg=f"Failed to find pattern: {hlo_pattern_str}") @jtu.sample_product( contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float32, jnp.bfloat16, jnp.float16], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_scaled_matmul_nvfp4( self, contract, lhs_non_contract, dtype, ): @@ -335,10 +372,26 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): .compile() .as_text() ) + + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) - self.assertRegex(hlo_text, hlo_pattern) + r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)]) + ) + + if jtu.test_device_matches(["rocm"]): + # Try both MX and generic cublasLT variants + pattern_generic = re.compile(r"custom\-call.*__cublas\$lt\$matmul", flags=re.DOTALL) + primary_matched = re.search(hlo_pattern, hlo_text) or re.search(pattern_generic, hlo_text) + + if not primary_matched: + if "__triton_gemm" not in hlo_text and "__cublas$gemm" not in hlo_text: + self.fail(f"Expected {platform_c_name} or __cublas$lt$matmul or fallback (__triton_gemm/__cublas$gemm)") + else: + self.assertRegex(hlo_text, hlo_pattern) out = j_scaled_matmul(a_q, b_q, a_s, b_s) out_ref = jnp.einsum( @@ -354,7 +407,7 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): dtype=[jnp.float16, jnp.bfloat16, jnp.float32], block_scale_configs=[mxfp8_configs,], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_scaled_matmul( self, contract, lhs_non_contract, dtype, block_scale_configs, ): @@ -379,10 +432,26 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): .compile() .as_text() ) + + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) - self.assertRegex(hlo_text, hlo_pattern) + r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)]) + ) + + if jtu.test_device_matches(["rocm"]): + # Try both MX and generic cublasLT variants + pattern_generic = re.compile(r"custom\-call.*__cublas\$lt\$matmul", flags=re.DOTALL) + primary_matched = re.search(hlo_pattern, hlo_text) or re.search(pattern_generic, hlo_text) + + if not primary_matched: + if "__triton_gemm" not in hlo_text and "__cublas$gemm" not in hlo_text: + self.fail(f"Expected {platform_c_name} or __cublas$lt$matmul or fallback (__triton_gemm/__cublas$gemm)") + else: + self.assertRegex(hlo_text, hlo_pattern) out = j_scaled_matmul(a_q, b_q, a_scales, b_scales) out_ref = np.einsum( @@ -396,7 +465,7 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): in_shardings=sharding_configs, block_scale_configs=[mxfp8_configs,], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs): if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") @@ -427,10 +496,23 @@ def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs): scaled_matmul_wrapper, in_shardings=input_shardings ) hlo_compiled = j_scaled_matmul.lower(*args).compile() + hlo_text = hlo_compiled.as_text() + + if jtu.test_device_matches(["rocm"]): + platform_c_name = c_name_rocm + else: + platform_c_name = c_name_cuda + hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) + r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)]) ) - self.assertRegex(hlo_compiled.as_text(), hlo_pattern) + + if jtu.test_device_matches(["rocm"]) and not re.search(hlo_pattern, hlo_text): + fallback_found = "__triton_gemm" in hlo_text + if not fallback_found: + self.fail(f"Expected {platform_c_name} or fallback (__triton_gemm)") + else: + self.assertRegex(hlo_text, hlo_pattern) j_ref = jax.jit( partial(