From 689325721e9d6dbf62e75699af51420ed423343c Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 16 Jan 2024 06:41:33 +0000 Subject: [PATCH 1/3] do not modify output at same cases --- .../transforms/auto_mixed_precision_pass.cc | 58 +++++++++++-------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 6846ccc6b57536..17a252838f06d9 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -610,15 +610,6 @@ class AutoMixedPrecisionPass : public pir::Pass { VLOG(6) << "Change result's dtype to low precision " << op->name() << std::endl; - if (op->HasAttribute("dtype") && - IsPhiDataTypeFloat( - op->attribute("dtype") - .data())) { - pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( - builder.ir_context(), precision_mode_); - op->set_attribute("dtype", attr_dtype); - } - auto phi_kernel = GetPhiKernelInPrecision(op_type, backend, precision_mode_); PADDLE_ENFORCE( @@ -633,6 +624,40 @@ class AutoMixedPrecisionPass : public pir::Pass { auto input_defs = args_def.input_defs(); auto output_defs = args_def.output_defs(); + // if any of the op's input is not in low precision, insert cast op + // input_defs will always be the smaller one? + for (size_t i = 0; i < input_defs.size(); i++) { + auto operand = op->operand(i); + auto in_phi_dtype = input_defs[i].dtype; + if (!IsOperandHasDenseTensorType(operand)) continue; + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype != in_phi_dtype) { + VLOG(6) << "Support low precision, insert CastOp for " << op->name() + << " operand " << i << std::endl; + DoInsertCastOp(op, operand, in_phi_dtype, builder); + } + } + + if (op->HasAttribute("dtype")) { + auto phi_dtype = op->attribute("dtype") + .dyn_cast() + .data(); + if (IsPhiDataTypeFloat(phi_dtype)) { + pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( + builder.ir_context(), precision_mode_); + op->set_attribute("dtype", attr_dtype); + } else if (phi_dtype == + phi::DataType::UNDEFINED) { // dtype is not set, means all + // ok + pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( + builder.ir_context(), precision_mode_); + op->set_attribute("dtype", attr_dtype); + } else { + return; // don't modify output dtype + } + } + PADDLE_ENFORCE_EQ( op->num_results(), output_defs.size(), @@ -653,21 +678,6 @@ class AutoMixedPrecisionPass : public pir::Pass { // type SetResultDataType(result, out_phi_dtype, builder.ir_context()); } - - // if any of the op's input is not in low precision, insert cast op - // input_defs will always be the smaller one? - for (size_t i = 0; i < input_defs.size(); i++) { - auto operand = op->operand(i); - auto in_phi_dtype = input_defs[i].dtype; - if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); - if (IsPhiDataTypeFloat(operand_phi_dtype) && - operand_phi_dtype != in_phi_dtype) { - VLOG(6) << "Support low precision, insert CastOp for " << op->name() - << " operand " << i << std::endl; - DoInsertCastOp(op, operand, in_phi_dtype, builder); - } - } } else { // current op doesn't support low precision, should cast to float // if the op's input is in low precision, insert cast op auto phi_dtype = phi::DataType::FLOAT32; From af97319c3a1b0d192ab57fae04ae3a984ad48867 Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 20 Jan 2024 05:58:26 +0000 Subject: [PATCH 2/3] auto_mixed_precision ca run --- paddle/fluid/inference/api/analysis_predictor.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index de6a2d0c97189c..e056c1d3631065 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -839,9 +839,7 @@ bool AnalysisPredictor::PrepareExecutor() { if (!config_.glog_info_disabled()) { gpu_pm.EnablePrintStatistics(); } - if (config_.ir_debug_) { - gpu_pm.EnableIRPrinting(); - } + gpu_pm.EnableIRPrinting(); gpu_pm.Run(pir_program_.get()); } From bbe99e9775c1ba75b0932f0680fabcf063962168 Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 20 Jan 2024 14:30:29 +0000 Subject: [PATCH 3/3] delete useless code --- test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc index 8c16fd37dd34a8..68d4de4758e178 100644 --- a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc +++ b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc @@ -18,7 +18,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/pir/core/builtin_dialect.h"