Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/layers/extensions/inference/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,12 @@ __forceinline__ __device__ T reciprocal(const T& a)
return make_vec4(reciprocal(a.x), reciprocal(a.y), reciprocal(a.z), reciprocal(a.w));
}

template <typename T1, typename T2>
__forceinline__ __device__ bool4 operator>(const T1& a, const T2& b)
// Lane-wise greater-than: compares each component of a float4 against a
// single scalar threshold and packs the four boolean results into a bool4.
__forceinline__ __device__ bool4 operator>(const float4& a, const float b)
{
    const bool gx = a.x > b;
    const bool gy = a.y > b;
    const bool gz = a.z > b;
    const bool gw = a.w > b;
    return make_vec4(gx, gy, gz, gw);
}

// Half-precision lane-wise greater-than: every lane of a Half4 is compared
// against the scalar c10::Half threshold; the packed bool4 holds the results.
__forceinline__ __device__ bool4 operator>(const Half4& a, const c10::Half& b)
{
    const bool above_x = a.x > b;
    const bool above_y = a.y > b;
    const bool above_z = a.z > b;
    const bool above_w = a.w > b;
    return make_vec4(above_x, above_y, above_z, above_w);
}
Expand Down
137 changes: 122 additions & 15 deletions src/layers/extensions/inference/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,16 @@ process_with_mask_dispatcher(torch::Tensor& y_res, torch::Tensor& y_q, torch::Te
const torch::Tensor& scales, const torch::Tensor& means,
const torch::Tensor& mask, const float force_zero_thres)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(y);
const bool force_zero = force_zero_thres > 0.f;
const auto launch_info = get_kernel_launch_info<vec_t>(y);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

