diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 54700bdcdb..21f916363f 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -399,7 +399,6 @@ bool ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, - bool is_training, bool output_logsumexp, Stream s) { if (s.device == Device::cpu) { @@ -460,7 +459,12 @@ void ScaledDotProductAttention::eval_gpu( } } -bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { +bool ScaledDotProductAttentionVJP::use_fallback( + const array& q, + Stream s, + bool has_mask, + bool has_sinks, + int /* n_kv_heads */) { // The frontend adds a padding mask when sequence length is not a multiple of // tile size. if (q.shape(2) % 128 != 0) { diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 7d010e6c8c..bf84ca17b7 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -53,7 +53,7 @@ build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) build_kernel(rope) -build_kernel(scaled_dot_product_attention sdpa_vector.h) +build_kernel(scaled_dot_product_attention sdpa_vector.h sdpa_vector_vjp.h) if(MLX_METAL_VERSION GREATER_EQUAL 320) build_kernel(fence) endif() diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index c668d9d8c5..45484d5fb2 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -3,6 +3,7 @@ // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/sdpa_vector.h" +#include "mlx/backend/metal/kernels/sdpa_vector_vjp.h" using namespace metal; @@ -41,4 +42,42 @@ using namespace metal; instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) + +// SDPA vector VJP instantiations +#define instantiate_sdpa_vector_vjp(type, qk_dim, value_dim) \ + instantiate_kernel( \ + "sdpa_vector_vjp_" #type "_" #qk_dim "_" #value_dim, \ + sdpa_vector_vjp, \ + type, \ + qk_dim, \ + value_dim) + +// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel +#define instantiate_sdpa_vector_vjp_heads(type) \ + instantiate_sdpa_vector_vjp(type, 64, 64) \ + instantiate_sdpa_vector_vjp(type, 96, 96) \ + instantiate_sdpa_vector_vjp(type, 128, 128) + +instantiate_sdpa_vector_vjp_heads(float) +instantiate_sdpa_vector_vjp_heads(bfloat16_t) +instantiate_sdpa_vector_vjp_heads(float16_t) + +// SDPA vector VJP accumulate instantiations (for half/bfloat16 with float32 accumulators) +#define instantiate_sdpa_vector_vjp_accumulate(type, qk_dim, value_dim) \ + instantiate_kernel( \ + "sdpa_vector_vjp_accumulate_" #type "_" #qk_dim "_" #value_dim, \ + sdpa_vector_vjp_accumulate, \ + type, \ + qk_dim, \ + value_dim) + +// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel +#define instantiate_sdpa_vector_vjp_accumulate_heads(type) \ + instantiate_sdpa_vector_vjp_accumulate(type, 64, 64) \ + instantiate_sdpa_vector_vjp_accumulate(type, 96, 96) \ + instantiate_sdpa_vector_vjp_accumulate(type, 128, 128) + +// Note: Only instantiate for half/bfloat16 since float32 doesn't need accumulate variant +instantiate_sdpa_vector_vjp_accumulate_heads(bfloat16_t) 
+instantiate_sdpa_vector_vjp_accumulate_heads(float16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index cccd6dce1c..7554316f2f 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -81,8 +81,10 @@ template out += o_offset * V + simd_gid * v_per_thread; // Read the query and 0 the output accumulator + // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp) + const U log2e_scale = static_cast(scale * M_LOG2E_F); for (int i = 0; i < qk_per_thread; i++) { - q[i] = static_cast(scale) * queries[i]; + q[i] = log2e_scale * queries[i]; } for (int i = 0; i < v_per_thread; i++) { o[i] = 0; @@ -91,7 +93,9 @@ template U max_score = Limits::finite_min; U sum_exp_score = 0; if (has_sinks && simd_gid == 0) { - max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + // Scale sink by M_LOG2E_F to match log2 domain + max_score = static_cast(M_LOG2E_F) * + static_cast(sinks[q_batch_head_idx % num_q_heads]); sum_exp_score = 1; } @@ -118,13 +122,14 @@ template } score = simd_sum(score); if (float_mask) { - score += static_cast(fmask[0]); + // Scale float mask by M_LOG2E_F to match log2 domain + score += static_cast(M_LOG2E_F) * static_cast(fmask[0]); } - // Update the accumulators + // Update the accumulators (using exp2 to match STEEL attention) U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + U factor = fast::exp2(max_score - new_max); + U exp_score = fast::exp2(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -156,7 +161,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); max_score = max_scores[simd_lid]; U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); + U factor = fast::exp2(max_score - new_max); sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); // Now we need to aggregate all the outputs @@ -247,15 +252,18 @@ template sums += o_offset * blocks + block_idx; maxs += o_offset * blocks + block_idx; - // Read the query + // Read the query and 0 the output accumulator + // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp) + const U log2e_scale = static_cast(scale * M_LOG2E_F); for (int i = 0; i < qk_per_thread; i++) { - q[i] = static_cast(scale) * queries[i]; + q[i] = log2e_scale * queries[i]; } U max_score = Limits::finite_min; U sum_exp_score = 0; if (has_sinks && block_idx == 0) { - max_score = static_cast(sinks[q_head_idx]); + // Scale sink by M_LOG2E_F to match log2 domain + max_score = static_cast(M_LOG2E_F) * static_cast(sinks[q_head_idx]); sum_exp_score = 1; } @@ -278,13 +286,14 @@ template score = simd_sum(score); if (float_mask) { - score += fmask[0]; + // Scale float mask by M_LOG2E_F to match log2 domain + score += static_cast(M_LOG2E_F) * static_cast(fmask[0]); } - // Update the accumulators + // Update the accumulators (using exp2 to match STEEL attention) U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + U factor = fast::exp2(max_score - new_max); + U exp_score = fast::exp2(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; diff --git a/mlx/backend/metal/kernels/sdpa_vector_vjp.h b/mlx/backend/metal/kernels/sdpa_vector_vjp.h new file mode 100644 index 0000000000..a66964e5c0 --- /dev/null +++ b/mlx/backend/metal/kernels/sdpa_vector_vjp.h @@ 
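
As a quick illustration (not part of the patch): the sdpa_vector.h changes above fold M_LOG2E_F into the query scale, the sinks, and the float mask so that exp can be replaced by exp2, using exp(x) = exp2(x * log2(e)). A minimal NumPy check, with made-up score values, that this leaves the softmax unchanged:

import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=16).astype(np.float32)   # scale * (q @ k) for one query
LOG2E = np.float32(np.log2(np.e))                 # M_LOG2E_F

# exp-domain softmax
p_exp = np.exp(scores - scores.max())
p_exp /= p_exp.sum()

# exp2-domain softmax on log2e-scaled scores (what the kernel now computes)
scores2 = LOG2E * scores
p_exp2 = np.exp2(scores2 - scores2.max())
p_exp2 /= p_exp2.sum()

assert np.allclose(p_exp, p_exp2, atol=1e-6)
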
-0,0 +1,574 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/backend/metal/kernels/atomic.h" + +using namespace metal; + +// Note: Function constants (has_mask, query_transposed, do_causal, bool_mask, +// float_mask, has_sinks) are defined in sdpa_vector.h with indices 20-25. +// This header assumes sdpa_vector.h is included first. + +/////////////////////////////////////////////////////////////////////////////// +// SDPA Vector VJP Kernel +// +// Computes gradients dQ, dK, dV for scaled dot-product attention backward pass. +// +// Forward: O = softmax(scale * Q @ K^T) @ V +// +// Backward (VJP): +// P = softmax(scale * Q @ K^T) [reconstructed from logsumexp] +// dV = P^T @ dO +// dP = dO @ V^T +// dS = P * (dP - sum(dP * P)) [softmax gradient] +// dQ = scale * dS @ K +// dK = scale * dS^T @ Q +// +// This kernel handles the "vector" case where Q_seq is small (<=8). +// Each threadgroup processes one (batch, head, q_seq) position. +// +// IMPORTANT: Stride Assumption +// This kernel uses input strides (k_head_stride, k_seq_stride, v_head_stride, +// v_seq_stride) for output array (d_keys, d_values) pointer arithmetic. +// The dispatch code must ensure that output arrays have matching strides: +// - d_k.strides() must match k.strides() for head and seq dimensions +// - d_v.strides() must match v.strides() for head and seq dimensions +// Failure to maintain this invariant will cause memory corruption. +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void sdpa_vector_vjp( + // Forward inputs + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + // Forward output and upstream gradient + const device T* out [[buffer(3)]], + const device T* d_out [[buffer(4)]], + // Logsumexp from forward (for numerically stable softmax reconstruction) + const device float* logsumexp [[buffer(5)]], + // Gradient outputs + device T* d_queries [[buffer(6)]], + device T* d_keys [[buffer(7)]], + device T* d_values [[buffer(8)]], + // Attention parameters + const constant int& gqa_factor [[buffer(9)]], + const constant int& N [[buffer(10)]], // KV sequence length + const constant size_t& k_head_stride [[buffer(11)]], + const constant size_t& k_seq_stride [[buffer(12)]], + const constant size_t& v_head_stride [[buffer(13)]], + const constant size_t& v_seq_stride [[buffer(14)]], + const constant float& scale [[buffer(15)]], + // Output (O/dO) stride parameters - STEEL forward may produce non-row-major + // layout Physical layout can be BLHV (strides [L*H*V, V, H*V, 1]) vs + // logical BHLV + const constant int& num_q_heads [[buffer(16)]], + const constant size_t& o_batch_stride [[buffer(17)]], + const constant size_t& o_head_stride [[buffer(18)]], + const constant size_t& o_seq_stride [[buffer(19)]], + // Optional mask inputs + const device bool* bmask [[buffer(20), function_constant(bool_mask)]], + const device T* fmask [[buffer(21), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(22), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(23), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(24), function_constant(has_mask)]], + // Optional attention sinks + const device T* sinks [[buffer(25), function_constant(has_sinks)]], + // Thread position info + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid 
[[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Block sizes matching forward kernel + constexpr int BN = 32; // Number of simdgroups (parallel KV positions) + constexpr int BD = 32; // Simdgroup width (threads per simdgroup) + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + int inner_k_stride = BN * int(k_seq_stride); + int inner_v_stride = BN * int(v_seq_stride); + + typedef float U; // Accumulator type for numerical stability + + // Thread-local storage for queries, keys, values, gradients + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U v[v_per_thread]; + thread U dq[qk_per_thread]; // Gradient w.r.t. query + thread U d_o[v_per_thread]; // Upstream gradient + thread U o[v_per_thread]; // Forward output (for delta computation) + + // Threadgroup memory for reductions and communication + threadgroup U shared_delta[1]; // delta = sum(dO * O) + threadgroup U shared_dQ[BN * D]; // For dQ reduction across simdgroups + + // Compute positions (same as forward) + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + // Decompose batch_head index into batch and head for stride-based O/dO access + // STEEL forward produces BLHV physical layout, so we need explicit strides + const int batch_idx = q_batch_head_idx / num_q_heads; + const int head_idx = q_batch_head_idx % num_q_heads; + // LSE is row-major [B*H, L], so keep the combined index for LSE access + const int lse_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : lse_offset; + + // Set up input pointers + queries += q_offset * D + simd_lid * qk_per_thread; + keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + + simd_lid * qk_per_thread; + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + + // Set up mask pointers if needed + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + + // Set up output/gradient pointers + // Use explicit strides for O/dO to handle BLHV physical layout from STEEL + // For BLHV strides: o_batch_stride = L*H*V, o_head_stride = V, o_seq_stride = + // H*V + out += batch_idx * o_batch_stride + head_idx * o_head_stride + + q_seq_idx * o_seq_stride + simd_lid * v_per_thread; + d_out += batch_idx * o_batch_stride + head_idx * o_head_stride + + q_seq_idx * o_seq_stride + simd_lid * v_per_thread; + // LSE is row-major [B*H, L] - no stride adjustment needed + logsumexp += lse_offset; + + d_queries += q_offset * D + simd_lid * qk_per_thread; + d_keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + + simd_lid * qk_per_thread; + d_values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + + // Load query (scaled by M_LOG2E_F to match exp2 domain) + const U log2e_scale = static_cast(scale * M_LOG2E_F); + const U inv_log2e = static_cast(1.0f / M_LOG2E_F); + for (int i = 0; i < qk_per_thread; i++) { + q[i] = log2e_scale * queries[i]; + } + + // Initialize dQ accumulator to zero + for (int i = 0; i < qk_per_thread; i++) { + dq[i] = 0; + } + + // Load forward output O and upstream gradient dO + for (int i = 0; i < v_per_thread; i++) { + o[i] = out[i]; + d_o[i] = d_out[i]; + } 
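
The header comment of this kernel lists the backward formulas; as a reference for readers, here is the same math for a single (batch, head, query) position as a NumPy sketch. It is illustrative only (function and variable names are hypothetical, and it uses plain exp rather than the kernel's exp2/logsumexp reconstruction):

import numpy as np

def sdpa_vjp_reference(q, K, V, d_o, scale):
    # q: [D], K: [N, D], V: [N, Dv], d_o: [Dv]
    s = scale * (K @ q)              # attention scores, [N]
    p = np.exp(s - s.max())
    p /= p.sum()                     # P = softmax(scale * q @ K^T)
    o = p @ V                        # forward output, [Dv]
    delta = d_o @ o                  # delta = sum(dO * O), scalar
    dP = V @ d_o                     # dP = dO @ V^T, [N]
    dS = p * (dP - delta)            # softmax gradient
    dq = scale * (dS @ K)            # dQ = scale * dS @ K, [D]
    dK = scale * np.outer(dS, q)     # dK = scale * dS^T @ Q, [N, D]
    dV = np.outer(p, d_o)            # dV = P^T @ dO, [N, Dv]
    return dq, dK, dV
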
+ + // Compute delta = sum(dO * O) - needed for softmax gradient + // This is invariant across all KV positions for this query + U local_delta = 0; + for (int i = 0; i < v_per_thread; i++) { + local_delta += d_o[i] * o[i]; + } + // Sum across simdgroup + local_delta = simd_sum(local_delta); + + // First simdgroup stores delta to shared memory + if (simd_gid == 0 && simd_lid == 0) { + shared_delta[0] = local_delta; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + U delta = shared_delta[0]; + + // Load logsumexp for this query position + U lse = logsumexp[0]; + + // Initialize shared_dQ to zero + for (int idx = simd_gid * BD + simd_lid; idx < BN * D; idx += BN * BD) { + shared_dQ[idx] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Main loop over KV sequence + for (int kv_idx = simd_gid; kv_idx < N; kv_idx += BN) { + bool use_key = true; + + // Apply causal or explicit mask + if (do_causal) { + use_key = kv_idx <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + + if (use_key) { + // Load key for this position + for (int j = 0; j < qk_per_thread; j++) { + k[j] = keys[j]; + } + + // Load value for this position + for (int j = 0; j < v_per_thread; j++) { + v[j] = values[j]; + } + + // Reconstruct attention score: S = scale * Q @ K^T + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + + // Add float mask if present (scaled by M_LOG2E_F to match log2 domain) + if (float_mask) { + score += static_cast(M_LOG2E_F) * static_cast(fmask[0]); + } + + // Reconstruct attention probability: P = exp2(S - logsumexp) + // Using exp2 to match STEEL attention domain (logsumexp is in log2 + // domain) + U prob = fast::exp2(score - lse); + + // Compute dP = dO @ V^T for this KV position + U dP = 0; + for (int j = 0; j < v_per_thread; j++) { + dP += d_o[j] * v[j]; + } + dP = simd_sum(dP); + + // Compute dS = P * (dP - delta) [softmax gradient] + U dS = prob * (dP - delta); + + // Accumulate dQ += scale * dS @ K + // Note: Although Q was scaled by M_LOG2E_F internally, the softmax + // gradient dS compensates for this because the overall softmax(S') = + // softmax(S). The gradient dQ = scale * dS @ K matches the reference. 
+ for (int j = 0; j < qk_per_thread; j++) { + dq[j] += static_cast(scale) * dS * k[j]; + } + + // Compute dK = dS @ Q * scale + // Note: q[j] = scale * M_LOG2E_F * Q[j], so dS * q gives: + // dK = scale * M_LOG2E_F * dS @ Q + // Reference expects: dK = scale * dS @ Q + // So we multiply by inv_log2e to cancel the M_LOG2E_F + for (int j = 0; j < qk_per_thread; j++) { + U dk_val = inv_log2e * dS * q[j]; + // Atomic add - multiple query positions may contribute to same dK + mlx_atomic_fetch_add_explicit( + reinterpret_cast*>(d_keys), + static_cast(dk_val), + j); + } + + // Accumulate dV += P^T @ dO = P * dO (for this KV position) + // prob is scalar for this (q, kv) pair, broadcast to all dO elements + // Atomic add - multiple query positions may contribute to same dV + for (int j = 0; j < v_per_thread; j++) { + mlx_atomic_fetch_add_explicit( + reinterpret_cast*>(d_values), + static_cast(prob * d_o[j]), + j); + } + } + + // Move to next KV block + keys += inner_k_stride; + values += inner_v_stride; + d_keys += inner_k_stride; + d_values += inner_v_stride; + + if (bool_mask) { + bmask += BN * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * mask_kv_seq_stride; + } + } + + // Write accumulated dQ gradient + // Need to reduce across simdgroups since each processed different KV + // positions but they all contribute to the same query gradient + + // Store each simdgroup's partial dQ to shared memory + // NOTE: Use D (head dimension) not BD (simdgroup width) for the stride + // Each simdgroup needs D elements to store its full dQ contribution + for (int i = 0; i < qk_per_thread; i++) { + shared_dQ[simd_gid * D + simd_lid * qk_per_thread + i] = dq[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduce dQ across simdgroups + if (simd_gid == 0) { + for (int i = 0; i < qk_per_thread; i++) { + U sum_dq = 0; + for (int sg = 0; sg < BN; sg++) { + sum_dq += shared_dQ[sg * D + simd_lid * qk_per_thread + i]; + } + d_queries[i] = static_cast(sum_dq); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// SDPA Vector VJP Kernel - Separate dK/dV accumulation version +// +// This version is more suitable when multiple query positions exist, +// as dK and dV need proper accumulation across all query contributions. 
+/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void sdpa_vector_vjp_accumulate( + // Forward inputs + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + // Forward output and upstream gradient + const device T* out [[buffer(3)]], + const device T* d_out [[buffer(4)]], + // Logsumexp from forward + const device float* logsumexp [[buffer(5)]], + // Gradient outputs (dK and dV are accumulated atomically) + device T* d_queries [[buffer(6)]], + device float* d_keys_accum [[buffer(7)]], // float for atomic accumulation + device float* d_values_accum [[buffer(8)]], // float for atomic accumulation + // Attention parameters + const constant int& gqa_factor [[buffer(9)]], + const constant int& N [[buffer(10)]], + const constant int& Q_seq [[buffer(11)]], // Number of query positions + const constant size_t& k_head_stride [[buffer(12)]], + const constant size_t& k_seq_stride [[buffer(13)]], + const constant size_t& v_head_stride [[buffer(14)]], + const constant size_t& v_seq_stride [[buffer(15)]], + const constant float& scale [[buffer(16)]], + // Output (O/dO) stride parameters - STEEL forward may produce non-row-major + // layout + const constant int& num_q_heads [[buffer(17)]], + const constant size_t& o_batch_stride [[buffer(18)]], + const constant size_t& o_head_stride [[buffer(19)]], + const constant size_t& o_seq_stride [[buffer(20)]], + // Optional mask inputs + const device bool* bmask [[buffer(21), function_constant(bool_mask)]], + const device T* fmask [[buffer(22), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(23), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(24), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(25), function_constant(has_mask)]], + // Thread position info + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + int inner_k_stride = BN * int(k_seq_stride); + int inner_v_stride = BN * int(v_seq_stride); + + typedef float U; + + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U v[v_per_thread]; + thread U dq[qk_per_thread]; + thread U d_o[v_per_thread]; + thread U o[v_per_thread]; + + threadgroup U shared_delta[1]; + threadgroup U shared_dQ[BN * D]; // For dQ reduction across simdgroups + + // Position setup + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + // Decompose batch_head index for stride-based O/dO access + const int batch_idx = q_batch_head_idx / num_q_heads; + const int head_idx = q_batch_head_idx % num_q_heads; + const int lse_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : lse_offset; + + // Input pointer setup + const device T* q_ptr = queries + q_offset * D + simd_lid * qk_per_thread; + const device T* k_base = + keys + kv_head_idx * k_head_stride + simd_lid * qk_per_thread; + const device T* v_base = + values + kv_head_idx * v_head_stride + simd_lid * v_per_thread; + + // Use explicit strides for O/dO to handle BLHV physical layout + const device T* o_ptr = out + batch_idx * o_batch_stride + + head_idx * o_head_stride + q_seq_idx * o_seq_stride + + simd_lid * v_per_thread; + const device T* do_ptr = d_out + batch_idx * o_batch_stride + + head_idx * o_head_stride + q_seq_idx * o_seq_stride + + simd_lid * v_per_thread; + // LSE is row-major [B*H, L] - use lse_offset + U lse = logsumexp[lse_offset]; + + // Output pointer setup + device T* dq_ptr = d_queries + q_offset * D + simd_lid * qk_per_thread; + device float* dk_base = + d_keys_accum + kv_head_idx * k_head_stride + simd_lid * qk_per_thread; + device float* dv_base = + d_values_accum + kv_head_idx * v_head_stride + simd_lid * v_per_thread; + + // Mask pointer setup + const device bool* bm_ptr = bool_mask ? bmask + + q_batch_head_idx * mask_head_stride + q_seq_idx * mask_q_seq_stride + : nullptr; + const device T* fm_ptr = float_mask ? fmask + + q_batch_head_idx * mask_head_stride + q_seq_idx * mask_q_seq_stride + : nullptr; + + // Load query (scaled by M_LOG2E_F to match exp2 domain) + const U log2e_scale = static_cast(scale * M_LOG2E_F); + for (int i = 0; i < qk_per_thread; i++) { + q[i] = log2e_scale * q_ptr[i]; + dq[i] = 0; + } + + // Load O and dO, compute delta + U local_delta = 0; + for (int i = 0; i < v_per_thread; i++) { + o[i] = o_ptr[i]; + d_o[i] = do_ptr[i]; + local_delta += d_o[i] * o[i]; + } + local_delta = simd_sum(local_delta); + + if (simd_gid == 0 && simd_lid == 0) { + shared_delta[0] = local_delta; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + U delta = shared_delta[0]; + + // Process KV sequence + const device T* k_ptr = k_base + simd_gid * k_seq_stride; + const device T* v_ptr = v_base + simd_gid * v_seq_stride; + device float* dk_ptr = dk_base + simd_gid * k_seq_stride; + device float* dv_ptr = dv_base + simd_gid * v_seq_stride; + // NOTE: mask_kv_seq_stride is only defined when has_mask is true + // (function_constant) Initialize mask_offset only when mask is present to + // avoid undefined behavior + int mask_offset = (has_mask) ? 
simd_gid * mask_kv_seq_stride : 0; + + for (int kv_idx = simd_gid; kv_idx < N; kv_idx += BN) { + bool use_key = true; + + if (do_causal) { + use_key = kv_idx <= (N - Q_seq + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bm_ptr[mask_offset]; + } else if (float_mask) { + use_key = (fm_ptr[mask_offset] >= Limits::finite_min); + } + + if (use_key) { + // Load K, V + for (int j = 0; j < qk_per_thread; j++) { + k[j] = k_ptr[j]; + } + for (int j = 0; j < v_per_thread; j++) { + v[j] = v_ptr[j]; + } + + // Compute score and probability + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + + if (float_mask) { + // Scale float mask by M_LOG2E_F to match log2 domain + score += + static_cast(M_LOG2E_F) * static_cast(fm_ptr[mask_offset]); + } + + // Reconstruct probability: P = exp2(S - logsumexp) + // Using exp2 to match STEEL attention domain (logsumexp is in log2 + // domain) + U prob = fast::exp2(score - lse); + + // Compute dP + U dP = 0; + for (int j = 0; j < v_per_thread; j++) { + dP += d_o[j] * v[j]; + } + dP = simd_sum(dP); + + // Compute dS + U dS = prob * (dP - delta); + + // Accumulate dQ + // Note: We use scale (not log2e_scale) because: + // - The formula dS = P * (dP - delta) gives the exp-domain gradient + // - The exp2 Jacobian has a ln(2) factor, but it cancels with the + // M_LOG2E_F factor from Q scaling, so the net effect is just scale + for (int j = 0; j < qk_per_thread; j++) { + dq[j] += static_cast(scale) * dS * k[j]; + } + + // Atomic add to dK + // dK = scale * dS * Q (q_ptr has unscaled query) + // All threads in simdgroup contribute to different elements + for (int j = 0; j < qk_per_thread; j++) { + U dk_val = static_cast(scale) * dS * q_ptr[j]; + mlx_atomic_fetch_add_explicit( + reinterpret_cast*>(dk_ptr), + static_cast(dk_val), + j); + } + + // Atomic add to dV + // dV = prob * dO + for (int j = 0; j < v_per_thread; j++) { + mlx_atomic_fetch_add_explicit( + reinterpret_cast*>(dv_ptr), + static_cast(prob * d_o[j]), + j); + } + } + + // Advance pointers + k_ptr += inner_k_stride; + v_ptr += inner_v_stride; + dk_ptr += inner_k_stride; + dv_ptr += inner_v_stride; + // NOTE: Only update mask_offset when mask is present (mask_kv_seq_stride is + // function_constant) + if (has_mask) { + mask_offset += BN * mask_kv_seq_stride; + } + } + + // Reduce and write dQ + // NOTE: Use D (head dimension) not BD (simdgroup width) for the stride + // Each simdgroup needs D elements to store its full dQ contribution + for (int i = 0; i < qk_per_thread; i++) { + shared_dQ[simd_gid * D + simd_lid * qk_per_thread + i] = dq[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_gid == 0) { + for (int i = 0; i < qk_per_thread; i++) { + U sum_dq = 0; + for (int sg = 0; sg < BN; sg++) { + sum_dq += shared_dQ[sg * D + simd_lid * qk_per_thread + i]; + } + dq_ptr[i] = static_cast(sum_dq); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 4de11b0819..e0a8a7b6b3 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]]; constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; constant bool has_sinks [[function_constant(302)]]; +constant bool output_logsumexp [[function_constant(303)]]; template struct TransformScale { 
@@ -86,6 +87,7 @@ template < const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], const device MaskType* mask [[buffer(6), function_constant(has_mask)]], const device T* sinks [[buffer(7), function_constant(has_sinks)]], + device float* LSE [[buffer(8), function_constant(output_logsumexp)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -460,6 +462,32 @@ template < Otile.template row_bin_op(sum_score); threadgroup_barrier(mem_flags::mem_none); + // Output logsumexp if requested for VJP backward pass + // LSE = max_score + log2(sum_score) in log2 domain (matches STEEL convention) + // Physical storage shape: [B*H, qL], laid out as linear array indexed by (B*H + // + head)*qL + query_pos LSE_strides[0] = qL (stride between (batch, head) + // rows) LSE_strides[1] = 1 (stride between query positions within a row) + if (output_logsumexp) { + // Compute linear index for (batch, head) combination + // This matches the VJP kernel's indexing: (tidl.z * H + tidl.y) * + // LSE_strides[0] + device float* lse_out = + LSE + (tidl.z * params->H + tidl.y) * params->LSE_strides[0]; + + // Write one logsumexp per query position in this tile + // Each thread handles kRowsPT query positions + // align_Q=true means query length is aligned (all blocks full), so always + // write align_Q=false means last block is partial, so check bounds + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + int row_pos = tid.x * BQ + tm + sm + (i * decltype(Stile)::kFragRows); + if (align_Q || row_pos < params->qL) { + AccumType lse_val = max_score[i] + fast::log2(sum_score[i]); + lse_out[row_pos * params->LSE_strides[1]] = static_cast(lse_val); + } + } + } + // Store results O += (tm + sm) * params->O_strides[2] + sn; diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h index f1cf09fada..4370589393 100644 --- a/mlx/backend/metal/kernels/steel/attn/params.h +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -34,6 +34,7 @@ struct AttnParams { int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) + int64_t LSE_strides[2]; ///< LSE strides (B*H, L) - logsumexp output for VJP }; struct AttnMaskParams { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index f09cacf1c7..11b5df1fc3 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,5 +1,7 @@ // Copyright © 2024 Apple Inc. 
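
The STEEL change above stores LSE = max_score + log2(sum_score), i.e. a log2-domain logsumexp over the already log2e-scaled scores, which is exactly what the VJP kernels consume via P = exp2(S' - LSE). A small NumPy sketch of that round trip (shapes and values are illustrative, not taken from the patch):

import numpy as np

rng = np.random.default_rng(1)
LOG2E = np.log2(np.e)
scale = 64 ** -0.5
q = rng.normal(size=64)
K = rng.normal(size=(32, 64))

s2 = LOG2E * scale * (K @ q)                             # scores in the exp2 domain
lse = s2.max() + np.log2(np.exp2(s2 - s2.max()).sum())   # max_score + log2(sum_score)
p = np.exp2(s2 - lse)                                    # reconstructed probabilities

assert np.isclose(p.sum(), 1.0)
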
+#include #include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" @@ -9,6 +11,7 @@ #include "mlx/backend/metal/kernels/steel/attn/params.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" +#include "mlx/ops.h" #include "mlx/utils.h" namespace mlx::core::fast { @@ -112,6 +115,7 @@ void sdpa_full_self_attention_nax( const int NQ_aligned = qL / bq; const int NK_aligned = kL / bk; + // NAX doesn't support logsumexp output - provide dummy strides AttnParams params{ /* int B = */ B, /* int H = */ H, @@ -136,7 +140,8 @@ void sdpa_full_self_attention_nax( /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, - /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}, + /* int64_t LSE_strides[2] = */ {0, 0}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); @@ -173,9 +178,12 @@ void sdpa_full_self_attention_metal( array& o, bool do_causal_, const std::optional& mask, - const std::optional& sinks) { + const std::optional& sinks, + bool output_logsumexp_ = false, + array* lse_out = nullptr) { + // NAX path does not support logsumexp output - skip when VJP needs it if (metal::is_nax_available() && q.shape(3) != 80 && - (env::enable_tf32() || q.dtype() != float32)) { + (env::enable_tf32() || q.dtype() != float32) && !output_logsumexp_) { return sdpa_full_self_attention_nax( /* const Stream& s = */ s, /* metal::Device& d = */ d, @@ -211,13 +219,15 @@ void sdpa_full_self_attention_metal( const bool has_mask = mask.has_value(); const bool do_causal = do_causal_; const bool has_sinks = sinks.has_value(); + const bool output_logsumexp = output_logsumexp_; metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, {&has_mask, MTL::DataType::DataTypeBool, 300}, {&do_causal, MTL::DataType::DataTypeBool, 301}, - {&has_sinks, MTL::DataType::DataTypeBool, 302}}; + {&has_sinks, MTL::DataType::DataTypeBool, 302}, + {&output_logsumexp, MTL::DataType::DataTypeBool, 303}}; std::string base_name; concatenate( @@ -250,7 +260,9 @@ void sdpa_full_self_attention_metal( "_do_causal_", (do_causal ? 't' : 'n'), "_has_sinks_", - (has_sinks ? 't' : 'n')); + (has_sinks ? 't' : 'n'), + "_lse_", + (output_logsumexp ? 
't' : 'n')); auto& compute_encoder = d.get_command_encoder(s.index); @@ -275,6 +287,14 @@ void sdpa_full_self_attention_metal( const int NQ_aligned = qL / bq; const int NK_aligned = kL / bk; + // Compute LSE strides if outputting logsumexp: shape [B, H, qL, 1] + // The VJP kernel expects strides as: + // LSE_strides[0] = qL (stride between heads within same batch) + // LSE_strides[1] = 1 (stride between query positions) + // Linear index = (batch * H + head) * qL + query_pos + int64_t lse_str_head = qL; // Stride between heads + int64_t lse_str_qpos = 1; // Stride between query positions + AttnParams params{ /* int B = */ B, /* int H = */ H, @@ -299,7 +319,8 @@ void sdpa_full_self_attention_metal( /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, - /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}, + /* int64_t LSE_strides[2] = */ {lse_str_head, lse_str_qpos}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); @@ -319,6 +340,9 @@ void sdpa_full_self_attention_metal( if (has_sinks) { compute_encoder.set_input_array(*sinks, 7); } + if (output_logsumexp && lse_out != nullptr) { + compute_encoder.set_output_array(*lse_out, 8); + } MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); @@ -594,17 +618,11 @@ bool ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, - bool is_training, bool output_logsumexp, Stream s) { - if (is_training) { - // It's faster for training on Metal to use the unfused SDPA for both - // forward and backward. - return true; - } - if (output_logsumexp) { - return true; - } + // Note: When output_logsumexp is true, the caller (fast.cpp) has already + // verified VJP availability with proper has_mask/has_sinks parameters. + // No redundant check needed here. if (s.device == Device::cpu) { return true; } @@ -681,7 +699,9 @@ void ScaledDotProductAttention::eval_gpu( bool has_arr_mask = inputs.size() > (3 + has_sinks_); // We are in vector mode ie single query - if (q_pre.shape(2) <= 8) { + // NOTE: Vector mode doesn't support logsumexp output needed for VJP. + // When output_logsumexp_ is true (training mode), use full attention instead. + if (q_pre.shape(2) <= 8 && !output_logsumexp_) { auto q_copy_unless = [](const array& arr) { if (arr.flags().row_contiguous) { return true; @@ -772,25 +792,592 @@ void ScaledDotProductAttention::eval_gpu( {str_oB, str_oH, str_oL, str_oD}, flags); + // Handle logsumexp output for VJP backward pass + array* lse_out = nullptr; + if (output_logsumexp_ && outputs.size() > 1) { + auto& lse = outputs[1]; + lse.set_data(allocator::malloc(lse.nbytes())); + lse_out = &outputs[1]; + } + auto mask = has_arr_mask ? 
std::optional{copy_unless(is_matrix_contiguous, inputs[3])} : std::nullopt; sdpa_full_self_attention_metal( - s, d, q, k, v, scale_, o, do_causal_, mask, sinks); + s, + d, + q, + k, + v, + scale_, + o, + do_causal_, + mask, + sinks, + output_logsumexp_, + lse_out); } d.add_temporaries(std::move(copies), s.index); } -bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { +bool ScaledDotProductAttentionVJP::use_fallback( + const array& q, + Stream s, + bool has_mask, + bool has_sinks, + int n_kv_heads) { + // Use fallback on CPU + if (s.device == Device::cpu) { + return true; + } + + const int query_head_dim = q.shape(-1); + const int query_seq_len = q.shape(2); + + // Vector VJP uses exp2() matching forward pass's log2 domain. + // Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP. + // Note: The accumulate variant (used for half/bfloat16) does NOT support + // sinks. + const bool vector_supported_head_dim = + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); + const bool is_float32 = (q.dtype() == float32); + + // For short sequences (seq <= 8), prefer vector VJP if head dim is supported. + // However, sinks are only supported in the float32 vector VJP kernel, + // not in the accumulate variant used for half/bfloat16. + if (query_seq_len <= 8 && vector_supported_head_dim) { + // If sinks are present and dtype is not float32, must use fallback + // because sdpa_vector_vjp_accumulate doesn't support sinks. + if (has_sinks && !is_float32) { + return true; // Must use unfused attention for sinks with half/bfloat16 + } + return false; // Use vector VJP + } + + // For longer sequences (L > 8), use fallback (unfused attention) + // STEEL VJP for longer sequences will be added in a future PR return true; } +namespace { + +void sdpa_vector_vjp_dispatch( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const array& out, + const array& d_out, + const array& logsumexp, + array& d_q, + array& d_k, + array& d_v, + float scale, + bool do_causal, + const std::optional& mask, + const std::optional& sinks) { + // Set the kernel name (matching forward pattern) + std::string kname; + kname.reserve(64); + kname += "sdpa_vector_vjp_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(v.shape(-1)); + + // Compute the necessary sizes (same as forward) + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); + size_t k_seq_stride = k.strides()[2]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); + size_t v_seq_stride = v.strides()[2]; + + // Vector VJP kernel uses input strides for output pointer arithmetic. + // Verify output strides match input strides to prevent memory corruption. + // Stride requirements: + // d_k head stride must match k head stride + // d_k seq stride must match k seq stride + // d_v head stride must match v head stride + // d_v seq stride must match v seq stride + size_t d_k_head_stride = d_k.shape(1) == 1 ? d_k.strides(0) : d_k.strides(1); + size_t d_v_head_stride = d_v.shape(1) == 1 ? d_v.strides(0) : d_v.strides(1); + if (d_k_head_stride != k_head_stride || d_k.strides()[2] != k_seq_stride || + d_v_head_stride != v_head_stride || d_v.strides()[2] != v_seq_stride) { + throw std::runtime_error( + "Stride mismatch in vector VJP kernel: " + "output array strides must match input array strides. 
" + "This may occur with non-contiguous array views."); + } + + MTL::Size group_dims(1024, 1, 1); + MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1); + + bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; + bool query_transposed = !q.flags().row_contiguous; + bool has_sinks_flag = sinks.has_value(); + + // Function constants (same indices as forward) + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks_flag, MTL::DataType::DataTypeBool, 25}, + }; + + std::string hash_name = kname; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_sinks_flag ? "_sinks" : "_nosinks"; + + // Get the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set kernel arguments + // Inputs: Q, K, V, O, dO, logsumexp + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_input_array(out, 3); + compute_encoder.set_input_array(d_out, 4); + compute_encoder.set_input_array(logsumexp, 5); + + // Outputs: dQ, dK, dV + compute_encoder.set_output_array(d_q, 6); + compute_encoder.set_output_array(d_k, 7); + compute_encoder.set_output_array(d_v, 8); + + // Parameters + compute_encoder.set_bytes(gqa_factor, 9); + compute_encoder.set_bytes(N, 10); + compute_encoder.set_bytes(k_head_stride, 11); + compute_encoder.set_bytes(k_seq_stride, 12); + compute_encoder.set_bytes(v_head_stride, 13); + compute_encoder.set_bytes(v_seq_stride, 14); + compute_encoder.set_bytes(scale, 15); + + // Output (O/dO) stride parameters - handle BLHV physical layout from STEEL + // For BLHV layout: strides are [L*H*V, V, H*V, 1] vs logical [B, H, L, V] + int num_q_heads = q.shape(1); + size_t o_batch_stride = out.strides(0); + size_t o_head_stride = out.shape(1) == 1 ? 0 : out.strides(1); + size_t o_seq_stride = out.strides(2); + compute_encoder.set_bytes(num_q_heads, 16); + compute_encoder.set_bytes(o_batch_stride, 17); + compute_encoder.set_bytes(o_head_stride, 18); + compute_encoder.set_bytes(o_seq_stride, 19); + + // Optional mask inputs (buffer indices shifted by 4) + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array( + m, 20 + float_mask); // 20 for bool, 21 for float + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? 
m.strides(0) : 0); + compute_encoder.set_bytes(kv_seq_stride, 22); + compute_encoder.set_bytes(q_seq_stride, 23); + compute_encoder.set_bytes(head_stride, 24); + } + + // Optional sinks (buffer index shifted by 4) + if (has_sinks_flag) { + compute_encoder.set_input_array(*sinks, 25); + } + + // Launch + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +// Dispatch function for vector VJP with float32 accumulators (for +// half/bfloat16) This variant uses the sdpa_vector_vjp_accumulate kernel which +// has device float* signature for dK and dV buffers, allowing correct pointer +// arithmetic for atomic float operations. +void sdpa_vector_vjp_accumulate_dispatch( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const array& out, + const array& d_out, + const array& logsumexp, + array& d_q, + array& d_k_accum, // float32 accumulator buffer + array& d_v_accum, // float32 accumulator buffer + float scale, + bool do_causal, + const std::optional& mask, + const std::optional& sinks) { + // Set the kernel name (uses accumulate variant) + std::string kname; + kname.reserve(64); + kname += "sdpa_vector_vjp_accumulate_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(v.shape(-1)); + + // Compute the necessary sizes + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + int Q_seq = q.shape(2); // Number of query positions + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); + size_t k_seq_stride = k.strides()[2]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); + size_t v_seq_stride = v.strides()[2]; + + // Vector VJP kernel uses input strides for output pointer arithmetic. + // Verify accumulator buffer strides match input strides to prevent memory + // corruption. Stride requirements: + // d_k_accum head stride must match k head stride + // d_k_accum seq stride must match k seq stride + // d_v_accum head stride must match v head stride + // d_v_accum seq stride must match v seq stride + size_t d_k_head_stride = + d_k_accum.shape(1) == 1 ? d_k_accum.strides(0) : d_k_accum.strides(1); + size_t d_v_head_stride = + d_v_accum.shape(1) == 1 ? d_v_accum.strides(0) : d_v_accum.strides(1); + if (d_k_head_stride != k_head_stride || + d_k_accum.strides()[2] != k_seq_stride || + d_v_head_stride != v_head_stride || + d_v_accum.strides()[2] != v_seq_stride) { + throw std::runtime_error( + "Stride mismatch in vector VJP kernel: " + "output array strides must match input array strides. " + "This may occur with non-contiguous array views."); + } + + MTL::Size group_dims(1024, 1, 1); + MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1); + + bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; + bool query_transposed = !q.flags().row_contiguous; + bool has_sinks_flag = sinks.has_value(); + + // Function constants (same indices as forward) + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks_flag, MTL::DataType::DataTypeBool, 25}, + }; + + std::string hash_name = kname; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; + hash_name += query_transposed ? 
"_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_sinks_flag ? "_sinks" : "_nosinks"; + + // Get the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set kernel arguments (accumulate variant has slightly different buffer + // layout) Inputs: Q, K, V, O, dO, logsumexp + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_input_array(out, 3); + compute_encoder.set_input_array(d_out, 4); + compute_encoder.set_input_array(logsumexp, 5); + + // Outputs: dQ, dK_accum (float32), dV_accum (float32) + compute_encoder.set_output_array(d_q, 6); + compute_encoder.set_output_array(d_k_accum, 7); + compute_encoder.set_output_array(d_v_accum, 8); + + // Parameters (note: buffer indices shifted from regular VJP kernel) + compute_encoder.set_bytes(gqa_factor, 9); + compute_encoder.set_bytes(N, 10); + compute_encoder.set_bytes( + Q_seq, 11); // Extra parameter for accumulate variant + compute_encoder.set_bytes(k_head_stride, 12); + compute_encoder.set_bytes(k_seq_stride, 13); + compute_encoder.set_bytes(v_head_stride, 14); + compute_encoder.set_bytes(v_seq_stride, 15); + compute_encoder.set_bytes(scale, 16); + + // Output (O/dO) stride parameters - handle BLHV physical layout from STEEL + int num_q_heads = q.shape(1); + size_t o_batch_stride = out.strides(0); + size_t o_head_stride = out.shape(1) == 1 ? 0 : out.strides(1); + size_t o_seq_stride = out.strides(2); + compute_encoder.set_bytes(num_q_heads, 17); + compute_encoder.set_bytes(o_batch_stride, 18); + compute_encoder.set_bytes(o_head_stride, 19); + compute_encoder.set_bytes(o_seq_stride, 20); + + // Optional mask inputs (buffer indices shifted by 4) + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array( + m, 21 + float_mask); // 21 for bool, 22 for float + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); + compute_encoder.set_bytes(kv_seq_stride, 23); + compute_encoder.set_bytes(q_seq_stride, 24); + compute_encoder.set_bytes(head_stride, 25); + } + + // Note: sinks not supported in accumulate variant. + // use_fallback() returns true for sinks with non-float32 dtypes, + // so this code path should never be reached with has_sinks_flag=true. 
+ + // Launch + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +} // namespace + void ScaledDotProductAttentionVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { - throw std::runtime_error("NYI"); + auto& s = stream(); + auto& d = metal::device(s.device); + + // Parse inputs: + // inputs = [Q, K, V, (optional mask), (optional sinks), O, logsumexp, dO] + // The last 3 are always O, logsumexp, dO + const auto& q_pre = inputs[0]; + const auto& k_pre = inputs[1]; + const auto& v_pre = inputs[2]; + + // Determine indices based on optional inputs + // primals can have mask and/or sinks appended + size_t num_primals = inputs.size() - 3; // Subtract O, logsumexp, dO + const auto& out = inputs[num_primals]; + const auto& logsumexp = inputs[num_primals + 1]; + const auto& d_out = inputs[num_primals + 2]; + + auto& d_q = outputs[0]; + auto& d_k = outputs[1]; + auto& d_v = outputs[2]; + + std::vector copies; + copies.reserve(inputs.size()); + + auto copy_unless = [&copies, &s]( + auto predicate, const array& arr) -> const array& { + if (!predicate(arr)) { + array arr_copy = contiguous_copy_gpu(arr, s); + copies.push_back(std::move(arr_copy)); + return copies.back(); + } else { + return arr; + } + }; + + auto is_matrix_contiguous = [](const array& arr) { + return arr.strides(-1) == 1; + }; + + // Handle optional sinks + std::optional sinks = std::nullopt; + if (has_sinks_) { + sinks = copy_unless(is_matrix_contiguous, inputs[num_primals - 1]); + } + + // Determine if we have a mask + bool has_arr_mask = num_primals > (3 + has_sinks_); + + // Determine early whether to use vector VJP (needed for K/V copy decisions) + // Vector VJP uses input K/V strides for output dK/dV pointer arithmetic, + // so K/V must be row-contiguous when using vector VJP. + // Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP. + const int query_head_dim_pre = q_pre.shape(-1); + const int value_head_dim_pre = v_pre.shape(-1); + const bool vector_supported_head_dim = + query_head_dim_pre == value_head_dim_pre && + (query_head_dim_pre == 64 || query_head_dim_pre == 96 || + query_head_dim_pre == 128); + bool use_vector_vjp = (q_pre.shape(2) <= 8) && vector_supported_head_dim; + + // Copy predicates for Q (same as forward) + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + auto is_row_contiguous = [](const array& arr) { + return arr.flags().row_contiguous; + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + // Vector VJP requires row-contiguous K/V because the kernel uses input + // strides for output pointer arithmetic, and dK/dV are always contiguous. + const auto& k = use_vector_vjp ? copy_unless(is_row_contiguous, k_pre) + : copy_unless(kv_copy_unless, k_pre); + const auto& v = use_vector_vjp ? 
copy_unless(is_row_contiguous, v_pre) + : copy_unless(kv_copy_unless, v_pre); + const auto& o = copy_unless(is_matrix_contiguous, out); + const auto& dO = copy_unless(is_matrix_contiguous, d_out); + const auto& lse = copy_unless(is_matrix_contiguous, logsumexp); + + // Allocate output gradient arrays + // The vector VJP kernel uses atomic adds to accumulate dK and dV, + // so we must zero-initialize these arrays. + const int query_head_dim = q.shape(-1); + const int value_head_dim = v.shape(-1); + + d_q.set_data(allocator::malloc(d_q.nbytes())); + + // CRITICAL FIX: The vector VJP kernel uses mlx_atomic for dK/dV + // accumulation. This works correctly ONLY when the output dtype is float32. + // For half/bfloat16, reinterpret_cast*>(d_keys) + // causes memory corruption because half is 2 bytes but float is 4 bytes. + // + // Solution: For non-float32 dtypes with vector VJP, we: + // 1. Allocate float32 temporary accumulators for dK and dV + // 2. Run the kernel with these float32 buffers + // 3. Copy/convert from float32 to the original dtype after kernel completion + bool needs_float32_accumulators = use_vector_vjp && (q.dtype() != float32); + std::optional dk_accum = std::nullopt; + std::optional dv_accum = std::nullopt; + + if (use_vector_vjp) { + if (needs_float32_accumulators) { + // Allocate float32 accumulator buffers with same shape as dK/dV + // Note: zeros() creates lazy arrays with null data pointer. + // We must explicitly allocate and zero-initialize for GPU kernel use. + size_t dk_bytes = d_k.size() * sizeof(float); + size_t dv_bytes = d_v.size() * sizeof(float); + dk_accum = array(allocator::malloc(dk_bytes), d_k.shape(), float32); + dv_accum = array(allocator::malloc(dv_bytes), d_v.shape(), float32); + + // Zero-initialize the accumulator buffers + array zero_f32 = array(0.0f, float32); + fill_gpu(zero_f32, dk_accum.value(), s); + fill_gpu(zero_f32, dv_accum.value(), s); + copies.push_back(std::move(zero_f32)); + + // Allocate the actual output arrays (will be written after kernel) + d_k.set_data(allocator::malloc(d_k.nbytes())); + d_v.set_data(allocator::malloc(d_v.nbytes())); + } else { + // No float32 accumulators needed: zero-initialize dK/dV directly + // Must allocate memory before fill_gpu + d_k.set_data(allocator::malloc(d_k.nbytes())); + d_v.set_data(allocator::malloc(d_v.nbytes())); + array zero = array(0.0f, d_k.dtype()); + fill_gpu(zero, d_k, s); + fill_gpu(zero, d_v, s); + copies.push_back(std::move(zero)); + } + } + + // Handle mask + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + + std::optional mask = std::nullopt; + if (has_arr_mask) { + mask = copy_unless(mask_copy_unless, inputs[3]); + } + + bool do_causal = do_causal_ && q.shape(2) > 1; + + // Dispatch to appropriate kernel based on sequence length + if (use_vector_vjp) { + if (needs_float32_accumulators) { + // Use float32 accumulator buffers with the accumulate kernel variant + // This variant has device float* signature for dK/dV, ensuring correct + // pointer arithmetic (sizeof(float)=4) instead of sizeof(T)=2 for + // half/bfloat16 + array& dk_acc = dk_accum.value(); + array& dv_acc = dv_accum.value(); + sdpa_vector_vjp_accumulate_dispatch( + s, + d, + q, + k, + v, + o, + dO, + lse, + d_q, + dk_acc, + dv_acc, + scale_, + do_causal, + mask, + sinks); + + // Convert float32 accumulators to original dtype + // This 
uses the standard copy primitive with type conversion + copy_gpu(dk_acc, d_k, CopyType::General, s); + copy_gpu(dv_acc, d_v, CopyType::General, s); + + // Add accumulators as temporaries for cleanup + d.add_temporary(dk_acc, s.index); + d.add_temporary(dv_acc, s.index); + } else { + // Float32: pass dK/dV directly (already zero-initialized above) + sdpa_vector_vjp_dispatch( + s, + d, + q, + k, + v, + o, + dO, + lse, + d_q, + d_k, + d_v, + scale_, + do_causal, + mask, + sinks); + } + } + + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..7e6d792bef 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -30,7 +30,6 @@ bool fast::ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, - bool is_training, bool output_logsumexp, Stream s) { return true; @@ -42,7 +41,10 @@ bool fast::ScaledDotProductAttention::supports_bool_mask() { bool fast::ScaledDotProductAttentionVJP::use_fallback( const array& q, - Stream s) { + Stream s, + bool /* has_mask */, + bool /* has_sinks */, + int /* n_kv_heads */) { return true; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index bf140b7b51..71887b2c20 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -805,9 +805,9 @@ array scaled_dot_product_attention( inputs.push_back(astype(*sinks, final_type, stream)); } - bool is_training = detail::in_grad_tracing(); - bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback(q, stream); - bool output_logsumexp = is_training && has_fast_vjp; + bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback( + q, stream, has_mask, has_sinks, static_cast(n_kv_heads)); + bool output_logsumexp = detail::in_grad_tracing() && has_fast_vjp; if (!ScaledDotProductAttention::use_fallback( q, k, @@ -815,7 +815,6 @@ array scaled_dot_product_attention( has_mask, has_arr_mask, do_causal, - is_training, output_logsumexp, stream)) { if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) { @@ -853,11 +852,26 @@ std::vector ScaledDotProductAttention::vjp( assert(cotangents.size() == outputs.size()); auto s = stream(); - if (ScaledDotProductAttentionVJP::use_fallback(primals[0], s)) { - assert(outputs.size() == 1); + + // Determine if mask is present: primals = [Q, K, V, (mask), (sinks)] + bool has_mask = primals.size() > static_cast(3 + has_sinks_); + int n_kv_heads = primals[1].shape(1); // K is at index 1 + + // Check if we can use Flash Attention VJP + if (ScaledDotProductAttentionVJP::use_fallback( + primals[0], s, has_mask, has_sinks_, n_kv_heads) || + !output_logsumexp_) { return Custom::vjp(primals, cotangents, argnums, outputs); } + // When output_logsumexp_ is true, the forward pass creates 2 sibling arrays: + // outputs[0] = attention output, outputs[1] = logsumexp + // Even though only outputs[0] is returned to the user, the tape tracks both + // siblings. 
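
For half/bfloat16, the eval_gpu path above allocates float32 scratch buffers (dk_accum/dv_accum), zero-fills them, lets the accumulate kernel perform its atomic float adds there, and only converts to the output dtype afterwards. A minimal NumPy sketch of that accumulate-then-convert flow (the shapes and the contribution loop are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(3)
d_k = np.empty((2, 8, 4, 64), dtype=np.float16)   # final gradient buffer (half)
dk_accum = np.zeros(d_k.shape, dtype=np.float32)  # float32 scratch, zero-initialized

for _ in range(5):                                # several query positions contribute
    dk_accum += rng.normal(size=d_k.shape).astype(np.float32)

d_k[...] = dk_accum.astype(np.float16)            # one conversion after accumulation
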
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 4434830875..e3ffbb7a51 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -225,7 +225,6 @@ class ScaledDotProductAttention : public Custom {
       bool has_mask,
       bool has_arr_mask,
       bool do_causal,
-      bool is_training,
       bool output_logsumexp,
       Stream s);
   static bool supports_bool_mask();
@@ -273,7 +272,12 @@ class ScaledDotProductAttentionVJP : public Custom {
         do_causal_(do_causal),
         has_sinks_(has_sinks) {}
 
-  static bool use_fallback(const array& q, Stream s);
+  static bool use_fallback(
+      const array& q,
+      Stream s,
+      bool has_mask = false,
+      bool has_sinks = false,
+      int n_kv_heads = -1);
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index fa6d039857..d2ab09869a 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -799,6 +799,391 @@ def test_grad(slow, fast, args):
             ).sum()
 
         test_grad(loss_slow, loss_fast, [q, k, v])
 
+    def test_sdpa_grad_vector_path(self):
+        """Test VJP with short sequences using the vector kernel (L <= 8)"""
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H = 2, 8
+        for L in [1, 4, 7, 8]:
+            for D in [64, 128]:
+                with self.subTest(L=L, D=D):
+                    scale = D**-0.5
+                    q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                    k = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                    v = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                    test_vjp([q, k, v], scale)
+
+    def test_sdpa_grad_steel_path(self):
+        """Test VJP with longer sequences using the STEEL kernel (L > 8)"""
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale, mask=None):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale, mask=mask)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale, mask=mask
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H = 2, 8
+        for L in [16, 32, 128, 256]:
+            for D in [64, 128]:
+                with self.subTest(L=L, D=D):
+                    scale = D**-0.5
+                    q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                    k = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                    v = mx.random.normal((B, H, L, D), dtype=mx.float16)
+
+                    # Test without mask
+                    test_vjp([q, k, v], scale)
+
+                    # Test with causal mask
+                    test_vjp([q, k, v], scale, mask="causal")
+
+    def test_sdpa_grad_head_dims(self):
+        """Test VJP across different head dimensions"""
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H, L = 2, 8, 64
+        # D=256 is not supported by the vector VJP (threadgroup memory limit)
+        for D in [32, 64, 96, 128]:
+            with self.subTest(D=D):
+                scale = D**-0.5
+                q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                k = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                v = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                test_vjp([q, k, v], scale)
+
+    def test_sdpa_grad_gqa(self):
+        """Test VJP with grouped query attention configurations"""
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, L, D = 2, 64, 64
+        scale = D**-0.5
+
+        # Test various GQA configurations
+        configs = [
+            (32, 8),  # 4:1 GQA
+            (32, 4),  # 8:1 GQA
+            (32, 2),  # 16:1 GQA
+            (8, 8),  # MHA (no GQA)
+            (16, 8),  # 2:1 GQA
+        ]
+        for n_q_heads, n_kv_heads in configs:
+            with self.subTest(n_q_heads=n_q_heads, n_kv_heads=n_kv_heads):
+                q = mx.random.normal((B, n_q_heads, L, D), dtype=mx.float16)
+                k = mx.random.normal((B, n_kv_heads, L, D), dtype=mx.float16)
+                v = mx.random.normal((B, n_kv_heads, L, D), dtype=mx.float16)
+                test_vjp([q, k, v], scale)
+
+    def test_sdpa_grad_dtypes(self):
+        """Test VJP with different precisions"""
+
+        def test_vjp(primals, scale, tolerance):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H, L, D = 2, 8, 64, 64
+        scale = D**-0.5
+
+        dtypes_and_tols = [
+            (mx.float16, {"rtol": 1e-2, "atol": 1e-2}),
+            # bfloat16 has lower precision (7 mantissa bits vs 10 for float16)
+            (mx.bfloat16, {"rtol": 5e-2, "atol": 5e-2}),
+            (mx.float32, {"rtol": 1e-4, "atol": 1e-4}),
+        ]
+
+        for dtype, tolerance in dtypes_and_tols:
+            with self.subTest(dtype=dtype):
+                q = mx.random.normal((B, H, L, D), dtype=dtype)
+                k = mx.random.normal((B, H, L, D), dtype=dtype)
+                v = mx.random.normal((B, H, L, D), dtype=dtype)
+                test_vjp([q, k, v], scale, tolerance)
+
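One property the GQA cases above rely on implicitly: the cotangents for K and V come back with the key/value head count rather than the query head count. A quick shape check (hedged, standalone example; not part of the test suite):

import mlx.core as mx

B, L, D = 2, 64, 64
n_q_heads, n_kv_heads = 32, 8
scale = D**-0.5

q = mx.random.normal((B, n_q_heads, L, D), dtype=mx.float16)
k = mx.random.normal((B, n_kv_heads, L, D), dtype=mx.float16)
v = mx.random.normal((B, n_kv_heads, L, D), dtype=mx.float16)

def loss(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale).sum()

dq, dk, dv = mx.grad(loss, argnums=(0, 1, 2))(q, k, v)
print(dq.shape, dk.shape, dv.shape)
# Expected: (2, 32, 64, 64) (2, 8, 64, 64) (2, 8, 64, 64)
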
+    def test_sdpa_grad_edge_cases(self):
+        """Test VJP edge cases"""
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale, mask=None):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale, mask=mask)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale, mask=mask
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        D = 64
+        scale = D**-0.5
+
+        # Test single element (L=1)
+        with self.subTest(case="L=1"):
+            q = mx.random.normal((2, 8, 1, D), dtype=mx.float16)
+            k = mx.random.normal((2, 8, 1, D), dtype=mx.float16)
+            v = mx.random.normal((2, 8, 1, D), dtype=mx.float16)
+            test_vjp([q, k, v], scale)
+
+        # Test non-power-of-2 lengths
+        for L in [63, 129, 65]:
+            with self.subTest(case=f"L={L}"):
+                q = mx.random.normal((2, 8, L, D), dtype=mx.float16)
+                k = mx.random.normal((2, 8, L, D), dtype=mx.float16)
+                v = mx.random.normal((2, 8, L, D), dtype=mx.float16)
+                test_vjp([q, k, v], scale)
+
+        # Test large batch
+        with self.subTest(case="B=8"):
+            q = mx.random.normal((8, 8, 64, D), dtype=mx.float16)
+            k = mx.random.normal((8, 8, 64, D), dtype=mx.float16)
+            v = mx.random.normal((8, 8, 64, D), dtype=mx.float16)
+            test_vjp([q, k, v], scale)
+
+        # Test different Q and KV lengths
+        with self.subTest(case="qL!=kvL"):
+            q = mx.random.normal((2, 8, 32, D), dtype=mx.float16)
+            k = mx.random.normal((2, 8, 64, D), dtype=mx.float16)
+            v = mx.random.normal((2, 8, 64, D), dtype=mx.float16)
+            test_vjp([q, k, v], scale)
+
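The mlx_ref_attn helper that these tests compare against is defined earlier in the test file. For readers of the diff, a minimal stand-in in the same spirit (an unfused softmax attention in float32 with optional boolean, additive, or causal mask handling) might look like the sketch below; this is an assumption-level illustration of the reference path, not the actual helper, and it ignores GQA and sinks:

import mlx.core as mx

def naive_attention(q, k, v, scale=1.0, mask=None):
    # Unfused attention computed in float32 as a trustworthy comparison point
    scores = scale * (
        q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
    )
    Lq, Lk = q.shape[-2], k.shape[-2]
    if isinstance(mask, str) and mask == "causal":
        keep = mx.tril(mx.ones((Lq, Lk), dtype=mx.bool_), k=Lk - Lq)
        scores = mx.where(keep, scores, mx.array(-1e9, dtype=mx.float32))
    elif mask is not None:
        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, mx.array(-1e9, dtype=mx.float32))
        else:
            scores = scores + mask.astype(mx.float32)
    p = mx.softmax(scores, axis=-1)
    return (p @ v.astype(mx.float32)).astype(q.dtype)
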
+    def test_sdpa_grad_with_mask(self):
+        """Test VJP with different mask types (boolean, additive, causal)"""
+        if not mx.is_available(mx.gpu):
+            return
+
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale, mask=None):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale, mask=mask)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale, mask=mask
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H, L, D = 2, 4, 32, 64
+        scale = D**-0.5
+
+        q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+        k = mx.random.normal((B, H, L, D), dtype=mx.float16)
+        v = mx.random.normal((B, H, L, D), dtype=mx.float16)
+
+        # Test with boolean mask
+        with self.subTest(mask_type="boolean"):
+            bool_mask = mx.random.uniform(0, 1, (B, H, L, L)) < 0.5
+            test_vjp([q, k, v], scale, mask=bool_mask)
+
+        # Test with additive mask
+        with self.subTest(mask_type="additive"):
+            additive_mask = mx.random.normal((B, H, L, L), dtype=mx.float16)
+            test_vjp([q, k, v], scale, mask=additive_mask)
+
+        # Test with causal mask (mask="causal")
+        with self.subTest(mask_type="causal"):
+            test_vjp([q, k, v], scale, mask="causal")
+
+        # Test with no mask
+        with self.subTest(mask_type="none"):
+            test_vjp([q, k, v], scale, mask=None)
+
+        # Test with broadcast mask (single head)
+        with self.subTest(mask_type="broadcast_head"):
+            broadcast_mask = mx.random.normal((B, 1, L, L), dtype=mx.float16)
+            test_vjp([q, k, v], scale, mask=broadcast_mask)
+
+        # Test with broadcast mask (single batch)
+        with self.subTest(mask_type="broadcast_batch"):
+            broadcast_mask = mx.random.normal((1, H, L, L), dtype=mx.float16)
+            test_vjp([q, k, v], scale, mask=broadcast_mask)
+
+    def test_sdpa_grad_short_seq(self):
+        """Test VJP for short sequences (L <= 8) that exercise the vector path"""
+        if not mx.is_available(mx.gpu):
+            return
+
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H, D = 2, 4, 64
+        scale = D**-0.5
+
+        # Test edge cases for vector mode (L <= 8)
+        for L in [1, 4, 7, 8]:
+            with self.subTest(L=L):
+                q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                k = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                v = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                test_vjp([q, k, v], scale)
+
+        # Test with GQA and short sequences
+        for L in [1, 4, 8]:
+            with self.subTest(L=L, gqa="4x"):
+                q = mx.random.normal((B, 8, L, D), dtype=mx.float16)
+                k = mx.random.normal((B, 2, L, D), dtype=mx.float16)
+                v = mx.random.normal((B, 2, L, D), dtype=mx.float16)
+                test_vjp([q, k, v], scale)
+
+    def test_sdpa_grad_noncontiguous_kv(self):
+        """Test VJP with non-contiguous K/V inputs (strided views).
+
+        This tests the fix for the stride mismatch bug where the vector VJP
+        assumed K/V have the same strides as the dK/dV outputs.
+        """
+        if not mx.is_available(mx.gpu):
+            return
+
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        def test_vjp(primals, scale):
+            slow = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+            fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=scale
+            )
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(
+                    mx.allclose(vjp1[i], vjp2[i], **tolerance),
+                    f"VJP mismatch for input {i}",
+                )
+
+        B, H, D = 2, 4, 64
+        scale = D**-0.5
+
+        # Test short sequences (L <= 8) with non-contiguous K/V views.
+        # This exercises the vector VJP path which had the stride mismatch bug.
+        for L in [1, 4, 8]:
+            with self.subTest(L=L, case="strided_view"):
+                q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                # Create non-contiguous K/V by slicing a larger array
+                k_full = mx.random.normal((B, H, L * 2, D), dtype=mx.float16)
+                v_full = mx.random.normal((B, H, L * 2, D), dtype=mx.float16)
+                k = k_full[:, :, ::2, :]  # Strided view, non-contiguous
+                v = v_full[:, :, ::2, :]  # Strided view, non-contiguous
+                test_vjp([q, k, v], scale)
+
+        # Test with transposed K/V (another form of non-contiguous)
+        for L in [4, 8]:
+            with self.subTest(L=L, case="transposed"):
+                q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                # Create by transposing batch and head dims, then transposing back
+                k_orig = mx.random.normal((H, B, L, D), dtype=mx.float16)
+                v_orig = mx.random.normal((H, B, L, D), dtype=mx.float16)
+                k = mx.transpose(k_orig, [1, 0, 2, 3])  # Non-contiguous
+                v = mx.transpose(v_orig, [1, 0, 2, 3])  # Non-contiguous
+                test_vjp([q, k, v], scale)
+
+        # Test longer sequences (L > 8) with non-contiguous K/V.
+        # This exercises the STEEL VJP path.
+        for L in [32, 64]:
+            with self.subTest(L=L, case="strided_view_long"):
+                q = mx.random.normal((B, H, L, D), dtype=mx.float16)
+                k_full = mx.random.normal((B, H, L * 2, D), dtype=mx.float16)
+                v_full = mx.random.normal((B, H, L * 2, D), dtype=mx.float16)
+                k = k_full[:, :, ::2, :]
+                v = v_full[:, :, ::2, :]
+                test_vjp([q, k, v], scale)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)
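Beyond the comparisons against mlx_ref_attn above, a quick standalone sanity check is a central finite difference on a single query element in float32. This is an ad hoc snippet, not part of the test suite, and the names in it are made up for illustration:

import mlx.core as mx

B, H, L, D = 1, 2, 8, 64
scale = D**-0.5
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

def loss(q):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale).sum()

dq = mx.grad(loss)(q)

# Central finite difference on q[0, 0, 0, 0]
eps = 1e-3
delta = mx.zeros_like(q).at[0, 0, 0, 0].add(eps)
numeric = (loss(q + delta) - loss(q - delta)) / (2 * eps)
print(dq[0, 0, 0, 0].item(), numeric.item())  # should agree closely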