Patch for tests/profiler_test.py: add ROCm profiler tests.

What this patch does:
  * Adds the import `from jax._src.lib import jaxlib_extension_version`.
  * Removes `test_advanced_configuration_getter`.
  * Adds `test_rocm_profiling`: on a backend whose platform_version contains
    "rocm", traces several matmuls and asserts that exactly one *.xplane.pb
    file is written and that it contains the bytes "/device:GPU".
  * Adds `test_rocm_kernel_details_in_trace_json`: same setup, asserts that
    exactly one *.trace.json.gz file is written and that its text contains
    "kernel_details".
  Both new tests call skipTest when the backend is not ROCm.

NOTE(review): leading whitespace was not recoverable from the collapsed
original. Indentation below is reconstructed from Python structure assuming
2-space indents; the depth of the two leading context lines of the second
hunk and the placement of blank context lines in the first hunk are guesses.
Use `git apply --ignore-whitespace` or regenerate the patch if it is rejected.
NOTE(review): no added line references `jaxlib_extension_version`; confirm
the import is needed.
NOTE(review): removing `test_advanced_configuration_getter` looks unrelated
to the ROCm tests; confirm the removal is intentional.
NOTE(review): the ROCm gate is duplicated in both tests and `import gzip`
sits mid-function; consider a shared helper and a top-of-file import.

diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index a803e010859f..76fd57760865 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -32,6 +32,7 @@
 import jax._src.test_util as jtu
 
 from jax._src import profiler
+from jax._src.lib import jaxlib_extension_version
 from jax import jit
 
 
@@ -508,16 +509,67 @@ def on_profile():
         unittest.mock.ANY,
     )
 
-  def test_advanced_configuration_getter(self):
-    options = jax.profiler.ProfileOptions()
-    advanced_config = {
-        "tpu_trace_mode": "TRACE_COMPUTE",
-        "tpu_num_sparse_cores_to_trace": 1,
-        "enableFwThrottleEvent": True,
-    }
-    options.advanced_configuration = advanced_config
-    returned_config = options.advanced_configuration
-    self.assertDictEqual(returned_config, advanced_config)
+  @jtu.run_on_devices("gpu")
+  @jtu.thread_unsafe_test()
+  def test_rocm_profiling(self):
+    """Test that ROCm profiling captures GPU kernel events."""
+    # ROCm-only gate using supported API
+    from jax.extend import backend as jax_backend
+
+    be = jax_backend.get_backend()
+    platform_version = getattr(be, "platform_version", "") or ""
+    if "rocm" not in platform_version.lower():
+      self.skipTest(f"Not ROCm backend: {platform_version}")
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      with jax.profiler.trace(tmpdir):
+        # Test multiple matmul shapes
+        shapes = [(32, 32), (64, 128), (256, 256), (512, 128)]
+        for i, (m, n) in enumerate(shapes):
+          x = jax.random.normal(jax.random.key(i * 2), (m, n))
+          y = jax.random.normal(jax.random.key(i * 2 + 1), (n, m))
+          jnp.dot(x, y).block_until_ready()
+
+      proto_path = glob.glob(
+          os.path.join(tmpdir, "**/*.xplane.pb"), recursive=True
+      )
+      self.assertEqual(len(proto_path), 1)
+      with open(proto_path[0], "rb") as f:
+        proto = f.read()
+      # Sanity check that serialized proto contains GPU traces
+      self.assertIn(b"/device:GPU", proto)
+
+  @jtu.run_on_devices("gpu")
+  @jtu.thread_unsafe_test()
+  def test_rocm_kernel_details_in_trace_json(self):
+    """Test that ROCm profiling captures kernel_details in trace.json.gz."""
+    # ROCm-only gate using supported API
+    from jax.extend import backend as jax_backend
+
+    be = jax_backend.get_backend()
+    platform_version = getattr(be, "platform_version", "") or ""
+    if "rocm" not in platform_version.lower():
+      self.skipTest(f"Not ROCm backend: {platform_version}")
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      with jax.profiler.trace(tmpdir):
+        # Test multiple matmul shapes
+        shapes = [(64, 64), (128, 256), (512, 512), (1024, 256)]
+        for i, (m, n) in enumerate(shapes):
+          x = jax.random.normal(jax.random.key(i * 2), (m, n))
+          y = jax.random.normal(jax.random.key(i * 2 + 1), (n, m))
+          jnp.dot(x, y).block_until_ready()
+
+      # Find and read trace.json.gz file
+      import gzip
+      trace_files = glob.glob(
+          os.path.join(tmpdir, "**/*.trace.json.gz"), recursive=True
+      )
+      self.assertEqual(len(trace_files), 1)
+      with gzip.open(trace_files[0], "rt") as f:
+        trace_content = f.read()
+      # Sanity check that trace contains kernel_details
+      self.assertIn("kernel_details", trace_content)
 
 
 if __name__ == "__main__":