Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions tests/profiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import jax._src.test_util as jtu

from jax._src import profiler
from jax._src.lib import jaxlib_extension_version
from jax import jit


Expand Down Expand Up @@ -508,16 +509,67 @@ def on_profile():
unittest.mock.ANY,
)

def test_advanced_configuration_getter(self):
  """Check that `advanced_configuration` reads back the dict assigned to it."""
  expected = {
      "tpu_trace_mode": "TRACE_COMPUTE",
      "tpu_num_sparse_cores_to_trace": 1,
      "enableFwThrottleEvent": True,
  }
  profile_options = jax.profiler.ProfileOptions()
  profile_options.advanced_configuration = expected
  # Round-trip: the getter must hand back an equal dict (string, int and bool
  # values are all covered by the fixture above).
  self.assertDictEqual(profile_options.advanced_configuration, expected)
@jtu.run_on_devices("gpu")
@jtu.thread_unsafe_test()
def test_rocm_profiling(self):
  """Profile a few matmuls on ROCm and check the XPlane proto mentions a GPU.

  The test is skipped unless the active backend's ``platform_version``
  string contains "rocm" (case-insensitive).
  """
  from jax.extend import backend as jax_backend

  # Narrow the "gpu" device gate to ROCm by inspecting the backend's version
  # string; a missing or None attribute is treated as an empty string.
  version = getattr(jax_backend.get_backend(), "platform_version", "") or ""
  if "rocm" not in version.lower():
    self.skipTest(f"Not ROCm backend: {version}")

  # (m, n) pairs: each iteration multiplies an (m, n) by an (n, m) matrix.
  matmul_shapes = ((32, 32), (64, 128), (256, 256), (512, 128))

  with tempfile.TemporaryDirectory() as tmpdir:
    with jax.profiler.trace(tmpdir):
      for idx, (rows, cols) in enumerate(matmul_shapes):
        lhs = jax.random.normal(jax.random.key(idx * 2), (rows, cols))
        rhs = jax.random.normal(jax.random.key(idx * 2 + 1), (cols, rows))
        # Block so the computation finishes while the trace is still active.
        jnp.dot(lhs, rhs).block_until_ready()

    # Look for the output only after the trace context has exited, but while
    # the temporary directory still exists.
    pattern = os.path.join(tmpdir, "**/*.xplane.pb")
    xplane_files = glob.glob(pattern, recursive=True)
    self.assertEqual(len(xplane_files), 1)
    with open(xplane_files[0], "rb") as fh:
      serialized = fh.read()
    # Sanity check on the raw bytes: the serialized proto names a GPU device.
    self.assertIn(b"/device:GPU", serialized)

@jtu.run_on_devices("gpu")
@jtu.thread_unsafe_test()
def test_rocm_kernel_details_in_trace_json(self):
  """Test that ROCm profiling captures kernel_details in trace.json.gz.

  The test is skipped unless the active backend's ``platform_version``
  string contains "rocm" (case-insensitive). It profiles several matmuls,
  then expects exactly one ``*.trace.json.gz`` file under the trace
  directory whose decompressed text contains "kernel_details".
  """
  # Function-scope imports, kept together at the top of the test.
  import gzip
  from jax.extend import backend as jax_backend

  # ROCm-only gate using supported API; a missing or None attribute is
  # treated as an empty string.
  be = jax_backend.get_backend()
  platform_version = getattr(be, "platform_version", "") or ""
  if "rocm" not in platform_version.lower():
    self.skipTest(f"Not ROCm backend: {platform_version}")

  with tempfile.TemporaryDirectory() as tmpdir:
    with jax.profiler.trace(tmpdir):
      # Test multiple matmul shapes: (m, n) @ (n, m) for each pair.
      shapes = [(64, 64), (128, 256), (512, 512), (1024, 256)]
      for i, (m, n) in enumerate(shapes):
        x = jax.random.normal(jax.random.key(i * 2), (m, n))
        y = jax.random.normal(jax.random.key(i * 2 + 1), (n, m))
        # Block so the computation finishes while the trace is still active.
        jnp.dot(x, y).block_until_ready()

    # Find and read the trace.json.gz file. This happens after the trace
    # context has exited but while the temporary directory still exists.
    trace_files = glob.glob(
        os.path.join(tmpdir, "**/*.trace.json.gz"), recursive=True
    )
    self.assertEqual(len(trace_files), 1)
    # Decode explicitly as UTF-8: text mode otherwise falls back to the
    # locale's default encoding, which can fail or mis-decode the JSON on
    # hosts with a non-UTF-8 locale.
    with gzip.open(trace_files[0], "rt", encoding="utf-8") as f:
      trace_content = f.read()
    # Sanity check that trace contains kernel_details
    self.assertIn("kernel_details", trace_content)


if __name__ == "__main__":
Expand Down