diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt
index 397123cb..78ec3391 100644
--- a/src/kernels/CMakeLists.txt
+++ b/src/kernels/CMakeLists.txt
@@ -1,4 +1,5 @@
 include(cc_library)
+include(cc_test)
 
 cc_library(
   NAME
@@ -72,6 +73,16 @@ cc_library(
     torch
 )
 
+cc_test(
+  NAME
+    layernorm_kernels_test
+  SRCS
+    layernorm_kernels_test.cu
+    layernorm_kernels.cu
+  DEPS
+    torch
+    GTest::gtest_main
+)
+
 add_subdirectory(flash_attn)
 add_subdirectory(flash_infer)
-
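For reference while reading the half2 kernel below: over each row of scalars, layer norm computes y[i] = (x[i] - E[x]) / sqrt(Var[x] + eps) * w[i] + b[i], where Var is the biased (divide-by-n) variance. A minimal single-threaded C++ sketch of those semantics, for comparison only (the function name and types are illustrative, not part of the patch):

#include <cmath>
#include <vector>

// Reference layer norm over one row; mirrors what the CUDA kernels below
// compute per thread block (one block per row).
std::vector<float> layer_norm_ref(const std::vector<float>& x,
                                  const std::vector<float>& w,
                                  const std::vector<float>& b,
                                  float eps) {
  const size_t n = x.size();
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);

  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);  // biased variance: divide by n, not n - 1

  const float inv_std = 1.0f / std::sqrt(var + eps);
  std::vector<float> y(n);
  for (size_t i = 0; i < n; ++i) {
    y[i] = (x[i] - mean) * inv_std * w[i] + b[i];
  }
  return y;
}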
diff --git a/src/kernels/layernorm_kernels.cu b/src/kernels/layernorm_kernels.cu
index 3e32bd8b..4ce7c38c 100644
--- a/src/kernels/layernorm_kernels.cu
+++ b/src/kernels/layernorm_kernels.cu
@@ -2,6 +2,7 @@
 #include <cuda_fp16.h>
 
 #include "dispatch.h"
+#include "layernorm_kernels.h"
 #include "reduce_kernel_utils.cuh"
 
 namespace llm::kernel {
@@ -173,6 +174,61 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   }
 }
 
+// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
+// The mean and standard deviation are calculated over the last dimension.
+template <>
+__global__ void layer_norm_kernel(half2* __restrict__ out,
+                                  const half2* __restrict__ input,
+                                  const half2* __restrict__ weight,
+                                  const half2* __restrict__ bias,
+                                  const float epsilon,
+                                  int64_t n) {
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  __shared__ half s_mean;
+  __shared__ half s_variance;
+  half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f));
+  half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f));
+
+  // calculate the mean of the input (2 * n scalars per row).
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    mean = __hadd2(mean, __ldg(&input[idx]));
+  }
+  mean = block_reduce_sum(mean);
+  if (tidx == 0) {
+    s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2));
+  }
+  __syncthreads();
+
+  // calculate the variance of the input.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean));
+    variance = __hadd2(variance, __hmul2(x, x));
+  }
+  variance = block_reduce_sum(variance);
+  if (tidx == 0) {
+    s_variance = __hadd(variance.x, variance.y);
+    s_variance = __hdiv(s_variance, __float2half((float)n * 2));
+    s_variance = __hadd(s_variance, __float2half(epsilon));
+    s_variance = hrsqrt(s_variance);
+  }
+  __syncthreads();
+
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    half2 local_out = __ldg(&input[idx]);
+    local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
+    local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
+    local_out = __hmul2(local_out, __ldg(&weight[i]));
+    if (bias != nullptr) {
+      local_out = __hadd2(local_out, __ldg(&bias[i]));
+    }
+    out[idx] = local_out;
+  }
+}
+
 void layer_norm(torch::Tensor& out,
                 torch::Tensor input,
                 torch::Tensor weight,
@@ -197,4 +253,54 @@ void layer_norm(torch::Tensor& out,
   });
 }
 
-}  // namespace llm::kernel
+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  layer_norm_kernel<T><<<m, std::min(n, 1024)>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel(half2* out,
+                             const half2* input,
+                             const half2* weight,
+                             const half2* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  layer_norm_kernel<<<m, std::min(n, 1024)>>>(out, input, weight, bias, epsilon, n);
+}
+template <>
+void invoke_layernorm_kernel(float* out,
+                             const float* input,
+                             const float* weight,
+                             const float* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  layer_norm_kernel<<<m, std::min(n, 1024)>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel(half* out,
+                             const half* input,
+                             const half* weight,
+                             const half* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  int half_n = n / 2;
+  half2* out_ptr = (half2*)out;
+  const half2* input_ptr = (const half2*)input;
+  const half2* weight_ptr = (const half2*)weight;
+  const half2* bias_ptr = (const half2*)bias;
+
+  dim3 block(std::min(half_n, 1024));
+  layer_norm_kernel
+      <<<m, block>>>(out_ptr, input_ptr, weight_ptr, bias_ptr, epsilon, half_n);
+}
+}  // namespace llm::kernel
\ No newline at end of file
diff --git a/src/kernels/layernorm_kernels.h b/src/kernels/layernorm_kernels.h
index 496622bb..57e8cf0b 100644
--- a/src/kernels/layernorm_kernels.h
+++ b/src/kernels/layernorm_kernels.h
@@ -20,4 +20,12 @@ void layer_norm(torch::Tensor& out,
                 torch::Tensor bias,
                 float epsilon);
 
+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n);
 }  // namespace llm::kernel
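A note on the half entry point above: it reinterprets the half buffers as half2 and halves n, so it silently assumes n is even and rows are contiguous; an odd n would drop the last column. A hypothetical host-side wrapper showing how it would be called from torch tensors (a sketch only; the helper name is invented here, and the half specialization must be visible to the linker, e.g. via a declaration in the header):

#include <cuda_fp16.h>
#include <torch/torch.h>

#include "layernorm_kernels.h"

// Hypothetical helper, assuming contiguous row-major [m, n] CUDA tensors.
void layer_norm_half(torch::Tensor& out,
                     const torch::Tensor& input,
                     const torch::Tensor& weight,
                     const torch::Tensor& bias,
                     float epsilon) {
  const int m = static_cast<int>(input.size(0));
  const int n = static_cast<int>(input.size(1));
  TORCH_CHECK(n % 2 == 0, "half path packs two scalars per half2");
  llm::kernel::invoke_layernorm_kernel<half>(
      reinterpret_cast<half*>(out.data_ptr()),
      reinterpret_cast<const half*>(input.data_ptr()),
      reinterpret_cast<const half*>(weight.data_ptr()),
      reinterpret_cast<const half*>(bias.data_ptr()),
      epsilon,
      m,
      n);
}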
diff --git a/src/kernels/layernorm_kernels_test.cu b/src/kernels/layernorm_kernels_test.cu
new file mode 100644
index 00000000..c50c18cf
--- /dev/null
+++ b/src/kernels/layernorm_kernels_test.cu
@@ -0,0 +1,54 @@
+#include <c10/core/ScalarType.h>
+#include <gtest/gtest.h>
+#include <torch/torch.h>
+
+#include <cuda_fp16.h>
+
+#include "layernorm_kernels.h"
+
+TEST(NormalizationKernelTest, LayernormFloatTest) {
+  float epsilon = 1e-6;
+  int m = 32;
+  int n = 512;
+
+  auto out = torch::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA));
+  auto input =
+      torch::randn({m, n}, torch::TensorOptions().device(torch::kCUDA));
+  auto weight = torch::randn({n}, torch::TensorOptions().device(torch::kCUDA));
+  auto bias = torch::randn({n}, torch::TensorOptions().device(torch::kCUDA));
+  auto desired_out = torch::nn::functional::layer_norm(
+      input,
+      torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(
+          bias).eps(epsilon));
+
+  llm::kernel::layer_norm(out, input, weight, bias, epsilon);
+
+  EXPECT_TRUE(torch::allclose(out, desired_out, 1e-3, 1e-5));
+}
+
+TEST(NormalizationKernelTest, LayernormHalfTest) {
+  float epsilon = 1e-6;
+  int m = 4;
+  int n = 512;
+
+  auto out = torch::zeros(
+      {m, n},
+      torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA));
+  auto input = torch::randn(
+      {m, n},
+      torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA));
+  auto weight = torch::randn(
+      {n},
+      torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA));
+  auto bias = torch::randn(
+      {n},
+      torch::TensorOptions().dtype(at::ScalarType::Half).device(torch::kCUDA));
+  auto desired_out = torch::nn::functional::layer_norm(
+      input,
+      torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(
+          bias).eps(epsilon));
+
+  llm::kernel::layer_norm(out, input, weight, bias, epsilon);
+
+  EXPECT_TRUE(torch::allclose(out, desired_out, 0.05, 1e-3));
+}
\ No newline at end of file
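On the looser tolerances in the half test: fp16 carries roughly 11 significand bits, and the half2 kernel also accumulates the mean and variance in half, so error grows with row width; rtol=0.05 / atol=1e-3 is a pragmatic bound for n=512. An alternative check (a sketch using the same test variables, not part of the patch) compares against a float32 reference built from the same inputs, which isolates the kernel's own rounding from torch's half reference path:

// Sketch: validate the half kernel against a float32 reference.
auto ref = torch::nn::functional::layer_norm(
    input.to(torch::kFloat),
    torch::nn::functional::LayerNormFuncOptions({n})
        .weight(weight.to(torch::kFloat))
        .bias(bias.to(torch::kFloat))
        .eps(epsilon));
EXPECT_TRUE(torch::allclose(out.to(torch::kFloat), ref,
                            /*rtol=*/0.05,
                            /*atol=*/1e-3));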
diff --git a/src/kernels/reduce_kernel_utils.cuh b/src/kernels/reduce_kernel_utils.cuh
index 5e414dff..16e077ad 100644
--- a/src/kernels/reduce_kernel_utils.cuh
+++ b/src/kernels/reduce_kernel_utils.cuh
@@ -24,6 +24,36 @@ __inline__ __device__ T warp_reduce_sum(T val) {
   return val;
 }
 
+// performs a parallel reduction operation across the threads within a single
+// warp (32 threads).
+// - val: The value to be reduced within a warp.
+template <>
+__inline__ __device__ half warp_reduce_sum(half val) {
+  // xor-shuffle butterfly reduction: the lane mask is halved on each
+  // iteration (16, 8, 4, 2, 1) until every lane holds the sum of all
+  // values in the warp.
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    val = __hadd(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  }
+  return val;
+}
+
+// performs a parallel reduction operation across the threads within a single
+// warp (32 threads).
+// - val: The value to be reduced within a warp.
+template <>
+__inline__ __device__ half2 warp_reduce_sum(half2 val) {
+  // same xor-shuffle butterfly reduction, applied lane-wise to both
+  // halves of the half2: the lane mask is halved on each iteration
+  // (16, 8, 4, 2, 1) until every lane holds the sum of all values.
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    val = __hadd2(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  }
+  return val;
+}
+
 // performs a parallel reduction operation across the threads within a single
 // warp (32 threads).
 // - val: The value to be reduced within a warp.
@@ -63,6 +93,35 @@ __inline__ __device__ T block_reduce_sum(T val) {
   return val;
 }
 
+/* Calculate the sum of all elements in a thread block */
+template <>
+__inline__ __device__ half2 block_reduce_sum(half2 val) {
+  // up to 32 warps in a block
+  static __shared__ half2 shared[32];
+  // lane id within a warp
+  int lane = threadIdx.x & 0x1f;
+  // warp id: threadIdx.x / 32
+  int wid = threadIdx.x >> 5;
+
+  // perform a parallel reduction across the threads within each warp
+  val = warp_reduce_sum(val);
+
+  if (lane == 0) {
+    // write the sum of each warp to shared memory
+    shared[wid] = val;
+  }
+  // wait for all warps to finish
+  __syncthreads();
+
+  // use blockDim.x / 32.f (rather than blockDim.x >> 5) so the last,
+  // partial warp still counts when blockDim.x is not a multiple of 32
+  val = (threadIdx.x < (blockDim.x / 32.f))
+            ? shared[lane]
+            : make_half2(__float2half(0.0f), __float2half(0.0f));
+  val = warp_reduce_sum(val);
+  return val;
+}
+
 /* Calculate the max of all elements in a thread block */
 template <typename T>
 __inline__ __device__ T block_reduce_max(T val) {
@@ -139,9 +198,8 @@ struct TopK {
 
 // operator for cub::BlockReduce to get topk across a thread block
 template <typename T, int K>
-__device__ __forceinline__ TopK<T, K> reduce_topk_op(
-    const TopK<T, K>& a,
-    const TopK<T, K>& b) {
+__device__ __forceinline__ TopK<T, K> reduce_topk_op(const TopK<T, K>& a,
+                                                     const TopK<T, K>& b) {
   TopK<T, K> res = a;
   for (int i = 0; i < K; ++i) {
     res.insert(b.u[i], b.p[i]);
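For readers new to the xor-shuffle pattern in warp_reduce_sum: each iteration exchanges values between lanes whose indices differ in one bit, so after log2(32) = 5 steps every lane holds the full warp sum. A worked 4-lane trace plus a minimal standalone demo kernel (illustrative names, not part of the patch):

// 4-lane trace (xor masks 2, then 1):
//   start:    L0=1  L1=2  L2=3  L3=4
//   mask=2:   L0=1+3=4   L1=2+4=6   L2=3+1=4   L3=4+2=6
//   mask=1:   L0=4+6=10  L1=6+4=10  L2=4+6=10  L3=6+4=10
// Every lane ends with the total, which is what lets block_reduce_sum
// read a finished partial sum from lane 0 of each warp.
__global__ void warp_sum_demo(const float* in, float* out) {
  float v = in[threadIdx.x];
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, mask, 32);
  }
  if (threadIdx.x == 0) {
    *out = v;  // any lane would do: all 32 hold the same sum
  }
}
// launch with exactly one warp: warp_sum_demo<<<1, 32>>>(d_in, d_out);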