From 0b06799c2403b3a10d6dc19450e8f50ec4837f09 Mon Sep 17 00:00:00 2001 From: Mor Tzur Date: Tue, 19 Feb 2019 14:47:36 -0800 Subject: [PATCH 1/7] adding unit-test example of per-channel scale and zero-point --- test/convolution-operator-tester.h | 302 ++++++++++++++++++++++++----- 1 file changed, 254 insertions(+), 48 deletions(-) diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h index 3f0a06a..e5dcb85 100644 --- a/test/convolution-operator-tester.h +++ b/test/convolution-operator-tester.h @@ -348,15 +348,19 @@ class ConvolutionOperatorTester { auto s32rng = std::bind(std::uniform_int_distribution(-10000, 10000), rng); auto u8rng = std::bind(std::uniform_int_distribution(), rng); - std::vector input(batchSize() * ((inputHeight() * inputWidth() - 1) * inputPixelStride() + groups() * groupInputChannels()) + 8); - std::vector kernel(groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * groupInputChannels()); + std::vector input( + batchSize() * ((inputHeight() * inputWidth() - 1) * inputPixelStride() + groups() * groupInputChannels()) + 8); + std::vector kernel( + groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * groupInputChannels()); std::vector bias(groups() * groupOutputChannels()); - std::vector output(batchSize() * ((outputHeight() * outputWidth() - 1) * outputPixelStride() + groups() * groupOutputChannels())); + std::vector output( + batchSize() * ((outputHeight() * outputWidth() - 1) * outputPixelStride() + groups() * groupOutputChannels())); std::vector accumulators(batchSize() * outputHeight() * outputWidth() * groups() * groupOutputChannels()); const uint8_t* inputPtr = input.data() + 8; const uint8_t inputZeroPoint = 127; const uint8_t kernelZeroPoint = 127; + const float kernelScale = 1.0f; for (size_t iteration = 0; iteration < iterations(); iteration++) { std::generate(input.begin(), input.end(), std::ref(u8rng)); @@ -370,8 +374,9 @@ class ConvolutionOperatorTester { for (size_t ox = 0; 
ox < outputWidth(); ox++) { for (size_t g = 0; g < groups(); g++) { for (size_t oc = 0; oc < groupOutputChannels(); oc++) { - accumulators[(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] = - bias[g * groupOutputChannels() + oc]; + accumulators + [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] = + bias[g * groupOutputChannels() + oc]; } } } @@ -389,9 +394,20 @@ class ConvolutionOperatorTester { for (size_t g = 0; g < groups(); g++) { for (size_t oc = 0; oc < groupOutputChannels(); oc++) { for (size_t ic = 0; ic < groupInputChannels(); ic++) { - accumulators[(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] += - (int32_t(inputPtr[((i * inputHeight() + iy) * inputWidth() + ix) * inputPixelStride() + g * groupInputChannels() + ic]) - int32_t(inputZeroPoint)) * - (int32_t(kernel[(((g * groupOutputChannels() + oc) * kernelHeight() + ky) * kernelWidth() + kx) * groupInputChannels() + ic]) - int32_t(kernelZeroPoint)); + accumulators + [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * + groupOutputChannels() + + oc] += + (int32_t(inputPtr + [((i * inputHeight() + iy) * inputWidth() + ix) * inputPixelStride() + + g * groupInputChannels() + ic]) - + int32_t(inputZeroPoint)) * + (int32_t(kernel + [(((g * groupOutputChannels() + oc) * kernelHeight() + ky) * kernelWidth() + + kx) * + groupInputChannels() + + ic]) - + int32_t(kernelZeroPoint)); } } } @@ -406,44 +422,228 @@ class ConvolutionOperatorTester { const int32_t accumulatorsMax = *std::max_element(accumulators.cbegin(), accumulators.cend()); const double outputScale = double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; - const uint8_t outputZeroPoint = uint8_t(std::max(std::min( - lrint(127.5 - 0.5 * double(accumulatorsMin + accumulatorsMax) / outputScale), - long(std::numeric_limits::max())), long(std::numeric_limits::min()))); + const uint8_t 
outputZeroPoint = uint8_t(std::max( + std::min( + lrint(127.5 - 0.5 * double(accumulatorsMin + accumulatorsMax) / outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); ASSERT_EQ(qnnp_status_success, qnnp_initialize()); qnnp_operator_t convolution = nullptr; + ASSERT_EQ( + qnnp_status_success, + qnnp_create_convolution2d_nhwc_q8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + kernelHeight(), + kernelWidth(), + subsamplingHeight(), + subsamplingWidth(), + dilationHeight(), + dilationWidth(), + groups(), + groupInputChannels(), + groupOutputChannels(), + inputZeroPoint, + 1.0f /* input scale */, + kernelZeroPoint, + kernelScale, + kernel.data(), + bias.data(), + outputZeroPoint, + outputScale, + qmin(), + qmax(), + 0, + &convolution)); + + ASSERT_EQ( + qnnp_status_success, + qnnp_setup_convolution2d_nhwc_q8( + convolution, + batchSize(), + inputHeight(), + inputWidth(), + inputPtr, + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ(qnnp_status_success, qnnp_run_operator(convolution, nullptr /* thread pool */)); + + ASSERT_EQ(qnnp_status_success, qnnp_delete_operator(convolution)); + convolution = nullptr; + + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t c = 0; c < groupOutputChannels(); c++) { + const double scaledAccumulator = + accumulators + [(((i * outputHeight() + y) * outputWidth() + x) * groups() + g) * groupOutputChannels() + c] * + kernelScale / outputScale; + const double clampedAccumulator = std::max( + std::min(scaledAccumulator, double(qmax()) - double(outputZeroPoint)), + double(qmin()) - double(outputZeroPoint)); + ASSERT_NEAR( + clampedAccumulator, + (int32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * outputPixelStride() + + g * groupOutputChannels() + c]) - + 
outputZeroPoint), + 0.9) + << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; + } + } + } + } + } + } + } + + void testQ8_perChannel() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + batchSize() * ((inputHeight() * inputWidth() - 1) * inputPixelStride() + groups() * groupInputChannels()) + 8); + std::vector kernel( + groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * groupInputChannels()); + std::vector bias(groups() * groupOutputChannels()); + std::vector output( + batchSize() * ((outputHeight() * outputWidth() - 1) * outputPixelStride() + groups() * groupOutputChannels())); + std::vector accumulators(batchSize() * outputHeight() * outputWidth() * groups() * groupOutputChannels()); + + const uint8_t* inputPtr = input.data() + 8; + const uint8_t inputZeroPoint = 127; + const uint8_t kernelZeroPointFixed = 127; + const float kernelScaleFixed = 1.0f; + std::vector kernelScale(groups() * groupOutputChannels()); + std::vector kernelZeroPoint(groups() * groupOutputChannels()); + std::fill(kernelScale.begin(), kernelScale.end(), kernelScaleFixed); + std::fill(kernelZeroPoint.begin(), kernelZeroPoint.end(), kernelZeroPointFixed); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(output.begin(), output.end(), 0xA5); + std::fill(accumulators.begin(), accumulators.end(), 0); + + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t oc = 0; oc < groupOutputChannels(); 
oc++) { + accumulators + [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] = + bias[g * groupOutputChannels() + oc]; + } + } + } + } + } + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t ky = 0; ky < kernelHeight(); ky++) { + const size_t iy = oy * subsamplingHeight() + ky * dilationHeight() - paddingTop(); + if (iy < inputHeight()) { + for (size_t kx = 0; kx < kernelWidth(); kx++) { + const size_t ix = ox * subsamplingWidth() + kx * dilationWidth() - paddingLeft(); + if (ix < inputWidth()) { + for (size_t g = 0; g < groups(); g++) { + for (size_t oc = 0; oc < groupOutputChannels(); oc++) { + for (size_t ic = 0; ic < groupInputChannels(); ic++) { + accumulators + [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * + groupOutputChannels() + + oc] += + (int32_t(inputPtr + [((i * inputHeight() + iy) * inputWidth() + ix) * inputPixelStride() + + g * groupInputChannels() + ic]) - + int32_t(inputZeroPoint)) * + (int32_t(kernel + [(((g * groupOutputChannels() + oc) * kernelHeight() + ky) * kernelWidth() + + kx) * + groupInputChannels() + + ic]) - + int32_t(kernelZeroPoint[g * groupOutputChannels() + oc])); + } + } + } + } + } + } + } + } + } + } + const int32_t accumulatorsMin = *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulatorsMax = *std::max_element(accumulators.cbegin(), accumulators.cend()); + + const double outputScale = double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; + const uint8_t outputZeroPoint = uint8_t(std::max( + std::min( + lrint(127.5 - 0.5 * double(accumulatorsMin + accumulatorsMax) / outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + ASSERT_EQ(qnnp_status_success, qnnp_initialize()); + qnnp_operator_t convolution = nullptr; - ASSERT_EQ(qnnp_status_success, - 
qnnp_create_convolution2d_nhwc_q8( - paddingTop(), paddingRight(), paddingBottom(), paddingLeft(), - kernelHeight(), kernelWidth(), - subsamplingHeight(), subsamplingWidth(), - dilationHeight(), dilationWidth(), - groups(), groupInputChannels(), groupOutputChannels(), - inputZeroPoint, 1.0f /* input scale */, - kernelZeroPoint, 1.0f /* kernel scale */, - kernel.data(), bias.data(), - outputZeroPoint, outputScale, qmin(), qmax(), - 0, &convolution)); - - ASSERT_EQ(qnnp_status_success, - qnnp_setup_convolution2d_nhwc_q8( - convolution, - batchSize(), - inputHeight(), - inputWidth(), - inputPtr, - inputPixelStride(), - output.data(), - outputPixelStride(), - nullptr /* thread pool */)); - - ASSERT_EQ(qnnp_status_success, - qnnp_run_operator(convolution, nullptr /* thread pool */)); - - ASSERT_EQ(qnnp_status_success, - qnnp_delete_operator(convolution)); + ASSERT_EQ( + qnnp_status_success, + qnnp_create_convolution2d_nhwc_q8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + kernelHeight(), + kernelWidth(), + subsamplingHeight(), + subsamplingWidth(), + dilationHeight(), + dilationWidth(), + groups(), + groupInputChannels(), + groupOutputChannels(), + inputZeroPoint, + 1.0f /* input scale */, + kernelZeroPointFixed, + kernelScaleFixed, + kernel.data(), + bias.data(), + outputZeroPoint, + outputScale, + qmin(), + qmax(), + 0, + &convolution)); + + ASSERT_EQ( + qnnp_status_success, + qnnp_setup_convolution2d_nhwc_q8( + convolution, + batchSize(), + inputHeight(), + inputWidth(), + inputPtr, + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ(qnnp_status_success, qnnp_run_operator(convolution, nullptr /* thread pool */)); + + ASSERT_EQ(qnnp_status_success, qnnp_delete_operator(convolution)); convolution = nullptr; for (size_t i = 0; i < batchSize(); i++) { @@ -452,14 +652,20 @@ class ConvolutionOperatorTester { for (size_t g = 0; g < groups(); g++) { for (size_t c = 0; c < groupOutputChannels(); 
c++) { const double scaledAccumulator = - accumulators[(((i * outputHeight() + y) * outputWidth() + x) * groups() + g) * groupOutputChannels() + c] / outputScale; - const double clampedAccumulator = std::max(std::min(scaledAccumulator, - double(qmax()) - double(outputZeroPoint)), - double(qmin()) - double(outputZeroPoint)); + accumulators + [(((i * outputHeight() + y) * outputWidth() + x) * groups() + g) * groupOutputChannels() + c] * + kernelScale[g * groupOutputChannels() + c] / outputScale; + const double clampedAccumulator = std::max( + std::min(scaledAccumulator, double(qmax()) - double(outputZeroPoint)), + double(qmin()) - double(outputZeroPoint)); ASSERT_NEAR( - clampedAccumulator, - (int32_t(output[((i * outputHeight() + y) * outputWidth() + x) * outputPixelStride() + g * groupOutputChannels() + c]) - outputZeroPoint), - 0.9) << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; + clampedAccumulator, + (int32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * outputPixelStride() + + g * groupOutputChannels() + c]) - + outputZeroPoint), + 0.9) + << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; } } } From 1909ceb40e3f9fd29770e1c10a1f895a3fadf51d Mon Sep 17 00:00:00 2001 From: Mor Tzur Date: Thu, 21 Feb 2019 10:51:47 -0800 Subject: [PATCH 2/7] Adding q8gemm neon ukernel with weights quantization parameters per output channel --- src/q8gemm/4x8-neon.c | 356 +++++++++++++++++++++++++++++++++ src/qnnpack/pack.h | 39 ++++ src/qnnpack/params.h | 15 ++ src/qnnpack/q8gemm.h | 15 ++ src/qnnpack/requantization.h | 104 ++++++++++ test/gemm-microkernel-tester.h | 124 +++++++++++- test/q8gemm.cc | 296 ++++++++++++++++++++++++++- 7 files changed, 947 insertions(+), 2 deletions(-) diff --git a/src/q8gemm/4x8-neon.c b/src/q8gemm/4x8-neon.c index ce05b91..fc47259 100644 --- a/src/q8gemm/4x8-neon.c +++ b/src/q8gemm/4x8-neon.c @@ -362,3 +362,359 @@ void q8gemm_ukernel_4x8__neon( } } } + +void 
q8gemm_per_channel_ukernel_4x8__neon( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union qnnp_conv_quantization_params quantization_params[restrict static 1], + size_t kernel_quantization_params_offset) +{ + int32x4_t vacc0x0123 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); + int32x4_t vacc0x4567 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride); + if (mr != 4) { + a3 = a2; + } + + const uint8x8_t vb_zero_point = vld1_u8((const uint8_t*) &quantization_params->neon.kernel_zero_point_v[kernel_quantization_params_offset]); + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); a0 += 8; + const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); + const uint8x8_t va1 = vld1_u8(a1); a1 += 8; + const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); + const uint8x8_t va2 = vld1_u8(a2); a2 += 8; + const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); + const uint8x8_t va3 = vld1_u8(a3); a3 += 8; + const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), 
vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), 
vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), 
vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, 
vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + + const uint8x8_t vb01234567c7 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c7 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); + const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); + const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); + const uint8x8_t va3 = 
vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + if (k >= 2) { + const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = 
vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + if (k >= 3) { + const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + if (k >= 4) { + const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), 
vget_low_s16(vxa3), 3); + + if (k >= 5) { + const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + if (k >= 6) { + const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + if (k >= 7) { + const 
uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + } + } + } + } + } + } + } + + const int32x4_t vmultiplier0x0123 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset]); + const int32x4_t vmultiplier0x4567 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset + 4]); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier0x0123); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier0x4567); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier0x0123); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier0x4567); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier0x0123); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier0x4567); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier0x0123); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier0x4567); + + const int32x4_t vright_shift_0x0123 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset]); + const int32x4_t vright_shift_0x4567 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset + 4]); + const int32x4_t 
vzero_shift_mask_0x0123 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x0123, vmovq_n_s32(0))); + const int32x4_t vzero_shift_mask_0x4567 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x4567, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask_0x0123), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask_0x4567), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask_0x0123), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask_0x4567), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask_0x0123), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask_0x4567), 31); + vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask_0x0123), 31); + vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask_0x4567), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift_0x0123); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift_0x4567); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift_0x0123); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift_0x4567); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift_0x0123); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift_0x4567); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift_0x0123); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift_0x4567); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), 
vacc3x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); +#endif + const uint8x16_t voutput_min = vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, 
vget_high_u8(vout2x01234567_3x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4; + vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2; + vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + } + } +} diff --git a/src/qnnpack/pack.h b/src/qnnpack/pack.h index 836b03c..885ed10 100644 --- a/src/qnnpack/pack.h +++ b/src/qnnpack/pack.h @@ -48,6 +48,45 @@ static inline void pack_q8gemm_w( } } +static inline void pack_q8gemm_w_per_channel( + size_t nc, // num output channels + size_t kc, // num input channels + uint32_t nr, // kernel-n-block-size + uint32_t np, // 
packed-n + uint32_t kr, + uint8_t izp, + uint8_t* kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) +{ + for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { + const size_t nr_block_size = min(nc - nr_block_start, nr); + int32_t* packed_b = (int32_t*) packed_w; + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { + *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + + (int32_t) kc * (int32_t) izp * (int32_t) kzp[nr_block_start + nr_block_offset]; + packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t)); + } + packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t)); + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { + int32_t ksum = 0; + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) { + const uint8_t kv = k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)]; + ksum += (int32_t) kv; + *((uint8_t*) packed_w) = kv; + packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t)); + } + packed_b[nr_block_offset] -= ksum * (int32_t) izp; + packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t)); + } + packed_w = (void*) ((uintptr_t) packed_w + ((nr - nr_block_size) & (np - 1)) * kr * sizeof(uint8_t)); + } + } +} + static inline void pack_q8conv_w( size_t n, size_t ks, diff --git a/src/qnnpack/params.h b/src/qnnpack/params.h index e30e237..1d14530 100644 --- a/src/qnnpack/params.h +++ b/src/qnnpack/params.h @@ -145,6 +145,9 @@ union qnnp_conv_quantization_params { int16_t output_zero_point; uint8_t output_max; uint8_t output_min; + uint8_t* kernel_zero_point_v; + int32_t* multiplier_v; + int32_t* right_shift_v; } neon; #endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ #if CPUINFO_ARCH_X86 || 
CPUINFO_ARCH_X86_64 @@ -275,6 +278,18 @@ typedef void (*q8gemm_ukernel_function)( size_t c_stride, const union qnnp_conv_quantization_params* quantization_params); +typedef void (*q8gemm_per_channel_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const uint8_t* a, + size_t a_stride, + const void* w, + uint8_t* c, + size_t c_stride, + const union qnnp_conv_quantization_params* quantization_params, + size_t kernel_quantization_params_offset); + typedef void (*q8conv_ukernel_function)( size_t mr, size_t nr, diff --git a/src/qnnpack/q8gemm.h b/src/qnnpack/q8gemm.h index f5ac117..f0c8841 100644 --- a/src/qnnpack/q8gemm.h +++ b/src/qnnpack/q8gemm.h @@ -43,6 +43,21 @@ DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_8x8__aarch64_neon) DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_2x4c8__sse2) DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_4x4c2__sse2) +#define DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(fn_name) \ + QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const uint8_t* a, \ + size_t a_stride, \ + const void* w, \ + uint8_t* c, \ + size_t c_stride, \ + const union qnnp_conv_quantization_params* quantization_params, \ + size_t kernel_quantization_params_offset); + +DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_per_channel_ukernel_4x8__neon) + #define DECLARE_Q8GEMM_XZP_UKERNEL_FUNCTION(fn_name) \ QNNP_INTERNAL void fn_name( \ size_t mr, \ diff --git a/src/qnnpack/requantization.h b/src/qnnpack/requantization.h index 2b5fe6f..6c97ac6 100644 --- a/src/qnnpack/requantization.h +++ b/src/qnnpack/requantization.h @@ -197,6 +197,110 @@ static inline union qnnp_conv_quantization_params qnnp_compute_conv_quantization return params; } +static inline union qnnp_conv_quantization_params qnnp_compute_conv_quantization_params_per_channel( + uint8_t input_zero_point, + size_t kernel_params_size, // should be identical to group_output_channels + uint8_t* kernel_zero_point_v, + const float* scale_v, + int32_t* multiplier_v, // 
pre-allocated in operator-create + int32_t* right_shift_v, // pre-allocated in operator-create + uint8_t output_zero_point, + uint8_t output_min, + uint8_t output_max) +{ + const float scale = *scale_v; + const uint8_t kernel_zero_point = *kernel_zero_point_v; + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + union qnnp_conv_quantization_params params; + #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point; + params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point; + } + params.sse2.multiplier[0] = multiplier; + params.sse2.multiplier[1] = multiplier; + params.sse2.multiplier[2] = multiplier; + params.sse2.multiplier[3] = multiplier; + params.sse2.rounding[0] = UINT64_C(0x40000000); + params.sse2.rounding[1] = UINT64_C(0x40000000); + params.sse2.remainder_mask[0] = (int32_t) remainder_mask; + params.sse2.remainder_mask[1] = (int32_t) remainder_mask; + params.sse2.remainder_mask[2] = (int32_t) remainder_mask; + params.sse2.remainder_mask[3] = (int32_t) remainder_mask; + params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold; + params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold; + params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold; + params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold; + params.sse2.shift[0] = 
(uint64_t) (uint32_t) shift; + params.sse2.shift[1] = (uint64_t) (uint32_t) shift; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point; + } + for (uint32_t i = 0; i < 16; i++) { + params.sse2.output_max[i] = output_max; + params.sse2.output_min[i] = output_min; + } + #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point; + params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point; + params.neon.multiplier = multiplier; + params.neon.right_shift = -shift; + params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point; + params.neon.output_max = output_max; + params.neon.output_min = output_min; + params.neon.kernel_zero_point_v = kernel_zero_point_v; + params.neon.multiplier_v = multiplier_v; + params.neon.right_shift_v = right_shift_v; + for (uint32_t i = 0; i < kernel_params_size; ++i) { + const float s = scale_v[i]; + const uint8_t kzp = kernel_zero_point_v[i]; + /* Compute requantization parameters */ + const uint32_t sbits = fp32_to_bits(s); + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t m = (int32_t)(((sbits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(m >= INT32_C(0x40000000)); + assert(m <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t rs = 127 + 31 - 32 - (fp32_to_bits(s) >> 23); + assert(rs >= 0); + assert(rs < 32); + params.neon.multiplier_v[i] = m; + params.neon.right_shift_v[i] = -rs; + } + + #else + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point; + params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point; + params.scalar.multiplier = multiplier; + params.scalar.remainder_mask = (int32_t) remainder_mask; + params.scalar.remainder_threshold = (int32_t) 
remainder_threshold; + params.scalar.shift = (uint32_t) shift; + params.scalar.output_min_less_zero_point = + (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point; + params.scalar.output_max_less_zero_point = + (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point; + params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point; + #endif + return params; +} + static inline union qnnp_avgpool_quantization_params qnnp_compute_avgpool_quantization_params( int32_t bias, float scale, diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index c01de36..8f4f675 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -275,6 +275,128 @@ class GemmMicrokernelTester { } } + void test(q8gemm_per_channel_ukernel_function qgemm) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a((m() - 1) * aStride() + k() + 8); + std::vector b(n() * k()); + std::vector bias(n()); + std::vector> packedW(packedN() * packedK() + biasN() * sizeof(uint32_t) / sizeof(uint8_t)); + std::vector c((m() - 1) * cStride() + n()); + std::vector acc(m() * n()); + std::vector cRef(m() * n()); + + // Per-Channel quantization parameters + std::vector kernelZeroPointPerChannel(nr()); + std::vector kernelAndInputScalePerChannel(nr()); + std::vector requantizationScalePerChannel(nr()); + std::vector multiplierPerChannel(nr()); + std::vector rightShiftPerChannel(nr()); + + // 1) Fill zero-point per-channel around bZeroPoint() as center value. + // 2) Fill kernel-and-input per-channel using linear interpolation between min and max values. 
+ // (Maintain: requantization_scale < 1 ; + // requantization_scale := input_scale * kernel_scale / output_scale) + const float scale_min = 0.5f; + const float scale_max = 0.99999f; + for (size_t i = 0; i < nr(); ++i) { + kernelZeroPointPerChannel[i] = + static_cast(std::min(255, std::max(0, bZeroPoint() + (int)(i - nr()/2)))); + kernelAndInputScalePerChannel[i] = scale_min + i * (scale_max - scale_min) / nr(); + } + + const uint8_t* aPtr = a.data() + 8; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(c.begin(), c.end(), 0xA5); + + std::fill(packedW.begin(), packedW.end(), bZeroPoint()); + pack_q8gemm_w_per_channel(n(), k(), + nr(), np(), kr(), + aZeroPoint(), kernelZeroPointPerChannel.data(), + b.data(), bias.data(), packedW.data()); + + ASSERT_NE(*std::max_element(a.cbegin(), a.cend()), *std::min_element(a.cbegin(), a.cend())); + ASSERT_NE(*std::max_element(b.cbegin(), b.cend()), *std::min_element(b.cbegin(), b.cend())); + + /* Compute 32-bit results and output quantization arguments */ + std::fill(acc.begin(), acc.end(), 0); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t kIndex = 0; kIndex < k(); kIndex++) { + ASSERT_LE(n(), packedN()); + ASSERT_LT(mIndex * n() + nIndex, acc.size()); + ASSERT_LT(mIndex * k() + kIndex, a.size()); + acc[mIndex * n() + nIndex] += + (int32_t(aPtr[mIndex * aStride() + kIndex]) - int32_t(aZeroPoint())) * + (int32_t(b[nIndex * k() + kIndex]) - int32_t(kernelZeroPointPerChannel[nIndex])); + } + acc[mIndex * n() + nIndex] += bias[nIndex]; + } + } + + const int32_t accMin = *std::min_element(acc.cbegin(), acc.cend()); + const int32_t accMax = *std::max_element(acc.cbegin(), acc.cend()); + if (m() * n() >= 3) { + ASSERT_NE(accMax, accMin) + << "Mr x Nr x Kr = " << 
mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + + const double cScale = uint32_t(accMax - accMin) >= 256 ? double(uint32_t(accMax - accMin)) / 255.0 : 1.00001; + const uint8_t cZeroPoint = uint8_t(std::max(std::min( + lrint(127.5 - 0.5 * double(accMin + accMax) / cScale), + long(std::numeric_limits::max())), long(std::numeric_limits::min()))); + + for (size_t nIndex = 0; nIndex < nr(); nIndex++) { + requantizationScalePerChannel[nIndex] = kernelAndInputScalePerChannel[nIndex] / float(cScale); + } + const union qnnp_conv_quantization_params quantizationParams = + qnnp_compute_conv_quantization_params_per_channel( + aZeroPoint(), nr(), kernelZeroPointPerChannel.data(), + requantizationScalePerChannel.data(), multiplierPerChannel.data(), rightShiftPerChannel.data(), cZeroPoint, qmin(), qmax()); + + qgemm( + m(), n(), k(), + aPtr, aStride() * sizeof(uint8_t), + packedW.data(), + c.data(), cStride() * sizeof(uint8_t), + &quantizationParams, 0); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + const union qnnp_q31_requantization_params scalarRequantizationParams = + qnnp_compute_scalar_requantization_params( + requantizationScalePerChannel[nIndex], cZeroPoint, qmin(), qmax()); + cRef[mIndex * n() + nIndex] = qnnp_q31_requantize(acc[mIndex * n() + nIndex], scalarRequantizationParams); + } + } + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_LE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmax())); + ASSERT_GE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmin())); + ASSERT_EQ(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(cRef[mIndex * n() + nIndex])) + << "at " << mIndex << ", " << nIndex << ": reference = " << (uint32_t) cRef[mIndex * n() + nIndex] + << " (accumulator = " << acc[mIndex * n() + nIndex] + << "), optimized = " << (uint32_t) c[mIndex * cStride() + nIndex] << ", 
Mr x Nr x Kr = " << mr() << " x " + << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() + << ", requantization scale = " << requantizationScalePerChannel[nIndex] << ", output zero point = " << int32_t(cZeroPoint); + } + } + } + } + void test(q8conv_ukernel_function qconv) const { ASSERT_LE(m(), mr()); ASSERT_LE(n(), nr()); @@ -826,5 +948,5 @@ class GemmMicrokernelTester { uint8_t bZeroPoint_{127}; uint8_t qmin_{0}; uint8_t qmax_{255}; - size_t iterations_{15}; + size_t iterations_{1}; }; diff --git a/test/q8gemm.cc b/test/q8gemm.cc index 4eb77f0..98b453b 100644 --- a/test/q8gemm.cc +++ b/test/q8gemm.cc @@ -14,7 +14,6 @@ #include "gemm-microkernel-tester.h" - #if CPUINFO_ARCH_ARM TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8) { TEST_REQUIRES_ARM_NEON; @@ -2064,6 +2063,301 @@ } } } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_qmin128_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .qmin(128) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_qmax128_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .qmax(128) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, 
k_eq_8_azp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_bzp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_eq_8_nozp_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_azp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_bzp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .bZeroPoint(0) + 
.test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_nozp_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_gt_8_subtile_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + } + } + + TEST(Q8GEMM_4x8__NEON, k_div_8_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_div_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_div_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + + TEST(Q8GEMM_4x8__NEON, k_div_8_subtile_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_per_channel_ukernel_4x8__neon); + } + } + } + } #endif #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 From 00c14947fff1910e9905bda8bc75e566e8aff31f Mon 
Sep 17 00:00:00 2001 From: Mor Tzur Date: Mon, 25 Feb 2019 11:03:10 -0800 Subject: [PATCH 3/7] Revert "adding unit-test example of per-channel scale and zero-point" This reverts commit 0b06799c2403b3a10d6dc19450e8f50ec4837f09. --- test/convolution-operator-tester.h | 302 +++++------------------------ test/gemm-microkernel-tester.h | 2 +- 2 files changed, 49 insertions(+), 255 deletions(-) diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h index e5dcb85..3f0a06a 100644 --- a/test/convolution-operator-tester.h +++ b/test/convolution-operator-tester.h @@ -348,19 +348,15 @@ class ConvolutionOperatorTester { auto s32rng = std::bind(std::uniform_int_distribution(-10000, 10000), rng); auto u8rng = std::bind(std::uniform_int_distribution(), rng); - std::vector input( - batchSize() * ((inputHeight() * inputWidth() - 1) * inputPixelStride() + groups() * groupInputChannels()) + 8); - std::vector kernel( - groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * groupInputChannels()); + std::vector input(batchSize() * ((inputHeight() * inputWidth() - 1) * inputPixelStride() + groups() * groupInputChannels()) + 8); + std::vector kernel(groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * groupInputChannels()); std::vector bias(groups() * groupOutputChannels()); - std::vector output( - batchSize() * ((outputHeight() * outputWidth() - 1) * outputPixelStride() + groups() * groupOutputChannels())); + std::vector output(batchSize() * ((outputHeight() * outputWidth() - 1) * outputPixelStride() + groups() * groupOutputChannels())); std::vector accumulators(batchSize() * outputHeight() * outputWidth() * groups() * groupOutputChannels()); const uint8_t* inputPtr = input.data() + 8; const uint8_t inputZeroPoint = 127; const uint8_t kernelZeroPoint = 127; - const float kernelScale = 1.0f; for (size_t iteration = 0; iteration < iterations(); iteration++) { std::generate(input.begin(), input.end(), std::ref(u8rng)); @@ -374,9 
+370,8 @@ class ConvolutionOperatorTester { for (size_t ox = 0; ox < outputWidth(); ox++) { for (size_t g = 0; g < groups(); g++) { for (size_t oc = 0; oc < groupOutputChannels(); oc++) { - accumulators - [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] = - bias[g * groupOutputChannels() + oc]; + accumulators[(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] = + bias[g * groupOutputChannels() + oc]; } } } @@ -394,20 +389,9 @@ class ConvolutionOperatorTester { for (size_t g = 0; g < groups(); g++) { for (size_t oc = 0; oc < groupOutputChannels(); oc++) { for (size_t ic = 0; ic < groupInputChannels(); ic++) { - accumulators - [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * - groupOutputChannels() + - oc] += - (int32_t(inputPtr - [((i * inputHeight() + iy) * inputWidth() + ix) * inputPixelStride() + - g * groupInputChannels() + ic]) - - int32_t(inputZeroPoint)) * - (int32_t(kernel - [(((g * groupOutputChannels() + oc) * kernelHeight() + ky) * kernelWidth() + - kx) * - groupInputChannels() + - ic]) - - int32_t(kernelZeroPoint)); + accumulators[(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] += + (int32_t(inputPtr[((i * inputHeight() + iy) * inputWidth() + ix) * inputPixelStride() + g * groupInputChannels() + ic]) - int32_t(inputZeroPoint)) * + (int32_t(kernel[(((g * groupOutputChannels() + oc) * kernelHeight() + ky) * kernelWidth() + kx) * groupInputChannels() + ic]) - int32_t(kernelZeroPoint)); } } } @@ -422,228 +406,44 @@ class ConvolutionOperatorTester { const int32_t accumulatorsMax = *std::max_element(accumulators.cbegin(), accumulators.cend()); const double outputScale = double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; - const uint8_t outputZeroPoint = uint8_t(std::max( - std::min( - lrint(127.5 - 0.5 * double(accumulatorsMin + accumulatorsMax) / outputScale), - 
long(std::numeric_limits::max())), - long(std::numeric_limits::min()))); + const uint8_t outputZeroPoint = uint8_t(std::max(std::min( + lrint(127.5 - 0.5 * double(accumulatorsMin + accumulatorsMax) / outputScale), + long(std::numeric_limits::max())), long(std::numeric_limits::min()))); ASSERT_EQ(qnnp_status_success, qnnp_initialize()); qnnp_operator_t convolution = nullptr; - ASSERT_EQ( - qnnp_status_success, - qnnp_create_convolution2d_nhwc_q8( - paddingTop(), - paddingRight(), - paddingBottom(), - paddingLeft(), - kernelHeight(), - kernelWidth(), - subsamplingHeight(), - subsamplingWidth(), - dilationHeight(), - dilationWidth(), - groups(), - groupInputChannels(), - groupOutputChannels(), - inputZeroPoint, - 1.0f /* input scale */, - kernelZeroPoint, - kernelScale, - kernel.data(), - bias.data(), - outputZeroPoint, - outputScale, - qmin(), - qmax(), - 0, - &convolution)); - - ASSERT_EQ( - qnnp_status_success, - qnnp_setup_convolution2d_nhwc_q8( - convolution, - batchSize(), - inputHeight(), - inputWidth(), - inputPtr, - inputPixelStride(), - output.data(), - outputPixelStride(), - nullptr /* thread pool */)); - - ASSERT_EQ(qnnp_status_success, qnnp_run_operator(convolution, nullptr /* thread pool */)); - - ASSERT_EQ(qnnp_status_success, qnnp_delete_operator(convolution)); - convolution = nullptr; - - for (size_t i = 0; i < batchSize(); i++) { - for (size_t y = 0; y < outputHeight(); y++) { - for (size_t x = 0; x < outputWidth(); x++) { - for (size_t g = 0; g < groups(); g++) { - for (size_t c = 0; c < groupOutputChannels(); c++) { - const double scaledAccumulator = - accumulators - [(((i * outputHeight() + y) * outputWidth() + x) * groups() + g) * groupOutputChannels() + c] * - kernelScale / outputScale; - const double clampedAccumulator = std::max( - std::min(scaledAccumulator, double(qmax()) - double(outputZeroPoint)), - double(qmin()) - double(outputZeroPoint)); - ASSERT_NEAR( - clampedAccumulator, - (int32_t(output - [((i * outputHeight() + y) * outputWidth() 
+ x) * outputPixelStride() + - g * groupOutputChannels() + c]) - - outputZeroPoint), - 0.9) - << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; - } - } - } - } - } - } - } - - void testQ8_perChannel() const { - std::random_device randomDevice; - auto rng = std::mt19937(randomDevice()); - auto s32rng = std::bind(std::uniform_int_distribution(-10000, 10000), rng); - auto u8rng = std::bind(std::uniform_int_distribution(), rng); - - std::vector input( - batchSize() * ((inputHeight() * inputWidth() - 1) * inputPixelStride() + groups() * groupInputChannels()) + 8); - std::vector kernel( - groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * groupInputChannels()); - std::vector bias(groups() * groupOutputChannels()); - std::vector output( - batchSize() * ((outputHeight() * outputWidth() - 1) * outputPixelStride() + groups() * groupOutputChannels())); - std::vector accumulators(batchSize() * outputHeight() * outputWidth() * groups() * groupOutputChannels()); - - const uint8_t* inputPtr = input.data() + 8; - const uint8_t inputZeroPoint = 127; - const uint8_t kernelZeroPointFixed = 127; - const float kernelScaleFixed = 1.0f; - std::vector kernelScale(groups() * groupOutputChannels()); - std::vector kernelZeroPoint(groups() * groupOutputChannels()); - std::fill(kernelScale.begin(), kernelScale.end(), kernelScaleFixed); - std::fill(kernelZeroPoint.begin(), kernelZeroPoint.end(), kernelZeroPointFixed); - - for (size_t iteration = 0; iteration < iterations(); iteration++) { - std::generate(input.begin(), input.end(), std::ref(u8rng)); - std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); - std::generate(bias.begin(), bias.end(), std::ref(s32rng)); - std::fill(output.begin(), output.end(), 0xA5); - std::fill(accumulators.begin(), accumulators.end(), 0); - - for (size_t i = 0; i < batchSize(); i++) { - for (size_t oy = 0; oy < outputHeight(); oy++) { - for (size_t ox = 0; ox < outputWidth(); ox++) { - for (size_t g = 0; g < 
groups(); g++) { - for (size_t oc = 0; oc < groupOutputChannels(); oc++) { - accumulators - [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * groupOutputChannels() + oc] = - bias[g * groupOutputChannels() + oc]; - } - } - } - } - } - for (size_t i = 0; i < batchSize(); i++) { - for (size_t oy = 0; oy < outputHeight(); oy++) { - for (size_t ox = 0; ox < outputWidth(); ox++) { - for (size_t ky = 0; ky < kernelHeight(); ky++) { - const size_t iy = oy * subsamplingHeight() + ky * dilationHeight() - paddingTop(); - if (iy < inputHeight()) { - for (size_t kx = 0; kx < kernelWidth(); kx++) { - const size_t ix = ox * subsamplingWidth() + kx * dilationWidth() - paddingLeft(); - if (ix < inputWidth()) { - for (size_t g = 0; g < groups(); g++) { - for (size_t oc = 0; oc < groupOutputChannels(); oc++) { - for (size_t ic = 0; ic < groupInputChannels(); ic++) { - accumulators - [(((i * outputHeight() + oy) * outputWidth() + ox) * groups() + g) * - groupOutputChannels() + - oc] += - (int32_t(inputPtr - [((i * inputHeight() + iy) * inputWidth() + ix) * inputPixelStride() + - g * groupInputChannels() + ic]) - - int32_t(inputZeroPoint)) * - (int32_t(kernel - [(((g * groupOutputChannels() + oc) * kernelHeight() + ky) * kernelWidth() + - kx) * - groupInputChannels() + - ic]) - - int32_t(kernelZeroPoint[g * groupOutputChannels() + oc])); - } - } - } - } - } - } - } - } - } - } - const int32_t accumulatorsMin = *std::min_element(accumulators.cbegin(), accumulators.cend()); - const int32_t accumulatorsMax = *std::max_element(accumulators.cbegin(), accumulators.cend()); - - const double outputScale = double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; - const uint8_t outputZeroPoint = uint8_t(std::max( - std::min( - lrint(127.5 - 0.5 * double(accumulatorsMin + accumulatorsMax) / outputScale), - long(std::numeric_limits::max())), - long(std::numeric_limits::min()))); - - ASSERT_EQ(qnnp_status_success, qnnp_initialize()); - qnnp_operator_t convolution = nullptr; 
- ASSERT_EQ( - qnnp_status_success, - qnnp_create_convolution2d_nhwc_q8( - paddingTop(), - paddingRight(), - paddingBottom(), - paddingLeft(), - kernelHeight(), - kernelWidth(), - subsamplingHeight(), - subsamplingWidth(), - dilationHeight(), - dilationWidth(), - groups(), - groupInputChannels(), - groupOutputChannels(), - inputZeroPoint, - 1.0f /* input scale */, - kernelZeroPointFixed, - kernelScaleFixed, - kernel.data(), - bias.data(), - outputZeroPoint, - outputScale, - qmin(), - qmax(), - 0, - &convolution)); - - ASSERT_EQ( - qnnp_status_success, - qnnp_setup_convolution2d_nhwc_q8( - convolution, - batchSize(), - inputHeight(), - inputWidth(), - inputPtr, - inputPixelStride(), - output.data(), - outputPixelStride(), - nullptr /* thread pool */)); - - ASSERT_EQ(qnnp_status_success, qnnp_run_operator(convolution, nullptr /* thread pool */)); - - ASSERT_EQ(qnnp_status_success, qnnp_delete_operator(convolution)); + ASSERT_EQ(qnnp_status_success, + qnnp_create_convolution2d_nhwc_q8( + paddingTop(), paddingRight(), paddingBottom(), paddingLeft(), + kernelHeight(), kernelWidth(), + subsamplingHeight(), subsamplingWidth(), + dilationHeight(), dilationWidth(), + groups(), groupInputChannels(), groupOutputChannels(), + inputZeroPoint, 1.0f /* input scale */, + kernelZeroPoint, 1.0f /* kernel scale */, + kernel.data(), bias.data(), + outputZeroPoint, outputScale, qmin(), qmax(), + 0, &convolution)); + + ASSERT_EQ(qnnp_status_success, + qnnp_setup_convolution2d_nhwc_q8( + convolution, + batchSize(), + inputHeight(), + inputWidth(), + inputPtr, + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ(qnnp_status_success, + qnnp_run_operator(convolution, nullptr /* thread pool */)); + + ASSERT_EQ(qnnp_status_success, + qnnp_delete_operator(convolution)); convolution = nullptr; for (size_t i = 0; i < batchSize(); i++) { @@ -652,20 +452,14 @@ class ConvolutionOperatorTester { for (size_t g = 0; g < groups(); g++) { for (size_t 
c = 0; c < groupOutputChannels(); c++) { const double scaledAccumulator = - accumulators - [(((i * outputHeight() + y) * outputWidth() + x) * groups() + g) * groupOutputChannels() + c] * - kernelScale[g * groupOutputChannels() + c] / outputScale; - const double clampedAccumulator = std::max( - std::min(scaledAccumulator, double(qmax()) - double(outputZeroPoint)), - double(qmin()) - double(outputZeroPoint)); + accumulators[(((i * outputHeight() + y) * outputWidth() + x) * groups() + g) * groupOutputChannels() + c] / outputScale; + const double clampedAccumulator = std::max(std::min(scaledAccumulator, + double(qmax()) - double(outputZeroPoint)), + double(qmin()) - double(outputZeroPoint)); ASSERT_NEAR( - clampedAccumulator, - (int32_t(output - [((i * outputHeight() + y) * outputWidth() + x) * outputPixelStride() + - g * groupOutputChannels() + c]) - - outputZeroPoint), - 0.9) - << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; + clampedAccumulator, + (int32_t(output[((i * outputHeight() + y) * outputWidth() + x) * outputPixelStride() + g * groupOutputChannels() + c]) - outputZeroPoint), + 0.9) << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; } } } diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index 8f4f675..9333e03 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -948,5 +948,5 @@ class GemmMicrokernelTester { uint8_t bZeroPoint_{127}; uint8_t qmin_{0}; uint8_t qmax_{255}; - size_t iterations_{1}; + size_t iterations_{15}; }; From 0ab971e9661ff311a0a5f8110e9a4135488a739e Mon Sep 17 00:00:00 2001 From: Mor Tzur Date: Mon, 25 Feb 2019 17:34:49 -0800 Subject: [PATCH 4/7] Benchmarks for q8gemm with per-channel kernel quantization parameters --- bench/q8gemm.cc | 134 ++++++++++++++++++++++++++++++++++++++++++ src/q8gemm/4x8-neon.c | 2 +- src/qnnpack/q8gemm.h | 2 +- test/q8gemm.cc | 38 ++++++------ 4 files changed, 155 insertions(+), 21 deletions(-) 
diff --git a/bench/q8gemm.cc b/bench/q8gemm.cc index f06d9e6..0fc098b 100644 --- a/bench/q8gemm.cc +++ b/bench/q8gemm.cc @@ -299,6 +299,105 @@ class Q8GEMM_XZP : public Q8GEMM { qnnp_q31_requantization_params requantizationParams_; }; +class Q8GEMM_PER_CHANNEL : public Q8GEMM { + public: + inline Q8GEMM_PER_CHANNEL(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr) : Q8GEMM(mr, nr, np, kr) {} + virtual void SetUp(const benchmark::State&) override + { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + a_.resize(mc() * kc()); + std::generate(a_.begin(), a_.end(), std::ref(u8rng)); + k_.resize(nc() * kc()); + std::generate(k_.begin(), k_.end(), std::ref(u8rng)); + b_.resize(mc()); + std::generate(b_.begin(), b_.end(), std::ref(s32rng)); + w_.resize(kcStride() * ncStride() + ncStride() * sizeof(int32_t) / sizeof(uint8_t)); + const uint8_t kernel_zero_point_center = 127; + kernelZeroPointPerChannel_.resize(nr()); + requantizationScalePerChannel_.resize(nr()); + multiplierPerChannel_.resize(nr()); + rightShiftPerChannel_.resize(nr()); + const float scale_min = 0.5f; + const float scale_max = 0.99999f; + for (size_t i = 0; i < nr(); ++i) { + kernelZeroPointPerChannel_[i] = + static_cast(std::min(255, std::max(0, kernel_zero_point_center + (int)(i - nr()/2)))); + requantizationScalePerChannel_[i] = scale_min + i * (scale_max - scale_min) / nr(); + } + std::fill(w_.begin(), w_.end(), kernel_zero_point_center); + pack_q8gemm_w_per_channel( + nc(), kc(), + nr(), np(), kr(), + 127, kernelZeroPointPerChannel_.data(), + k(), b(), w()); + c_.resize(mc() * nc()); + std::fill(c_.begin(), c_.end(), 0xA5); + quantizationParams_ = + qnnp_compute_conv_quantization_params_per_channel( + 127, nr(), kernelZeroPointPerChannel_.data(), + requantizationScalePerChannel_.data(), multiplierPerChannel_.data(), 
rightShiftPerChannel_.data(), + 127, 1, 254); + } + + virtual void TearDown(benchmark::State& state) override + { + state.SetItemsProcessed(uint64_t(state.iterations()) * 2 * mc() * nc() * kc()); + a_.clear(); + k_.clear(); + b_.clear(); + w_.clear(); + c_.clear(); + kernelZeroPointPerChannel_.clear(); + kernelAndInputScalePerChannel_.clear(); + requantizationScalePerChannel_.clear(); + multiplierPerChannel_.clear(); + rightShiftPerChannel_.clear(); + } + + protected: + std::vector kernelZeroPointPerChannel_; + std::vector kernelAndInputScalePerChannel_; + std::vector requantizationScalePerChannel_; + std::vector multiplierPerChannel_; + std::vector rightShiftPerChannel_; +}; + +template +class Q8GEMM_PER_CHANNEL_L1 : public Q8GEMM_PER_CHANNEL { + public: + inline Q8GEMM_PER_CHANNEL_L1() : Q8GEMM_PER_CHANNEL(MR, NR, NP, KR) + { + cpuinfo_initialize(); + const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size; + const size_t l1d_reserve = 512; + kc_ = ((l1d_size - l1d_reserve) / sizeof(uint8_t) - mr() * nr()) / (mr() + nr()); + if (kr() != 1) { + kc_ = kc_ / kr() * kr(); + } else { + kc_ = kc_ / nr() * nr(); + } + } +}; + +template +class Q8GEMM_PER_CHANNEL_Op : public Q8GEMM_PER_CHANNEL { + public: + inline Q8GEMM_PER_CHANNEL_Op() : Q8GEMM_PER_CHANNEL(MR, NR, NP, KR) {} + + virtual void SetUp(const benchmark::State& state) override + { + mc_ = state.range(0); + nc_ = state.range(1); + kc_ = state.range(2); + + Q8GEMM_PER_CHANNEL::SetUp(state); + } +}; + template class Q8GEMM_XZP_L1 : public Q8GEMM_XZP { public: @@ -770,6 +869,41 @@ BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(MobileNetV1GemmArguments); BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(SqueezeNetV10GemmArguments); BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(GemmArguments); +BENCHMARK_TEMPLATE_F(Q8GEMM_PER_CHANNEL_L1, 4x8__neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ + for (auto _ : state) { + q8gemm_ukernel_4x8__neon_per_channel( + mr(), nr(), kc(), + a(), kc() * 
sizeof(uint8_t), + w(), + c(), mr() * sizeof(uint8_t), + quantizationParams(), 0); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + q8gemm_ukernel_4x8__neon_per_channel( + mrr, nrr, kc(), + a() + m * kc(), kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, nc() * sizeof(uint8_t), + quantizationParams(), 0); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(GemmArguments); + BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__neon, 8, 8, 8, 1)(benchmark::State& state) { for (auto _ : state) { diff --git a/src/q8gemm/4x8-neon.c b/src/q8gemm/4x8-neon.c index fc47259..5855d43 100644 --- a/src/q8gemm/4x8-neon.c +++ b/src/q8gemm/4x8-neon.c @@ -363,7 +363,7 @@ void q8gemm_ukernel_4x8__neon( } } -void q8gemm_per_channel_ukernel_4x8__neon( +void q8gemm_ukernel_4x8__neon_per_channel( size_t mr, size_t nr, size_t k, diff --git a/src/qnnpack/q8gemm.h b/src/qnnpack/q8gemm.h index f0c8841..d063c55 100644 --- a/src/qnnpack/q8gemm.h +++ b/src/qnnpack/q8gemm.h @@ -56,7 +56,7 @@ DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_4x4c2__sse2) const union qnnp_conv_quantization_params* quantization_params, \ size_t kernel_quantization_params_offset); -DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_per_channel_ukernel_4x8__neon) +DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_ukernel_4x8__neon_per_channel) 
#define DECLARE_Q8GEMM_XZP_UKERNEL_FUNCTION(fn_name) \ QNNP_INTERNAL void fn_name( \ diff --git a/test/q8gemm.cc b/test/q8gemm.cc index 98b453b..e4e9ee4 100644 --- a/test/q8gemm.cc +++ b/test/q8gemm.cc @@ -2074,7 +2074,7 @@ .m(4) .n(8) .k(8) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_a_per_channel) { @@ -2088,7 +2088,7 @@ .n(8) .k(8) .aStride(37) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_c_per_channel) { @@ -2102,7 +2102,7 @@ .n(8) .k(8) .cStride(17) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_qmin128_per_channel) { @@ -2116,7 +2116,7 @@ .n(8) .k(8) .qmin(128) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_qmax128_per_channel) { @@ -2130,7 +2130,7 @@ .n(8) .k(8) .qmax(128) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_azp0_per_channel) { @@ -2144,7 +2144,7 @@ .n(8) .k(8) .aZeroPoint(0) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_bzp0_per_channel) { @@ -2158,7 +2158,7 @@ .n(8) .k(8) .bZeroPoint(0) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_eq_8_nozp_per_channel) { @@ -2173,7 +2173,7 @@ .k(8) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } TEST(Q8GEMM_4x8__NEON, k_gt_8_per_channel) { @@ -2187,7 +2187,7 @@ .m(4) .n(8) .k(k) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2203,7 +2203,7 @@ .n(8) .k(k) .aStride(37) - .test(q8gemm_per_channel_ukernel_4x8__neon); + 
.test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2219,7 +2219,7 @@ .n(8) .k(k) .cStride(17) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2235,7 +2235,7 @@ .n(8) .k(k) .aZeroPoint(0) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2251,7 +2251,7 @@ .n(8) .k(k) .bZeroPoint(0) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2268,7 +2268,7 @@ .k(k) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2286,7 +2286,7 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } } @@ -2303,7 +2303,7 @@ .m(4) .n(8) .k(k) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2319,7 +2319,7 @@ .n(8) .k(k) .aStride(171) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2335,7 +2335,7 @@ .n(8) .k(k) .cStride(17) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } @@ -2353,7 +2353,7 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_per_channel_ukernel_4x8__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } } From efe00eb8ffe16210d172ebc3425155c1992eba08 Mon Sep 17 00:00:00 2001 From: Mor Tzur Date: Tue, 26 Feb 2019 17:38:07 -0800 Subject: [PATCH 5/7] moving 4x8-neon per channel ukernel to a separate file --- configure.py | 1 + src/q8gemm/4x8-neon.c | 356 ----------------------------- src/q8gemm/4x8-neon_per_channel.c | 368 ++++++++++++++++++++++++++++++ 3 files changed, 369 insertions(+), 356 deletions(-) create mode 100644 src/q8gemm/4x8-neon_per_channel.c diff --git a/configure.py b/configure.py index 8c9d1ae..03304f8 100755 --- a/configure.py +++ b/configure.py @@ -107,6 +107,7 @@ def main(args): 
build.cc("q8gavgpool/up8xm-neon.c"), build.cc("q8gemm/4x-sumrows-neon.c"), build.cc("q8gemm/4x8-neon.c"), + build.cc("q8gemm/4x8-neon_per_channel.c"), build.cc("q8gemm/4x8c2-xzp-neon.c"), build.cc("q8gemm/6x4-neon.c"), build.cc("q8gemm/8x8-neon.c"), diff --git a/src/q8gemm/4x8-neon.c b/src/q8gemm/4x8-neon.c index 5855d43..ce05b91 100644 --- a/src/q8gemm/4x8-neon.c +++ b/src/q8gemm/4x8-neon.c @@ -362,359 +362,3 @@ void q8gemm_ukernel_4x8__neon( } } } - -void q8gemm_ukernel_4x8__neon_per_channel( - size_t mr, - size_t nr, - size_t k, - const uint8_t* restrict a, - size_t a_stride, - const void* restrict w, - uint8_t* restrict c, - size_t c_stride, - const union qnnp_conv_quantization_params quantization_params[restrict static 1], - size_t kernel_quantization_params_offset) -{ - int32x4_t vacc0x0123 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); - int32x4_t vacc0x4567 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); - int32x4_t vacc1x0123 = vacc0x0123; - int32x4_t vacc1x4567 = vacc0x4567; - int32x4_t vacc2x0123 = vacc0x0123; - int32x4_t vacc2x4567 = vacc0x4567; - int32x4_t vacc3x0123 = vacc0x0123; - int32x4_t vacc3x4567 = vacc0x4567; - - const uint8_t* a0 = a; - const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride); - if (mr < 2) { - a1 = a0; - } - const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride); - if (mr <= 2) { - a2 = a1; - } - const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride); - if (mr != 4) { - a3 = a2; - } - - const uint8x8_t vb_zero_point = vld1_u8((const uint8_t*) &quantization_params->neon.kernel_zero_point_v[kernel_quantization_params_offset]); - for (; k >= 8; k -= 8) { - const uint8x8_t va0 = vld1_u8(a0); a0 += 8; - const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); - const uint8x8_t va1 = vld1_u8(a1); a1 += 8; - const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); - const uint8x8_t va2 = vld1_u8(a2); a2 += 8; - const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); - const 
uint8x8_t va3 = vld1_u8(a3); a3 += 8; - const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); - - const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); - - const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); - 
- const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); - - const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); - - const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t 
vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); - - const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); - - const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); - - 
vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); - - const uint8x8_t vb01234567c7 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c7 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3); - } - if (k != 0) { - const size_t a_predecrement = 8 - k; - const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); - const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); - const int16x8_t vxa0 = 
vreinterpretq_s16_u16(vmovl_u8(va0)); - const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); - const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); - const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); - const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); - const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); - const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); - - const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); - - if (k >= 2) { - const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), 
vget_low_s16(vxa1), 1); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); - - if (k >= 3) { - const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); - - if (k >= 4) { - const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, 
vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); - - if (k >= 5) { - const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); - - if (k >= 6) { - const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); - 
vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); - - if (k >= 7) { - const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); - const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); - - vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); - vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); - vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); - vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); - vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); - vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); - vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); - vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); - } - } - } - } - } - } - } - - const int32x4_t vmultiplier0x0123 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset]); - const int32x4_t vmultiplier0x4567 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset + 4]); - vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier0x0123); - vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier0x4567); - vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier0x0123); - vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier0x4567); - vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier0x0123); - vacc2x4567 = vqrdmulhq_s32(vacc2x4567, 
vmultiplier0x4567); - vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier0x0123); - vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier0x4567); - - const int32x4_t vright_shift_0x0123 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset]); - const int32x4_t vright_shift_0x4567 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset + 4]); - const int32x4_t vzero_shift_mask_0x0123 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x0123, vmovq_n_s32(0))); - const int32x4_t vzero_shift_mask_0x4567 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x4567, vmovq_n_s32(0))); - vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask_0x0123), 31); - vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask_0x4567), 31); - vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask_0x0123), 31); - vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask_0x4567), 31); - vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask_0x0123), 31); - vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask_0x4567), 31); - vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask_0x0123), 31); - vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask_0x4567), 31); - - vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift_0x0123); - vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift_0x4567); - vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift_0x0123); - vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift_0x4567); - vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift_0x0123); - vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift_0x4567); - vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift_0x0123); - vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift_0x4567); - - const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); -#ifdef __aarch64__ - const int16x8_t 
vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); - const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); - const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); - const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); - - uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); - uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); -#else - const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); - const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); - const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); - const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point); - - uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); - uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); -#endif - const uint8x16_t voutput_min = vld1q_dup_u8(&quantization_params->neon.output_min); - const uint8x16_t voutput_max = vld1q_dup_u8(&quantization_params->neon.output_max); - - vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); - vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); - vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); - vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); - - uint8_t* c0 = c; - uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + c_stride); 
- if (mr < 2) { - c1 = c0; - } - uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + c_stride); - if (mr <= 2) { - c2 = c1; - } - uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + c_stride); - if (mr != 4) { - c3 = c2; - } - if (nr == 8) { - vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); - vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); - vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); - vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); - } else { - if (nr >= 4) { - vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4; - vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4; - vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4; - vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4; - vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); - vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); - nr -= 4; - } - if (nr >= 2) { - vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2; - vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2; - vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2; - vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2; - vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); - vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); - nr -= 2; - } - if (nr != 0) { - vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); - vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); - vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); - vst1q_lane_u8(c3, 
vout2x01234567_3x01234567, 8); - } - } -} diff --git a/src/q8gemm/4x8-neon_per_channel.c b/src/q8gemm/4x8-neon_per_channel.c new file mode 100644 index 0000000..c7af2b5 --- /dev/null +++ b/src/q8gemm/4x8-neon_per_channel.c @@ -0,0 +1,368 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + + +void q8gemm_ukernel_4x8__neon_per_channel( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union qnnp_conv_quantization_params quantization_params[restrict static 1], + size_t kernel_quantization_params_offset) +{ /* Q8 GEMM micro-kernel for a tile of up to 4 rows (mr) by up to 8 columns (nr): accumulates k products per output in int32, then requantizes to uint8 using the 8 entries that start at index kernel_quantization_params_offset of kernel_zero_point_v, multiplier_v and right_shift_v. */ + int32x4_t vacc0x0123 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); + int32x4_t vacc0x4567 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; /* all four rows start from the same 8 int32 values read from the head of w (presumably the per-column bias - confirm against the weight packing code) */ + + const uint8_t* a0 = a; /* a0..a3 are the input row pointers; a row whose index is >= mr reuses the previous row's pointer */ + const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride); + if (mr != 4) { + a3 = a2; + } + + const uint8x8_t vb_zero_point = vld1_u8((const uint8_t*) &quantization_params->neon.kernel_zero_point_v[kernel_quantization_params_offset]); /* 8 kernel zero points, lane j is subtracted from the weights of output column j */ + for (; k >= 8; k -= 8) { /* main loop: 8 values of k per iteration, each using one 8-byte group of w */ + const uint8x8_t va0 = vld1_u8(a0); a0 += 8; + const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); + const uint8x8_t va1 = vld1_u8(a1); a1 += 8; + const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); + const uint8x8_t va2 = vld1_u8(a2); a2 += 8; + const int16x8_t vxa2 = 
vreinterpretq_s16_u16(vmovl_u8(va2)); + const uint8x8_t va3 = vld1_u8(a3); a3 += 8; + const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, 
vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + const uint8x8_t vb01234567c4 = vld1_u8(w); w 
= (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c6 = 
vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + + const uint8x8_t vb01234567c7 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c7 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + } + if (k != 0) { /* remainder: 1 to 7 values of k are left */ + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); /* each row load below starts a_predecrement bytes before the row pointer; the negative count makes vshl_u64 shift right, dropping those leading bytes and leaving zeros in the unused upper lanes */ + const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - 
a_predecrement)), va_shift)); + const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); + const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); + const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); + const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + if (k >= 2) { + const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = 
vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + if (k >= 3) { + const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + if (k >= 4) { + const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), 
vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + if (k >= 5) { + const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + if (k >= 6) { + const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = 
vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + if (k >= 7) { + const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + } + } + } + } + } + } + } + + const int32x4_t vmultiplier0x0123 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset]); + const int32x4_t vmultiplier0x4567 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset + 4]); /* requantization: saturating rounding doubling multiply-high of every row by the 8 per-column multipliers */ + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier0x0123); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier0x4567); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier0x0123); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier0x4567); + vacc2x0123 = 
vqrdmulhq_s32(vacc2x0123, vmultiplier0x0123); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier0x4567); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier0x0123); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier0x4567); + + const int32x4_t vright_shift_0x0123 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset]); + const int32x4_t vright_shift_0x4567 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset + 4]); + const int32x4_t vzero_shift_mask_0x0123 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x0123, vmovq_n_s32(0))); + const int32x4_t vzero_shift_mask_0x4567 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x4567, vmovq_n_s32(0))); /* mask lanes are all ones where the shift is 0; vbicq clears those lanes, so the vsraq below adds the sign bit (-1 or 0) only to accumulators whose shift is non-zero */ + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask_0x0123), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask_0x4567), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask_0x0123), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask_0x4567), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask_0x0123), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask_0x4567), 31); + vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask_0x0123), 31); + vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask_0x4567), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift_0x0123); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift_0x4567); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift_0x0123); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift_0x4567); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift_0x0123); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift_0x4567); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift_0x0123); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift_0x4567); /* vrshlq_s32 is a rounding shift left by a signed per-lane count; right_shift_v presumably holds non-positive counts so this acts as a rounding right shift - TODO confirm against the code that fills right_shift_v */ + + const int16x8_t voutput_zero_point = 
vld1q_dup_s16(&quantization_params->neon.output_zero_point); /* below: saturating narrow to int16, saturating add of the output zero point, then saturating narrow to uint8 with two rows packed per 16-byte vector */ +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); +#endif + const uint8x16_t voutput_min = vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = vld1q_dup_u8(&quantization_params->neon.output_max); /* clamp every output byte to [output_min, output_max] */ + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = 
vminq_u8(vout2x01234567_3x01234567, voutput_max); + + uint8_t* c0 = c; /* c0..c3 are the output row pointers; a row whose index is >= mr reuses the previous row's pointer */ + uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 8) { /* full tile width: one 8-byte store per row */ + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + } else { /* partial width: store 4, then 2, then 1 byte(s) per row as nr requires, rotating already-stored bytes out of the vectors with vextq_u8 */ + if (nr >= 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4; + vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2; + vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); 
+ vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + } + } +} From fe09e886b15133d1636905ade721d77490c02f0e Mon Sep 17 00:00:00 2001 From: Mor Tzur Date: Wed, 27 Feb 2019 17:16:46 -0800 Subject: [PATCH 6/7] q8gemm per-channel armv7 ukernel --- bench/q8gemm.cc | 34 + configure.py | 1 + src/q8gemm/4x8-aarch32-neon-per-channel.S | 819 ++++++++++++++++++++++ src/q8gemm/4x8-aarch32-neon.S | 2 +- src/qnnpack/q8gemm.h | 2 + test/gemm-microkernel-tester.h | 3 + test/q8gemm.cc | 295 ++++++++ 7 files changed, 1155 insertions(+), 1 deletion(-) create mode 100644 src/q8gemm/4x8-aarch32-neon-per-channel.S diff --git a/bench/q8gemm.cc b/bench/q8gemm.cc index 0fc098b..8800efd 100644 --- a/bench/q8gemm.cc +++ b/bench/q8gemm.cc @@ -746,6 +746,40 @@ BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(MobileNetV1GemmArgumen BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(SqueezeNetV10GemmArguments); BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(GemmArguments); +BENCHMARK_TEMPLATE_F(Q8GEMM_PER_CHANNEL_L1, 4x8__aarch32_neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ /* times one micro-kernel call on a single mr() x nr() tile per benchmark iteration */ + for (auto _ : state) { + q8gemm_ukernel_4x8__aarch32_neon_per_channel( + mr(), nr(), kc(), + a(), kc() * sizeof(uint8_t), + w(), + c(), mr() * sizeof(uint8_t), + quantizationParams(), 0); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ /* sweeps the mc() x nc() output in mr() x nr() tiles, clamping the edge tiles with min() */ + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + q8gemm_ukernel_4x8__aarch32_neon_per_channel( + mrr, nrr, kc(), + a() + m * kc(), kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, nc() * sizeof(uint8_t), + quantizationParams(), 0); /* NOTE(review): the params offset is 0 for every column tile, so all tiles read the same 8 channels' parameters - presumably acceptable for timing, confirm */ + } + } + } +} 
+BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(GemmArguments); + BENCHMARK_TEMPLATE_F(Q8GEMM_XZP_L1, 4x8c2__aarch32_neon, 4, 8, 8, 2)(benchmark::State& state) { for (auto _ : state) { diff --git a/configure.py b/configure.py index 03304f8..b6adc24 100755 --- a/configure.py +++ b/configure.py @@ -129,6 +129,7 @@ def main(args): build.cc("q8conv/4x8-aarch32-neon.S"), build.cc("q8dwconv/up8x9-aarch32-neon.S"), build.cc("q8gemm/4x8-aarch32-neon.S"), + build.cc("q8gemm/4x8-aarch32-neon-per-channel.S"), build.cc("q8gemm/4x8c2-xzp-aarch32-neon.S"), ] if build.target.is_arm64: diff --git a/src/q8gemm/4x8-aarch32-neon-per-channel.S b/src/q8gemm/4x8-aarch32-neon-per-channel.S new file mode 100644 index 0000000..e200124 --- /dev/null +++ b/src/q8gemm/4x8-aarch32-neon-per-channel.S @@ -0,0 +1,819 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +.syntax unified + +# void q8gemm_ukernel_4x8__aarch32_neon_per_channel( +# size_t mr, +# size_t nr, +# size_t k, +# const uint8_t*restrict a, +# size_t a_stride, +# const void*restrict w, +# uint8_t*restrict c, +# size_t c_stride, +# const union qnnp_conv_quantization_params quantization_params[restrict static 1], +# size_t kernel_quantization_params_offset) +BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + # Load w + # - ip = w + LDR ip, [sp, 4] + PUSH {r4-r8} + + VPUSH {d8-d15} + # Load quantization params + # - r7 = quantization_params + LDR r7, [sp, 100] + + # load kernel_quantization_params_offset + LDR r5, [sp, 104] + + # Load bias0123, bias4567 + VLDM ip!, {d16-d19} + + # Load a_stride + # - r6 = a_stride + LDR r6, [sp, 84] + CMP r0, 2 + + # a1 = a0 + a_stride + ADD r4, r3, r6 + + # Load b_zero_point from kernel_zero_point_v: + # - d15 = b_zero_point + MOV r8, 16 + ADD r8, r7, r8 + LDR r8, [r8, r5] + VLD1.8 {d15}, [r8] + MOVLO r4, r3 + + # Move kernel_quantization_params_offset to r8 to use later + MOV r8, r5 + + ADD r7, r7, 4 + ADD r5, r4, r6 + + # q10 := vacc1x0123 + VMOV.I32 q10, q8 + MOVLS r5, r4 + # q11 := vacc1x4567 + VMOV.I32 q11, q9 + ADD r6, r5, r6 + # q12 := vacc2x0123 + VMOV.I32 q12, q8 + CMP r0, 4 + # q13 := vacc2x4567 + VMOV.I32 q13, q9 + MOVNE r6, r5 + # q14 := vacc3x0123 + VMOV.I32 q14, q8 + SUBS r2, r2, 8 + # q15 := vacc3x4567 + VMOV.I32 q15, q9 + + BLO 1f + + .p2align 5 +0: + # Load a0 + # - d1 = a0 + VLD1.8 {d1}, [r3]! + + # Load a1 + # - d3 = a1 + VLD1.8 {d3}, [r4]! + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r5]! + + # q0 = va0 = a0 + VMOVL.U8 q0, d1 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r6]! 
+ + # q1 = va1 = a1 + VMOVL.U8 q1, d3 + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + VMOVL.U8 q2, d5 + # q3 = va3 = a3 + VMOVL.U8 q3, d7 + + ### Channel 0 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + ### Channel 1 ### + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + + # Load b0-b7 (channel 3) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! 
+ + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 3) + # - d11 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + + # Load b0-b7 (channel 4) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 4) + # - d9 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + + # Load b0-b7 (channel 5) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! 
+ + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 5) + # - d11 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + + # Load b0-b7 (channel 7) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! 
+ + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 7) + # - d11 = vb4567 (channel 7) + VSUBL.U8 q5, d11, d15 + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + ### Channel 7 ### + SUBS r2, r2, 8 + + # vacc0x0123 += vb0123 * va0[7] + VMLAL.S16 q8, d10, d1[3] + # vacc0x4567 += vb4567 * va0[7] + VMLAL.S16 q9, d11, d1[3] + + # vacc1x0123 += vb0123 * va1[7] + VMLAL.S16 q10, d10, d3[3] + # vacc1x4567 += vb4567 * va1[7] + VMLAL.S16 q11, d11, d3[3] + + # vacc2x0123 += vb0123 * va2[7] + VMLAL.S16 q12, d10, d5[3] + # vacc2x4567 += vb4567 * va2[7] + VMLAL.S16 q13, d11, d5[3] + + # vacc3x0123 += vb0123 * va3[7] + VMLAL.S16 q14, d10, d7[3] + # vacc3x4567 += vb4567 * va3[7] + VMLAL.S16 q15, d11, d7[3] + + BHS 0b + +1: + CMP r2, -8 + BEQ 2f + + # Adjust a0, a1, a2, a3 + ADD r3, r2 + ADD r4, r2 + ADD r5, r2 + ADD r6, r2 + + # a_shift = 8 * k - 64 + LSL r2, r2, 3 + VDUP.32 d13, r2 + + # Load a0 + # - d1 = a0 + VLD1.8 {d1}, [r3] + + # Load a1 + # - d3 = a1 + VLD1.8 {d3}, [r4] + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! 
+ + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r5] + + # q0 = va0 = a0 + VSHL.U64 d1, d1, d13 + VMOVL.U8 q0, d1 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r6] + + # q1 = va1 = a1 + VSHL.U64 d3, d3, d13 + VMOVL.U8 q1, d3 + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + VSHL.U64 d5, d5, d13 + VMOVL.U8 q2, d5 + # q3 = va3 = a3 + VSHL.U64 d7, d7, d13 + VMOVL.U8 q3, d7 + + ### Channel 0 ### + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + CMP r2, -48 + BLO 2f + + ### Channel 1 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + BLS 2f + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! 
+ + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + CMP r2, -32 + BLO 2f + + # Load b0-b7 (channel 3) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 3) + # - d11 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + BLS 2f + + # Load b0-b7 (channel 4) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! 
+ + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 4) + # - d9 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + CMP r2, -16 + BLO 2f + + # Load b0-b7 (channel 5) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 5) + # - d11 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + BLS 2f + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # 
vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + .p2align 4 +2: + # Load multiplier: + # - d12 = vmultiplier ( vmultiplier0x0123 ) + ADD r7, 16 + LDR r8, [sp, 104] + LDR r8, [r7, r8] + VLD1.32 {d12, d13}, [r8]! + SUB r7, 12 + + VQRDMULH.S32 q8, q8, q6 + VQRDMULH.S32 q10, q10, q6 + VQRDMULH.S32 q12, q12, q6 + VQRDMULH.S32 q14, q14, q6 + + VLD1.32 {d12, d13}, [r8] + + VQRDMULH.S32 q9, q9, q6 + VQRDMULH.S32 q11, q11, q6 + VQRDMULH.S32 q13, q13, q6 + VQRDMULH.S32 q15, q15, q6 + + # Load right_shift + # - q4 = d8:d9 = vright_shift_0x0123 + ADD r7, 16 + LDR r8, [sp, 104] + LDR r8, [r7, r8] + VLD1.32 {d8, d9}, [r8]! + SUB r7, 12 + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask_0x0123 + VCEQ.S32 q5, q4, 0 + + VBIC q0, q8, q5 + VBIC q1, q10, q5 + VBIC q2, q12, q5 + VBIC q3, q14, q5 + + VSRA.S32 q8, q0, 31 + VSRA.S32 q10, q1, 31 + VSRA.S32 q12, q2, 31 + VSRA.S32 q14, q3, 31 + + VRSHL.S32 q8, q8, q4 + VRSHL.S32 q10, q10, q4 + VRSHL.S32 q12, q12, q4 + VRSHL.S32 q14, q14, q4 + + # - q4 = d8:d9 = vright_shift_0x4567 + VLD1.32 {d8, d9}, [r8] + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask_0x4567 + VCEQ.S32 q5, q4, 0 + + VBIC q0, q9, q5 + VBIC q1, q11, q5 + VBIC q2, q13, q5 + VBIC q3, q15, q5 + + VSRA.S32 q9, q0, 31 + VSRA.S32 q11, q1, 31 + VSRA.S32 q13, q2, 31 + VSRA.S32 q15, q3, 31 + + VRSHL.S32 q9, q9, q4 + VRSHL.S32 q11, q11, q4 + VRSHL.S32 q13, q13, q4 + VRSHL.S32 q15, q15, q4 + + # Load output_zero_point + # - q7 = d14:d15 = voutput_zero_point + VLD1.16 {d14[], d15[]}, [r7]! + + # Load max: + # - q5 = d10:d11 = vmax + VLD1.8 {d10[], d11[]}, [r7]! + + # Load c, c_stride: + # - r2 = c + # - r3 = c_stride + LDRD r2, r3, [sp, 92] + + VQMOVN.S32 d16, q8 + VQMOVN.S32 d17, q9 + VQMOVN.S32 d18, q10 + VQMOVN.S32 d19, q11 + VQMOVN.S32 d20, q12 + VQMOVN.S32 d21, q13 + VQMOVN.S32 d22, q14 + VQMOVN.S32 d23, q15 + + # Load min: + # - q4 = d8:d9 = vmin + VLD1.8 {d8[], d9[]}, [r7]! 
+ ADD r4, r2, r3 + + VQADD.S16 q8, q8, q7 + VQADD.S16 q9, q9, q7 + CMP r0, 2 + VQADD.S16 q10, q10, q7 + VQADD.S16 q11, q11, q7 + MOVLO r4, r2 + + VQMOVUN.S16 d16, q8 + VQMOVUN.S16 d17, q9 + ADD r5, r4, r3 + VQMOVUN.S16 d18, q10 + VQMOVUN.S16 d19, q11 + MOVLS r5, r4 + + VMIN.U8 q8, q8, q5 + CMP r0, 4 + VMIN.U8 q9, q9, q5 + ADD r3, r5, r3 + + VMAX.U8 q8, q8, q4 + MOVNE r3, r5 + CMP r1, 8 + VMAX.U8 q9, q9, q4 + + BNE 4f + + VST1.8 {d16}, [r2] + VST1.8 {d17}, [r4] + VST1.8 {d18}, [r5] + VST1.8 {d19}, [r3] + + VPOP {d8-d15} + POP {r4-r8} + BX lr + + .p2align 3 +4: + CMP r1, 4 + BLO 5f + + VST1.32 {d16[0]}, [r2]! + VST1.32 {d17[0]}, [r4]! + VST1.32 {d18[0]}, [r5]! + VST1.32 {d19[0]}, [r3]! + + SUB r1, 4 + VEXT.8 q8, q8, q8, 4 + VEXT.8 q9, q9, q9, 4 + +5: + CMP r1, 2 + BLO 6f + + VST1.16 {d16[0]}, [r2]! + VST1.16 {d17[0]}, [r4]! + VST1.16 {d18[0]}, [r5]! + VST1.16 {d19[0]}, [r3]! + + SUB r1, 2 + VEXT.8 q8, q8, q8, 2 + VEXT.8 q9, q9, q9, 2 + +6: + TEQ r1, 0 + BEQ 7f + + VST1.8 {d16[0]}, [r2] + VST1.8 {d17[0]}, [r4] + VST1.8 {d18[0]}, [r5] + VST1.8 {d19[0]}, [r3] + +7: + VPOP {d8-d15} + POP {r4-r8} + BX lr +END_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/src/q8gemm/4x8-aarch32-neon.S b/src/q8gemm/4x8-aarch32-neon.S index a8c1021..1cbbc1a 100644 --- a/src/q8gemm/4x8-aarch32-neon.S +++ b/src/q8gemm/4x8-aarch32-neon.S @@ -710,7 +710,7 @@ BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon VQADD.S16 q9, q9, q7 CMP r0, 2 VQADD.S16 q10, q10, q7 - VQADD.S16 q11, q11, q7 + VQADD.S16 q11, q11, q7 MOVLO r4, r2 VQMOVUN.S16 d16, q8 diff --git a/src/qnnpack/q8gemm.h b/src/qnnpack/q8gemm.h index d063c55..2bf8753 100644 --- a/src/qnnpack/q8gemm.h +++ b/src/qnnpack/q8gemm.h @@ -58,6 +58,8 @@ DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_4x4c2__sse2) DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_ukernel_4x8__neon_per_channel) 
+DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_ukernel_4x8__aarch32_neon_per_channel) + #define DECLARE_Q8GEMM_XZP_UKERNEL_FUNCTION(fn_name) \ QNNP_INTERNAL void fn_name( \ size_t mr, \ diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index 9333e03..6a0b4e8 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -16,6 +16,7 @@ #include #include #include +#include #include @@ -310,6 +311,8 @@ class GemmMicrokernelTester { kernelZeroPointPerChannel[i] = static_cast(std::min(255, std::max(0, bZeroPoint() + (int)(i - nr()/2)))); kernelAndInputScalePerChannel[i] = scale_min + i * (scale_max - scale_min) / nr(); + // kernelZeroPointPerChannel[i] = 127; + // kernelAndInputScalePerChannel[i] = 1.0f; } const uint8_t* aPtr = a.data() + 8; diff --git a/test/q8gemm.cc b/test/q8gemm.cc index e4e9ee4..35254bd 100644 --- a/test/q8gemm.cc +++ b/test/q8gemm.cc @@ -604,6 +604,301 @@ } } } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_qmin128_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .qmin(128) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + 
TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_qmax128_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .qmax(128) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_azp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_bzp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_nozp_per_channel) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, 
k_gt_8_azp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_bzp0_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .bZeroPoint(0) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_nozp_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_subtile_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; 
k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_subtile_per_channel) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); + } + } + } + } #endif #if CPUINFO_ARCH_ARM64 From 5609f855b52837e98f602b838cd262e1d20e19a2 Mon Sep 17 00:00:00 2001 From: Mor Tzur Date: Thu, 7 Mar 2019 11:07:09 -0800 Subject: [PATCH 7/7] cleanup comments --- src/q8gemm/4x8-aarch32-neon-per-channel.S | 94 +++++++++++------------ test/gemm-microkernel-tester.h | 3 - 2 files changed, 47 insertions(+), 50 deletions(-) diff --git a/src/q8gemm/4x8-aarch32-neon-per-channel.S b/src/q8gemm/4x8-aarch32-neon-per-channel.S index e200124..a88f368 100644 --- a/src/q8gemm/4x8-aarch32-neon-per-channel.S +++ b/src/q8gemm/4x8-aarch32-neon-per-channel.S @@ -37,8 +37,8 @@ BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel # - r7 = quantization_params LDR r7, [sp, 100] - # load kernel_quantization_params_offset - LDR r5, [sp, 104] + # load kernel_quantization_params_offset + LDR r5, [sp, 104] # Load bias0123, bias4567 VLDM ip!, {d16-d19} @@ -48,19 +48,19 @@ BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel LDR r6, [sp, 84] CMP r0, 2 - # a1 = a0 + a_stride + # a1 = a0 + a_stride ADD r4, r3, r6 # Load b_zero_point from kernel_zero_point_v: # - d15 = b_zero_point - MOV r8, 16 - ADD r8, r7, r8 - LDR r8, [r8, r5] + MOV r8, 16 + ADD r8, r7, r8 + LDR r8, [r8, r5] VLD1.8 {d15}, [r8] MOVLO r4, r3 - # Move kernel_quantization_params_offset to r8 to use later - MOV r8, r5 + # Move kernel_quantization_params_offset to r8 to use later + MOV r8, r5 ADD r7, r7, 4 ADD r5, r4, r6 @@ 
-641,61 +641,61 @@ BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel .p2align 4 2: - # Load multiplier: - # - d12 = vmultiplier ( vmultiplier0x0123 ) - ADD r7, 16 - LDR r8, [sp, 104] - LDR r8, [r7, r8] - VLD1.32 {d12, d13}, [r8]! - SUB r7, 12 + # Load multiplier: + # - d12 = vmultiplier ( vmultiplier0x0123 ) + ADD r7, 16 + LDR r8, [sp, 104] + LDR r8, [r7, r8] + VLD1.32 {d12, d13}, [r8]! + SUB r7, 12 VQRDMULH.S32 q8, q8, q6 VQRDMULH.S32 q10, q10, q6 - VQRDMULH.S32 q12, q12, q6 - VQRDMULH.S32 q14, q14, q6 + VQRDMULH.S32 q12, q12, q6 + VQRDMULH.S32 q14, q14, q6 - VLD1.32 {d12, d13}, [r8] + VLD1.32 {d12, d13}, [r8] - VQRDMULH.S32 q9, q9, q6 - VQRDMULH.S32 q11, q11, q6 - VQRDMULH.S32 q13, q13, q6 + VQRDMULH.S32 q9, q9, q6 + VQRDMULH.S32 q11, q11, q6 + VQRDMULH.S32 q13, q13, q6 VQRDMULH.S32 q15, q15, q6 - # Load right_shift + # Load right_shift # - q4 = d8:d9 = vright_shift_0x0123 - ADD r7, 16 - LDR r8, [sp, 104] - LDR r8, [r7, r8] + ADD r7, 16 + LDR r8, [sp, 104] + LDR r8, [r7, r8] VLD1.32 {d8, d9}, [r8]! 
- SUB r7, 12 + SUB r7, 12 # Compute vzero_shift_mask # - q5 = vzero_shift_mask_0x0123 VCEQ.S32 q5, q4, 0 - VBIC q0, q8, q5 - VBIC q1, q10, q5 - VBIC q2, q12, q5 - VBIC q3, q14, q5 + VBIC q0, q8, q5 + VBIC q1, q10, q5 + VBIC q2, q12, q5 + VBIC q3, q14, q5 - VSRA.S32 q8, q0, 31 - VSRA.S32 q10, q1, 31 - VSRA.S32 q12, q2, 31 - VSRA.S32 q14, q3, 31 + VSRA.S32 q8, q0, 31 + VSRA.S32 q10, q1, 31 + VSRA.S32 q12, q2, 31 + VSRA.S32 q14, q3, 31 - VRSHL.S32 q8, q8, q4 - VRSHL.S32 q10, q10, q4 - VRSHL.S32 q12, q12, q4 - VRSHL.S32 q14, q14, q4 + VRSHL.S32 q8, q8, q4 + VRSHL.S32 q10, q10, q4 + VRSHL.S32 q12, q12, q4 + VRSHL.S32 q14, q14, q4 - # - q4 = d8:d9 = vright_shift_0x4567 - VLD1.32 {d8, d9}, [r8] + # - q4 = d8:d9 = vright_shift_0x4567 + VLD1.32 {d8, d9}, [r8] - # Compute vzero_shift_mask + # Compute vzero_shift_mask # - q5 = vzero_shift_mask_0x4567 VCEQ.S32 q5, q4, 0 - VBIC q0, q9, q5 + VBIC q0, q9, q5 VBIC q1, q11, q5 VBIC q2, q13, q5 VBIC q3, q15, q5 @@ -705,12 +705,12 @@ BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel VSRA.S32 q13, q2, 31 VSRA.S32 q15, q3, 31 - VRSHL.S32 q9, q9, q4 - VRSHL.S32 q11, q11, q4 - VRSHL.S32 q13, q13, q4 - VRSHL.S32 q15, q15, q4 + VRSHL.S32 q9, q9, q4 + VRSHL.S32 q11, q11, q4 + VRSHL.S32 q13, q13, q4 + VRSHL.S32 q15, q15, q4 - # Load output_zero_point + # Load output_zero_point # - q7 = d14:d15 = voutput_zero_point VLD1.16 {d14[], d15[]}, [r7]! 
diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index 6a0b4e8..9333e03 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -16,7 +16,6 @@ #include #include #include -#include #include @@ -311,8 +310,6 @@ class GemmMicrokernelTester { kernelZeroPointPerChannel[i] = static_cast(std::min(255, std::max(0, bZeroPoint() + (int)(i - nr()/2)))); kernelAndInputScalePerChannel[i] = scale_min + i * (scale_max - scale_min) / nr(); - // kernelZeroPointPerChannel[i] = 127; - // kernelAndInputScalePerChannel[i] = 1.0f; } const uint8_t* aPtr = a.data() + 8;