MLX topk on CUDA is pretty slow in some cases (especially compared to PyTorch).
Here is a benchmark:
```python
import time

import mlx.core as mx

b = 2048
v = 8192
k = 32

q = mx.random.normal(shape=(b, v)).astype(mx.bfloat16)

def fun(q):
    for _ in range(50):
        # Top-k via argpartition: indices of the k largest entries (unsorted),
        # then gather the values and scatter them back.
        idx = mx.argpartition(-q, kth=k - 1, axis=-1)[:, :k]
        values = mx.take_along_axis(q, idx, axis=-1)
        q = mx.put_along_axis(q, idx, values, axis=-1)
    mx.eval(q)

# warmup
for _ in range(20):
    fun(q)

tic = time.time()
for _ in range(20):
    fun(q)
toc = time.time()
ms = 1e3 * (toc - tic)
print(f"MLX {ms=:.3f}")
```
```python
import torch

q = torch.randn(size=(b, v)).to("cuda").to(torch.bfloat16)

def fun(q):
    for _ in range(50):
        values, idx = torch.topk(q, k=k, dim=-1)
        q = torch.scatter(q, -1, idx, values)
    torch.cuda.synchronize()

# warmup
for _ in range(20):
    fun(q)

tic = time.time()
for _ in range(20):
    fun(q)
toc = time.time()
ms = 1e3 * (toc - tic)
print(f"PyTorch {ms=:.3f}")
```

On a DGX Spark:
```
MLX ms=2975.014
PyTorch ms=919.764
```
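One caveat when reading these numbers: `torch.topk` returns its values sorted by default (`sorted=True`), while `mx.argpartition` makes no ordering guarantee, so the PyTorch loop is actually doing slightly more work. A minimal sketch (not from the original report) of a cross-check that both paths select the same top-k indices, order aside, run on CPU from a shared NumPy input:

```python
import numpy as np
import mlx.core as mx
import torch

b, v, k = 4, 64, 8
x = np.random.randn(b, v).astype(np.float32)

# MLX: unsorted indices of the k largest entries per row.
idx_mx = mx.argpartition(-mx.array(x), kth=k - 1, axis=-1)[:, :k]
idx_mx = np.sort(np.array(idx_mx), axis=-1)

# PyTorch: indices of the k largest entries, sorted by value.
_, idx_pt = torch.topk(torch.from_numpy(x), k=k, dim=-1)
idx_pt = np.sort(idx_pt.numpy(), axis=-1)

# With continuous random data, ties are essentially impossible,
# so the sorted index sets should match exactly.
assert (idx_mx == idx_pt).all()
```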