Skip to content

[CUDA][Performance] Topk is slow #3064

@awni

Description

@awni

MLX topk on CUDA is pretty slow in some cases (especially compared to PyTorch).

Here is a benchmark:

import time
import mlx.core as mx

# Problem size: q has b rows and v columns; k is the number of entries the
# benchmark selects per row.
b, v, k = 2048, 8192, 32

# Random normal input, cast to bfloat16 as a separate step.
q = mx.random.normal(shape=(b, v))
q = q.astype(mx.bfloat16)

def fun(q):
    """Run 50 rounds of top-k gather/put on ``q`` and then evaluate it.

    Each round picks ``k`` indices per row via ``mx.argpartition`` on the
    negated input, gathers the matching values from ``q`` and writes them
    back to the same positions. ``mx.eval`` is called once, after the loop.
    """
    for _ in range(50):
        # Partition on -q so that the first k slots hold the largest entries.
        order = mx.argpartition(-q, kth=k - 1, axis=-1)
        top_idx = order[:, :k]
        top_vals = mx.take_along_axis(q, top_idx, axis=-1)
        q = mx.put_along_axis(q, top_idx, top_vals, axis=-1)
    mx.eval(q)

# Warm-up: 20 untimed calls, presumably so one-time startup costs are not
# counted in the measurement below.
for _ in range(20):
    fun(q)

# Timed section: 20 calls. perf_counter() is a monotonic, high-resolution
# clock meant for measuring intervals; time.time() follows the wall clock
# and can jump if the system time is adjusted mid-run.
tic = time.perf_counter()
for _ in range(20):
    fun(q)
toc = time.perf_counter()
ms = 1e3 * (toc - tic)  # seconds -> milliseconds
print(f"MLX {ms=:.3f}")

import torch


# Same (b, v) problem size as above: sample, move to the GPU, then cast to
# bfloat16, one step per line.
q = torch.randn(size=(b, v))
q = q.to("cuda")
q = q.to(torch.bfloat16)
def topk_old(q, top_k=None):
    """Return indices and values of the largest entries along the last axis.

    NOTE(review): the listing this replaces consisted only of
    ``return idx, values`` with neither name defined, so any call raised
    ``NameError``. The intended body is not visible; this version computes
    the pair with ``torch.topk`` -- confirm against the original intent.

    Args:
        q: Input tensor; selection runs along its last dimension.
        top_k: Number of entries to select. When omitted, the module-level
            ``k`` is used.

    Returns:
        A tuple ``(idx, values)``. The order is indices first, which is the
        reverse of what ``torch.topk`` itself returns.
    """
    count = k if top_k is None else top_k
    values, idx = torch.topk(q, k=count, dim=-1)
    return idx, values


def fun(q):
    """Run 50 rounds of ``torch.topk`` followed by ``torch.scatter``.

    Each round selects ``k`` entries along the last axis of ``q`` and
    scatters their values back to the indices they were taken from.
    ``torch.cuda.synchronize()`` is called once, after the loop.
    """
    for _ in range(50):
        # topk returns a (values, indices) named tuple.
        top = torch.topk(q, k=k, axis=-1)
        q = torch.scatter(q, -1, top.indices, top.values)
    torch.cuda.synchronize()

# Warm-up: 20 untimed calls, presumably so one-time startup costs are not
# counted in the measurement below.
for _ in range(20):
    fun(q)

# Timed section: 20 calls. perf_counter() is a monotonic, high-resolution
# clock meant for measuring intervals; time.time() follows the wall clock
# and can jump if the system time is adjusted mid-run.
tic = time.perf_counter()
for _ in range(20):
    fun(q)
toc = time.perf_counter()
ms = 1e3 * (toc - tic)  # seconds -> milliseconds
print(f"PyTorch {ms=:.3f}")

Results on a Spark (total wall-clock time in milliseconds for the 20 timed calls to `fun`):

MLX ms=2975.014
PyTorch ms=919.764

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions