From 4b0ab8d749e61c23715b5dae9b27e8a0795010d7 Mon Sep 17 00:00:00 2001
From: "zhuan.liu@intel.com" <zhuan.liu@intel.com>
Date: Wed, 2 Apr 2025 01:07:49 -0400
Subject: [PATCH] Use the oneDNN matmul primitive instead of inner-product in
 MKL fused MatMul

---
 .../core/kernels/mkl/mkl_matmul_op_fused.cc   | 109 +++++++-----------
 .../core/kernels/mkl/mkl_matmul_ops_common.h  |  24 ++--
 tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc |   6 +-
 3 files changed, 58 insertions(+), 81 deletions(-)

diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
index 2d0065a52e5b4a..4403f817f16f46 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
@@ -27,7 +27,7 @@ limitations under the License.
 namespace tensorflow {
 
 // Fuse Operation
-template <typename Device, typename T, bool native_format = false>
+template <typename Device, typename T>
 class MklFusedMatMulOp : public MklDnnMatMulOpBase {
  public:
   explicit MklFusedMatMulOp(OpKernelConstruction* ctx)
@@ -68,17 +68,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
       (void)SetFPMathMode();
     }
 
-    MklDnnShape src_mkl_shape;
-    MklDnnShape weight_mkl_shape;
-    GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape, native_format);
-    GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape, native_format);
-    OP_REQUIRES(
-        ctx, !weight_mkl_shape.IsMklTensor(),
-        absl::InvalidArgumentError("Weight should not be in MKL Layout"));
-
     // Get shapes of input tensors
-    auto src_tf_shape = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
-                                                    : src_tensor.shape();
+    auto src_tf_shape = src_tensor.shape();
     auto weight_tf_shape = weight_tensor.shape();
 
     // Check the constraint of input matrix and bias
@@ -90,42 +81,47 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
         OP_REQUIRES(ctx, bias_tensor.dim_size(i) == 1,
                     absl::InvalidArgumentError(
                         absl::StrCat("For bias_dims > 1, all except the "
-                                     "last dimension (channel) must be 1, got: ",
+                                     "last dimension (n) must be 1, got: ",
                                      bias_tensor.shape().DebugString())));
       }
 
-    // Expression: [batch, k] * [k, channel] + [channel] = [batch, channel]
+    // Expression: [m, k] * [k, n] + [n] = [m, n]
     //
     // Get dimension size of each matrix, dim_pair[] is the location of k
    // in the inputs, we have constraint that k of the two inputs are
     // the same
     const int64_t dim_pair[] = {1, transpose_b_ ? 1 : 0};
-    const int64_t batch = src_tf_shape.dim_size(1 - dim_pair[0]);
+    const int64_t m = src_tf_shape.dim_size(1 - dim_pair[0]);
     const int64_t k = src_tf_shape.dim_size(dim_pair[0]);
-    const int64_t channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
+    const int64_t n = weight_tf_shape.dim_size(1 - dim_pair[1]);
 
     OP_REQUIRES(
         ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
         absl::InvalidArgumentError(absl::StrCat(
             "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
             ", In[1]: ", weight_tf_shape.DebugString())));
-    OP_REQUIRES(ctx, bias_tensor.dim_size(bias_tensor.dims() - 1) == channel,
+    OP_REQUIRES(ctx, bias_tensor.dim_size(bias_tensor.dims() - 1) == n,
                 absl::InvalidArgumentError(absl::StrCat(
-                    "Must provide as many biases as the channel size: ",
-                    bias_tensor.shape().DebugString(), " vs. ", channel)));
+                    "Must provide as many biases as the output dimension n: ",
+                    bias_tensor.shape().DebugString(), " vs. ", n)));
 
-    // For inputs s[batch, k], w[k, channel] and b[channel], the primitive
+    // For inputs s[m, k], w[k, n] and b[n], the primitive
     // dims should be described like this:
-    // s[batch, k] * w^T[channel, k] + b[channel] = dst[batch, channel]
+    // s[m, k] * w[k, n] + b[n] = dst[m, n]
     // [n,    ic] * [oc, ic] + [oc] = [n, oc]
-    memory::dims src_dims = memory::dims({batch, k});
-    // Reverse the weights dims from [k, channel] to [channel, k].
-    memory::dims weight_dims = memory::dims({channel, k});
-    memory::dims bias_dims = memory::dims({channel});
-    memory::dims dst_dims = memory::dims({batch, channel});
-    memory::format_tag src_format = memory::format_tag::nc;
+    // memory::dims src_dims = memory::dims({m, k});
+    // // Reverse the weights dims from [k, n] to [n, k].
+    // memory::dims weight_dims = memory::dims({n, k});
+    memory::dims src_dims = memory::dims({m, k});
+    memory::dims weight_dims = memory::dims({k, n});
+    // Bias broadcast: this op used to call the oneDNN inner-product
+    // primitive, whose bias input is 1-dimensional. Now this op calls the
+    // oneDNN matmul primitive, so the bias must be described as 2-dimensional.
+    memory::dims bias_dims = memory::dims({1, n});
+    memory::dims dst_dims = memory::dims({m, n});
+    memory::format_tag src_format = memory::format_tag::ab;
     memory::format_tag weight_format =
-        transpose_b_ ? memory::format_tag::oi : memory::format_tag::io;
+        transpose_b_ ? memory::format_tag::ba : memory::format_tag::ab;
 
     // Set weight format `any` for primitive as per oneDNN recommendation.
     MklDnnMatMulFwdParams matmul_params(
@@ -134,7 +130,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
         memory::format_tag::nc, this->is_weight_const_);
     // Extend the basic parameters for data types and fusions.
     ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
-    auto st = ExecuteSingleThreadedGemm(batch, channel, k, sizeof(T));
+    auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T));
     // Create the oneDNN wrapper over Eigen threadpool and set max threads
     // in oneDNN.
     Eigen::ThreadPoolInterface* eigen_interface =
@@ -146,7 +142,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
 
     // Allocate output tensor.
     Tensor* dst_tensor = nullptr;
-    std::shared_ptr<dnnl::inner_product_forward::primitive_desc> matmul_pd =
+    std::shared_ptr<dnnl::matmul::primitive_desc> matmul_pd =
         matmul_prim->GetPrimitiveDesc();
 
     // The output shape of MatMul is same both for MKL and TF version.
@@ -155,34 +151,21 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
     MklDnnShape output_mkl_shape;
     output_mkl_shape.SetMklTensor(false);
 
-    TensorShape output_tf_shape({batch, channel});
+    TensorShape output_tf_shape({m, n});
 
     if (fuse_add_) {
       const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
-      MklDnnShape add_mkl_shape;
-      GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
-
-      // For native format, we need not to set metadata.
-      if (native_format && ctx->forward_input_to_output_with_shape(
-              kInputIndex_Add, kOutputIndex_Dst,
-              output_tf_shape, &dst_tensor)) {
-        ;  // Need to do nothing for native format
-      } else if (!native_format && ForwardMklTensorInToOutWithMklShape(
-                     ctx, kInputIndex_Add, kOutputIndex_Dst,
-                     &dst_tensor, output_mkl_shape, false)) {
-        ;  // If it's not native format, need to forward and set meta first
-      } else {
-        // If forward is not successful, we should use reorder to copy add
+      if (!ctx->forward_input_to_output_with_shape(
+              kInputIndex_Add, kOutputIndex_Dst,
+              output_tf_shape, &dst_tensor)) {
+        // If forward is not successful, we should use reorder to copy the add
         // tensor to dst tensor
-        AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
-                                  output_tf_shape, output_mkl_shape,
-                                  native_format);
+        OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputIndex_Dst,
+                                                 output_tf_shape, &dst_tensor));
         auto output_format_tag =
             MklTensorFormatToMklDnnDataFormat(MklTensorFormat::FORMAT_NC);
         auto add_md =
-            add_mkl_shape.IsMklTensor()
-                ? add_mkl_shape.GetMklLayout()
-                : memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+            memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
         auto dst_md =
             memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
@@ -190,13 +173,11 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
             static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
         void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
 
-        if (native_format) {
-          // We are simply deep copying the add_tensor to dst_tensor without
-          // changing memory layout, hence using same memory descriptor.
-          add_md = dst_md =
-              memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
-                           dnnl::memory::format_tag::x);
-        }
+        // We are simply deep copying the add_tensor to dst_tensor without
+        // changing memory layout, hence using same memory descriptor.
+        add_md = dst_md =
+            memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
+                         dnnl::memory::format_tag::x);
 
         auto fuse_add_src_ = memory(add_md, this->cpu_engine_, add_buf);
         auto fuse_add_dst_ = memory(dst_md, this->cpu_engine_, dst_buf);
@@ -207,12 +188,12 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
                                  this->cpu_engine_, ctx);
       }
     } else {
-      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
-                                output_mkl_shape, native_format);
+      OP_REQUIRES_OK(ctx,
+                     ctx->allocate_output(0, output_tf_shape, &dst_tensor));
     }
 
     // if there's nothing to compute, just return.
-    if (batch == 0 || channel == 0) {
+    if (m == 0 || n == 0) {
       return;
     }
 
@@ -227,9 +208,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
     MklDnnData<T> src_mkl(&(this->cpu_engine_));
     MklDnnData<T> weight_mkl(&(this->cpu_engine_));
 
-    auto src_md = src_mkl_shape.IsMklTensor()
-                      ? src_mkl_shape.GetMklLayout()
-                      : memory::desc(src_dims, MklDnnType<T>(), src_format);
+    auto src_md = memory::desc(src_dims, MklDnnType<T>(), src_format);
 
     if (src_md != matmul_pd->src_desc()) {
       src_mkl.SetUsrMem(src_md, src_data);
@@ -344,7 +323,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
           .Device(DEVICE_CPU)                                           \
           .TypeConstraint<type>("T")                                    \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),               \
-          MklFusedMatMulOp<CPUDevice, type, true>);
+          MklFusedMatMulOp<CPUDevice, type>);
 TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_bfloat16(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_half(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
@@ -352,4 +331,4 @@ TF_CALL_half(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
 
 }  // namespace tensorflow
 
-#endif  // INTEL_MKL
+#endif  // INTEL_MKL
\ No newline at end of file
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
index 6585c4dbf1316f..56f9cae00d2d29 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -33,7 +33,7 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #endif
 
-using dnnl::inner_product_forward;
+using dnnl::matmul;
 using dnnl::primitive_attr;
 using dnnl::prop_kind;
 using dnnl::stream;
@@ -188,7 +188,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     context_.dst_mem->set_data_handle(DummyData);
   }
 
-  std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
+  std::shared_ptr<dnnl::matmul::primitive_desc>
  GetPrimitiveDesc() const {
     return context_.fwd_pd;
   }
@@ -209,9 +209,9 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 
   // Descriptor and primitive-descriptor for forward inner-product.
 #ifndef ENABLE_ONEDNN_V3
-  std::shared_ptr<dnnl::inner_product_forward::desc> fwd_desc;
+  std::shared_ptr<dnnl::matmul::desc> fwd_desc;
 #endif  // !ENABLE_ONEDNN_V3
-  std::shared_ptr<dnnl::inner_product_forward::primitive_desc> fwd_pd;
+  std::shared_ptr<dnnl::matmul::primitive_desc> fwd_pd;
 
   // Memory descriptors.
   std::shared_ptr<dnnl::memory::desc> src_md;
@@ -283,12 +283,12 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     }
 
     // Create an inner-product.
 #ifndef ENABLE_ONEDNN_V3
-    context_.fwd_desc.reset(new inner_product_forward::desc(
+    context_.fwd_desc.reset(new matmul::desc(
         matmul_fwd_params.const_weight ? prop_kind::forward_inference
                                        : prop_kind::forward_training,
         *context_.src_md, *context_.weight_md, *context_.bias_md,
         *context_.dst_md));
-    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
+    context_.fwd_pd.reset(new matmul::primitive_desc(
         *context_.fwd_desc, cpu_engine_));
 #endif  // !ENABLE_ONEDNN_V3
 
@@ -396,13 +396,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     }
 
 #ifndef ENABLE_ONEDNN_V3
-    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
+    context_.fwd_pd.reset(new matmul::primitive_desc(
         *context_.fwd_desc, post_ops_attr, cpu_engine_));
 #else
-    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
+    context_.fwd_pd.reset(new matmul::primitive_desc(
         cpu_engine_,
-        matmul_fwd_params.const_weight ? prop_kind::forward_inference
-                                       : prop_kind::forward_training,
         *context_.src_md, *context_.weight_md, *context_.bias_md,
         *context_.dst_md, post_ops_attr));
 #endif  // !ENABLE_ONEDNN_V3
@@ -421,7 +419,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
         new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));
 
     // Create inner-product primitive.
-    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
+    context_.matmul_fwd.reset(new matmul(*context_.fwd_pd));
     std::unordered_map<int, memory> net_args = {
         {DNNL_ARG_SRC, *context_.src_mem},
         {DNNL_ARG_WEIGHTS, *context_.weight_mem},
@@ -561,7 +559,7 @@ class MklDnnMatMulOpBase : public OpKernel {
   // Allocate output tensor.
   virtual void AllocateOutputTensor(
       OpKernelContext* context,
-      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
+      const matmul::primitive_desc& mkldnn_matmul_prim_desc,
       const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor,
       bool native_format = false) {
@@ -599,7 +597,7 @@ class MklDnnMatMulOpBase : public OpKernel {
   // Only one thread can execute this method at any given time.
   void CacheWeight(
       OpKernelContext* context,
-      const std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
+      const std::shared_ptr<dnnl::matmul::primitive_desc>&
          matmul_fwd_pd,
       Tweight* weight_data, const Tensor& weight_tensor,
       MklDnnData<Tweight>& weight, const memory::desc& weight_md)
diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc
index efb33375d1669d..f7aa531723ea1a 100644
--- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc
+++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc
@@ -262,7 +262,7 @@ class MklDnnQuantizedMatMulOp
                              Toutput>::Get(matmul_fwd_dims, 0);
 
     // Allocate output Tensor.
-    std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
+    std::shared_ptr<dnnl::matmul::primitive_desc>
         matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc();
     this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
                                input_output_fmt_mkldnn, &dst_tensor,
@@ -515,7 +515,7 @@ class MklDnnQuantizedMatMulOp
 #ifndef ENABLE_ONEDNN_V3
   Tbias* GetBiasHandle(
       OpKernelContext* context,
-      std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
+      std::shared_ptr<dnnl::matmul::primitive_desc>&
          mkldnn_matmul_fwd_pd,
      const Tensor& bias_tensor, const Tensor& weight_tensor,
      std::shared_ptr<stream> reorder_stream) {
@@ -621,7 +621,7 @@ class MklDnnQuantizedMatMulOp
 #else
   void GetBiasHandle(
      OpKernelContext* context,
-      std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
+      std::shared_ptr<dnnl::matmul::primitive_desc>&
          mkldnn_matmul_fwd_pd,
      const Tensor& bias_tensor, const Tensor& weight_tensor,
      std::shared_ptr<stream> reorder_stream, Tensor* temp_scaled_bias_tensor,
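---

Editor's note: the central change above is describing the fused MatMul to oneDNN as a matmul primitive with a 2-D `{1, n}` bias instead of an inner-product with a 1-D bias. Below is a minimal, self-contained sketch of that configuration. It is not part of the patch: it assumes the oneDNN v3 `dnnl.hpp` C++ API, and the `m`/`k`/`n` sizes and variable names are illustrative only.

```cpp
// Sketch: oneDNN v3 matmul with plain row-major ("ab") layouts and a
// {1, n} bias descriptor that broadcasts across the m output rows.
#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dim m = 4, k = 8, n = 16;  // illustrative sizes
  // "ab" = plain 2-D row-major; "ba" would describe a transposed weight,
  // which is how the patch handles transpose_b.
  auto src_md  = memory::desc({m, k}, memory::data_type::f32, memory::format_tag::ab);
  auto wei_md  = memory::desc({k, n}, memory::data_type::f32, memory::format_tag::ab);
  // Bias is 2-D {1, n}: matmul broadcasts it over every row of the output,
  // which is why the op's bias dims change from {n} to {1, n}.
  auto bias_md = memory::desc({1, n}, memory::data_type::f32, memory::format_tag::ab);
  auto dst_md  = memory::desc({m, n}, memory::data_type::f32, memory::format_tag::ab);

  // oneDNN v3 style: the primitive descriptor is built directly from the
  // engine and memory descriptors, with no prop_kind argument.
  auto pd = matmul::primitive_desc(eng, src_md, wei_md, bias_md, dst_md);
  auto prim = matmul(pd);

  std::vector<float> src(m * k, 1.f), wei(k * n, 1.f), bias(n, 0.5f), dst(m * n);
  memory src_mem(src_md, eng, src.data()), wei_mem(wei_md, eng, wei.data());
  memory bias_mem(bias_md, eng, bias.data()), dst_mem(dst_md, eng, dst.data());

  prim.execute(strm, {{DNNL_ARG_SRC, src_mem},
                      {DNNL_ARG_WEIGHTS, wei_mem},
                      {DNNL_ARG_BIAS, bias_mem},
                      {DNNL_ARG_DST, dst_mem}});
  strm.wait();  // every dst element should equal k * 1.0 + 0.5
  return 0;
}
```

Describing the bias as a 1-by-n matrix is what lets the op keep accepting TensorFlow's 1-D bias tensor unchanged while satisfying matmul's requirement that all operands have the same rank.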