Skip to content

Conversation

@nastya236
Copy link
Contributor

@nastya236 nastya236 commented Feb 2, 2026

Fuse unary ops into reduction.

It will be useful for:

  • mx::max(mx::abs(x)) for nvfp4 quantization
  • clip_grad_norm
def bench(f, *args, **kwargs):
    for i in range(N_warmup):
        x = mx.eval(f(*args, **kwargs))

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        x = mx.eval(f(*args, **kwargs))
    e = time.perf_counter_ns()
    return (e - s) * 1e-9

def abs_max(a):
    return mx.max(mx.abs(a))

shape = (4*4096, 11008)
x = mx.random.uniform(shape=shape)
y = mx.random.uniform(shape=shape)
f_com = mx.compile(abs_max)

time_fused = bench(f_com, x)
time_unfused = bench(abs_max, y)

Fused time: 0.001803 s
Unfused time: 0.004267 s

TODO:

  • Not sure if we want to: do the same for row_reduce and col_reduce, seems that we don't do unary + row/col_reduce very often
  • Metal

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant