Metal: Add deformable conv2d and performance optimizations #3355
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Adds Metal GPU kernels for deformable convolution 2D (DCNv2) on Apple Silicon, plus several performance optimizations for the Metal backend.
Deformable Conv2D Kernels
deformable_im2col: forward pass with bilinear sampling at learned offsetsdeformable_col2im: backward pass for input gradientsdeformable_col2im_coord: backward pass for offset and mask gradientsPerformance Optimizations
[B, N, C] + [C](~2.7x faster Linear layers)transpose_last2kernel for attention K^T pattern (~14x faster contiguous)Features
Benchmarks
Use Cases
Enables models that rely on deformable convolutions to run on Apple Silicon:
Origin
Deformable conv ported from mps-deform-conv, a standalone PyTorch MPS extension.
Test Plan
candle-metal-kernels/src/tests.rs