Problem Description
I encountered a massive performance regression on MI300X compared to H100 when computing many small float64 determinants (`slogdet`) inside a `vmap` (MCMC loop). This will impact many QMC codes.
- Operation: `jnp.linalg.slogdet` on shapes like `(Batch, 1, 1)` or `(Batch, 2, 2)`.
- Context: unrolled loop in JAX (MCMC steps).
- Observed Behavior:
- H100 (NVIDIA): ~5 seconds per step.
- MI300X (AMD): ~1000 seconds per step (100% GPU Utilization reported).
- Analytic Fix: Replacing `slogdet` with the analytic determinant (`a` for N=1, `ad - bc` for N=2) reduced runtime on MI300X by a factor of ~140. Still not as fast as the H100, but much better.
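The analytic fix described above can be sketched as a small dispatching helper (`small_slogdet` is a hypothetical name for illustration, not part of the original code): it uses the closed-form determinant for N ≤ 2 and falls back to the generic solver path otherwise.

```python
import jax
import jax.numpy as jnp

def small_slogdet(x):
    """Closed-form slogdet for N <= 2; generic LU-based path otherwise."""
    n = x.shape[-1]
    if n == 1:
        det = x[..., 0, 0]
    elif n == 2:
        det = x[..., 0, 0] * x[..., 1, 1] - x[..., 0, 1] * x[..., 1, 0]
    else:
        return jnp.linalg.slogdet(x)  # generic solver path (slow on ROCm)
    return jnp.sign(det), jnp.log(jnp.abs(det))

# Sanity check against the generic implementation on a 2x2 batch.
x = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 2))
s1, l1 = small_slogdet(x)
s2, l2 = jnp.linalg.slogdet(x)
```

Because the closed-form branch is pure elementwise arithmetic, XLA can fuse it (and its gradient) into a single kernel with no solver launch.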
Hypothesis
It appears that rocSOLVER or the XLA ROCm backend lacks a specialized path for small batched float64 matrices. It seems to either launch a full solver kernel for every tiny operation or fall back to an extremely inefficient scalar implementation, causing massive launch overhead or serialization. I have a trace file from the JAX profiler for one step, if that is helpful.
I put together a snippet that reproduces a 6x slowdown for N=2 and a 5000x slowdown for N=1. The real-world impact was a mix of these (embedded in a complex call graph). I haven't run this benchmark on the H100 yet.
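To check which path XLA actually emits, one can dump the compiled HLO of the gradient and inspect it for solver custom calls. This is a sketch; the exact custom-call names in the ROCm lowering (e.g. rocSOLVER getrf/trsm wrappers) are an assumption and will differ by backend.

```python
import jax
import jax.numpy as jnp

def loss(x):
    _, logdets = jnp.linalg.slogdet(x)
    return jnp.sum(logdets)

x = jnp.ones((16, 2, 2))
# Dump the compiled HLO text; search it for solver custom calls
# (backend-specific names, e.g. getrf-related calls on ROCm).
hlo = jax.jit(jax.grad(loss)).lower(x).compile().as_text()
print(len(hlo))
```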
Environment
MI300X
AMD-SMI 26.1.0+5df6c765
amdgpu version: 6.16.6
ROCm version: 7.1.0
VBIOS version: 00123529
Platform: Linux Guest (AMD dev cloud)
jax 0.7.1
jax-rocm7-pjrt 0.7.1
jax-rocm7-plugin 0.7.1
jaxlib 0.7.1
Reproduce
Operating System
Ubuntu 24.04.3 LTS (Noble Numbat)
CPU
INTEL(R) XEON(R) PLATINUM 8568Y+
GPU
AMD Instinct MI300X VF
ROCm Version
7.1.0
ROCm Component
No response
Steps to Reproduce
```python
import jax
import jax.numpy as jnp
import time

# CRITICAL: Enable float64.
# The slow path on MI300X is specific to double-precision solver kernels.
jax.config.update("jax_enable_x64", True)

# Configuration matching your FermiNet/MCMC workload
N_BATCH = 4096
N_DETS = 64   # (Batch, Determinants, N, N)
N_DIM = 1     # The "killer" dimension (small matrix)
DTYPE = jnp.float64

def hard_path_loss(x):
    """
    The Slow Way: generic slogdet.
    Forward pass is okay, but the BACKWARD pass requires A^-1 (triangular_solve).
    This hits the unoptimized kernel path on ROCm for small N.
    """
    signs, logdets = jnp.linalg.slogdet(x)
    return jnp.sum(logdets)

def easy_path_loss(x):
    """
    The Fast Way: analytic determinant for 1x1/2x2.
    det = ad - bc.
    JAX fuses this entire operation (and its gradient) into a single kernel.
    No solver launch required.
    """
    if N_DIM == 1:
        det = x[..., 0, 0]
        logdet = jnp.log(jnp.abs(det) + 1e-20)
        return jnp.sum(logdet)
    # x shape: (..., 2, 2)
    a = x[..., 0, 0]
    b = x[..., 0, 1]
    c = x[..., 1, 0]
    d = x[..., 1, 1]
    det = (a * d) - (b * c)
    logdet = jnp.log(jnp.abs(det) + 1e-20)
    return jnp.sum(logdet)

# We benchmark the GRADIENT calculation, as that's what happens in MCMC training.
grad_hard = jax.jit(jax.grad(hard_path_loss))
grad_easy = jax.jit(jax.grad(easy_path_loss))

def benchmark(name, fn, x_input, n_iter=10):
    # Compile and warm up
    print(f"Compiling {name}...")
    warmup = fn.lower(x_input).compile()
    _ = warmup(x_input).block_until_ready()
    print(f"Running {name} ({n_iter} iterations)...")
    start = time.time()
    for _ in range(n_iter):
        out = warmup(x_input)
        out.block_until_ready()  # Force GPU sync
    end = time.time()
    avg_time = (end - start) / n_iter
    print(f" -> Avg Time: {avg_time:.6f} s")
    return avg_time

if __name__ == "__main__":
    print(f"Backend: {jax.devices()[0].platform.upper()}")
    print(f"Input Shape: {(N_BATCH, N_DETS, N_DIM, N_DIM)} | Dtype: {DTYPE}")
    print("-" * 40)

    # Generate data
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (N_BATCH, N_DETS, N_DIM, N_DIM), dtype=DTYPE)

    # Run benchmarks
    t_hard = benchmark("Hard Path (slogdet)", grad_hard, x)
    t_easy = benchmark("Easy Path (ad-bc)", grad_easy, x)

    # Report
    print("-" * 40)
    print(f"Speedup Factor: {t_hard / t_easy:.1f}x")
    if t_hard > 0.1 and t_easy < 0.01:
        print("CONCLUSION: Reproduced. Analytic path avoids slow kernel launch.")
    else:
        print("CONCLUSION: Performance gap not observed.")
```
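For anyone who wants to capture a trace like the one mentioned above, a minimal sketch using the JAX profiler (the temp directory and toy workload here are illustrative, not the original setup):

```python
import tempfile
import jax
import jax.numpy as jnp

# Trace one step of a slogdet gradient; the output is viewable in
# TensorBoard / Perfetto. log_dir is a throwaway temp directory.
log_dir = tempfile.mkdtemp()
f = jax.jit(jax.grad(lambda m: jnp.linalg.slogdet(m)[1].sum()))
x = jnp.eye(2).reshape(1, 2, 2).repeat(64, axis=0)
f(x).block_until_ready()  # compile outside the trace
with jax.profiler.trace(log_dir):
    f(x).block_until_ready()
```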
(Optional for Linux users) Output of `/opt/rocm/bin/rocminfo --support`
ROCk module version 6.16.6 is loaded
HSA System Attributes
Runtime Version: 1.18
Runtime Ext Version: 1.14
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
XNACK enabled: NO
DMAbuf Support: YES
VMM Support: YES
==========
HSA Agents
Agent 1
Name: INTEL(R) XEON(R) PLATINUM 8568Y+
Uuid: CPU-XX
Marketing Name: INTEL(R) XEON(R) PLATINUM 8568Y+
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 0
BDFID: 0
Internal Node ID: 0
Compute Unit: 20
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Memory Properties:
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 4
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:
Agent 2
Name: gfx942
Uuid: GPU-f2250ba5c72e9953
Marketing Name: AMD Instinct MI300X VF
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 4096(0x1000) KB
L3: 262144(0x40000) KB
Chip ID: 29877(0x74b5)
ASIC Revision: 1(0x1)
Cacheline Size: 128(0x80)
Max Clock Freq. (MHz): 2100
BDFID: 33536
Internal Node ID: 1
Compute Unit: 304
SIMDs per CU: 4
Shader Engines: 32
Shader Arrs. per Eng.: 1
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Memory Properties:
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 64(0x40)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 2048(0x800)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 189
SDMA engine uCode:: 25
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 200998912(0xbfb0000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 200998912(0xbfb0000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 200998912(0xbfb0000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 4
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Recommended Granule:0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx942:sramecc+:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
ISA 2
Name: amdgcn-amd-amdhsa--gfx9-4-generic:sramecc+:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
*** Done ***
Additional Information
No response