From b42da8b21b02c2493e5dde81aeee61fe81e43b2b Mon Sep 17 00:00:00 2001 From: chenhuwa Date: Mon, 13 Jun 2022 12:52:24 +0800 Subject: [PATCH] denormals optimization for brgemm kernels --- src/cpu/x64/jit_brdgmm_dw_conv.cpp | 8 ++++++++ src/cpu/x64/jit_brgemm_1x1_conv.cpp | 13 +++++++++++++ src/cpu/x64/jit_brgemm_conv.cpp | 15 +++++++++++++++ src/cpu/x64/jit_brgemm_inner_product.cpp | 15 +++++++++++++++ 4 files changed, 51 insertions(+) diff --git a/src/cpu/x64/jit_brdgmm_dw_conv.cpp b/src/cpu/x64/jit_brdgmm_dw_conv.cpp index f890444725c..716a414f783 100644 --- a/src/cpu/x64/jit_brdgmm_dw_conv.cpp +++ b/src/cpu/x64/jit_brdgmm_dw_conv.cpp @@ -16,6 +16,7 @@ #include "common/memory_tracking.hpp" #include "common/utils.hpp" +#include #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" #include "cpu/x64/jit_brdgmm_dw_conv.hpp" @@ -358,6 +359,13 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { const size_t dst_mb_stride = jcp.ngroups * jcp.ow * jcp.oh * jcp.dst_dsz; parallel(jcp.nthr, [&](const int ithr, const int nthr) { + // TODO this is a hot fix for denormals, need refactor + unsigned int DENORMALS_ZERO = 0x0040; + unsigned int FLUSH_ZERO = 0x8000; + unsigned int csr = _mm_getcsr(); + csr |= DENORMALS_ZERO; + csr |= FLUSH_ZERO; + _mm_setcsr(csr); int start {0}, end {0}; balance211(work_amount, nthr, ithr, start, end); int n {0}, chb {0}, oh {0}, owb {0}; diff --git a/src/cpu/x64/jit_brgemm_1x1_conv.cpp b/src/cpu/x64/jit_brgemm_1x1_conv.cpp index 9548187a7ad..1fdf9d745d3 100644 --- a/src/cpu/x64/jit_brgemm_1x1_conv.cpp +++ b/src/cpu/x64/jit_brgemm_1x1_conv.cpp @@ -19,6 +19,7 @@ #include "common/nstl.hpp" #include "common/type_helpers.hpp" #include "common/utils.hpp" +#include #include "cpu/cpu_primitive.hpp" @@ -487,6 +488,12 @@ status_t brgemm_1x1_convolution_fwd_t::execute_forward_all( #define BRGC_WO(...) \ parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \ if (ithr >= work_amount) return; \ + unsigned int DENORMALS_ZERO = 0x0040; \ + unsigned int FLUSH_ZERO = 0x8000; \ + unsigned int csr = _mm_getcsr(); \ + csr |= DENORMALS_ZERO; \ + csr |= FLUSH_ZERO; \ + _mm_setcsr(csr); \ brgemm_batch_element_t *const brg_batch \ = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \ char *const c_buffer = (jcp.use_buffer) \ @@ -552,6 +559,12 @@ status_t brgemm_1x1_convolution_fwd_t::execute_forward_all( #define BRGC_WO(...) \ parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \ if (ithr >= work_amount) return; \ + unsigned int DENORMALS_ZERO = 0x0040; \ + unsigned int FLUSH_ZERO = 0x8000; \ + unsigned int csr = _mm_getcsr(); \ + csr |= DENORMALS_ZERO; \ + csr |= FLUSH_ZERO; \ + _mm_setcsr(csr); \ brgemm_batch_element_t *const brg_batch \ = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \ char *const c_buffer = (jcp.use_buffer) \ diff --git a/src/cpu/x64/jit_brgemm_conv.cpp b/src/cpu/x64/jit_brgemm_conv.cpp index b8add4c78af..07f150ad835 100644 --- a/src/cpu/x64/jit_brgemm_conv.cpp +++ b/src/cpu/x64/jit_brgemm_conv.cpp @@ -20,6 +20,7 @@ #include "common/utils.hpp" #include "cpu/cpu_primitive.hpp" #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" +#include #include "cpu/x64/amx_tile_configure.hpp" #include "cpu/x64/jit_brgemm_conv.hpp" @@ -855,6 +856,13 @@ status_t brgemm_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { parallel(jcp.nthr, [&](const int ithr, const int nthr) { if (ithr >= work_amount) return; + // TODO this is a hot fix for denormals, need refactor + unsigned int DENORMALS_ZERO = 0x0040; + unsigned int FLUSH_ZERO = 0x8000; + unsigned int csr = _mm_getcsr(); + csr |= DENORMALS_ZERO; + csr |= FLUSH_ZERO; + _mm_setcsr(csr); brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global + static_cast(ithr) * jcp.adjusted_batch_size; @@ -1014,6 +1022,13 @@ status_t brgemm_convolution_fwd_t::cal_compensation( const dim_t kd_b {kd_bs[k]}, kd_e {kd_es[k]}, kh_b {kh_bs[k]}, kh_e {kh_es[k]}; assert(kd_e > kd_b && kh_e > kh_b); + // TODO this is a hot fix for denormals, need refactor + unsigned int DENORMALS_ZERO = 0x0040; + unsigned int FLUSH_ZERO = 0x8000; + unsigned int csr = _mm_getcsr(); + csr |= DENORMALS_ZERO; + csr |= FLUSH_ZERO; + _mm_setcsr(csr); if (jcp.exec_type == exec_vpad && jcp.max_vpad > 0) { const auto ow = owb * jcp.ow_block; diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 557196f5c2f..56c01be8d41 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -18,6 +18,7 @@ #include "common/dnnl_thread.hpp" #include "common/type_helpers.hpp" #include "common/utils.hpp" +#include #include "cpu/cpu_primitive.hpp" @@ -297,6 +298,13 @@ status_t brgemm_inner_product_fwd_t::execute_forward( bool ok = init_thr_groups( ithr, nthr, nthr_ic, nthr_oc_mb, ithr_ic, ithr_oc_mb); if (!ok) return; + // TODO this is a hot fix for denormals, need refactor + unsigned int DENORMALS_ZERO = 0x0040; + unsigned int FLUSH_ZERO = 0x8000; + unsigned int csr = _mm_getcsr(); + csr |= DENORMALS_ZERO; + csr |= FLUSH_ZERO; + _mm_setcsr(csr); int start {0}, end {0}; balance211(work_amount, nthr_oc_mb, ithr_oc_mb, start, end); @@ -376,6 +384,13 @@ status_t brgemm_inner_product_fwd_t::execute_forward( bool ok = init_thr_groups( ithr, nthr, nthr_ic, nthr_oc_mb, ithr_ic, ithr_oc_mb); if (!ok) return; + // TODO this is a hot fix for denormals, need refactor + unsigned int DENORMALS_ZERO = 0x0040; + unsigned int FLUSH_ZERO = 0x8000; + unsigned int csr = _mm_getcsr(); + csr |= DENORMALS_ZERO; + csr |= FLUSH_ZERO; + _mm_setcsr(csr); int ocmb_start {0}, ocmb_end {0}; int start {0}, end {0};