diff --git a/src/kernels.cu b/src/kernels.cu index 74312070..0d6d7d47 100644 --- a/src/kernels.cu +++ b/src/kernels.cu @@ -2,6 +2,7 @@ #include #include "../tester/utils.h" +#define BLOCKSIZE 256 /** * @brief Computes the trace of a matrix. @@ -17,10 +18,46 @@ * @param cols Number of columns in the matrix. * @return The trace (sum of diagonal values) of the matrix. */ + template + __global__ void traceKernel(const T* input,size_t rows, size_t cols,T* output){ + __shared__ T cache[BLOCKSIZE]; + cache[threadIdx.x] = 0; + for(int i = blockIdx.x * blockDim.x + threadIdx.x;i < cols && i 0;total /= 2){ + if(threadIdx.x < total){ + cache[threadIdx.x] += cache[threadIdx.x + total]; + } + __syncthreads(); + } + if(threadIdx.x == 0){ + atomicAdd(output,cache[0]); + } + } + template T trace(const std::vector& h_input, size_t rows, size_t cols) { // TODO: Implement the trace function - return T(-1); + int Blocknumber_x = (cols + BLOCKSIZE - 1) / BLOCKSIZE; + + T* d_input; + size_t size = rows * cols * sizeof(T); + cudaMalloc((void**)&d_input, size); + cudaMemcpy(d_input, h_input.data(), size, cudaMemcpyHostToDevice); + + T* d_output; + cudaMalloc((void**)&d_output, sizeof(T)); + cudaMemset(d_output, 0, sizeof(T)); + + traceKernel<<>>(d_input, rows, cols,d_output); + T h_output; + cudaMemcpy(&h_output, d_output, sizeof(T), cudaMemcpyDeviceToHost); + cudaFree(d_input); + cudaFree(d_output); + return h_output; } /** @@ -45,6 +82,7 @@ void flashAttention(const std::vector& h_q, const std::vector& h_k, int batch_size, int target_seq_len, int src_seq_len, int query_heads, int kv_heads, int head_dim, bool is_causal) { // TODO: Implement the flash attention function + } // ********************************************************************* diff --git a/src/kernels.o b/src/kernels.o new file mode 100644 index 00000000..adb738ef Binary files /dev/null and b/src/kernels.o differ diff --git a/test_kernels b/test_kernels new file mode 100755 index 00000000..309b8504 Binary files /dev/null and b/test_kernels differ diff --git a/tester/tester_nv.o b/tester/tester_nv.o index ab43e278..8e26b60d 100644 Binary files a/tester/tester_nv.o and b/tester/tester_nv.o differ