Skip to content

Commit 98ce375

Browse files
committed
Update to MLX 0.23.0
1 parent 99da298 commit 98ce375

File tree

7 files changed

+37
-4
lines changed

7 files changed

+37
-4
lines changed

deps/mlx

Submodule mlx updated 88 files

index.d.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export namespace core {
4545
const int64: Dtype;
4646
const float16: Dtype;
4747
const float32: Dtype;
48+
const float64: Dtype;
4849
const bfloat16: Dtype;
4950
const complex64: Dtype;
5051

@@ -201,6 +202,7 @@ export namespace core {
201202
function atleast3d(...arrays: array[]): array;
202203
function issubdtype(a: Dtype | DtypeCategory, b: Dtype | DtypeCategory): boolean;
203204
function bitwiseAnd(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
205+
function bitwiseInvert(array: ScalarOrArray, s?: StreamOrDevice): array;
204206
function bitwiseOr(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
205207
function bitwiseXor(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
206208
function broadcastArrays(arrays: array[], s?: StreamOrDevice): array[];
@@ -437,6 +439,10 @@ export namespace core {
437439
function svd(array: ScalarOrArray, s?: StreamOrDevice): array[];
438440
function eigvalsh(array: ScalarOrArray, uplo?: string, s?: StreamOrDevice): array[];
439441
function eigh(array: ScalarOrArray, uplo?: string, s?: StreamOrDevice): array[];
442+
function lu(array: ScalarOrArray, s?: StreamOrDevice): array[];
443+
function luFactor(array: ScalarOrArray, s?: StreamOrDevice): [array, array];
444+
function solve(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
445+
function solveTriangular(a: ScalarOrArray, b: ScalarOrArray, upper: boolean, s?: StreamOrDevice): array;
440446
}
441447

442448
// Fast operations.

lib/nn/layers/convolution.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ export class Conv2d extends Module {
123123

124124
override toStringExtra(): string {
125125
return `${this.weight.shape[3]}, ${this.weight.shape[0]}, ` +
126-
`kernelSize=${this.weight.shape.slice(1, 2)}, stride=${this.stride}, ` +
126+
`kernelSize=${this.weight.shape.slice(1, 3)}, stride=${this.stride}, ` +
127127
`padding=${this.padding}, dilation=${this.dilation}, ` +
128128
`groups=${this.groups}, bias=${!!this.bias}`;
129129
}
@@ -187,7 +187,7 @@ export class Conv3d extends Module {
187187

188188
override toStringExtra(): string {
189189
return `${this.weight.shape[3]}, ${this.weight.shape[0]}, ` +
190-
`kernelSize=${this.weight.shape.slice(1, 3)}, stride=${this.stride}, ` +
190+
`kernelSize=${this.weight.shape.slice(1, 4)}, stride=${this.stride}, ` +
191191
`padding=${this.padding}, dilation=${this.dilation}, ` +
192192
`bias=${!!this.bias}`;
193193
}

src/array.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ T JsTypedArrayToMxArray(napi_env env, napi_value value) {
231231
return CreateInstance<T>(static_cast<uint32_t*>(data), shape, mx::uint32);
232232
case napi_float32_array:
233233
return CreateInstance<T>(static_cast<float*>(data), shape, mx::float32);
234+
case napi_float64_array:
235+
return CreateInstance<T>(static_cast<double*>(data), shape, mx::float64);
234236
default:
235237
napi_throw_type_error(env, nullptr, "Unsupported TypedArray type.");
236238
return T();
@@ -330,6 +332,8 @@ auto VisitArrayData(F&& visitor, mx::array* a) {
330332
return visitor(a->data<mx::float16_t>());
331333
case mx::float32:
332334
return visitor(a->data<float>());
335+
case mx::float64:
336+
return visitor(a->data<double>());
333337
case mx::bfloat16:
334338
return visitor(a->data<mx::bfloat16_t>());
335339
case mx::complex64:
@@ -399,6 +403,7 @@ napi_value ToTypedArray(mx::array* a, napi_env env) {
399403
case mx::int16: type = napi_int16_array; break;
400404
case mx::int32: type = napi_int32_array; break;
401405
case mx::float32: type = napi_float32_array; break;
406+
case mx::float64: type = napi_float64_array; break;
402407
default:
403408
napi_throw_type_error(env, nullptr, "No matching TypedArray for dtype.");
404409
return nullptr;
@@ -565,6 +570,8 @@ napi_status Type<mx::Dtype>::ToNode(napi_env env,
565570
return ConvertToNode(env, &mx::float16, result);
566571
if (value == mx::float32)
567572
return ConvertToNode(env, &mx::float32, result);
573+
if (value == mx::float64)
574+
return ConvertToNode(env, &mx::float64, result);
568575
if (value == mx::bfloat16)
569576
return ConvertToNode(env, &mx::bfloat16, result);
570577
if (value == mx::complex64)
@@ -778,6 +785,7 @@ void InitArray(napi_env env, napi_value exports) {
778785
"int64", &mx::int64,
779786
"float16", &mx::float16,
780787
"float32", &mx::float32,
788+
"float64", &mx::float64,
781789
"bfloat16", &mx::bfloat16,
782790
"complex64", &mx::complex64);
783791

src/linalg.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,9 @@ void InitLinalg(napi_env env, napi_value exports) {
4747
"pinv", &mx::linalg::pinv,
4848
"cross", &linalg::Cross,
4949
"eigvalsh", &mx::linalg::eigvalsh,
50-
"eigh", &mx::linalg::eigvalsh);
50+
"eigh", &mx::linalg::eigvalsh,
51+
"lu", &mx::linalg::lu,
52+
"luFactor", &mx::linalg::lu_factor,
53+
"solve", &mx::linalg::solve,
54+
"solveTriangular", &mx::linalg::solve_triangular);
5155
}

src/ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,7 @@ void InitOps(napi_env env, napi_value exports) {
906906
"atleast3d", NdOpWrapper(&mx::atleast_1d, &mx::atleast_3d),
907907
"issubdtype", &ops::IsSubDtype,
908908
"bitwiseAnd", BinOpWrapper(&mx::bitwise_and),
909+
"bitwiseInvert", &mx::bitwise_invert,
909910
"bitwiseOr", BinOpWrapper(&mx::bitwise_or),
910911
"bitwiseXor", BinOpWrapper(&mx::bitwise_xor),
911912
"leftShift", BinOpWrapper(&mx::left_shift),

tests/vmap.spec.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,4 +521,18 @@ describe('vmap', () => {
521521
let out = mx.vmap(fun, [0, 1, 1])(a, idx, upd);
522522
assert.deepEqual(out.shape, [4, 5, 1]);
523523
});
524+
525+
it('vmapSplitVmap', () => {
526+
const fun = (x: any) => {
527+
const [a, b] = mx.split(x, 2, 1);
528+
return mx.concatenate([b, a], 1);
529+
}
530+
531+
const x = mx.ones([5, 6, 7]);
532+
const y = mx.ones([5, 4, 6, 7]);
533+
const fx = fun(x);
534+
const fy = mx.vmap(fun, 1)(y);
535+
assert.deepEqual(fx.shape, [5, 6, 7]);
536+
assert.deepEqual(fy.shape, [4, 5, 6, 7]);
537+
});
524538
});

0 commit comments

Comments
 (0)