const bool force_zero = force_zero_thres > 0.f;
auto launch_kernel = [&](auto in_v) {
using in_t = decltype(in_v);
if (force_zero) {
Expand Down Expand Up @@ -160,7 +167,15 @@ template <typename scalar_t, typename vec_t>
__forceinline__ void combine_for_reading_2x_dispatcher(torch::Tensor& out, const torch::Tensor& x,
const torch::Tensor& mask)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x, 2);
const auto launch_info = get_kernel_launch_info<vec_t>(x,2);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
combine_for_reading_2x_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(out, x, mask, N);
} else {
Expand Down Expand Up @@ -202,7 +217,15 @@ template <typename scalar_t, typename vec_t>
__forceinline__ void restore_y_2x_dispatcher(torch::Tensor& out, const torch::Tensor& y,
const torch::Tensor& means, const torch::Tensor& mask)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(y);
const auto launch_info = get_kernel_launch_info<vec_t>(y);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
restore_y_2x_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(out, y, means, mask, N);
} else {
Expand Down Expand Up @@ -255,7 +278,15 @@ template <typename scalar_t, typename vec_t>
__forceinline__ void restore_y_4x_dispatcher(torch::Tensor& out, const torch::Tensor& y,
const torch::Tensor& means, const torch::Tensor& mask)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(y);
const auto launch_info = get_kernel_launch_info<vec_t>(y);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
restore_y_4x_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(out, y, means, mask, N);
} else {
Expand Down Expand Up @@ -312,7 +343,13 @@ build_index_dec_dispatcher(torch::Tensor& out, torch::optional<torch::Tensor>& c
const scalar_t scale_max, const scalar_t log_scale_min,
const scalar_t log_step_recip, const scalar_t skip_thres)
{
auto [blockDim, gridDim, stream, useVec, N] = get_kernel_launch_info_flatten<vec_t>(scales);
const auto launch_info = get_kernel_launch_info_flatten<vec_t>(scales);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const int N = std::get<4>(launch_info);

const bool with_cond = static_cast<float>(skip_thres) > 0.f;

auto launch_kernel = [&](auto in_v, auto out_v, auto cond_out_v) {
Expand Down Expand Up @@ -380,7 +417,13 @@ __forceinline__ void build_index_enc_dispatcher(
const torch::Tensor& scales, const scalar_t scale_min, const scalar_t scale_max,
const scalar_t log_scale_min, const scalar_t log_step_recip, const scalar_t skip_thres)
{
auto [blockDim, gridDim, stream, useVec, N] = get_kernel_launch_info_flatten<vec_t>(scales);
const auto launch_info = get_kernel_launch_info_flatten<vec_t>(scales);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const int N = std::get<4>(launch_info);

const bool with_cond = static_cast<float>(skip_thres) > 0.f;

auto launch_kernel = [&](auto in_v, auto out_v, auto cond_out_v) {
Expand Down Expand Up @@ -458,7 +501,15 @@ __global__ void bias_wsilu_kernel(GPUTensor1D<vec_t> x, const GPUTensor1D<scalar
template <typename scalar_t, typename vec_t>
__forceinline__ void bias_wsilu_dispatcher(torch::Tensor& x, const torch::Tensor& bias)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
const auto launch_info = get_kernel_launch_info<vec_t>(x);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
if (biasSafe) {
bias_wsilu_kernel<scalar_t, vec_t, true><<<gridDim, blockDim, 0, stream>>>(x, bias, N, HW);
Expand Down Expand Up @@ -507,7 +558,15 @@ __forceinline__ void bias_shortcut_dispatcher(torch::Tensor& x, const torch::Ten
const torch::Tensor& quant_step,
const torch::Tensor& shortcut)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
const auto launch_info = get_kernel_launch_info<vec_t>(x);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
if (biasSafe) {
bias_shortcut_kernel<scalar_t, vec_t, true, with_shortcut, with_quant>
Expand Down Expand Up @@ -563,7 +622,15 @@ __forceinline__ void bias_shortcut_no_inplace_dispatcher(torch::Tensor& out, con
const torch::Tensor& bias,
const torch::Tensor& shortcut)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
const auto launch_info = get_kernel_launch_info<vec_t>(x);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
if (biasSafe) {
bias_shortcut_no_inplace_kernel<scalar_t, vec_t, true>
Expand Down Expand Up @@ -608,7 +675,15 @@ template <typename scalar_t, typename vec_t>
__forceinline__ void bias_shortcut_2_dispatcher(torch::Tensor& x, const torch::Tensor& bias,
torch::Tensor& shortcut)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x);
const auto launch_info = get_kernel_launch_info<vec_t>(x);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
if (biasSafe) {
bias_shortcut_2_kernel<scalar_t, vec_t, true>
Expand Down Expand Up @@ -667,7 +742,15 @@ __global__ void bias_wsilu_chunk_add_kernel(GPUTensor1D<vec_t> x, const GPUTenso
template <typename scalar_t, typename vec_t>
__forceinline__ void bias_wsilu_chunk_add_dispatcher(torch::Tensor& x, const torch::Tensor& bias)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x, 2);
const auto launch_info = get_kernel_launch_info<vec_t>(x, 2);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
if (biasSafe) {
bias_wsilu_chunk_add_kernel<scalar_t, vec_t, true>
Expand Down Expand Up @@ -843,7 +926,15 @@ __global__ void round_and_to_int8_kernel(GPUTensor1D<vec_t1> z, GPUTensor1D<vec_
template <typename scalar_t, typename vec_t>
__forceinline__ void round_and_to_int8_dispatcher(torch::Tensor& z, torch::Tensor& z_int8)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(z);
const auto launch_info = get_kernel_launch_info<vec_t>(z);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
round_and_to_int8_kernel<scalar_t, vec_t, char4>
<<<gridDim, blockDim, 0, stream>>>(z, z_int8, N);
Expand Down Expand Up @@ -887,7 +978,15 @@ __forceinline__ void clamp_reciprocal_with_quant_dispatcher(torch::Tensor& q_dec
const torch::Tensor& q_dec,
torch::Tensor& y, const float min_val)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(q_dec);
const auto launch_info = get_kernel_launch_info<vec_t>(q_dec);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
clamp_reciprocal_with_quant_kernel<scalar_t, vec_t><<<gridDim, blockDim, 0, stream>>>(
q_dec_clamp, q_dec, y, static_cast<scalar_t>(min_val), N);
Expand Down Expand Up @@ -929,7 +1028,15 @@ template <typename scalar_t, typename vec_t>
__forceinline__ void add_and_multiply_dispatcher(torch::Tensor& x0, const torch::Tensor& x1,
const torch::Tensor& q)
{
auto [blockDim, gridDim, stream, useVec, biasSafe, N, HW] = get_kernel_launch_info<vec_t>(x0);
const auto launch_info = get_kernel_launch_info<vec_t>(x0);
const dim3& blockDim = std::get<0>(launch_info);
const dim3& gridDim = std::get<1>(launch_info);
const at::cuda::CUDAStream& stream = std::get<2>(launch_info);
const bool useVec = std::get<3>(launch_info);
const bool biasSafe = std::get<4>(launch_info);
const int N = std::get<5>(launch_info);
const int HW = std::get<6>(launch_info);

if (useVec) {
add_and_multiply_kernel<vec_t><<<gridDim, blockDim, 0, stream>>>(x0, x1, q, N);
} else {
Expand Down