diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 111cfee170..64ac63180e 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -197,28 +197,23 @@ __global__ void kernel_sdpav_2pass_1(
     const T* K,
     const T* V,
     const T* sinks,
-    float* partials,
+    T* partials,
     float* sums,
     float* maxs,
+    int blocks,
     __grid_constant__ const AttnParams params) {
-  constexpr int BN = 8;
   constexpr int BD = 32;
-  constexpr int blocks = 32;
-
   constexpr int v_per_thread = D / BD;
 
-  const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
-  const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
+  const int inner_k_stride = blocks * int(params.K_strides[2]);
+  const int inner_v_stride = blocks * int(params.V_strides[2]);
 
   typedef float U;
 
   U q[v_per_thread];
-  U k[v_per_thread];
-  U o[v_per_thread];
-
-  __shared__ U outputs[BN][BD + 1];
-  __shared__ U max_scores[BN];
-  __shared__ U sum_exp_scores[BN];
+  U o[v_per_thread] = {0.f};
+  __shared__ U shared_keys[D];
+  __shared__ U shared_values[D];
 
   const U scale_log2 = params.scale * 1.44269504089f;
@@ -226,34 +221,31 @@ __global__ void kernel_sdpav_2pass_1(
   auto warp = cg::tiled_partition<32>(block);
 
   const int lane_idx = warp.thread_rank();
-  const int warp_idx = warp.meta_group_rank();
 
   // Adjust to thread block and thread
-  const int batch_idx = blockIdx.z / blocks;
-  const int block_idx = blockIdx.z % blocks;
-  const int head_idx = blockIdx.x;
-  const int kv_head_idx = head_idx / params.gqa_factor;
-
-  const int q_seq_idx = blockIdx.y;
-  const int kv_seq_idx = block_idx * BN + warp_idx;
+  const int kv_head_idx = blockIdx.x;
+  const int batch_idx = blockIdx.y;
+  const int block_idx = blockIdx.z;
+  const int q_seq_idx = threadIdx.z;
+  const int q_head_idx = kv_head_idx * params.gqa_factor + threadIdx.y;
 
   Q += batch_idx * params.Q_strides[0] + // Batch
-      head_idx * params.Q_strides[1] + // Head
+      q_head_idx * params.Q_strides[1] + // Head
       q_seq_idx * params.Q_strides[2]; // Sequence
 
   K += batch_idx * params.K_strides[0] + // Batch
       kv_head_idx * params.K_strides[1] + // Head
-      kv_seq_idx * params.K_strides[2]; // Sequence
+      block_idx * params.K_strides[2]; // Sequence
 
   V += batch_idx * params.V_strides[0] + // Batch
       kv_head_idx * params.V_strides[1] + // Head
-      kv_seq_idx * params.V_strides[2]; // Sequence
+      block_idx * params.V_strides[2]; // Sequence
 
   const int p_stride_s = blocks;
   const int p_stride_h = params.qL * p_stride_s;
   const int p_stride_b = params.H * p_stride_h;
   const int p_offset = batch_idx * p_stride_b + // Batch
-      head_idx * p_stride_h + // Head
+      q_head_idx * p_stride_h + // Head
       q_seq_idx * p_stride_s + // Sequence
       block_idx; // Block
 
@@ -261,38 +253,40 @@ __global__ void kernel_sdpav_2pass_1(
   sums += p_offset;
   maxs += p_offset;
 
-  // Read the query and 0 the output accumulator
+  // Read the query
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
     q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
   }
 
-  PRAGMA_LOOP_UNROLL
-  for (int i = 0; i < v_per_thread; i++) {
-    o[i] = 0.f;
-  }
-
   U max_score = Limits<U>::finite_min();
   U sum_exp_score = 0.f;
-  if (sinks && warp_idx == 0 && block_idx == 0) {
-    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+  if (sinks && block_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[q_head_idx]);
     sum_exp_score = 1.f;
   }
 
+  auto k = shared_keys + v_per_thread * lane_idx;
+  auto v = shared_values + v_per_thread * lane_idx;
   // For each key
-  for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
+  for (int i = block_idx; i < params.kL; i += blocks) {
     bool use_key = true;
     if constexpr (do_causal) {
      use_key = i <= (params.kL - params.qL + q_seq_idx);
    }
-
-    if (use_key) {
-      // Read the key
-      PRAGMA_LOOP_UNROLL
-      for (int j = 0; j < v_per_thread; j++) {
-        k[j] = K[v_per_thread * lane_idx + j];
+    // Load keys and values into shared memory
+    if (warp.any(use_key)) {
+      block.sync();
+      if (threadIdx.y == 0 && threadIdx.z == 0) {
+        for (int j = 0; j < v_per_thread; j++) {
+          k[j] = static_cast<U>(K[v_per_thread * lane_idx + j]);
+          v[j] = static_cast<U>(V[v_per_thread * lane_idx + j]);
+        }
      }
+      block.sync();
+    }
 
+    if (use_key) {
       // Compute the i-th score
       U score = 0.f;
       PRAGMA_LOOP_UNROLL
@@ -314,8 +308,7 @@ __global__ void kernel_sdpav_2pass_1(
       // Update the output accumulator
       PRAGMA_LOOP_UNROLL
       for (int j = 0; j < v_per_thread; j++) {
-        o[j] = o[j] * factor +
-            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
+        o[j] = o[j] * factor + exp_score * static_cast<U>(v[j]);
      }
    }
 
@@ -325,67 +318,31 @@ __global__ void kernel_sdpav_2pass_1(
  }
 
   if (lane_idx == 0) {
-    max_scores[warp_idx] = max_score;
-    sum_exp_scores[warp_idx] = sum_exp_score;
-  }
-
-  block.sync();
-
-  max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
-  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-  U factor = exp2f(max_score - new_max);
-  sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
-  sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
-
-  // Write the sum and new max
-  if (warp_idx == 0) {
     sums[0] = sum_exp_score;
-    maxs[0] = new_max;
+    maxs[0] = max_score;
  }
 
-  // Now we need to aggregate all the outputs
-  auto ff = exp2f(max_scores[warp_idx] - new_max);
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
-    outputs[warp_idx][lane_idx] = o[i] * ff;
-    block.sync();
-
-    if (warp_idx == 0) {
-      U ot = outputs[0][lane_idx];
-      PRAGMA_LOOP_UNROLL
-      for (int j = 1; j < BN; j++) {
-        ot += outputs[j][lane_idx];
-        warp.sync();
-      }
-      o[i] = ot;
-    }
-    block.sync();
-  }
-
-  if (warp_idx == 0) {
-    PRAGMA_LOOP_UNROLL
-    for (int i = 0; i < v_per_thread; i++) {
-      partials[v_per_thread * lane_idx + i] = o[i];
-    }
+    partials[v_per_thread * lane_idx + i] = static_cast<T>(o[i]);
  }
 }
 
 template <typename T, int D>
 __global__ void kernel_sdpav_2pass_2(
-    const float* partials,
+    const T* partials,
     const float* sums,
     const float* maxs,
     T* O,
+    int blocks,
     __grid_constant__ const AttnParams params) {
   constexpr int BN = 32;
   constexpr int BD = 32;
-  constexpr int blocks = 32;
-
   constexpr int v_per_thread = D / BD;
 
   typedef float U;
 
-  U o[v_per_thread];
+  U o[v_per_thread] = {0.f};
 
   __shared__ U outputs[BN][BD + 1];
 
   auto block = cg::this_thread_block();
@@ -395,8 +352,9 @@ __global__ void kernel_sdpav_2pass_2(
   const int warp_idx = warp.meta_group_rank();
 
   // Adjust to thread block and thread
-  const int batch_idx = blockIdx.z;
-  const int head_idx = blockIdx.x;
+  const int bh_idx = blockIdx.x;
+  const int batch_idx = bh_idx / params.H;
+  const int head_idx = bh_idx - batch_idx * params.H;
   const int q_seq_idx = blockIdx.y;
 
   const int p_stride_s = blocks;
@@ -406,7 +364,7 @@ __global__ void kernel_sdpav_2pass_2(
       head_idx * p_stride_h + // Head
       q_seq_idx * p_stride_s; // Sequence
 
-  partials += p_offset * D + warp_idx * D;
+  partials += p_offset * D + warp_idx * D + v_per_thread * lane_idx;
   sums += p_offset;
   maxs += p_offset;
 
@@ -414,15 +372,29 @@ __global__ void kernel_sdpav_2pass_2(
       head_idx * params.O_strides[1] + // Head
       q_seq_idx * params.O_strides[2]; // Sequence
 
-  U max_score = maxs[lane_idx];
-  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
-  U factor = exp2f(max_score - new_max);
-  U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
-  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
+  U sum_exp_score = 0.f;
+  U max_score = Limits<U>::finite_min();
 
-  PRAGMA_LOOP_UNROLL
-  for (int i = 0; i < v_per_thread; i++) {
-    o[i] = partials[v_per_thread * lane_idx + i];
+  for (int b = 0; b < blocks / BN; ++b) {
+    max_score = max(max_score, maxs[lane_idx + BN * b]);
+  }
+  max_score = cg::reduce(warp, max_score, cg::greater<U>());
+
+  for (int b = 0; b < blocks / BN; ++b) {
+    U factor = exp2f(maxs[lane_idx + BN * b] - max_score);
+    sum_exp_score += factor * sums[lane_idx + BN * b];
+  }
+  sum_exp_score = cg::reduce(warp, sum_exp_score, cg::plus<U>());
+
+  for (int b = 0; b < blocks / BN; ++b) {
+    U factor = exp2f(maxs[warp_idx] - max_score);
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < v_per_thread; i++) {
+      o[i] += factor * static_cast<U>(partials[i]);
+    }
+    maxs += BN;
+    sums += BN;
+    partials += BN * D;
  }
 
   // Now we need to aggregate all the outputs
@@ -430,8 +402,11 @@
   for (int i = 0; i < v_per_thread; i++) {
     outputs[lane_idx][warp_idx] = o[i];
     block.sync();
-    U ot = outputs[warp_idx][lane_idx] * factor;
-    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
+    U ot = outputs[warp_idx][lane_idx];
+    o[i] = cg::reduce(warp, ot, cg::plus<U>());
+    if (sum_exp_score != 0.f) {
+      o[i] *= __frcp_rn(sum_exp_score);
+    }
     block.sync();
  }
 
@@ -550,7 +525,9 @@ void sdpa_vector_2pass_fallback(
       /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
 
   // Allocate the intermediates
-  int blocks = 32;
+  int n_simds = params.gqa_factor * params.qL;
+  // TODO tune on different machines
+  int blocks = 256;
 
   Shape intermediate_shape;
   intermediate_shape.reserve(o.ndim() + 1);
@@ -559,7 +536,7 @@ void sdpa_vector_2pass_fallback(
   intermediate_shape.push_back(blocks);
   intermediate_shape.push_back(o.shape().back());
 
-  array intermediate(intermediate_shape, float32, nullptr, {});
+  array intermediate(intermediate_shape, q.dtype(), nullptr, {});
   intermediate_shape.pop_back();
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
@@ -592,8 +569,8 @@ void sdpa_vector_2pass_fallback(
     encoder.set_output_array(sums);
     encoder.set_output_array(maxs);
 
-    dim3 grid_dim(params.H, params.qL, params.B * 32);
-    dim3 block_dim(8 * 32, 1, 1);
+    dim3 grid_dim(k.shape(1), params.B, blocks);
+    dim3 block_dim(32, params.gqa_factor, params.qL);
 
     encoder.add_kernel_node(
         kernel,
@@ -604,9 +581,10 @@ void sdpa_vector_2pass_fallback(
         gpu_ptr<T>(q),
         gpu_ptr<T>(k),
         gpu_ptr<T>(v),
         sinks ? gpu_ptr<T>(*sinks) : nullptr,
-        gpu_ptr<float>(intermediate),
+        gpu_ptr<T>(intermediate),
         gpu_ptr<float>(sums),
         gpu_ptr<float>(maxs),
+        blocks,
         params);
  }
 
@@ -619,7 +597,7 @@ void sdpa_vector_2pass_fallback(
     encoder.set_input_array(maxs);
     encoder.set_output_array(o);
 
-    dim3 grid_dim(params.H, params.qL, params.B);
+    dim3 grid_dim(params.B * params.H, params.qL, 1);
     dim3 block_dim(1024, 1, 1);
 
     encoder.add_kernel_node(
@@ -627,10 +605,11 @@ void sdpa_vector_2pass_fallback(
         grid_dim,
         block_dim,
         0,
-        gpu_ptr<float>(intermediate),
+        gpu_ptr<T>(intermediate),
         gpu_ptr<float>(sums),
         gpu_ptr<float>(maxs),
         gpu_ptr<T>(o),
+        blocks,
         params);
  }
 });
@@ -677,12 +656,15 @@ bool supports_sdpa_vector(
   const int query_head_dim = q.shape(-1);
   const int query_sequence_length = q.shape(2);
   const int key_sequence_length = k.shape(2);
+  const int num_query_heads = q.shape(1);
+  const int num_kv_heads = k.shape(1);
+  const int gqa_factor = num_query_heads / num_kv_heads;
 
   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
 
-  const bool supported_vector_config =
-      sdpa_supported_head_dim && query_sequence_length < 4;
+  const bool supported_vector_config = sdpa_supported_head_dim &&
+      query_sequence_length < 4 && (query_sequence_length * gqa_factor) <= 32;
 
   return supported_vector_config && !has_arr_mask;
 }