From 55825ab651c1fdef9d003adfd2f60bcaa0407f28 Mon Sep 17 00:00:00 2001
From: FFBP
Date: Tue, 3 Feb 2026 21:52:32 +0800
Subject: [PATCH 1/4] Add .gitignore for build artifacts

---
 .gitignore | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..25752ccb
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.o
+test_kernels

From 66d70e98c7b8231ddca451484242bd73b34d01be Mon Sep 17 00:00:00 2001
From: FFBP
Date: Wed, 4 Feb 2026 09:33:50 +0800
Subject: [PATCH 2/4] kernels

---
 src/kernels.cu | 370 ++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 323 insertions(+), 47 deletions(-)

diff --git a/src/kernels.cu b/src/kernels.cu
index 74312070..e7f3d71d 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -1,61 +1,337 @@
 #include <cuda_runtime.h>
 #include <vector>
+#include <cfloat>
+#include <cmath>
+#include <cstdio>
+#include <cuda_fp16.h>
 #include "../tester/utils.h"

-/**
- * @brief Computes the trace of a matrix.
- *
- * The trace of a matrix is defined as the sum of its diagonal elements.
- * This function expects a flattened row-major matrix stored in a
- * std::vector. If the matrix is not square, the trace will sum up
- * elements along the main diagonal up to the smaller of rows or cols.
- *
- * @tparam T The numeric type of matrix elements (e.g., float, int).
- * @param h_input A flattened matrix of size rows * cols.
- * @param rows Number of rows in the matrix.
- * @param cols Number of columns in the matrix.
- * @return The trace (sum of diagonal values) of the matrix.
- */
+// ============================================================================
+// TRACE IMPLEMENTATION
+// ============================================================================
+
+template <typename T>
+__global__ void trace_kernel(const T* d_input, T* d_partial_sums,
+                             size_t rows, size_t cols, size_t min_dim) {
+  extern __shared__ char shared_mem[];
+  T* sdata = reinterpret_cast<T*>(shared_mem);
+
+  unsigned int tid = threadIdx.x;
+  unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  T val = (idx < min_dim) ? d_input[idx * cols + idx] : T(0);
+  sdata[tid] = val;
+  __syncthreads();
+
+  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      sdata[tid] += sdata[tid + s];
+    }
+    __syncthreads();
+  }
+
+  if (tid == 0) {
+    d_partial_sums[blockIdx.x] = sdata[0];
+  }
+}
+
 template <typename T>
 T trace(const std::vector<T>& h_input, size_t rows, size_t cols) {
-  // TODO: Implement the trace function
-  return T(-1);
+  size_t min_dim = (rows < cols) ?
rows : cols; + + if (min_dim == 0) return T(0); + + if (min_dim < 1024) { + T sum = T(0); + for (size_t i = 0; i < min_dim; ++i) { + sum += h_input[i * cols + i]; + } + return sum; + } + + T* d_input; + size_t matrix_size = rows * cols * sizeof(T); + cudaMalloc(&d_input, matrix_size); + cudaMemcpy(d_input, h_input.data(), matrix_size, cudaMemcpyHostToDevice); + + int block_size = 256; + int num_blocks = (min_dim + block_size - 1) / block_size; + + T* d_partial_sums; + cudaMalloc(&d_partial_sums, num_blocks * sizeof(T)); + + size_t shared_mem_size = block_size * sizeof(T); + trace_kernel<<>>( + d_input, d_partial_sums, rows, cols, min_dim + ); + + std::vector h_partial_sums(num_blocks); + cudaMemcpy(h_partial_sums.data(), d_partial_sums, + num_blocks * sizeof(T), cudaMemcpyDeviceToHost); + + T result = T(0); + for (int i = 0; i < num_blocks; ++i) { + result += h_partial_sums[i]; + } + + cudaFree(d_input); + cudaFree(d_partial_sums); + + return result; +} + +// ============================================================================ +// FLASH ATTENTION IMPLEMENTATION +// ============================================================================ + +template +struct TypeConverter { + __device__ __forceinline__ static float to_float(T val); + __device__ __forceinline__ static T from_float(float val); +}; + +template <> +struct TypeConverter { + __device__ __forceinline__ static float to_float(float val) { + return val; + } + __device__ __forceinline__ static float from_float(float val) { + return val; + } +}; + +template <> +struct TypeConverter { + __device__ __forceinline__ static float to_float(half val) { + return __half2float(val); + } + __device__ __forceinline__ static half from_float(float val) { + return __float2half(val); + } +}; + +// Flash Attention Kernel +template +__global__ void flash_attention_forward_kernel( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + T* __restrict__ O, + const int batch_size, + const int tgt_seq_len, + const int src_seq_len, + const int query_heads, + const int kv_heads, + const int head_dim, + const bool is_causal, + const float softmax_scale +) { + // 线程索引 + const int batch_idx = blockIdx.z; + const int head_idx = blockIdx.y; + const int global_tgt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // 边界检查 + if (global_tgt_idx >= tgt_seq_len) return; + + // GQA 支持 + const int kv_head_idx = (head_idx * kv_heads) / query_heads; + + // 计算偏移量 + const size_t q_batch_stride = tgt_seq_len * query_heads * head_dim; + const size_t q_seq_stride = query_heads * head_dim; + const size_t q_head_stride = head_dim; + + const int q_offset = batch_idx * q_batch_stride + + global_tgt_idx * q_seq_stride + + head_idx * q_head_stride; + + const size_t kv_batch_stride = src_seq_len * kv_heads * head_dim; + const size_t kv_seq_stride = kv_heads * head_dim; + const size_t kv_head_stride = head_dim; + + const int kv_batch_offset = batch_idx * kv_batch_stride + + kv_head_idx * kv_head_stride; + + const int o_offset = batch_idx * q_batch_stride + + global_tgt_idx * q_seq_stride + + head_idx * q_head_stride; + + // 在线 Softmax 状态 + float max_score = -FLT_MAX; + float sum_exp = 0.0f; + + // 动态分配输出累加器(避免固定大小假设) + float* output_acc = new float[head_dim]; + float* q_vec = new float[head_dim]; + + // 初始化 + for (int d = 0; d < head_dim; ++d) { + output_acc[d] = 0.0f; + q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); + } + + // 分块处理 Key 和 Value + const int BLOCK_SIZE = 64; // 固定分块大小 + + for (int src_block_start = 0; src_block_start < 
src_seq_len; src_block_start += BLOCK_SIZE) { + const int src_block_end = min(src_block_start + BLOCK_SIZE, src_seq_len); + + // 计算当前块的注意力分数 + float block_scores[64]; // 匹配 BLOCK_SIZE + + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int src_idx = src_block_start + i; + + if (src_idx >= src_block_end) { + block_scores[i] = -FLT_MAX; + continue; + } + + // Causal masking + if (is_causal && src_idx > global_tgt_idx) { + block_scores[i] = -FLT_MAX; + continue; + } + + // 计算 Q·K^T + const int k_offset = kv_batch_offset + src_idx * kv_seq_stride; + + float score = 0.0f; + for (int d = 0; d < head_dim; ++d) { + float k_val = TypeConverter::to_float(K[k_offset + d]); + score += q_vec[d] * k_val; + } + + block_scores[i] = score * softmax_scale; + } + + // 在线 Softmax 更新 + float prev_max = max_score; + + // 更新最大值 + for (int i = 0; i < BLOCK_SIZE; ++i) { + if (block_scores[i] > -FLT_MAX) { + max_score = fmaxf(max_score, block_scores[i]); + } + } + + // 重新缩放 + const float correction = expf(prev_max - max_score); + sum_exp *= correction; + + for (int d = 0; d < head_dim; ++d) { + output_acc[d] *= correction; + } + + // 累加当前块的贡献 + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int src_idx = src_block_start + i; + + if (src_idx >= src_block_end || block_scores[i] == -FLT_MAX) { + continue; + } + + const float attn_weight = expf(block_scores[i] - max_score); + sum_exp += attn_weight; + + const int v_offset = kv_batch_offset + src_idx * kv_seq_stride; + + for (int d = 0; d < head_dim; ++d) { + float v_val = TypeConverter::to_float(V[v_offset + d]); + output_acc[d] += attn_weight * v_val; + } + } + } + + // 最终归一化并写出 + const float inv_sum = (sum_exp > 1e-6f) ? (1.0f / sum_exp) : 0.0f; + + for (int d = 0; d < head_dim; ++d) { + O[o_offset + d] = TypeConverter::from_float(output_acc[d] * inv_sum); + } + + // 清理 + delete[] output_acc; + delete[] q_vec; } -/** - * @brief Computes flash attention for given query, key, and value tensors. 
- *
- * @tparam T Data type (float) for input/output tensors
- * @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim]
- * @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim]
- * @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim]
- * @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim]
- * @param[in] batch_size Batch dimension size
- * @param[in] target_seq_len Target sequence length
- * @param[in] src_seq_len Source sequence length
- * @param[in] query_heads Number of query attention heads
- * @param[in] kv_heads Number of key/value heads (supports grouped query attention)
- * @param[in] head_dim Dimension size of each attention head
- * @param[in] is_causal Whether to apply causal masking
- */
 template <typename T>
-void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
-                    const std::vector<T>& h_v, std::vector<T>& h_o,
-                    int batch_size, int target_seq_len, int src_seq_len,
-                    int query_heads, int kv_heads, int head_dim, bool is_causal) {
-  // TODO: Implement the flash attention function
+void flashAttention(
+    const std::vector<T>& h_q,
+    const std::vector<T>& h_k,
+    const std::vector<T>& h_v,
+    std::vector<T>& h_o,
+    int batch_size,
+    int target_seq_len,
+    int src_seq_len,
+    int query_heads,
+    int kv_heads,
+    int head_dim,
+    bool is_causal
+) {
+    const float softmax_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
+
+    const size_t q_size = batch_size * target_seq_len * query_heads * head_dim;
+    const size_t kv_size = batch_size * src_seq_len * kv_heads * head_dim;
+    const size_t o_size = q_size;
+
+    T *d_q, *d_k, *d_v, *d_o;
+    cudaMalloc(&d_q, q_size * sizeof(T));
+    cudaMalloc(&d_k, kv_size * sizeof(T));
+    cudaMalloc(&d_v, kv_size * sizeof(T));
+    cudaMalloc(&d_o, o_size * sizeof(T));
+
+    cudaMemcpy(d_q, h_q.data(), q_size * sizeof(T), cudaMemcpyHostToDevice);
+    cudaMemcpy(d_k, h_k.data(), kv_size * sizeof(T), cudaMemcpyHostToDevice);
+    cudaMemcpy(d_v, h_v.data(), kv_size * sizeof(T), cudaMemcpyHostToDevice);
+
+    const int threads_per_block = 256;
+    dim3 block_dim(threads_per_block);
+    dim3 grid_dim(
+        (target_seq_len + threads_per_block - 1) / threads_per_block,
+        query_heads,
+        batch_size
+    );
+
+    flash_attention_forward_kernel<<<grid_dim, block_dim>>>(
+        d_q, d_k, d_v, d_o,
+        batch_size, target_seq_len, src_seq_len,
+        query_heads, kv_heads, head_dim,
+        is_causal, softmax_scale
+    );
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
+    }
+
+    cudaMemcpy(h_o.data(), d_o, o_size * sizeof(T), cudaMemcpyDeviceToHost);
+    cudaDeviceSynchronize();
+
+    cudaFree(d_q);
+    cudaFree(d_k);
+    cudaFree(d_v);
+    cudaFree(d_o);
 }

-// *********************************************************************
-// Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O)
-// DO NOT MODIFY THIS SECTION
-// *********************************************************************
+// ============================================================================
+// Explicit template instantiations
+// ============================================================================
+
 template int trace(const std::vector<int>&, size_t, size_t);
 template float trace(const std::vector<float>&, size_t, size_t);
-template void flashAttention(const std::vector<float>&, const std::vector<float>&,
-                             const std::vector<float>&, std::vector<float>&,
-                             int, int, int, int, int, int, bool);
-template void flashAttention(const std::vector<half>&, const std::vector<half>&,
-                             const std::vector<half>&, std::vector<half>&,
-                             int, int, int, int, int, int,
bool); + +template void flashAttention( + const std::vector&, const std::vector&, + const std::vector&, std::vector&, + int, int, int, int, int, int, bool +); + +template void flashAttention( + const std::vector&, const std::vector&, + const std::vector&, std::vector&, + int, int, int, int, int, int, bool +); + From ed8da94b44c0da637ad5ec3c1149c0673b5656a4 Mon Sep 17 00:00:00 2001 From: FFBP Date: Thu, 5 Feb 2026 11:33:18 +0800 Subject: [PATCH 3/4] update kernels --- src/kernels.cu | 61 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/src/kernels.cu b/src/kernels.cu index e7f3d71d..60208110 100644 --- a/src/kernels.cu +++ b/src/kernels.cu @@ -10,6 +10,20 @@ // ============================================================================ // TRACE IMPLEMENTATION // ============================================================================ +/** + * @brief Computes the trace of a matrix. + * + * The trace of a matrix is defined as the sum of its diagonal elements. + * This function expects a flattened row-major matrix stored in a + * std::vector. If the matrix is not square, the trace will sum up + * elements along the main diagonal up to the smaller of rows or cols. + * + * @tparam T The numeric type of matrix elements (e.g., float, int). + * @param h_input A flattened matrix of size rows * cols. + * @param rows Number of rows in the matrix. + * @param cols Number of columns in the matrix. + * @return The trace (sum of diagonal values) of the matrix. + */ template __global__ void trace_kernel(const T* d_input, T* d_partial_sums, @@ -24,6 +38,7 @@ __global__ void trace_kernel(const T* d_input, T* d_partial_sums, sdata[tid] = val; __syncthreads(); + // 并行 reduction for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s) { sdata[tid] += sdata[tid + s]; @@ -42,6 +57,7 @@ T trace(const std::vector& h_input, size_t rows, size_t cols) { if (min_dim == 0) return T(0); + // 矩阵维度小,cpu 直接算 if (min_dim < 1024) { T sum = T(0); for (size_t i = 0; i < min_dim; ++i) { @@ -84,6 +100,22 @@ T trace(const std::vector& h_input, size_t rows, size_t cols) { // ============================================================================ // FLASH ATTENTION IMPLEMENTATION // ============================================================================ +/** + * @brief Computes flash attention for given query, key, and value tensors. 
+ * + * @tparam T Data type (float) for input/output tensors + * @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] + * @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] + * @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] + * @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] + * @param[in] batch_size Batch dimension size + * @param[in] target_seq_len Target sequence length + * @param[in] src_seq_len Source sequence length + * @param[in] query_heads Number of query attention heads + * @param[in] kv_heads Number of key/value heads (supports grouped query attention) + * @param[in] head_dim Dimension size of each attention head + * @param[in] is_causal Whether to apply causal masking + */ template struct TypeConverter { @@ -169,12 +201,13 @@ __global__ void flash_attention_forward_kernel( // 初始化 for (int d = 0; d < head_dim; ++d) { output_acc[d] = 0.0f; - q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); + q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); // 加载 Q 的第 i 行作为外循环 } // 分块处理 Key 和 Value const int BLOCK_SIZE = 64; // 固定分块大小 + // 使用 KV 的第 j 行作为内循环 for (int src_block_start = 0; src_block_start < src_seq_len; src_block_start += BLOCK_SIZE) { const int src_block_end = min(src_block_start + BLOCK_SIZE, src_seq_len); @@ -316,22 +349,16 @@ void flashAttention( cudaFree(d_o); } -// ============================================================================ -// 显式模板实例化 -// ============================================================================ - +// ********************************************************************* +// Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O) +// DO NOT MODIFY THIS SECTION +// ********************************************************************* template int trace(const std::vector&, size_t, size_t); template float trace(const std::vector&, size_t, size_t); - -template void flashAttention( - const std::vector&, const std::vector&, - const std::vector&, std::vector&, - int, int, int, int, int, int, bool -); - -template void flashAttention( - const std::vector&, const std::vector&, - const std::vector&, std::vector&, - int, int, int, int, int, int, bool -); +template void flashAttention(const std::vector&, const std::vector&, + const std::vector&, std::vector&, + int, int, int, int, int, int, bool); +template void flashAttention(const std::vector&, const std::vector&, + const std::vector&, std::vector&, + int, int, int, int, int, int, bool); From 67565e56c8bfc2265dc71806d51ec5cb2cb28eaf Mon Sep 17 00:00:00 2001 From: FFBP Date: Fri, 6 Feb 2026 13:59:02 +0800 Subject: [PATCH 4/4] updata kernel --- src/kernels.cu | 514 +++++++++++++++++++++++++++---------------------- src/kernels.mu | 345 +++++++++++++++++++++++++++++---- 2 files changed, 592 insertions(+), 267 deletions(-) diff --git a/src/kernels.cu b/src/kernels.cu index 60208110..bdeccc09 100644 --- a/src/kernels.cu +++ b/src/kernels.cu @@ -117,238 +117,297 @@ T trace(const std::vector& h_input, size_t rows, size_t cols) { * @param[in] is_causal Whether to apply causal masking */ -template -struct TypeConverter { - __device__ __forceinline__ static float to_float(T val); - __device__ __forceinline__ static T from_float(float val); -}; + template + struct TypeConverter; + + template <> + struct TypeConverter { + __device__ __forceinline__ static float to_float(float v) { + return v; + } + __device__ __forceinline__ static float 
from_float(float v) { + return v; + } + }; + + template <> + struct TypeConverter { + __device__ __forceinline__ static float to_float(half v) { + return __half2float(v); + } + __device__ __forceinline__ static half from_float(float v) { + return __float2half(v); + } + }; -template <> -struct TypeConverter { - __device__ __forceinline__ static float to_float(float val) { - return val; - } - __device__ __forceinline__ static float from_float(float val) { - return val; - } -}; + // 特殊 head_dim 模板特化 + template + __global__ void flash_attention_forward_kernel_fast( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + T* __restrict__ O, + int batch_size, + int tgt_seq_len, + int src_seq_len, + int query_heads, + int kv_heads, + bool is_causal, + float softmax_scale + ) { + // 每个线程处理一个 Q 众多头中的一个头中的 token + int tgt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_idx = blockIdx.y; + int batch_idx = blockIdx.z; + + if (tgt_idx >= tgt_seq_len) return; + + int kv_head_idx = (head_idx * kv_heads) / query_heads; + -template <> -struct TypeConverter { - __device__ __forceinline__ static float to_float(half val) { - return __half2float(val); - } - __device__ __forceinline__ static half from_float(float val) { - return __float2half(val); - } -}; + // 计算内存偏移量 + int q_offset = + batch_idx * tgt_seq_len * query_heads * HEAD_DIM + + tgt_idx * query_heads * HEAD_DIM + + head_idx * HEAD_DIM; + + int kv_batch_offset = + batch_idx * src_seq_len * kv_heads * HEAD_DIM + + kv_head_idx * HEAD_DIM; + + // Q[i] 向量 + float q_vec[HEAD_DIM]; + float out[HEAD_DIM]; + + #pragma unroll + for (int d = 0; d < HEAD_DIM; ++d) { + q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); + out[d] = 0.0f; + } -// Flash Attention Kernel -template -__global__ void flash_attention_forward_kernel( - const T* __restrict__ Q, - const T* __restrict__ K, - const T* __restrict__ V, - T* __restrict__ O, - const int batch_size, - const int tgt_seq_len, - const int src_seq_len, - const int query_heads, - const int kv_heads, - const int head_dim, - const bool is_causal, - const float softmax_scale -) { - // 线程索引 - const int batch_idx = blockIdx.z; - const int head_idx = blockIdx.y; - const int global_tgt_idx = blockIdx.x * blockDim.x + threadIdx.x; - - // 边界检查 - if (global_tgt_idx >= tgt_seq_len) return; - - // GQA 支持 - const int kv_head_idx = (head_idx * kv_heads) / query_heads; - - // 计算偏移量 - const size_t q_batch_stride = tgt_seq_len * query_heads * head_dim; - const size_t q_seq_stride = query_heads * head_dim; - const size_t q_head_stride = head_dim; - - const int q_offset = batch_idx * q_batch_stride + - global_tgt_idx * q_seq_stride + - head_idx * q_head_stride; - - const size_t kv_batch_stride = src_seq_len * kv_heads * head_dim; - const size_t kv_seq_stride = kv_heads * head_dim; - const size_t kv_head_stride = head_dim; - - const int kv_batch_offset = batch_idx * kv_batch_stride + - kv_head_idx * kv_head_stride; - - const int o_offset = batch_idx * q_batch_stride + - global_tgt_idx * q_seq_stride + - head_idx * q_head_stride; - - // 在线 Softmax 状态 - float max_score = -FLT_MAX; - float sum_exp = 0.0f; - - // 动态分配输出累加器(避免固定大小假设) - float* output_acc = new float[head_dim]; - float* q_vec = new float[head_dim]; - - // 初始化 - for (int d = 0; d < head_dim; ++d) { - output_acc[d] = 0.0f; - q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); // 加载 Q 的第 i 行作为外循环 - } - - // 分块处理 Key 和 Value - const int BLOCK_SIZE = 64; // 固定分块大小 - - // 使用 KV 的第 j 行作为内循环 - for (int src_block_start = 0; src_block_start < 
src_seq_len; src_block_start += BLOCK_SIZE) { - const int src_block_end = min(src_block_start + BLOCK_SIZE, src_seq_len); - - // 计算当前块的注意力分数 - float block_scores[64]; // 匹配 BLOCK_SIZE - - for (int i = 0; i < BLOCK_SIZE; ++i) { - const int src_idx = src_block_start + i; - - if (src_idx >= src_block_end) { - block_scores[i] = -FLT_MAX; - continue; - } - - // Causal masking - if (is_causal && src_idx > global_tgt_idx) { - block_scores[i] = -FLT_MAX; - continue; - } - - // 计算 Q·K^T - const int k_offset = kv_batch_offset + src_idx * kv_seq_stride; - - float score = 0.0f; - for (int d = 0; d < head_dim; ++d) { - float k_val = TypeConverter::to_float(K[k_offset + d]); - score += q_vec[d] * k_val; - } - - block_scores[i] = score * softmax_scale; - } - - // 在线 Softmax 更新 - float prev_max = max_score; - - // 更新最大值 - for (int i = 0; i < BLOCK_SIZE; ++i) { - if (block_scores[i] > -FLT_MAX) { - max_score = fmaxf(max_score, block_scores[i]); - } - } - - // 重新缩放 - const float correction = expf(prev_max - max_score); - sum_exp *= correction; - - for (int d = 0; d < head_dim; ++d) { - output_acc[d] *= correction; - } - - // 累加当前块的贡献 - for (int i = 0; i < BLOCK_SIZE; ++i) { - const int src_idx = src_block_start + i; - - if (src_idx >= src_block_end || block_scores[i] == -FLT_MAX) { - continue; - } - - const float attn_weight = expf(block_scores[i] - max_score); - sum_exp += attn_weight; - - const int v_offset = kv_batch_offset + src_idx * kv_seq_stride; - - for (int d = 0; d < head_dim; ++d) { - float v_val = TypeConverter::to_float(V[v_offset + d]); - output_acc[d] += attn_weight * v_val; - } - } - } - - // 最终归一化并写出 - const float inv_sum = (sum_exp > 1e-6f) ? (1.0f / sum_exp) : 0.0f; - - for (int d = 0; d < head_dim; ++d) { - O[o_offset + d] = TypeConverter::from_float(output_acc[d] * inv_sum); - } - - // 清理 - delete[] output_acc; - delete[] q_vec; -} -template -void flashAttention( - const std::vector& h_q, - const std::vector& h_k, - const std::vector& h_v, - std::vector& h_o, - int batch_size, - int target_seq_len, - int src_seq_len, - int query_heads, - int kv_heads, - int head_dim, - bool is_causal -) { - const float softmax_scale = 1.0f / sqrtf(static_cast(head_dim)); - - const size_t q_size = batch_size * target_seq_len * query_heads * head_dim; - const size_t kv_size = batch_size * src_seq_len * kv_heads * head_dim; - const size_t o_size = q_size; - - T *d_q, *d_k, *d_v, *d_o; - cudaMalloc(&d_q, q_size * sizeof(T)); - cudaMalloc(&d_k, kv_size * sizeof(T)); - cudaMalloc(&d_v, kv_size * sizeof(T)); - cudaMalloc(&d_o, o_size * sizeof(T)); - - cudaMemcpy(d_q, h_q.data(), q_size * sizeof(T), cudaMemcpyHostToDevice); - cudaMemcpy(d_k, h_k.data(), kv_size * sizeof(T), cudaMemcpyHostToDevice); - cudaMemcpy(d_v, h_v.data(), kv_size * sizeof(T), cudaMemcpyHostToDevice); - - const int threads_per_block = 256; - dim3 block_dim(threads_per_block); - dim3 grid_dim( - (target_seq_len + threads_per_block - 1) / threads_per_block, - query_heads, - batch_size - ); - - flash_attention_forward_kernel<<>>( - d_q, d_k, d_v, d_o, - batch_size, target_seq_len, src_seq_len, - query_heads, kv_heads, head_dim, - is_causal, softmax_scale - ); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - printf("CUDA Error: %s\n", cudaGetErrorString(err)); - } - - cudaMemcpy(h_o.data(), d_o, o_size * sizeof(T), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - - cudaFree(d_q); - cudaFree(d_k); - cudaFree(d_v); - cudaFree(d_o); -} + // Online Softmax + Attention + float max_score = -FLT_MAX; + float sum_exp = 
0.0f; // Softmax 分母 + + for (int s = 0; s < src_seq_len; ++s) { + if (is_causal && s > tgt_idx) continue; + + int k_offset = kv_batch_offset + s * kv_heads * HEAD_DIM; + + float score = 0.0f; + + // 计算 Q * K ^ T + #pragma unroll + for (int d = 0; d < HEAD_DIM; ++d) { + score += q_vec[d] * + TypeConverter::to_float(K[k_offset + d]); + } + score *= softmax_scale; + + // Online Softmax 更新 + float prev_max = max_score; + max_score = fmaxf(max_score, score); + + float scale = expf(prev_max - max_score); + sum_exp *= scale; + #pragma unroll + for (int d = 0; d < HEAD_DIM; ++d) out[d] *= scale; + + float w = expf(score - max_score); + sum_exp += w; + + int v_offset = kv_batch_offset + s * kv_heads * HEAD_DIM; + #pragma unroll + for (int d = 0; d < HEAD_DIM; ++d) { + out[d] += w * TypeConverter::to_float(V[v_offset + d]); + } + } + + float inv = (sum_exp > 1e-6f) ? 1.0f / sum_exp : 0.0f; + + #pragma unroll + for (int d = 0; d < HEAD_DIM; ++d) { + O[q_offset + d] = + TypeConverter::from_float(out[d] * inv); + } + } + + + // 支持任意 head_dim + template + __global__ void flash_attention_forward_kernel_generic( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + T* __restrict__ O, + int batch_size, + int tgt_seq_len, + int src_seq_len, + int query_heads, + int kv_heads, + int head_dim, + bool is_causal, + float softmax_scale + ) { + extern __shared__ float smem[]; + + int tid = threadIdx.x; + int tgt_idx = blockIdx.x * blockDim.x + tid; + int head_idx = blockIdx.y; + int batch_idx = blockIdx.z; + + if (tgt_idx >= tgt_seq_len) return; + + float* q_vec = smem + tid * head_dim; + float* out = q_vec + blockDim.x * head_dim; + + int kv_head_idx = (head_idx * kv_heads) / query_heads; + + int q_offset = + batch_idx * tgt_seq_len * query_heads * head_dim + + tgt_idx * query_heads * head_dim + + head_idx * head_dim; + + int kv_batch_offset = + batch_idx * src_seq_len * kv_heads * head_dim + + kv_head_idx * head_dim; + + // 加载到 shared memory + for (int d = 0; d < head_dim; ++d) { + q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); + out[d] = 0.0f; + } + + float max_score = -FLT_MAX; + float sum_exp = 0.0f; + + for (int s = 0; s < src_seq_len; ++s) { + if (is_causal && s > tgt_idx) continue; + + int k_offset = kv_batch_offset + s * kv_heads * head_dim; + + float score = 0.0f; + for (int d = 0; d < head_dim; ++d) { + score += q_vec[d] * + TypeConverter::to_float(K[k_offset + d]); + } + score *= softmax_scale; + + float prev_max = max_score; + max_score = fmaxf(max_score, score); + + float scale = expf(prev_max - max_score); + sum_exp *= scale; + for (int d = 0; d < head_dim; ++d) out[d] *= scale; + + float w = expf(score - max_score); + sum_exp += w; + + int v_offset = kv_batch_offset + s * kv_heads * head_dim; + for (int d = 0; d < head_dim; ++d) { + out[d] += w * TypeConverter::to_float(V[v_offset + d]); + } + } + + float inv = (sum_exp > 1e-6f) ? 
1.0f / sum_exp : 0.0f; + + for (int d = 0; d < head_dim; ++d) { + O[q_offset + d] = + TypeConverter::from_float(out[d] * inv); + } + } + + template + void flashAttention( + const std::vector& h_q, + const std::vector& h_k, + const std::vector& h_v, + std::vector& h_o, + int batch_size, + int tgt_seq_len, + int src_seq_len, + int query_heads, + int kv_heads, + int head_dim, + bool is_causal + ) { + float softmax_scale = 1.0f / sqrtf((float)head_dim); + + size_t q_size = batch_size * tgt_seq_len * query_heads * head_dim; + size_t kv_size = batch_size * src_seq_len * kv_heads * head_dim; + + T *d_q, *d_k, *d_v, *d_o; + cudaMalloc(&d_q, q_size * sizeof(T)); + cudaMalloc(&d_k, kv_size * sizeof(T)); + cudaMalloc(&d_v, kv_size * sizeof(T)); + cudaMalloc(&d_o, q_size * sizeof(T)); + + cudaMemcpy(d_q, h_q.data(), q_size * sizeof(T), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k.data(), kv_size * sizeof(T), cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v.data(), kv_size * sizeof(T), cudaMemcpyHostToDevice); + + dim3 block(128); + dim3 grid( + (tgt_seq_len + block.x - 1) / block.x, + query_heads, + batch_size + ); + + // 选择 kernel 路径 + bool use_fast = + (head_dim == 16 || head_dim == 32 || + head_dim == 64 || head_dim == 128); + + if (use_fast) { + // 特殊 head_dim + if (head_dim == 16) + flash_attention_forward_kernel_fast<<>>( + d_q,d_k,d_v,d_o, + batch_size,tgt_seq_len,src_seq_len, + query_heads,kv_heads,is_causal,softmax_scale); + else if (head_dim == 32) + flash_attention_forward_kernel_fast<<>>( + d_q,d_k,d_v,d_o, + batch_size,tgt_seq_len,src_seq_len, + query_heads,kv_heads,is_causal,softmax_scale + ); + else if (head_dim == 64) + flash_attention_forward_kernel_fast<<>>( + d_q,d_k,d_v,d_o, + batch_size,tgt_seq_len,src_seq_len, + query_heads,kv_heads,is_causal,softmax_scale + ); + else if (head_dim == 128) + flash_attention_forward_kernel_fast<<>>( + d_q,d_k,d_v,d_o, + batch_size,tgt_seq_len,src_seq_len, + query_heads,kv_heads,is_causal,softmax_scale + ); + } else { + + // 任意 head_dim + size_t smem_size = + 2 * block.x * head_dim * sizeof(float); + + flash_attention_forward_kernel_generic + <<>>( + d_q, d_k, d_v, d_o, + batch_size, tgt_seq_len, src_seq_len, + query_heads, kv_heads, + head_dim, is_causal, softmax_scale); + } + + cudaMemcpy(h_o.data(), d_o, q_size * sizeof(T), + cudaMemcpyDeviceToHost); + + cudaFree(d_q); + cudaFree(d_k); + cudaFree(d_v); + cudaFree(d_o); + } + // ********************************************************************* // Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O) // DO NOT MODIFY THIS SECTION @@ -361,4 +420,3 @@ template void flashAttention(const std::vector&, const std::vector template void flashAttention(const std::vector&, const std::vector&, const std::vector&, std::vector&, int, int, int, int, int, int, bool); - diff --git a/src/kernels.mu b/src/kernels.mu index 1fb87770..caa0fc4b 100644 --- a/src/kernels.mu +++ b/src/kernels.mu @@ -1,51 +1,318 @@ -#include +#include #include - +#include +#include +#include +#include +#include #include "../tester/utils.h" -/** - * @brief Computes the trace of a matrix. - * - * The trace of a matrix is defined as the sum of its diagonal elements. - * This function expects a flattened row-major matrix stored in a - * std::vector. If the matrix is not square, the trace will sum up - * elements along the main diagonal up to the smaller of rows or cols. - * - * @tparam T The numeric type of matrix elements (e.g., float, int). - * @param h_input A flattened matrix of size rows * cols. 
- * @param rows Number of rows in the matrix. - * @param cols Number of columns in the matrix. - * @return The trace (sum of diagonal values) of the matrix. - */ +// ============================================================================ +// TRACE IMPLEMENTATION +// ============================================================================ + +template +__global__ void trace_kernel(const T* d_input, T* d_partial_sums, + size_t rows, size_t cols, size_t min_dim) { + extern __shared__ char shared_mem[]; + T* sdata = reinterpret_cast(shared_mem); + + unsigned int tid = threadIdx.x; + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + + T val = (idx < min_dim) ? d_input[idx * cols + idx] : T(0); + sdata[tid] = val; + __syncthreads(); + + // 并行 reduction + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + sdata[tid] += sdata[tid + s]; + } + __syncthreads(); + } + + if (tid == 0) { + d_partial_sums[blockIdx.x] = sdata[0]; + } +} + template T trace(const std::vector& h_input, size_t rows, size_t cols) { - // TODO: Implement the trace function - return T(-1); + size_t min_dim = (rows < cols) ? rows : cols; + if (min_dim == 0) return T(0); + + // 小矩阵CPU直接计算 + if (min_dim < 1024) { + T sum = T(0); + for (size_t i = 0; i < min_dim; ++i) { + sum += h_input[i * cols + i]; + } + return sum; + } + + T* d_input; + size_t matrix_size = rows * cols * sizeof(T); + musaMalloc(&d_input, matrix_size); + musaMemcpy(d_input, h_input.data(), matrix_size, musaMemcpyHostToDevice); + + int block_size = 256; + int num_blocks = (min_dim + block_size - 1) / block_size; + + T* d_partial_sums; + musaMalloc(&d_partial_sums, num_blocks * sizeof(T)); + + size_t shared_mem_size = block_size * sizeof(T); + trace_kernel<<>>( + d_input, d_partial_sums, rows, cols, min_dim + ); + + std::vector h_partial_sums(num_blocks); + musaMemcpy(h_partial_sums.data(), d_partial_sums, + num_blocks * sizeof(T), musaMemcpyDeviceToHost); + + T result = T(0); + for (int i = 0; i < num_blocks; ++i) { + result += h_partial_sums[i]; + } + + musaFree(d_input); + musaFree(d_partial_sums); + + return result; +} + +// ============================================================================ +// FLASH ATTENTION IMPLEMENTATION +// ============================================================================ + +template +struct TypeConverter { + __device__ __forceinline__ static float to_float(T val); + __device__ __forceinline__ static T from_float(float val); +}; + +template <> +struct TypeConverter { + __device__ __forceinline__ static float to_float(float val) { + return val; + } + __device__ __forceinline__ static float from_float(float val) { + return val; + } +}; + +template <> +struct TypeConverter { + __device__ __forceinline__ static float to_float(half val) { + return __half2float(val); + } + __device__ __forceinline__ static half from_float(float val) { + return __float2half(val); + } +}; + +// Flash Attention Forward Kernel +template +__global__ void flash_attention_forward_kernel( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + T* __restrict__ O, + const int batch_size, + const int tgt_seq_len, + const int src_seq_len, + const int query_heads, + const int kv_heads, + const int head_dim, + const bool is_causal, + const float softmax_scale +) { + const int batch_idx = blockIdx.z; + const int head_idx = blockIdx.y; + const int global_tgt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_tgt_idx >= tgt_seq_len) return; + + // GQA 支持 + const int kv_head_idx = 
(head_idx * kv_heads) / query_heads; + + // 计算偏移量 + const size_t q_batch_stride = tgt_seq_len * query_heads * head_dim; + const size_t q_seq_stride = query_heads * head_dim; + const size_t q_head_stride = head_dim; + const int q_offset = batch_idx * q_batch_stride + + global_tgt_idx * q_seq_stride + + head_idx * q_head_stride; + + const size_t kv_batch_stride = src_seq_len * kv_heads * head_dim; + const size_t kv_seq_stride = kv_heads * head_dim; + const size_t kv_head_stride = head_dim; + const int kv_batch_offset = batch_idx * kv_batch_stride + + kv_head_idx * kv_head_stride; + + const int o_offset = batch_idx * q_batch_stride + + global_tgt_idx * q_seq_stride + + head_idx * q_head_stride; + + // 在线 Softmax 状态 + float max_score = -FLT_MAX; + float sum_exp = 0.0f; + + // Q 中的一行 + float output_acc[256]; + float q_vec[256]; + + // 初始化 + for (int d = 0; d < head_dim; ++d) { + output_acc[d] = 0.0f; + q_vec[d] = TypeConverter::to_float(Q[q_offset + d]); + } + + // 分块处理 Key 和 Value + const int BLOCK_SIZE = 64; + + for (int src_block_start = 0; src_block_start < src_seq_len; + src_block_start += BLOCK_SIZE) { + const int src_block_end = min(src_block_start + BLOCK_SIZE, src_seq_len); + + float block_scores[64]; + + // 计算注意力分数 + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int src_idx = src_block_start + i; + if (src_idx >= src_block_end) { + block_scores[i] = -FLT_MAX; + continue; + } + + // Causal masking + if (is_causal && src_idx > global_tgt_idx) { + block_scores[i] = -FLT_MAX; + continue; + } + + const int k_offset = kv_batch_offset + src_idx * kv_seq_stride; + float score = 0.0f; + + // 向量化点积 + #pragma unroll 8 + for (int d = 0; d < head_dim; ++d) { + float k_val = TypeConverter::to_float(K[k_offset + d]); + score += q_vec[d] * k_val; + } + + block_scores[i] = score * softmax_scale; + } + + // 在线 Softmax 更新 + float prev_max = max_score; + + for (int i = 0; i < BLOCK_SIZE; ++i) { + if (block_scores[i] > -FLT_MAX) { + max_score = fmaxf(max_score, block_scores[i]); + } + } + + const float correction = expf(prev_max - max_score); + sum_exp *= correction; + + #pragma unroll 8 + for (int d = 0; d < head_dim; ++d) { + output_acc[d] *= correction; + } + + // 累加当前块的贡献 + for (int i = 0; i < BLOCK_SIZE; ++i) { + const int src_idx = src_block_start + i; + if (src_idx >= src_block_end || block_scores[i] == -FLT_MAX) { + continue; + } + + const float attn_weight = expf(block_scores[i] - max_score); + sum_exp += attn_weight; + + const int v_offset = kv_batch_offset + src_idx * kv_seq_stride; + + #pragma unroll 8 + for (int d = 0; d < head_dim; ++d) { + float v_val = TypeConverter::to_float(V[v_offset + d]); + output_acc[d] += attn_weight * v_val; + } + } + } + + // 最终归一化并写出 + const float inv_sum = (sum_exp > 1e-6f) ? (1.0f / sum_exp) : 0.0f; + + #pragma unroll 8 + for (int d = 0; d < head_dim; ++d) { + O[o_offset + d] = TypeConverter::from_float(output_acc[d] * inv_sum); + } } -/** - * @brief Computes flash attention for given query, key, and value tensors. 
- * - * @tparam T Data type (float) for input/output tensors - * @param[in] h_q Query tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] h_k Key tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[in] h_v Value tensor of shape [batch_size, src_seq_len, kv_heads, head_dim] - * @param[out] h_o Output attention tensor of shape [batch_size, tgt_seq_len, query_heads, head_dim] - * @param[in] batch_size Batch dimension size - * @param[in] target_seq_len Target sequence length - * @param[in] src_seq_len Source sequence length - * @param[in] query_heads Number of query attention heads - * @param[in] kv_heads Number of key/value heads (supports grouped query attention) - * @param[in] head_dim Dimension size of each attention head - * @param[in] is_causal Whether to apply causal masking - */ template -void flashAttention(const std::vector& h_q, const std::vector& h_k, - const std::vector& h_v, std::vector& h_o, - int batch_size, int target_seq_len, int src_seq_len, - int query_heads, int kv_heads, int head_dim, bool is_causal) { +void flashAttention( + const std::vector& h_q, + const std::vector& h_k, + const std::vector& h_v, + std::vector& h_o, + int batch_size, + int target_seq_len, + int src_seq_len, + int query_heads, + int kv_heads, + int head_dim, + bool is_causal +) { + const float softmax_scale = 1.0f / sqrtf(static_cast(head_dim)); + + const size_t q_size = batch_size * target_seq_len * query_heads * head_dim; + const size_t kv_size = batch_size * src_seq_len * kv_heads * head_dim; + const size_t o_size = q_size; + + T *d_q, *d_k, *d_v, *d_o; + musaMalloc(&d_q, q_size * sizeof(T)); + musaMalloc(&d_k, kv_size * sizeof(T)); + musaMalloc(&d_v, kv_size * sizeof(T)); + musaMalloc(&d_o, o_size * sizeof(T)); + + musaMemcpy(d_q, h_q.data(), q_size * sizeof(T), musaMemcpyHostToDevice); + musaMemcpy(d_k, h_k.data(), kv_size * sizeof(T), musaMemcpyHostToDevice); + musaMemcpy(d_v, h_v.data(), kv_size * sizeof(T), musaMemcpyHostToDevice); + + const int threads_per_block = 256; + dim3 block_dim(threads_per_block); + dim3 grid_dim( + (target_seq_len + threads_per_block - 1) / threads_per_block, + query_heads, + batch_size + ); + + flash_attention_forward_kernel<<>>( + d_q, d_k, d_v, d_o, + batch_size, target_seq_len, src_seq_len, + query_heads, kv_heads, head_dim, + is_causal, softmax_scale + ); + + musaError_t err = musaGetLastError(); + if (err != musaSuccess) { + printf("MUSA Error: %s\n", musaGetErrorString(err)); + } + + musaMemcpy(h_o.data(), d_o, o_size * sizeof(T), musaMemcpyDeviceToHost); + musaDeviceSynchronize(); + + musaFree(d_q); + musaFree(d_k); + musaFree(d_v); + musaFree(d_o); } +// ********************************************************************* +// Explicit Template Instantiations +// ********************************************************************* // ********************************************************************* // Explicit Template Instantiations (REQUIRED FOR LINKING WITH TESTER.O) // DO NOT MODIFY THIS SECTION @@ -57,4 +324,4 @@ template void flashAttention(const std::vector&, const std::vector int, int, int, int, int, int, bool); template void flashAttention(const std::vector&, const std::vector&, const std::vector&, std::vector&, - int, int, int, int, int, int, bool); + int, int, int, int, int, int, bool); \ No newline at end of file
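
For local validation of these kernels, the sketch below is one way to cross-check flashAttention<float> against a naive CPU softmax-attention reference that uses the same [batch, seq, heads, head_dim] layout, the same GQA head mapping (kv_head = q_head * kv_heads / query_heads), the same 1/sqrt(head_dim) scale, and the same causal mask as the kernels above. It is illustrative only and not part of the patches: the helper name cpu_attention_reference and the chosen tensor sizes are assumptions; only the flashAttention declaration is taken from src/kernels.cu. Build it together with the compiled kernels (e.g. whatever objects the test_kernels binary links against).

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Declaration matching the templates instantiated in src/kernels.cu.
template <typename T>
void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
                    const std::vector<T>& h_v, std::vector<T>& h_o,
                    int batch_size, int target_seq_len, int src_seq_len,
                    int query_heads, int kv_heads, int head_dim, bool is_causal);

// Naive O(B*H*T*S*D) reference using the [batch, seq, heads, head_dim] layout.
static void cpu_attention_reference(const std::vector<float>& q, const std::vector<float>& k,
                                    const std::vector<float>& v, std::vector<float>& o,
                                    int B, int T, int S, int QH, int KVH, int D, bool causal) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(D));
  for (int b = 0; b < B; ++b)
    for (int h = 0; h < QH; ++h) {
      const int kvh = (h * KVH) / QH;  // same GQA head mapping as the kernels
      for (int t = 0; t < T; ++t) {
        const float* qp = &q[((size_t)b * T + t) * QH * D + (size_t)h * D];
        std::vector<float> scores(S, -INFINITY);
        float max_s = -INFINITY;
        for (int s = 0; s < S; ++s) {
          if (causal && s > t) continue;  // causal mask: ignore future keys
          const float* kp = &k[((size_t)b * S + s) * KVH * D + (size_t)kvh * D];
          float dot = 0.0f;
          for (int d = 0; d < D; ++d) dot += qp[d] * kp[d];
          scores[s] = dot * scale;
          max_s = std::max(max_s, scores[s]);
        }
        float denom = 0.0f;
        float* op = &o[((size_t)b * T + t) * QH * D + (size_t)h * D];
        for (int d = 0; d < D; ++d) op[d] = 0.0f;
        for (int s = 0; s < S; ++s) {
          if (scores[s] == -INFINITY) continue;
          const float w = std::exp(scores[s] - max_s);  // numerically stable softmax
          denom += w;
          const float* vp = &v[((size_t)b * S + s) * KVH * D + (size_t)kvh * D];
          for (int d = 0; d < D; ++d) op[d] += w * vp[d];
        }
        for (int d = 0; d < D; ++d) op[d] /= denom;
      }
    }
}

int main() {
  // Small shapes chosen only for the sanity check (hypothetical values).
  const int B = 2, T = 7, S = 9, QH = 4, KVH = 2, D = 32;
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  std::vector<float> q(B * T * QH * D), k(B * S * KVH * D), v(B * S * KVH * D);
  for (float& x : q) x = dist(rng);
  for (float& x : k) x = dist(rng);
  for (float& x : v) x = dist(rng);
  std::vector<float> o_gpu(q.size()), o_ref(q.size());

  flashAttention<float>(q, k, v, o_gpu, B, T, S, QH, KVH, D, /*is_causal=*/true);
  cpu_attention_reference(q, k, v, o_ref, B, T, S, QH, KVH, D, true);

  float max_err = 0.0f;
  for (size_t i = 0; i < o_gpu.size(); ++i)
    max_err = std::max(max_err, std::fabs(o_gpu[i] - o_ref[i]));
  printf("max abs diff vs CPU reference: %g\n", max_err);
  return 0;
}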