From 6722b73b9316e3dc84c7eaa5f81794ea901eaa45 Mon Sep 17 00:00:00 2001 From: Jochen Kiemes Date: Sun, 7 Dec 2025 13:44:40 +0100 Subject: [PATCH] add MetalCompilerPlugin support --- Package.swift | 23 +- README.md | 2 + Source/Cmlx/.metal-compiler-plugin.json | 13 + Source/Cmlx/mlx-generated/metal/bf16.h | 304 +------------------ Source/Cmlx/mlx-generated/metal/random.metal | 103 ------- 5 files changed, 35 insertions(+), 410 deletions(-) create mode 100644 Source/Cmlx/.metal-compiler-plugin.json delete mode 100644 Source/Cmlx/mlx-generated/metal/random.metal diff --git a/Package.swift b/Package.swift index 49c39b2b..6b1b3d45 100644 --- a/Package.swift +++ b/Package.swift @@ -1,16 +1,19 @@ -// swift-tools-version: 5.10 +// swift-tools-version: 6.2 // The swift-tools-version declares the minimum version of Swift required to build this package. // Copyright © 2024 Apple Inc. +import Foundation import PackageDescription +let inXcode = ProcessInfo.processInfo.environment["XCODE_VERSION_ACTUAL"] != nil + let package = Package( name: "mlx-swift", platforms: [ - .macOS("13.3"), - .iOS(.v16), - .tvOS(.v16), + .macOS("14.0"), + .iOS(.v17), + .tvOS(.v17), .visionOS(.v1), ], @@ -26,7 +29,8 @@ let package = Package( ], dependencies: [ // for Complex type - .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0") + .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"), + .package(url: "https://github.com/schwa/MetalCompilerPlugin", branch: "main"), ], targets: [ .target( @@ -150,7 +154,14 @@ let package = Package( .linkedFramework("Foundation"), .linkedFramework("Metal"), .linkedFramework("Accelerate"), - ] + ], + + plugins: + // Optional: Use plugin for custom Metal compilation + // needed for swift build. Xcode does it automatically + inXcode ? 
[] : [ + .plugin(name: "MetalCompilerPlugin", package: "MetalCompilerPlugin") + ], ), .testTarget( name: "CmlxTests", diff --git a/README.md b/README.md index 3a71d45d..e1cb45a1 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,8 @@ dependencies: [.product(name: "MLX", package: "mlx-swift"), > SwiftPM (command line) cannot build the Metal shaders so the ultimate build has to be done > via Xcode. +Update: With the [Metal Compiler Plugin](https://github.com/schwa/MetalCompilerPlugin) enabled, a command-line `swift build` compiles the Metal shaders and stores the resulting shader library as `default.metallib`. + ### xcodebuild Although `SwiftPM` (command line) cannot build the Metal shaders, `xcodebuild` can and diff --git a/Source/Cmlx/.metal-compiler-plugin.json b/Source/Cmlx/.metal-compiler-plugin.json new file mode 100644 index 00000000..cf002e37 --- /dev/null +++ b/Source/Cmlx/.metal-compiler-plugin.json @@ -0,0 +1,13 @@ +{ + "xcrun": true, + "find-inputs": true, + "include-dependencies": false, + "dependency-path-suffix": "include", + "include-paths": ["mlx/mlx/backend/metal/kernels/metal_3_1", "mlx/mlx/backend/metal/kernels", "mlx"], + "output": "default.metallib", + "flags": ["-gline-tables-only", "-frecord-sources"], + "plugin-logging": false, + "verbose-logging": false, + "metal-enable-logging": false, + "logging-prefix": "[Metal]" +} diff --git a/Source/Cmlx/mlx-generated/metal/bf16.h b/Source/Cmlx/mlx-generated/metal/bf16.h index f5d48670..aa3c3c78 100644 --- a/Source/Cmlx/mlx-generated/metal/bf16.h +++ b/Source/Cmlx/mlx-generated/metal/bf16.h @@ -6,309 +6,11 @@ using namespace metal; -///////////////////////////////////////////////////////////////////////////// -// Helpers -///////////////////////////////////////////////////////////////////////////// - -constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { - // Check for nan - if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > - _fp_encoding_traits::inf_mask) { - return uint16_t(as_type(0x7FC0)); - } - // Take bits - uint32_t float_bits = as_type(x); - - // 
Round to nearest even - float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); - - // Take upper 16 bits - return float_bits >> 16; -} - -constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { - // Upper 16 bits are the data and lower 16 bits are 0s - return as_type((uint32_t)x << 16); -} - -struct _MLX_BFloat16; - -template -static constexpr constant bool can_convert_to_bfloat = - !is_same_v && is_convertible_v; - -template -static constexpr constant bool can_convert_from_bfloat = - !is_same_v && is_convertible_v; - -///////////////////////////////////////////////////////////////////////////// -// Bfloat struct -///////////////////////////////////////////////////////////////////////////// - -struct _MLX_BFloat16 { - ///////////////////////////////////////////////////////////////////////////// - // Constructors - uint16_t bits_; - _MLX_BFloat16() thread = default; - _MLX_BFloat16() threadgroup = default; - _MLX_BFloat16() device = default; - _MLX_BFloat16() constant = default; - - struct bits_to_bfloat_struct {}; - static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { - return bits_to_bfloat_struct(); - } - constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) - : bits_(bits) {} - - ///////////////////////////////////////////////////////////////////////////// - // Conversions to bfloat - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC _MLX_BFloat16(T x) thread - : bits_(float_to_bfloat_bits(static_cast(x))) {} - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup - : bits_(float_to_bfloat_bits(static_cast(x))) {} - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC _MLX_BFloat16(T x) device - : bits_(float_to_bfloat_bits(static_cast(x))) {} - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC _MLX_BFloat16(T x) constant - : 
bits_(float_to_bfloat_bits(static_cast(x))) {} - - ///////////////////////////////////////////////////////////////////////////// - // Conversions from bfloat - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC operator T() const thread { - return static_cast(bfloat_bits_to_float(bits_)); - } - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC operator T() const threadgroup { - return static_cast(bfloat_bits_to_float(bits_)); - } - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC operator T() const device { - return static_cast(bfloat_bits_to_float(bits_)); - } - - template < - typename T, - typename = typename enable_if>::type> - constexpr METAL_FUNC operator T() const constant { - return static_cast(bfloat_bits_to_float(bits_)); - } -}; - -///////////////////////////////////////////////////////////////////////////// -// Bfloat operators -///////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////// -// Unary ops -constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { - return -static_cast(x); -} - -///////////////////////////////////////////////////////////////////////////// -// Binary operators -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -///////////////////////////////////////////////////////////////////////////// -// Arithmetic Operators 
-#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, float, half, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); - -bfloat_binop(+, operator+); -bfloat_binop(-, operator-); -bfloat_binop(*, operator*); -bfloat_binop(/, operator/); - -///////////////////////////////////////////////////////////////////////////// -// Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, __operator__, bool, half, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); - -bfloat_compop(>, operator>); -bfloat_compop(<, operator<); -bfloat_compop(>=, operator>=); -bfloat_compop(<=, operator<=); -bfloat_compop(==, operator==); -bfloat_compop(!=, operator!=); - -#undef bfloat_compop -#undef bfloat_binop_base -#undef bfloat_binop_helper -#undef bfloat_binop - -///////////////////////////////////////////////////////////////////////////// -// Inplace Operators -#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, itype rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } \ - constexpr 
METAL_FUNC addr_space itype& __operator__( \ - addr_space itype& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } - -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ - bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ - bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ - bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); - -#define bfloat_inplace_op(itype) \ - bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ - bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ - bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ - bfloat_inplace_op_addr_space_helper(/, operator/=, itype); - -bfloat_inplace_op(float); -bfloat_inplace_op(half); -bfloat_inplace_op(int16_t); -bfloat_inplace_op(int32_t); -bfloat_inplace_op(int64_t); -bfloat_inplace_op(uint16_t); -bfloat_inplace_op(uint32_t); -bfloat_inplace_op(uint64_t); - -#undef bfloat_inplace_op_helper -#undef bfloat_inplace_op_addr_space_helper -#undef bfloat_inplace_op - -#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } - -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ - bfloat_inplace_op_helper(__op__, __operator__, device); \ - bfloat_inplace_op_helper(__op__, __operator__, thread); \ - bfloat_inplace_op_helper(__op__, __operator__, threadgroup); - -bfloat_inplace_op_addr_space_helper(+, operator+=); -bfloat_inplace_op_addr_space_helper(-, operator-=); -bfloat_inplace_op_addr_space_helper(*, operator*=); -bfloat_inplace_op_addr_space_helper(/, operator/=); - -#undef bfloat_inplace_op_helper -#undef bfloat_inplace_op_addr_space_helper - -///////////////////////////////////////////////////////////////////////////// -// Bfloat typedef 
-///////////////////////////////////////////////////////////////////////////// - -typedef struct _MLX_BFloat16 bfloat16_t; - -///////////////////////////////////////////////////////////////////////////// -// Bfloat numeric limits -///////////////////////////////////////////////////////////////////////////// - -#pragma METAL internals : enable - -namespace metal { - -template <> -struct _numeric_limits_impl : _fp_numeric_limits_impl_base { - static constexpr constant int digits = 8; - static constexpr constant int digits10 = 2; - static constexpr constant int max_digits10 = 4; - static constexpr constant int radix = 2; - static constexpr constant int min_exponent = -125; - static constexpr constant int min_exponent10 = -37; - static constexpr constant int max_exponent = 128; - static constexpr constant int max_exponent10 = 38; - - static constexpr bfloat16_t min() { - return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t lowest() { - return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t max() { - return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t epsilon() { - return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t round_error() { - return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t infinity() { - return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t quiet_NaN() { - return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t signaling_NaN() { - return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); - } - static constexpr bfloat16_t denorm_min() { - return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat()); - } -}; - -METAL_FUNC bool isnan(_MLX_BFloat16 x) { - return x != x; -} - -} // namespace metal - -#pragma METAL internals : disable +typedef bfloat 
bfloat16_t; inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { - return x.bits_; + return as_type(x); } inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { - return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()); + return as_type(x); } diff --git a/Source/Cmlx/mlx-generated/metal/random.metal b/Source/Cmlx/mlx-generated/metal/random.metal deleted file mode 100644 index eb6234d8..00000000 --- a/Source/Cmlx/mlx-generated/metal/random.metal +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include "utils.h" - -static constexpr constant uint32_t rotations[2][4] = { - {13, 15, 26, 6}, - {17, 29, 16, 24}}; - -union rbits { - uint2 val; - uchar4 bytes[2]; -}; - -rbits threefry2x32_hash(const thread uint2& key, uint2 count) { - uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; - - rbits v; - v.val.x = count.x + ks[0]; - v.val.y = count.y + ks[1]; - - for (int i = 0; i < 5; ++i) { - for (auto r : rotations[i % 2]) { - v.val.x += v.val.y; - v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); - v.val.y ^= v.val.x; - } - v.val.x += ks[(i + 1) % 3]; - v.val.y += ks[(i + 2) % 3] + i + 1; - } - - return v; -} - -[[kernel]] void rbitsc( - device const uint32_t* keys, - device char* out, - constant const bool& odd, - constant const uint& bytes_per_key, - uint2 grid_dim [[threads_per_grid]], - uint2 index [[thread_position_in_grid]]) { - auto kidx = 2 * index.x; - auto key = uint2(keys[kidx], keys[kidx + 1]); - auto half_size = grid_dim.y - odd; - out += index.x * bytes_per_key; - bool drop_last = odd && (index.y == half_size); - auto bits = threefry2x32_hash( - key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); - size_t idx = size_t(index.y) << 2; - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[0][i]; - } - if (!drop_last) { - idx = (drop_last ? 
0 : size_t(index.y) + grid_dim.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { - int edge_bytes = (bytes_per_key % 4); - for (int i = 0; i < edge_bytes; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } else { - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } - } -} - -[[kernel]] void rbits( - device const uint32_t* keys, - device char* out, - constant const bool& odd, - constant const uint& bytes_per_key, - constant const int& ndim, - constant const int* key_shape, - constant const int64_t* key_strides, - uint2 grid_dim [[threads_per_grid]], - uint2 index [[thread_position_in_grid]]) { - auto kidx = 2 * index.x; - auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim); - auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim); - auto key = uint2(keys[k1_elem], keys[k2_elem]); - auto half_size = grid_dim.y - odd; - out += size_t(index.x) * bytes_per_key; - bool drop_last = odd && (index.y == half_size); - auto bits = threefry2x32_hash( - key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); - size_t idx = size_t(index.y) << 2; - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[0][i]; - } - if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { - int edge_bytes = (bytes_per_key % 4); - for (int i = 0; i < edge_bytes; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } else { - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } - } -}