[MI300X] 200x Slowdown on small float64 slogdets vs H100 (Launch Overhead/Fallback) #259

@StephenJHardy

Problem Description
I encountered a massive performance gap on MI300X compared to H100 when computing many small float64 determinants (slogdet) inside a vmap (MCMC loop). This is likely to impact many QMC codes.

  • Operation: jnp.linalg.slogdet on shapes like (Batch, 1, 1) or (Batch, 2, 2).
  • Context: Unrolled loop in JAX (MCMC steps).
  • Observed Behavior:
    • H100 (NVIDIA): ~5 seconds per step.
    • MI300X (AMD): ~1000 seconds per step (100% GPU Utilization reported).
    • Analytic Fix: Replacing slogdet with the analytic determinant (a for N=1, ad - bc for N=2) reduced the MI300X runtime by a factor of ~140. Still not as fast as the H100, but much better.
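
For callers that need slogdet's (sign, logabsdet) pair rather than just the summed log-determinant, the analytic workaround can be written as a drop-in replacement. A minimal sketch for the N=2 case (function name is my own; the 1e-20 guard used in the benchmark below is omitted here so the outputs match jnp.linalg.slogdet exactly):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def slogdet_2x2(x):
    """Closed-form slogdet for arrays of shape (..., 2, 2): det = ad - bc.
    Compiles to pure elementwise ops, so no solver kernel is launched."""
    det = x[..., 0, 0] * x[..., 1, 1] - x[..., 0, 1] * x[..., 1, 0]
    return jnp.sign(det), jnp.log(jnp.abs(det))
```

This vmaps and differentiates like the stock primitive, but XLA can fuse it (and its gradient) into the surrounding computation.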

Hypothesis
It appears that rocSOLVER or the XLA backend lacks a specialized path for small batched float64 matrices: it either launches a full-weight solver kernel for every tiny operation or falls back to an extremely inefficient scalar implementation, causing massive launch overhead or serialization. I have a trace file from the JAX profiler for a single step if that is helpful.
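
One way to check this hypothesis without a full profile is to dump the compiled HLO and look for solver hand-offs. A sketch (the exact kernel names are backend-specific, so this only reports whether any custom-call appears at all):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Small batched float64 matrices, as in the failing workload.
x = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 2), dtype=jnp.float64)

# Lower and compile slogdet, then inspect the backend-specific HLO text.
hlo = jax.jit(jnp.linalg.slogdet).lower(x).compile().as_text()

# A "custom-call" entry indicates a hand-off to an external solver
# (LAPACK on CPU, cuSOLVER/rocSOLVER on GPU) rather than fused XLA ops.
print("custom-call" in hlo)
```

If the ROCm build shows a solver custom-call here while the analytic version compiles to plain fusions, that would support the fallback theory.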

The snippet below reproduces a 6x slowdown for N=2 and a 5000x slowdown for N=1. The real-world impact was a mix of these (embedded in a complex call graph). I haven't run this benchmark on the H100 yet.

Environment

MI300X

AMD-SMI 26.1.0+5df6c765
amdgpu version: 6.16.6
ROCm version: 7.1.0
VBIOS version: 00123529
Platform: Linux Guest (AMD dev cloud)

jax 0.7.1
jax-rocm7-pjrt 0.7.1
jax-rocm7-plugin 0.7.1
jaxlib 0.7.1


Operating System

Ubuntu 24.04.3 LTS (Noble Numbat)

CPU

INTEL(R) XEON(R) PLATINUM 8568Y+

GPU

AMD Instinct MI300X VF

ROCm Version

7.1.0

ROCm Component

No response

Steps to Reproduce

```python
import jax
import jax.numpy as jnp
import time

# CRITICAL: Enable float64.
# The slow path on MI300X is specific to double-precision solver kernels.
jax.config.update("jax_enable_x64", True)

# Configuration matching the FermiNet/MCMC workload
N_BATCH = 4096
N_DETS = 64  # (Batch, Determinants, N, N)
N_DIM = 1    # The "killer" dimension (small matrix)
DTYPE = jnp.float64


def hard_path_loss(x):
    """
    The Slow Way: generic slogdet.
    The forward pass is okay, but the BACKWARD pass requires A^-1
    (triangular_solve). This hits the unoptimized kernel path on ROCm
    for small N.
    """
    signs, logdets = jnp.linalg.slogdet(x)
    return jnp.sum(logdets)


def easy_path_loss(x):
    """
    The Fast Way: analytic determinant, a for N=1 or ad - bc for N=2.
    JAX fuses this entire operation (and its gradient) into a single
    kernel. No solver launch required.
    """
    if N_DIM == 1:
        det = x[..., 0, 0]
        logdet = jnp.log(jnp.abs(det) + 1e-20)
        return jnp.sum(logdet)

    # x shape: (..., 2, 2)
    a = x[..., 0, 0]
    b = x[..., 0, 1]
    c = x[..., 1, 0]
    d = x[..., 1, 1]

    det = (a * d) - (b * c)
    logdet = jnp.log(jnp.abs(det) + 1e-20)
    return jnp.sum(logdet)


# We benchmark the GRADIENT calculation, as that is what happens in
# MCMC training.
grad_hard = jax.jit(jax.grad(hard_path_loss))
grad_easy = jax.jit(jax.grad(easy_path_loss))


def benchmark(name, fn, x_input, n_iter=10):
    # Compile and warm up
    print(f"Compiling {name}...")
    compiled = fn.lower(x_input).compile()
    _ = compiled(x_input).block_until_ready()

    print(f"Running {name} ({n_iter} iterations)...")
    start = time.time()
    for _ in range(n_iter):
        out = compiled(x_input)
        out.block_until_ready()  # Force GPU sync
    end = time.time()

    avg_time = (end - start) / n_iter
    print(f"  -> Avg Time: {avg_time:.6f} s")
    return avg_time


if __name__ == "__main__":
    print(f"Backend: {jax.devices()[0].platform.upper()}")
    print(f"Input Shape: {(N_BATCH, N_DETS, N_DIM, N_DIM)} | Dtype: {DTYPE}")
    print("-" * 40)

    # Generate data
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (N_BATCH, N_DETS, N_DIM, N_DIM), dtype=DTYPE)

    # Run benchmarks
    t_hard = benchmark("Hard Path (slogdet)", grad_hard, x)
    t_easy = benchmark("Easy Path (ad-bc)", grad_easy, x)

    # Report
    print("-" * 40)
    print(f"Speedup Factor: {t_hard / t_easy:.1f}x")
    if t_hard > 0.1 and t_easy < 0.01:
        print("CONCLUSION: Reproduced. Analytic path avoids slow kernel launch.")
    else:
        print("CONCLUSION: Performance gap not observed.")
```
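
To capture the per-step trace mentioned above, the standard JAX profiler can wrap a single hard-path step. A sketch (the output path is my own choice; open it with TensorBoard or Perfetto):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x = jax.random.normal(jax.random.PRNGKey(0), (64, 8, 1, 1), dtype=jnp.float64)

# The hard path: gradient of summed slogdets, as in the benchmark above.
f = jax.jit(jax.grad(lambda m: jnp.sum(jnp.linalg.slogdet(m)[1])))
f(x).block_until_ready()  # compile and warm up outside the trace

# Capture exactly one post-compilation step, so launch overhead per
# solver call is visible in the timeline.
with jax.profiler.trace("/tmp/slogdet_trace"):
    f(x).block_until_ready()
```

On MI300X the timeline should show whether the step is dominated by many tiny solver kernel launches or by one long serialized kernel.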

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

ROCk module version 6.16.6 is loaded

HSA System Attributes

Runtime Version: 1.18
Runtime Ext Version: 1.14
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
XNACK enabled: NO
DMAbuf Support: YES
VMM Support: YES

==========
HSA Agents


Agent 1


Name: INTEL(R) XEON(R) PLATINUM 8568Y+
Uuid: CPU-XX
Marketing Name: INTEL(R) XEON(R) PLATINUM 8568Y+
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 0
BDFID: 0
Internal Node ID: 0
Compute Unit: 20
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Memory Properties:
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 4
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 247409248(0xebf2a60) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:


Agent 2


Name: gfx942
Uuid: GPU-f2250ba5c72e9953
Marketing Name: AMD Instinct MI300X VF
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 4096(0x1000) KB
L3: 262144(0x40000) KB
Chip ID: 29877(0x74b5)
ASIC Revision: 1(0x1)
Cacheline Size: 128(0x80)
Max Clock Freq. (MHz): 2100
BDFID: 33536
Internal Node ID: 1
Compute Unit: 304
SIMDs per CU: 4
Shader Engines: 32
Shader Arrs. per Eng.: 1
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Memory Properties:
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 64(0x40)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 2048(0x800)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 189
SDMA engine uCode:: 25
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 200998912(0xbfb0000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 200998912(0xbfb0000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 200998912(0xbfb0000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 4
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Recommended Granule:0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx942:sramecc+:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
ISA 2
Name: amdgcn-amd-amdhsa--gfx9-4-generic:sramecc+:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
*** Done ***

Additional Information

No response
