Skip to content

Commit da1effa

Browse files
committed
Update to MLX 0.24.2
1 parent af09a39 commit da1effa

File tree

3 files changed

+69
-5
lines changed

3 files changed

+69
-5
lines changed

deps/mlx

Submodule mlx updated 67 files

src/fast.cc

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,41 @@ mx::array Rope(const mx::array& x,
2020
}
2121
}
2222

23+
// Wrapper around mx::fast::scaled_dot_product_attention that accepts the
// attention mask as a single variant argument.
//
// The alternative held by `mask` decides what is forwarded:
//   - std::monostate: no mask; an empty mask mode and no mask arrays.
//   - std::string:    a named mask mode; only "causal" is accepted.
//   - mx::array:      an explicit mask, forwarded as the only mask array.
//
// Throws std::invalid_argument when `mask` holds a string other than
// "causal".
mx::array ScaledDotProductAttention(
    const mx::array& queries,
    const mx::array& keys,
    const mx::array& values,
    const float scale,
    const std::variant<std::monostate, std::string, mx::array>& mask,
    mx::StreamOrDevice s) {
  // Named mask mode: validate before forwarding. std::get_if yields a pointer
  // into the variant, so the string is not copied.
  if (const auto* mask_mode = std::get_if<std::string>(&mask)) {
    if (*mask_mode != "causal") {
      std::ostringstream msg;
      msg << "[scaled_dot_product_attention] invalid mask option '"
          << *mask_mode << "'. Must be 'causal', or an array.";
      throw std::invalid_argument(msg.str());
    }
    return mx::fast::scaled_dot_product_attention(
        queries, keys, values, scale, *mask_mode, {}, s);
  }

  // Explicit mask array: forwarded with an empty mask mode.
  if (const auto* mask_arr = std::get_if<mx::array>(&mask)) {
    return mx::fast::scaled_dot_product_attention(
        queries, keys, values, scale, "", {*mask_arr}, s);
  }

  // std::monostate: no mask at all.
  return mx::fast::scaled_dot_product_attention(
      queries, keys, values, scale, "", {}, s);
}
57+
2358
} // namespace fast_ops
2459

2560
void InitFast(napi_env env, napi_value exports) {
@@ -30,5 +65,5 @@ void InitFast(napi_env env, napi_value exports) {
3065
"rmsNorm", &mx::fast::rms_norm,
3166
"layerNorm", &mx::fast::layer_norm,
3267
"rope", &fast_ops::Rope,
33-
"scaledDotProductAttention", &mx::fast::scaled_dot_product_attention);
68+
"scaledDotProductAttention", &fast_ops::ScaledDotProductAttention);
3469
}

tests/ops.spec.ts

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,9 +580,28 @@ describe('ops', () => {
580580
});
581581

582582
it('logsumexp', () => {
583-
const x = mx.array([[1.0, 2.0], [3.0, 4.0]]);
584-
const expected = 4.44;
585-
assert.closeTo(mx.logsumexp(x).item() as number, expected, 0.01);
583+
// Reference logsumexp composed from primitive ops. The max is subtracted
// before exponentiating and added back afterwards; reductions keep their
// dimensions so `maxs` broadcasts against `x`.
const logsumexp = (x: mx.array, axes = undefined) => {
  const maxs = mx.max(x, axes, true);
  return mx.add(mx.log(mx.sum(mx.exp(mx.subtract(x, maxs)), axes, true)), maxs);
}

// Small fixed input, compared as scalars.
let x = mx.array([[1.0, 2.0], [3.0, 4.0]]);
assert.closeTo(mx.logsumexp(x).item() as number, logsumexp(x).item() as number, 1e-7);

// 1-D input with 1025 elements.
x = mx.random.uniform(0, 1, [1025]);
assertArrayAllTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)));

// Transposed
x = mx.random.uniform(0, 1, [2, 2, 8]).swapaxes(0, 1);
assertArrayAllTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)));

// Broadcast
x = mx.broadcastTo(mx.random.uniform(0, 1, [2, 1, 8]), [2, 2, 8]);
assertArrayAllTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)));

// Large: a 1025-element row broadcast to 8 rows, so this case differs from
// the small broadcast input above.
x = mx.broadcastTo(mx.random.uniform(0, 1, [1025]), [8, 1025]);
assertArrayAllTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)));
586605
});
587606

588607
it('mean', () => {
@@ -1243,6 +1262,16 @@ describe('ops', () => {
12431262
const x = mx.full([n], -Infinity);
12441263
assertArrayAllTrue(mx.isnan(mx.softmax(x)));
12451264
}
1265+
1266+
// Transposed inputs.
1267+
const a = mx.random.uniform(0, 1, [32, 32, 32]);
1268+
const b = mx.softmax(a, -1);
1269+
const c = mx.softmax(a.transpose(1, 0, 2), -1).transpose(1, 0, 2);
1270+
assert.equal(mx.abs(mx.subtract(b, c)).max().item(), 0.0);
1271+
1272+
assert.throws(() => {
1273+
mx.softmax(mx.array(1.0), -1);
1274+
}, Error);
12461275
});
12471276

12481277
it('concatenate', function() {

0 commit comments

Comments
 (0)