
rebeccajae/metal-bindgen


metal-bindgen

When writing Metal kernel code, you have to create a binding to each kernel in Swift. This is moderately annoying, and when many kernels share similar calling conventions, you end up with a lot of copy-pasted bindings that can become buggy if you don't keep track of them.

metal-bindgen exists because it solved my specific problems with Metal shaders and writing the wrappers for them.

I took some inspiration from the mechanisms that SWIG provides, but tried to make them as painless as possible. However, this tool does not, and should not, strive to allow arbitrary support code. If you need additional support code, you may have to write wrapper functions or modify the generated code.

Example

Take the shader below, which adds 2.0 to every element.

kernel void add_two(device const float* matrix [[buffer(0)]],
                    device float* output [[buffer(1)]],
                    constant uint& count [[buffer(2)]],
                    uint index [[thread_position_in_grid]]) {
    // Bail if this thread index is larger than the number of items.
    if (index >= count) return;
    output[index] = matrix[index] + 2.0;
}

Calling it from Swift takes a few lines of boilerplate to set up the operation and dispatch it. Most of the information needed to generate that boilerplate can be reduced to a handful of annotation comments.

With metal-bindgen, you just add some annotation comments:

// @swift_function_name(AddTwo)
// @swift_binding(matrix: MPSMatrix, output: MPSMatrix)
// @swift_bind_value(count: UInt32, matrix.rows * matrix.columns)
// @swift_threadgroup_geometry(32, 1, 1)
kernel void add_two(device const float* matrix [[buffer(0)]],
                    device float* output [[buffer(1)]],
                    constant uint& count [[buffer(2)]],
                    uint index [[thread_position_in_grid]]) {
    // Bail if this thread index is larger than the number of items.
    if (index >= count) return;
    output[index] = matrix[index] + 2.0;
}

and you get the binding code for free.
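To make the annotation format concrete, here is a minimal Swift sketch that pulls the // @name(args) comments out of a shader source string. It is not metal-bindgen's actual parser; the Annotation type and parseAnnotations function are hypothetical, and the sketch assumes each annotation sits on its own line with its arguments closed by the last ) on that line.

```swift
// Hypothetical sketch, not metal-bindgen's parser: collect `// @name(args)`
// annotation comments from Metal source text. Uses only the Swift stdlib.
struct Annotation: Equatable {
    let name: String       // e.g. "swift_function_name"
    let arguments: String  // raw text between the first "(" and the last ")"
}

func parseAnnotations(in source: String) -> [Annotation] {
    var result: [Annotation] = []
    for rawLine in source.split(separator: "\n") {
        // Ignore indentation, then require the "// @" prefix.
        let line = rawLine.drop(while: { $0 == " " || $0 == "\t" })
        guard line.hasPrefix("// @") else { continue }
        let body = line.dropFirst(4)  // text after "// @"
        // Name runs up to the first "("; arguments run up to the last ")".
        guard let open = body.firstIndex(of: "("),
              let close = body.lastIndex(of: ")"),
              open < close else { continue }
        let name = String(body[..<open])
        let arguments = String(body[body.index(after: open)..<close])
        result.append(Annotation(name: name, arguments: arguments))
    }
    return result
}
```

Keeping the arguments as raw text avoids having to split expressions like matrix.rows * matrix.columns, which could contain their own commas or parentheses.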

Usage

@swift_function_name(String)

The name of the generated Swift function for the underlying kernel. If omitted, the function name is the kernel name with its first letter uppercased.
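As a sketch of the documented default, assuming only the first character changes (no snake_case conversion), the fallback name could be computed like this. defaultSwiftFunctionName is a hypothetical helper, not part of the tool.

```swift
// Hypothetical helper mirroring the documented default:
// the kernel name with only its first letter uppercased.
func defaultSwiftFunctionName(forKernel kernel: String) -> String {
    guard let first = kernel.first else { return kernel }  // empty name stays empty
    return first.uppercased() + String(kernel.dropFirst())
}
```

Under this rule, add_two on its own would become Add_two; the example above sets @swift_function_name(AddTwo) explicitly instead.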

@swift_binding(params...)

The parameter list of the generated Swift function. MPSMatrix and MTLBuffer parameters are automatically routed to the kernel parameters with the same name.
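Routing by name can be pictured as a lookup from each binding parameter to the [[buffer(n)]] index of the same-named kernel parameter. The sketch below is a hypothetical illustration: KernelParameter and routeBindings are invented names, and returning nil when a binding parameter has no match is an assumption, since how metal-bindgen reports that case is not described here.

```swift
// Hypothetical model of a kernel parameter: its name and its [[buffer(n)]] index.
struct KernelParameter {
    let name: String
    let bufferIndex: Int
}

// Minimal sketch of name-based routing: each Swift binding parameter is sent
// to the buffer slot of the kernel parameter with the same name.
// Returns nil (an assumption) if any binding parameter has no matching kernel parameter.
func routeBindings(_ bindingNames: [String],
                   to kernelParameters: [KernelParameter]) -> [String: Int]? {
    var indexByName: [String: Int] = [:]
    for parameter in kernelParameters {
        indexByName[parameter.name] = parameter.bufferIndex
    }
    var routes: [String: Int] = [:]
    for name in bindingNames {
        guard let index = indexByName[name] else { return nil }
        routes[name] = index
    }
    return routes
}
```

For the add_two example, matrix would route to buffer 0 and output to buffer 1, leaving buffer 2 for the value bound by @swift_bind_value.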

@swift_bind_value(paramName: Type, expr)

Binds the kernel parameter paramName, with Swift type Type, to the value of the Swift expression expr.
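For the example's @swift_bind_value(count: UInt32, matrix.rows * matrix.columns), the bound value is the product of the matrix dimensions as a UInt32. MatrixShape below is a hypothetical stand-in for the rows and columns properties of MPSMatrix so the sketch runs without Metal, and using UInt32(exactly:) to reject out-of-range values is an assumption; the generated code may convert differently.

```swift
// Hypothetical stand-in for the two MPSMatrix properties the example uses
// (MPSMatrix exposes `rows` and `columns` as Int).
struct MatrixShape {
    let rows: Int
    let columns: Int
}

// Sketch of the value `count` is bound to: evaluate the expression,
// then convert it to the declared type. UInt32(exactly:) returns nil
// instead of trapping when the product does not fit in a UInt32.
func boundCount(for matrix: MatrixShape) -> UInt32? {
    return UInt32(exactly: matrix.rows * matrix.columns)
}
```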

@swift_ensure(cond, errorMessage: String)

Guards on cond: if cond is false, the generated function throws .preconditionFailed(errorMessage).

Use these sparingly; prefer wrapping the generated code if you need complex preconditions.
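A guard generated from @swift_ensure might look roughly like the following. Only the case name preconditionFailed comes from the description above; the BindingError enum, the ensureSameRowCount function, and the row-count condition are hypothetical.

```swift
// Hypothetical error type: only the case name `preconditionFailed`
// comes from the documentation; the enum name is an assumption.
enum BindingError: Error, Equatable {
    case preconditionFailed(String)
}

// Sketch of the guard @swift_ensure describes, using a hypothetical
// condition that two row counts match.
func ensureSameRowCount(inputRows: Int, outputRows: Int) throws {
    guard inputRows == outputRows else {
        throw BindingError.preconditionFailed("input and output must have the same number of rows")
    }
}
```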

@swift_threadgroup_geometry(w: Int, h: Int, d: Int)

Specifies the threadgroup geometry (threads per threadgroup) used to execute the kernel. If omitted, it defaults to (32, 1, 1).

@swift_grid_size(w: expr, h: expr, d: expr)

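Provides, for each dimension, the expr used in the threadgroupsPerGrid calculation, which is an integer ceiling division so that a partially filled threadgroup is still launched: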

(expr + threadsPerThreadgroup.dimension - 1) / threadsPerThreadgroup.dimension

For example, a kernel with threadgroup geometry (32, 1, 1) and grid size (1, 1, 1) gets a threadgroupsPerGrid width of

(1 + 32 - 1) / 32 = 32/32 = 1

This can be used to compute dispatch sizes that depend on the input and output dimensions.
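The threadgroupsPerGrid formula above is applied independently to each dimension. This sketch uses a hypothetical Size3 struct in place of MTLSize so it runs without Metal; in real dispatch code the result would typically be an MTLSize passed to dispatchThreadgroups(_:threadsPerThreadgroup:) together with the threadgroup geometry.

```swift
// Hypothetical stand-in for MTLSize (width/height/depth) so this runs without Metal.
struct Size3: Equatable {
    var width: Int
    var height: Int
    var depth: Int
}

// (expr + threadsPerThreadgroup.dimension - 1) / threadsPerThreadgroup.dimension
func threadgroupCount(extent: Int, threadsPerGroup: Int) -> Int {
    return (extent + threadsPerGroup - 1) / threadsPerGroup
}

// Grid size comes from @swift_grid_size, threads per threadgroup
// from @swift_threadgroup_geometry.
func threadgroupsPerGrid(gridSize: Size3, threadsPerThreadgroup: Size3) -> Size3 {
    return Size3(
        width: threadgroupCount(extent: gridSize.width, threadsPerGroup: threadsPerThreadgroup.width),
        height: threadgroupCount(extent: gridSize.height, threadsPerGroup: threadsPerThreadgroup.height),
        depth: threadgroupCount(extent: gridSize.depth, threadsPerGroup: threadsPerThreadgroup.depth))
}
```

With a grid width of 100 and a threadgroup width of 32, this yields 4 threadgroups (128 threads). The 28 surplus threads are why kernels like add_two bail out when index >= count.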

Caveats

This probably doesn't cover all use cases; in fact, I'm pretty certain it can't. I'll do what I can to support different use cases, but I want to keep things simple-ish.

About

MPS binding code made easy!
