190 changes: 86 additions & 104 deletions mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -197,102 +197,96 @@ __global__ void kernel_sdpav_2pass_1(
const T* K,
const T* V,
const T* sinks,
float* partials,
T* partials,
float* sums,
float* maxs,
int blocks,
__grid_constant__ const AttnParams params) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int blocks = 32;

constexpr int v_per_thread = D / BD;

const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
const int inner_k_stride = blocks * int(params.K_strides[2]);
const int inner_v_stride = blocks * int(params.V_strides[2]);

typedef float U;

U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];

__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
U o[v_per_thread] = {0.f};
__shared__ U shared_keys[D];
__shared__ U shared_values[D];

const U scale_log2 = params.scale * 1.44269504089f;

auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);

const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();

// Adjust to thread block and thread
const int batch_idx = blockIdx.z / blocks;
const int block_idx = blockIdx.z % blocks;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;

const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = block_idx * BN + warp_idx;
const int kv_head_idx = blockIdx.x;
const int batch_idx = blockIdx.y;
const int block_idx = blockIdx.z;
const int q_seq_idx = threadIdx.z;
const int q_head_idx = kv_head_idx * params.gqa_factor + threadIdx.y;

Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence

K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
block_idx * params.K_strides[2]; // Sequence

V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
block_idx * params.V_strides[2]; // Sequence

const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s + // Sequence
block_idx; // Block

partials += p_offset * D;
sums += p_offset;
maxs += p_offset;

// Read the query and 0 the output accumulator
// Read the query
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}

PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}

U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
if (sinks && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1.f;
}
auto k = shared_keys + lane_idx;
auto v = shared_values + lane_idx;

// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
for (int i = block_idx; i < params.kL; i += blocks) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}

if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
// Load keys and values into shared memory
if (warp.any(use_key)) {
block.sync();
if (threadIdx.y == 0 && threadIdx.z == 0) {
for (int j = 0; j < v_per_thread; j++) {
k[j] = static_cast<U>(K[v_per_thread * lane_idx + j]);
v[j] = static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
block.sync();
}

if (use_key) {
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
@@ -314,8 +308,7 @@ __global__ void kernel_sdpav_2pass_1(
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
o[j] = o[j] * factor + exp_score * static_cast<U>(v[j]);
}
}

Expand All @@ -325,67 +318,31 @@ __global__ void kernel_sdpav_2pass_1(
}

if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}

block.sync();

max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());

// Write the sum and new max
if (warp_idx == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
maxs[0] = max_score;
}

// Now we need to aggregate all the outputs
auto ff = exp2f(max_scores[warp_idx] - new_max);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[warp_idx][lane_idx] = o[i] * ff;
block.sync();

if (warp_idx == 0) {
U ot = outputs[0][lane_idx];
PRAGMA_LOOP_UNROLL
for (int j = 1; j < BN; j++) {
ot += outputs[j][lane_idx];
warp.sync();
}
o[i] = ot;
}
block.sync();
}

if (warp_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
partials[v_per_thread * lane_idx + i] = o[i];
}
partials[v_per_thread * lane_idx + i] = static_cast<T>(o[i]);
}
}
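
For readers following the math rather than the diff: below is a minimal, host-side scalar sketch of the running-softmax accumulation that pass 1 performs per block. It ignores the causal mask, the sink initialization, and the kernel's thread layout; the function and variable names are illustrative and not part of the kernel.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

struct BlockPartial {
  std::vector<float> out; // unnormalized sum_i exp(score_i - max) * v_i
  float sum_exp;          // sum_i exp(score_i - max)
  float max_score;        // running max of the scores seen by this block
};

BlockPartial accumulate_block(
    const std::vector<float>& q,
    const std::vector<std::vector<float>>& keys,
    const std::vector<std::vector<float>>& vals,
    float scale,
    int block_idx,
    int blocks) {
  const float log2e = 1.44269504089f;
  BlockPartial p{
      std::vector<float>(q.size(), 0.f),
      0.f,
      std::numeric_limits<float>::lowest()};
  // Each block strides over the keys, like the `i += blocks` loop above.
  for (size_t i = block_idx; i < keys.size(); i += blocks) {
    float score = 0.f;
    for (size_t d = 0; d < q.size(); ++d)
      score += q[d] * keys[i][d];
    score *= scale * log2e; // work in base-2 exponents, as the kernel does
    float new_max = std::max(p.max_score, score);
    float factor = std::exp2(p.max_score - new_max); // rescale old state
    float e = std::exp2(score - new_max);
    for (size_t d = 0; d < q.size(); ++d)
      p.out[d] = p.out[d] * factor + e * vals[i][d];
    p.sum_exp = p.sum_exp * factor + e;
    p.max_score = new_max;
  }
  return p; // written out through partials, sums and maxs for pass 2
}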

template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
const float* partials,
const T* partials,
const float* sums,
const float* maxs,
T* O,
int blocks,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int blocks = 32;

constexpr int v_per_thread = D / BD;

typedef float U;

U o[v_per_thread];
U o[v_per_thread] = {0.f};
__shared__ U outputs[BN][BD + 1];

auto block = cg::this_thread_block();
@@ -395,8 +352,9 @@ __global__ void kernel_sdpav_2pass_2(
const int warp_idx = warp.meta_group_rank();

// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int bh_idx = blockIdx.x;
const int batch_idx = bh_idx / params.H;
const int head_idx = bh_idx - batch_idx * params.H;
const int q_seq_idx = blockIdx.y;

const int p_stride_s = blocks;
@@ -406,32 +364,49 @@ __global__ void kernel_sdpav_2pass_2(
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s; // Sequence

partials += p_offset * D + warp_idx * D;
partials += p_offset * D + warp_idx * D + v_per_thread * lane_idx;
sums += p_offset;
maxs += p_offset;

O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence

U max_score = maxs[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
U sum_exp_score = 0.f;
U max_score = Limits<U>::finite_min();

PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = partials[v_per_thread * lane_idx + i];
for (int b = 0; b < blocks / BN; ++b) {
max_score = max(max_score, maxs[lane_idx + BN * b]);
}
max_score = cg::reduce(warp, max_score, cg::greater<U>());

for (int b = 0; b < blocks / BN; ++b) {
U factor = exp2f(maxs[lane_idx + BN * b] - max_score);
sum_exp_score += factor * sums[lane_idx + BN * b];
}
sum_exp_score = cg::reduce(warp, sum_exp_score, cg::plus<U>());

for (int b = 0; b < blocks / BN; ++b) {
U factor = exp2f(maxs[warp_idx] - max_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] += factor * static_cast<U>(partials[i]);
}
maxs += BN;
sums += BN;
partials += BN * D;
}

// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
U ot = outputs[warp_idx][lane_idx];
o[i] = cg::reduce(warp, ot, cg::plus<U>());
if (sum_exp_score != 0.f) {
o[i] *= __frcp_rn(sum_exp_score);
}
block.sync();
}
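
Similarly, a self-contained scalar sketch of what this pass-2 reduction computes: merging the per-block (partial output, sum, max) triples into the final normalized output. Layout and names are illustrative; partials here is indexed as [block][d].

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> merge_partials(
    const std::vector<std::vector<float>>& partials,
    const std::vector<float>& sums,
    const std::vector<float>& maxs) {
  float global_max = *std::max_element(maxs.begin(), maxs.end());
  float total = 0.f;
  std::vector<float> out(partials[0].size(), 0.f);
  for (size_t b = 0; b < partials.size(); ++b) {
    float factor = std::exp2(maxs[b] - global_max); // rescale block b
    total += factor * sums[b];
    for (size_t d = 0; d < out.size(); ++d)
      out[d] += factor * partials[b][d];
  }
  if (total != 0.f) {
    for (auto& x : out)
      x /= total; // final softmax normalization
  }
  return out;
}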

@@ -550,7 +525,9 @@ void sdpa_vector_2pass_fallback(
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

// Allocate the intermediates
int blocks = 32;
int n_simds = params.gqa_factor * params.qL;
// TODO tune on different machines
int blocks = 256;

Shape intermediate_shape;
intermediate_shape.reserve(o.ndim() + 1);
@@ -559,7 +536,7 @@ void sdpa_vector_2pass_fallback(
intermediate_shape.push_back(blocks);
intermediate_shape.push_back(o.shape().back());

array intermediate(intermediate_shape, float32, nullptr, {});
array intermediate(intermediate_shape, q.dtype(), nullptr, {});
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
@@ -592,8 +569,8 @@ void sdpa_vector_2pass_fallback(
encoder.set_output_array(sums);
encoder.set_output_array(maxs);

dim3 grid_dim(params.H, params.qL, params.B * 32);
dim3 block_dim(8 * 32, 1, 1);
dim3 grid_dim(k.shape(1), params.B, blocks);
dim3 block_dim(32, params.gqa_factor, params.qL);

encoder.add_kernel_node(
kernel,
@@ -604,9 +581,10 @@ void sdpa_vector_2pass_fallback(
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
gpu_ptr<float>(intermediate),
gpu_ptr<DataType>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
blocks,
params);
}

@@ -619,18 +597,19 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(maxs);
encoder.set_output_array(o);

dim3 grid_dim(params.H, params.qL, params.B);
dim3 grid_dim(params.B * params.H, params.qL, 1);
dim3 block_dim(1024, 1, 1);

encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
gpu_ptr<float>(intermediate),
gpu_ptr<DataType>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
gpu_ptr<DataType>(o),
blocks,
params);
}
});
@@ -677,12 +656,15 @@ bool supports_sdpa_vector(
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const int num_query_heads = q.shape(1);
const int num_kv_heads = k.shape(1);
const int gqa_factor = num_query_heads / num_kv_heads;

const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);

const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_vector_config = sdpa_supported_head_dim &&
query_sequence_length < 4 && (query_sequence_length * gqa_factor) <= 32;

return supported_vector_config && !has_arr_mask;
}
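
As far as I can tell, the new query_sequence_length * gqa_factor <= 32 guard lines up with the pass-1 launch shape above, block_dim(32, params.gqa_factor, params.qL): one warp per (query head, query position) pair, and CUDA caps a block at 1024 threads. A tiny sketch of that arithmetic, with an illustrative name:

inline bool fits_in_one_block(int gqa_factor, int qL) {
  // 32 lanes per warp, one warp per (head, query position) pair in the block
  return 32 * gqa_factor * qL <= 1024; // i.e. gqa_factor * qL <= 32
}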