From d0197776b718ea18cbf2f5e579f7b3c47fb6b56f Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 7 Dec 2023 13:32:11 +0000 Subject: [PATCH 01/40] new auto-mixed-precision --- .../framework/ir/auto_mixed_precision_pass.cc | 8 + .../fluid/inference/api/analysis_predictor.cc | 5 + .../transforms/auto_mixed_precision_pass.cc | 505 ++++++++++++++++++ .../transforms/auto_mixed_precision_pass.h | 29 + .../transforms/transform_general_functions.cc | 17 +- test/cpp/pir/pattern_rewrite/CMakeLists.txt | 3 + .../auto_mixed_precision_test.cc | 104 ++++ 7 files changed, 666 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc create mode 100644 paddle/fluid/pir/transforms/auto_mixed_precision_pass.h create mode 100644 test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index fb75c18a6fae65..69292e18edabf0 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -26,6 +26,8 @@ #include "paddle/phi/backends/device_manager.h" #endif +PHI_DECLARE_bool(enable_pir_in_executor); + namespace paddle { namespace framework { namespace ir { @@ -40,6 +42,9 @@ bool PhiKernelSupportPrecision( phi::DataType data_type, phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { const auto& kernels = phi::KernelFactory::Instance().kernels(); + // for (auto [k, v] : kernels) { + // LOG(INFO) << "kernel name " << k << std::endl; + // } if (kernels.count(op_type) == 0) { return false; } @@ -270,6 +275,9 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const { } void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const { + if (FLAGS_enable_pir_in_executor) { + return; + } PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::PreconditionNotMet( "During the auto_mixed_precision_pass, the graph " diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index c70ef74e94baad..bc4c242c8ccb57 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -103,6 +103,7 @@ #endif #include "paddle/fluid/ir_adaptor/translator/translate.h" +#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" @@ -801,6 +802,10 @@ bool AnalysisPredictor::PrepareExecutor() { //----------------------------------------------------------------------------------------------// // Functional pass + // Do auto mixed precision pass first, so do not need to handle + // shadowoutput. + pm_for_op_program.AddPass(::pir::CreateAutoMixedPrecisionPass( + place_, ConvertPrecision(config_.mixed_precision_mode_))); gpu_pm.AddPass(::pir::CreateIdentityOpCleanPass()); //----------------------------------------------------------------------------------------------// diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc new file mode 100644 index 00000000000000..c95d976b765128 --- /dev/null +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -0,0 +1,505 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" +#include +#include +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" + +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/parameter.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class AutoMixedPrecisionPattern : public pir::RewritePattern { + public: + AutoMixedPrecisionPattern( + pir::IrContext* context, + const phi::Place& place, + const phi::DataType& precision_mode, + bool enable_low_precision_io = false, + pir::PatternBenefit benefit = 1, + const std::vector& generated_names = {}) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) { + precision_mode_ = precision_mode; // should be set by user + place_ = place; // should be set by user + enable_low_precision_io_ = enable_low_precision_io; + SetDefaultBlacklist(); + SetDefaultWhitelist(); + } + + void SetDefaultBlacklist() { + black_list_.insert({ + paddle::dialect::ExpOp::name(), + paddle::dialect::SquareOp::name(), + paddle::dialect::LogOp::name(), + // paddle::dialect::FetchOp::name(), + + // paddle::dialect::Mean::name(), + // paddle::dialect::Sum::name(), + paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), + }); + } + + void SetDefaultWhitelist() { + // white_list_.insert({paddle::dialect::FullOp::name(), + // paddle::dialect::Conv2dOp::name(), + // paddle::dialect::TransposeOp::name()}); + // return; + } + + bool Match(pir::Operation* op) const override { + // if enable_low_precision_io_ is true, all the op will be transformed into, + // input and output included + if (op->isa() || op->isa() || + op->isa() || + op->isa()) + return false; + + if (!enable_low_precision_io_) { + if (op->isa()) return false; + } + + if (!IsBuiltinOp(op)) { + return OpHasFloatResult(op); + 
} + + return true; + } + + void Rewrite(pir::Operation* op, + pir::PatternRewriter& rewriter) const override { // NOLINT + LOG(INFO) << "Rewrite op " << op->name() << std::endl; + if (IsBuiltinOp(op)) { + RewriteBuiltinOp(op, rewriter); + return; + } else { + RewritePdOp(op, rewriter); + return; + } + } + + private: + std::unordered_set black_list_; + std::unordered_set white_list_; + phi::DataType precision_mode_{phi::DataType::UNDEFINED}; + + phi::Place place_; + bool enable_low_precision_io_; + + bool PhiKernelSupportPrecision( + const std::string& op_type, + phi::Backend backend, + phi::DataType data_type, + phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) const { + const auto& kernels = phi::KernelFactory::Instance().kernels(); + if (kernels.count(op_type) == 0) { + return false; + } + phi::KernelKey kernel_key(backend, layout, data_type); + return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key); + } + + phi::Backend ConvertPlaceToBackend(const phi::Place& place) const { + switch (place.GetType()) { + case phi::AllocationType::CPU: + return phi::Backend::CPU; + case phi::AllocationType::GPU: + return phi::Backend::GPU; + case phi::AllocationType::XPU: + return phi::Backend::XPU; + default: + return phi::Backend::UNDEFINED; + } + return phi::Backend::UNDEFINED; + } + + bool KernelSupportPrecision( + const std::string& op_type, + phi::Backend backend, + phi::DataType precision, + phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) const { + auto& phi_op_type = op_type; + LOG(INFO) << "phi_op_type = " << phi_op_type << std::endl; + + bool support = + PhiKernelSupportPrecision(phi_op_type, backend, precision, layout); + if (backend == phi::Backend::GPU) { + support |= PhiKernelSupportPrecision( + phi_op_type, phi::Backend::GPUDNN, precision, layout); + } + + if (!support) { + const auto& all_kernels = + paddle::framework::OperatorWithKernel::AllOpKernels(); + auto it = all_kernels.find(op_type); + if (it != all_kernels.end()) { + for (const auto& kern_pair : it->second) { + if (ConvertPlaceToBackend(kern_pair.first.place_) == backend && + kern_pair.first.data_type_ == + paddle::framework::TransToProtoVarType(precision)) { + support = true; + break; + } + } + } + } + return support; + } + + phi::Kernel GetPhiKernelInPrecision(const std::string& kernel_fn_str, + phi::Backend backend, + phi::DataType precision) const { + if (backend == phi::Backend::GPU) { + if (PhiKernelSupportPrecision( + kernel_fn_str, phi::Backend::GPUDNN, precision)) { + phi::KernelKey kernel_key( + phi::Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, precision); + return phi::KernelFactory::Instance().SelectKernel(kernel_fn_str, + kernel_key); + } + phi::KernelKey kernel_key( + phi::Backend::GPU, phi::DataLayout::ALL_LAYOUT, precision); + return phi::KernelFactory::Instance().SelectKernel(kernel_fn_str, + kernel_key); + } + return phi::KernelFactory::Instance().SelectKernel( + kernel_fn_str, + phi::KernelKey(backend, phi::DataLayout::ALL_LAYOUT, precision)); + } + + bool IsBuiltinOp(pir::Operation* op) const { + return op->name().find("builtin") != std::string::npos; + } + + bool OpSupportPrecision(const std::string& kernel_fn_str, + phi::Backend backend, + phi::DataType precision) const { + // if the op is in white list, return true + if (white_list_.count(kernel_fn_str)) { + return true; + } + + // if the op is in black list, return false + if (black_list_.count(kernel_fn_str)) { + return false; + } + + return KernelSupportPrecision(kernel_fn_str, backend, precision); + } + + bool 
ValueInPrecision(pir::Value value, phi::DataType precision) const { + auto dtype = pir::GetDataTypeFromValue(value); + return paddle::dialect::TransToPhiDataType(dtype) == precision; + } + + void SetResultDataType(pir::Value result, + phi::DataType precision, + pir::IrContext* context) const { + auto type = result.type(); + if (type.isa()) { + auto dense_type = type.dyn_cast(); + auto new_type = paddle::dialect::DenseTensorType::get( + context, + paddle::dialect::TransToIrDataType(precision, context), + dense_type.dims(), + dense_type.data_layout(), + dense_type.lod(), + dense_type.offset()); + result.set_type(new_type); + } else if (type.isa()) { + auto vec_type = type.dyn_cast(); + auto output_num = vec_type.size(); + std::vector results_type(output_num); + for (size_t idx = 0; idx < output_num; ++idx) { + auto dense_type = + vec_type[idx].dyn_cast(); + auto new_type = paddle::dialect::DenseTensorType::get( + context, + paddle::dialect::TransToIrDataType(precision, context), + dense_type.dims(), + dense_type.data_layout(), + dense_type.lod(), + dense_type.offset()); + results_type[idx] = new_type; + } + auto new_vec_type = pir::VectorType::get(context, results_type); + result.set_type(new_vec_type); + } else { + LOG(INFO) << "result type is not DenseTensorType or VectorType" + << std::endl; + } + } + + bool OpHasFloatResult(pir::Operation* op) const { + for (size_t i = 0; i < op->num_results(); i++) { + auto result = op->result(i); + if (!result.type()) continue; + if (result.type().isa()) { + auto dtype = pir::GetDataTypeFromValue(result); + if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { + return true; + } + } else if (result.type().isa()) { + auto vec_type = result.type().dyn_cast(); + for (size_t j = 0; j < vec_type.size(); j++) { + auto dtype = + vec_type[j].dyn_cast().dtype(); + if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { + return true; + } + } + } + } + LOG(INFO) << "op " << op->name() << " doesn't have float result" + << std::endl; + return false; + } + + bool IsDataTypeFloat(const phi::DataType& dtype) const { + return dtype == phi::DataType::FLOAT32 || dtype == phi::DataType::FLOAT16 || + dtype == phi::DataType::BFLOAT16; + } + + bool IsOperandDenseTensorType(pir::OpOperand operand) const { + return operand.type() && + operand.type().isa(); + } + + void InsertCastOp(pir::Operation* op, + pir::OpOperand operand, + phi::DataType precision, + pir::PatternRewriter& rewriter) const { // NOLINT + auto value = operand.source(); + rewriter.set_insertion_point(op); // before op + paddle::dialect::CastOp cast_op = + rewriter.Build(value, precision); + operand.set_source(cast_op->result(0)); + } + + void RewriteBuiltinOp(pir::Operation* op, + pir::PatternRewriter& rewriter) const { // NOLINT + LOG(INFO) << "Rewrite builtin op " << op->name() << std::endl; + // Rewrite CombineOp + if (op->isa()) { + // auto vec_type = op->result(0).type().dyn_cast(); + auto input_num = op->num_operands(); + std::vector inputs_type(input_num); + for (size_t idx = 0; idx < input_num; ++idx) { + inputs_type[idx] = op->operand(idx).type(); + } + auto new_vec_type = + pir::VectorType::get(rewriter.ir_context(), inputs_type); + op->result(0).set_type(new_vec_type); + } + + // Rewrite SliceOp + if (op->isa()) { + auto index = + op->attribute("index").dyn_cast().data(); + auto input_type = op->operand(0).type().dyn_cast(); + auto new_type = input_type[index]; + op->result(0).set_type(new_type); + } + + // Rewrite SplitOp + if (op->isa()) { + auto input_type = 
op->operand(0).type().dyn_cast(); + int output_num = op->num_results(); + for (int i = 0; i < output_num; ++i) { + op->result(i).set_type(input_type[i]); + } + } + } + + void RewritePdOp(pir::Operation* op, + pir::PatternRewriter& rewriter) const { // NOLINT + LOG(INFO) << "Rewrite pd op " << op->name() << std::endl; + phi::Backend backend = ConvertPlaceToBackend(place_); + std::string op_type = op->name().substr(op->name().find(".") + 1); + + // Rewrite FetchOp + if (op->isa()) { + auto fetch_operand = op->operand(0); + if (enable_low_precision_io_) { + SetResultDataType( + op->result(0), precision_mode_, rewriter.ir_context()); + } + if (!op->result(0).type().isa()) return; + auto result_dtype = pir::GetDataTypeFromValue(op->result(0)); + if (!ValueInPrecision( + fetch_operand.source(), + paddle::dialect::TransToPhiDataType(result_dtype))) { + InsertCastOp(op, + fetch_operand, + paddle::dialect::TransToPhiDataType(result_dtype), + rewriter); + } + return; + } + // Rewrite FeedOp + if (op->isa() && enable_low_precision_io_) { + SetResultDataType(op->result(0), precision_mode_, rewriter.ir_context()); + return; + } + + if (OpSupportPrecision(op_type, backend, precision_mode_)) { + // change result's dtype to low precision + LOG(INFO) << "Change result's dtype to low precision " << op->name() + << std::endl; + + if (op->HasAttribute("dtype")) { + if (!IsDataTypeFloat( + op->attribute("dtype") + .data())) + return; + pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( + rewriter.ir_context(), precision_mode_); + op->set_attribute("dtype", attr_dtype); + } + + auto phi_kernel = + GetPhiKernelInPrecision(op_type, backend, precision_mode_); + PADDLE_ENFORCE( + phi_kernel.IsValid(), + phi::errors::PreconditionNotMet( + "op [%s] kernel doesn't support precision [%s] on backend [%s]", + op->name(), + phi::DataTypeToString(precision_mode_).c_str(), + paddle::experimental::BackendToString(backend).c_str())); + + auto args_def = phi_kernel.args_def(); + auto input_defs = args_def.input_defs(); + auto output_defs = args_def.output_defs(); + + PADDLE_ENFORCE_EQ( + op->num_results(), + output_defs.size(), + phi::errors::PreconditionNotMet( + "op [%s] kernel output args defs should equal op outputs", + op->name())); + + for (size_t i = 0; i < op->num_results(); i++) { + auto result = op->result(i); + if (!result.type()) continue; + phi::DataType out_phi_dtype = output_defs[i].dtype; + LOG(INFO) << "result dtype = " << phi::DataTypeToString(out_phi_dtype) + << std::endl; + if (out_phi_dtype == phi::DataType::UNDEFINED) continue; + SetResultDataType(result, out_phi_dtype, rewriter.ir_context()); + } + + // if any of the op's input is not in low precision, insert cast op + // input_defs will always be the smaller one? 
+      for (size_t i = 0; i < input_defs.size(); i++) {
+        auto operand = op->operand(i);
+        if (!IsOperandDenseTensorType(operand)) continue;
+        auto in_phi_dtype = input_defs[i].dtype;
+        if (IsDataTypeFloat(in_phi_dtype) &&
+            !ValueInPrecision(operand.source(), in_phi_dtype)) {
+          InsertCastOp(op, operand, in_phi_dtype, rewriter);
+        }
+      }
+    } else {  // current op doesn't support low precision, should cast to float
+      // if the op's input is in low precision, insert cast op
+      auto phi_dtype = phi::DataType::FLOAT32;
+      for (size_t i = 0; i < op->num_operands(); i++) {
+        auto operand = op->operand(i);
+        if (!IsOperandDenseTensorType(operand)) continue;
+        auto operand_dtype = pir::GetDataTypeFromValue(operand.source());
+        if (!IsDataTypeFloat(
+                paddle::dialect::TransToPhiDataType(operand_dtype)))
+          continue;
+
+        // Only cast float16 or bfloat16 to float32
+        if (ValueInPrecision(operand.source(), precision_mode_)) {
+          InsertCastOp(op, operand, phi_dtype, rewriter);
+        }
+      }
+    }
+  }
+};
+
+class AutoMixedPrecisionPass : public pir::Pass {
+ public:
+  AutoMixedPrecisionPass(const phi::Place& place,
+                         const phi::DataType& precision_mode)
+      : pir::Pass("auto_mixed_precision_pass", 1),
+        place_(place),
+        precision_mode_(precision_mode) {}
+
+  bool Initialize(pir::IrContext* context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add<AutoMixedPrecisionPattern>(context, place_, precision_mode_);
+    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
+    return true;
+  }
+
+  void Run(pir::Operation* op) override {
+    pir::GreedyRewriteConfig cfg;
+    cfg.use_top_down_traversal = true;
+    cfg.max_iterations = 5;
+    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
+  }
+
+  bool CanApplyOn(pir::Operation* op) const override {
+    return op->isa<::pir::ModuleOp>() && op->num_regions() > 0 &&
+           place_ == paddle::PlaceType::kGPU &&
+           (precision_mode_ == phi::DataType::FLOAT16 ||
+            precision_mode_ == phi::DataType::BFLOAT16);
+  }
+
+ private:
+  pir::FrozenRewritePatternSet patterns_;
+  phi::Place place_;
+  phi::DataType precision_mode_;
+};
+
+}  // namespace
+
+namespace pir {
+
+std::unique_ptr<Pass> CreateAutoMixedPrecisionPass(
+    const phi::Place& place, const phi::DataType& precision_mode) {
+  return std::make_unique<AutoMixedPrecisionPass>(place, precision_mode);
+}
+
+}  // namespace pir
diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h
new file mode 100644
index 00000000000000..2544219494a10f
--- /dev/null
+++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateAutoMixedPrecisionPass( + const phi::Place& place, const phi::DataType& precision_mode); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc index d0d44b1a720af9..f17e00e28f961e 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -51,11 +51,18 @@ const phi::DDim& GetShapeFromValue(pir::Value value) { pir::Type GetDataTypeFromValue(pir::Value value) { // TODO(dev): Support other types like DenseTensor. - PADDLE_ENFORCE_EQ( - value.type().isa(), - true, - phi::errors::InvalidArgument("Value's type must be a DenseTensorType.")); - return value.type().dyn_cast().dtype(); + if (value.type().isa()) { + return value.type().dyn_cast().dtype(); + } else if (value.type().isa()) { + return value.type().dyn_cast().dtype(); + } else if (value.type().isa()) { + return value.type() + .dyn_cast() + .dtype(); + } else { + PADDLE_THROW(phi::errors::Unimplemented("Unsupported pir data type: %s.", + value.type())); + } } Operation* GetDefiningOpForInput(Operation* op, uint32_t index) { diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index b06577552d52b3..05977180d951f7 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -3,6 +3,9 @@ cc_test( SRCS pattern_rewrite_test.cc DEPS gtest op_dialect_vjp pir pir_transforms) +cc_test_old(auto_mixed_precision_test SRCS auto_mixed_precision_test.cc DEPS + ${PATTERN_REWRITE_TEST_DEPS}) + cc_test( drr_test SRCS drr_test.cc diff --git a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc new file mode 100644 index 00000000000000..6f880900e21d00 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
+#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h"
+#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
+#include "paddle/pir/core/builtin_dialect.h"
+#include "paddle/pir/pass/pass.h"
+#include "paddle/pir/pass/pass_manager.h"
+#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
+
+void BuildProgram(pir::Builder &builder) {  // NOLINT
+  paddle::dialect::FullOp full_input_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_filter_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 3, 3, 3},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_mean_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::FullOp full_variance_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_scale_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_bias_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::Conv2dOp conv2d_op =
+      builder.Build<paddle::dialect::Conv2dOp>(full_input_op.out(),
+                                               full_filter_op.out());
+
+  paddle::dialect::BatchNormOp batch_norm_op =
+      builder.Build<paddle::dialect::BatchNormOp>(conv2d_op.out(),
+                                                  full_mean_op.out(),
+                                                  full_variance_op.out(),
+                                                  full_scale_op.out(),
+                                                  full_bias_op.out(),
+                                                  true,
+                                                  0.9,
+                                                  1e-6,
+                                                  "NCHW",
+                                                  false,
+                                                  false);
+
+  auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
+      batch_norm_op.out(), std::vector<int>{0, 2, 3, 1});
+
+  auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
+      transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
+
+  builder.Build<paddle::dialect::FetchOp>(transpose2_op.out(), "out", 0);
+}
+
+TEST(AutoMixedPrecisionTest, MixedPrecisionTest) {
+  pir::IrContext *ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
+  ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
+  pir::Program program(ctx);
+  pir::Builder builder = pir::Builder(ctx, program.block());
+  BuildProgram(builder);
+
+  EXPECT_EQ(program.block()->size(), 11u);
+
+  pir::PassManager pm(ctx);
+  pm.AddPass(pir::CreateAutoMixedPrecisionPass(phi::GPUPlace(),
+                                               phi::DataType::FLOAT16));
+  pm.AddPass(pir::CreateDeadCodeEliminationPass());
+  // pm.EnablePassTiming();
+  pm.EnableIRPrinting();
+
+  CHECK_EQ(pm.Run(&program), true);
+}

From dc209103656937c3e4b944ab0c205fdeab26369d Mon Sep 17 00:00:00 2001
From: yxy
Date: Fri, 8 Dec 2023 02:57:55 +0000
Subject: [PATCH 02/40] clean code

---
 .../transforms/auto_mixed_precision_pass.cc   | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc
index c95d976b765128..3e49da0f580d09 100644
--- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc
@@ -299,7 +299,7 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern {
            dtype == phi::DataType::BFLOAT16;
   }
 
-  bool IsOperandDenseTensorType(pir::OpOperand operand) const {
+  bool IsOperandHasDenseTensorType(pir::OpOperand operand) const {
     return operand.type() &&
            operand.type().isa<paddle::dialect::DenseTensorType>();
   }
@@ -359,19 +359,17 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern {
     // Rewrite FetchOp
     if (op->isa<paddle::dialect::FetchOp>()) {
       auto fetch_operand =
op->operand(0); + auto fetch_operand_dtype = + pir::GetDataTypeFromValue(fetch_operand.source()); if (enable_low_precision_io_) { SetResultDataType( op->result(0), precision_mode_, rewriter.ir_context()); } if (!op->result(0).type().isa()) return; - auto result_dtype = pir::GetDataTypeFromValue(op->result(0)); - if (!ValueInPrecision( - fetch_operand.source(), - paddle::dialect::TransToPhiDataType(result_dtype))) { - InsertCastOp(op, - fetch_operand, - paddle::dialect::TransToPhiDataType(result_dtype), - rewriter); + auto result_dtype = paddle::dialect::TransToPhiDataType( + pir::GetDataTypeFromValue(op->result(0))); + if (fetch_operand_dtype != result_dtype) { + InsertCastOp(op, fetch_operand, result_dtype, rewriter); } return; } @@ -431,10 +429,12 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { // input_defs will always be the smaller one? for (size_t i = 0; i < input_defs.size(); i++) { auto operand = op->operand(i); - if (!IsOperandDenseTensorType(operand)) continue; - auto in_phi_dtype = input_defs[i].dtype; - if (IsDataTypeFloat(in_phi_dtype) && - !ValueInPrecision(operand.source(), in_phi_dtype)) { + if (!IsOperandHasDenseTensorType(operand)) continue; + auto operand_dtype = pir::GetDataTypeFromValue(operand.source()); + if (!IsDataTypeFloat( + paddle::dialect::TransToPhiDataType(operand_dtype))) + continue; + if (operand_dtype != in_phi_dtype) { InsertCastOp(op, operand, in_phi_dtype, rewriter); } } @@ -443,14 +443,14 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { auto phi_dtype = phi::DataType::FLOAT32; for (size_t i = 0; i < op->num_operands(); i++) { auto operand = op->operand(i); - if (!IsOperandDenseTensorType(operand)) continue; + if (!IsOperandHasDenseTensorType(operand)) continue; auto operand_dtype = pir::GetDataTypeFromValue(operand.source()); if (!IsDataTypeFloat( paddle::dialect::TransToPhiDataType(operand_dtype))) continue; // Only cast float16 or bfloat16 to float32 - if (ValueInPrecision(operand.source(), precision_mode_)) { + if (operand_dtype != precision_mode_) { InsertCastOp(op, operand, phi_dtype, rewriter); } } From b3271d89bcf95e0f50e13590a6b79e992b4cba71 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 8 Dec 2023 10:43:58 +0000 Subject: [PATCH 03/40] refine CMakeLists --- test/cpp/pir/pattern_rewrite/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 05977180d951f7..5bc063240fb4af 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -3,8 +3,10 @@ cc_test( SRCS pattern_rewrite_test.cc DEPS gtest op_dialect_vjp pir pir_transforms) -cc_test_old(auto_mixed_precision_test SRCS auto_mixed_precision_test.cc DEPS - ${PATTERN_REWRITE_TEST_DEPS}) +cc_test( + auto_mixed_precision_test + SRCS auto_mixed_precision_test.cc + DEPS gtest pir pir_transforms) cc_test( drr_test From 7c1fcb3ec6edc1a08bd7bf68fbd5c394b8489489 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 8 Dec 2023 11:16:14 +0000 Subject: [PATCH 04/40] fix bug --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 3e49da0f580d09..599cb105f793b1 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -17,6 +17,9 @@ 
#include #include +#include "paddle/common/enforce.h" +#include "paddle/common/errors.h" + #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" @@ -34,8 +37,6 @@ #include "paddle/phi/common/float16.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/errors.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/ir_context.h" @@ -359,8 +360,8 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { // Rewrite FetchOp if (op->isa()) { auto fetch_operand = op->operand(0); - auto fetch_operand_dtype = - pir::GetDataTypeFromValue(fetch_operand.source()); + auto fetch_operand_dtype = paddle::dialect::TransToPhiDataType( + pir::GetDataTypeFromValue(fetch_operand.source())); if (enable_low_precision_io_) { SetResultDataType( op->result(0), precision_mode_, rewriter.ir_context()); From 6a9f3dc2e51512f0c93f1453b7b2c643d9064f58 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 8 Dec 2023 11:39:28 +0000 Subject: [PATCH 05/40] can compile --- .../transforms/auto_mixed_precision_pass.cc | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 599cb105f793b1..e91134991fec2d 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -300,6 +300,11 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { dtype == phi::DataType::BFLOAT16; } + phi::DataType OperandDataType(const pir::OpOperand& operand) const { + auto dtype = pir::GetDataTypeFromValue(operand.source()); + return paddle::dialect::TransToPhiDataType(dtype); + } + bool IsOperandHasDenseTensorType(pir::OpOperand operand) const { return operand.type() && operand.type().isa(); @@ -360,8 +365,7 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { // Rewrite FetchOp if (op->isa()) { auto fetch_operand = op->operand(0); - auto fetch_operand_dtype = paddle::dialect::TransToPhiDataType( - pir::GetDataTypeFromValue(fetch_operand.source())); + auto fetch_operand_dtype = OperandDataType(fetch_operand); if (enable_low_precision_io_) { SetResultDataType( op->result(0), precision_mode_, rewriter.ir_context()); @@ -430,12 +434,10 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { // input_defs will always be the smaller one? 
for (size_t i = 0; i < input_defs.size(); i++) { auto operand = op->operand(i); + auto in_phi_dtype = input_defs[i].dtype; if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_dtype = pir::GetDataTypeFromValue(operand.source()); - if (!IsDataTypeFloat( - paddle::dialect::TransToPhiDataType(operand_dtype))) - continue; - if (operand_dtype != in_phi_dtype) { + auto operand_dtype = OperandDataType(operand); + if (IsDataTypeFloat(operand_dtype) && operand_dtype != in_phi_dtype) { InsertCastOp(op, operand, in_phi_dtype, rewriter); } } @@ -445,13 +447,9 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { for (size_t i = 0; i < op->num_operands(); i++) { auto operand = op->operand(i); if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_dtype = pir::GetDataTypeFromValue(operand.source()); - if (!IsDataTypeFloat( - paddle::dialect::TransToPhiDataType(operand_dtype))) - continue; - - // Only cast float16 or bfloat16 to float32 - if (operand_dtype != precision_mode_) { + auto operand_dtype = OperandDataType(operand); + if (IsDataTypeFloat(operand_dtype) && + operand_dtype == precision_mode_) { InsertCastOp(op, operand, phi_dtype, rewriter); } } From 821caa34ce3900e183bb989ea6a7efca4eb326f2 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 8 Dec 2023 13:40:27 +0000 Subject: [PATCH 06/40] fix:Cast op int to f32,full op scale,ShareDataOp,Result all type -> precision_mode --- .../transforms/auto_mixed_precision_pass.cc | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index e91134991fec2d..4f9e1cf4a7d294 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -91,7 +91,6 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { // if enable_low_precision_io_ is true, all the op will be transformed into, // input and output included if (op->isa() || op->isa() || - op->isa() || op->isa()) return false; @@ -99,6 +98,23 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { if (op->isa()) return false; } + // is op is a cast op, its input type should be not be float + if (op->isa()) { + auto cast_operand = op->operand(0); + auto cast_operand_dtype = OperandDataType(cast_operand); + return !IsDataTypeFloat(cast_operand_dtype); + } + + // if op is a full op, its user cannot be a scale op + if (op->isa()) { + auto use_ops = GetUseOpsForOutput(op, 0); + for (auto [use_op, idx] : use_ops) { + if (use_op->isa()) { + return false; + } + } + } + if (!IsBuiltinOp(op)) { return OpHasFloatResult(op); } @@ -374,6 +390,7 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); if (fetch_operand_dtype != result_dtype) { + LOG(INFO) << "Insert CastOp for FetchOp" << std::endl; InsertCastOp(op, fetch_operand, result_dtype, rewriter); } return; @@ -384,6 +401,15 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { return; } + // Rewrite ShareDataOp + if (op->isa()) { + auto share_data_operand = op->operand(0); + auto share_data_operand_dtype = OperandDataType(share_data_operand); + SetResultDataType( + op->result(0), share_data_operand_dtype, rewriter.ir_context()); + return; + } + if (OpSupportPrecision(op_type, backend, precision_mode_)) { // change result's dtype to low precision LOG(INFO) << "Change result's dtype 
to low precision " << op->name() @@ -426,7 +452,8 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { phi::DataType out_phi_dtype = output_defs[i].dtype; LOG(INFO) << "result dtype = " << phi::DataTypeToString(out_phi_dtype) << std::endl; - if (out_phi_dtype == phi::DataType::UNDEFINED) continue; + if (out_phi_dtype == phi::DataType::UNDEFINED) + out_phi_dtype = precision_mode_; SetResultDataType(result, out_phi_dtype, rewriter.ir_context()); } @@ -438,6 +465,8 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { if (!IsOperandHasDenseTensorType(operand)) continue; auto operand_dtype = OperandDataType(operand); if (IsDataTypeFloat(operand_dtype) && operand_dtype != in_phi_dtype) { + LOG(INFO) << "Support low precision, insert CastOp for " << op->name() + << " operand " << i << std::endl; InsertCastOp(op, operand, in_phi_dtype, rewriter); } } @@ -450,6 +479,8 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { auto operand_dtype = OperandDataType(operand); if (IsDataTypeFloat(operand_dtype) && operand_dtype == precision_mode_) { + LOG(INFO) << "Not support low precision, insert CastOp for " + << op->name() << " operand " << i << std::endl; InsertCastOp(op, operand, phi_dtype, rewriter); } } @@ -475,7 +506,7 @@ class AutoMixedPrecisionPass : public pir::Pass { void Run(pir::Operation* op) override { pir::GreedyRewriteConfig cfg; cfg.use_top_down_traversal = true; - cfg.max_iterations = 5; + cfg.max_iterations = 1; pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); } From 1744eab5b2b8c2db792470963efa16719762bcd6 Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 9 Dec 2023 02:10:22 +0000 Subject: [PATCH 07/40] recover transform_general_functions.cc --- .../transforms/transform_general_functions.cc | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc index f17e00e28f961e..d0d44b1a720af9 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -51,18 +51,11 @@ const phi::DDim& GetShapeFromValue(pir::Value value) { pir::Type GetDataTypeFromValue(pir::Value value) { // TODO(dev): Support other types like DenseTensor. 
- if (value.type().isa()) { - return value.type().dyn_cast().dtype(); - } else if (value.type().isa()) { - return value.type().dyn_cast().dtype(); - } else if (value.type().isa()) { - return value.type() - .dyn_cast() - .dtype(); - } else { - PADDLE_THROW(phi::errors::Unimplemented("Unsupported pir data type: %s.", - value.type())); - } + PADDLE_ENFORCE_EQ( + value.type().isa(), + true, + phi::errors::InvalidArgument("Value's type must be a DenseTensorType.")); + return value.type().dyn_cast().dtype(); } Operation* GetDefiningOpForInput(Operation* op, uint32_t index) { From 1eb9cd084431339605cb2048a0f67947236eed5d Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 13 Dec 2023 17:58:33 +0000 Subject: [PATCH 08/40] fix bug, now can run --- .../transforms/auto_mixed_precision_pass.cc | 63 +++++++++++++++---- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 4f9e1cf4a7d294..29b446a8190e8a 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -50,6 +50,28 @@ namespace { +// This pattern is used to rewrite the CastOp that has a CastOp as its operand +class FoldMultiCastOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite( + paddle::dialect::CastOp cast_op, + pir::PatternRewriter& rewriter) const override { // NOLINT + auto input_op = pir::GetDefiningOpForInput(cast_op, 0) + ->dyn_cast(); + if (!input_op) return false; + auto op_type = pir::GetDataTypeFromValue(cast_op.out()); + auto new_cast_op = rewriter.Build( + input_op.x().dyn_cast(), + paddle::dialect::TransToPhiDataType(op_type)); + rewriter.ReplaceOp(cast_op, std::vector{new_cast_op.out()}); + rewriter.EraseOp(cast_op); + return true; + } +}; + class AutoMixedPrecisionPattern : public pir::RewritePattern { public: AutoMixedPrecisionPattern( @@ -91,6 +113,7 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { // if enable_low_precision_io_ is true, all the op will be transformed into, // input and output included if (op->isa() || op->isa() || + op->isa() || op->isa()) return false; @@ -98,13 +121,6 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { if (op->isa()) return false; } - // is op is a cast op, its input type should be not be float - if (op->isa()) { - auto cast_operand = op->operand(0); - auto cast_operand_dtype = OperandDataType(cast_operand); - return !IsDataTypeFloat(cast_operand_dtype); - } - // if op is a full op, its user cannot be a scale op if (op->isa()) { auto use_ops = GetUseOpsForOutput(op, 0); @@ -344,6 +360,27 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { if (op->isa()) { // auto vec_type = op->result(0).type().dyn_cast(); auto input_num = op->num_operands(); + bool in_low_precision = false; + bool should_insert_cast = false; + for (size_t i = 0; i < input_num; ++i) { + auto operand = op->operand(i); + auto operand_dtype = OperandDataType(operand); + if (operand_dtype == precision_mode_) { + in_low_precision = true; + } else if (IsDataTypeFloat(operand_dtype)) { + should_insert_cast = true; + } + } + if (in_low_precision && should_insert_cast) { + LOG(INFO) << "Insert CastOp for CombineOp" << std::endl; + for (size_t i = 0; i < input_num; ++i) { + auto operand = op->operand(i); + auto operand_dtype = OperandDataType(operand); + if (operand_dtype != precision_mode_) { + InsertCastOp(op, 
operand, precision_mode_, rewriter); + } + } + } std::vector inputs_type(input_num); for (size_t idx = 0; idx < input_num; ++idx) { inputs_type[idx] = op->operand(idx).type(); @@ -415,11 +452,10 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { LOG(INFO) << "Change result's dtype to low precision " << op->name() << std::endl; - if (op->HasAttribute("dtype")) { - if (!IsDataTypeFloat( - op->attribute("dtype") - .data())) - return; + if (op->HasAttribute("dtype") && + IsDataTypeFloat( + op->attribute("dtype") + .data())) { pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( rewriter.ir_context(), precision_mode_); op->set_attribute("dtype", attr_dtype); @@ -499,6 +535,7 @@ class AutoMixedPrecisionPass : public pir::Pass { bool Initialize(pir::IrContext* context) override { pir::RewritePatternSet ps(context); ps.Add(context, place_, precision_mode_); + // ps.Add(context); patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); return true; } @@ -506,7 +543,7 @@ class AutoMixedPrecisionPass : public pir::Pass { void Run(pir::Operation* op) override { pir::GreedyRewriteConfig cfg; cfg.use_top_down_traversal = true; - cfg.max_iterations = 1; + cfg.max_iterations = 2; pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); } From 432aa8d4384c2e0ed825e1feb6e22400a300dfc8 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 18 Dec 2023 14:32:41 +0000 Subject: [PATCH 09/40] refactor code --- .../transforms/auto_mixed_precision_pass.cc | 183 +++++++++++------- 1 file changed, 109 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 29b446a8190e8a..0621706c6b0bb8 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -82,24 +82,24 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { pir::PatternBenefit benefit = 1, const std::vector& generated_names = {}) : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) { - precision_mode_ = precision_mode; // should be set by user - place_ = place; // should be set by user - enable_low_precision_io_ = enable_low_precision_io; - SetDefaultBlacklist(); - SetDefaultWhitelist(); + // precision_mode_ = precision_mode; // should be set by user + // place_ = place; // should be set by user + // // enable_low_precision_io_ = enable_low_precision_io; + // SetDefaultBlacklist(); + // SetDefaultWhitelist(); } void SetDefaultBlacklist() { - black_list_.insert({ - paddle::dialect::ExpOp::name(), - paddle::dialect::SquareOp::name(), - paddle::dialect::LogOp::name(), - // paddle::dialect::FetchOp::name(), - - // paddle::dialect::Mean::name(), - // paddle::dialect::Sum::name(), - paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), - }); + // black_list_.insert({ + // paddle::dialect::ExpOp::name(), + // paddle::dialect::SquareOp::name(), + // paddle::dialect::LogOp::name(), + // // paddle::dialect::FetchOp::name(), + + // // paddle::dialect::Mean::name(), + // // paddle::dialect::Sum::name(), + // paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), + // }); } void SetDefaultWhitelist() { @@ -117,29 +117,108 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { op->isa()) return false; - if (!enable_low_precision_io_) { - if (op->isa()) return false; - } + // if (!enable_low_precision_io_) { + // if (op->isa()) return false; + // } // if op is a full op, its user cannot be a scale op - if (op->isa()) { - 
auto use_ops = GetUseOpsForOutput(op, 0); - for (auto [use_op, idx] : use_ops) { - if (use_op->isa()) { - return false; - } + // if (op->isa()) { + // auto use_ops = GetUseOpsForOutput(op, 0); + // for (auto [use_op, idx] : use_ops) { + // if (use_op->isa()) { + // return false; + // } + // } + // } + + // if (!IsBuiltinOp(op)) { + // return OpHasFloatResult(op); + // } + + // return true; + // } + return true; + } +}; + +class AutoMixedPrecisionPass : public pir::Pass { + public: + AutoMixedPrecisionPass(const phi::Place& place, + const phi::DataType& precision_mode) + : pir::Pass("auto_mixed_precision_pass", 1), + place_(place), + precision_mode_(precision_mode) {} + + bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context, place_, precision_mode_); + // ps.Add(context); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation* op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 2; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0 && + place_ == paddle::PlaceType::kGPU && + (precision_mode_ == phi::DataType::FLOAT16 || + precision_mode_ == phi::DataType::BFLOAT16); + } + + private: + pir::FrozenRewritePatternSet patterns_; + phi::Place place_; + phi::DataType precision_mode_; + bool enable_low_precision_io_; + + std::unordered_set black_list_; + std::unordered_set white_list_; + + std::unordered_set op_run_low_precision_; + + void ProcessBlock(pir::Block* block) {} + + void GetOpPrecision(pir::Block* block) { + for (auto& op_item : *block) { + VLOG(6) << "op name " << op_item.name(); + auto op_name = op_item.name(); + bool support_low_precision = true; + if (black_list.count(op_name)) { + support_low_precision = false; + } else if (IsBuiltinOp(&op_item)) { // other builtin ops + if (op->isa() || op->isa()) + support_low_precision = false; + } else if (op_item->isa() || + op_item->isa()) { + support_low_precision = enable_low_precision_io_; + } else if (OpHasFloatResult(&op_item)) { // pd op with float result + auto op_type = op_name.substr(op_name.find(".") + 1); + auto backend = ConvertPlaceToBackend(place_); + support_low_precision = + OpSupportPrecision(op_type, backend, precision_mode_); + } + if (support_low_precision) { + op_run_low_precision_.insert(&op_item); } } + } - if (!IsBuiltinOp(op)) { - return OpHasFloatResult(op); + void UpdateOpPrecision(pir::Block* block) { + for (auto& op_item : *block) { + if (op_run_low_precision_.count(&op_item)) { + RewriteOp(&op_item); + } } - - return true; } - void Rewrite(pir::Operation* op, - pir::PatternRewriter& rewriter) const override { // NOLINT + void RewriteOp(pir::Operation* op, + pir::PatternRewriter& rewriter) const { // NOLINT LOG(INFO) << "Rewrite op " << op->name() << std::endl; if (IsBuiltinOp(op)) { RewriteBuiltinOp(op, rewriter); @@ -150,14 +229,6 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { } } - private: - std::unordered_set black_list_; - std::unordered_set white_list_; - phi::DataType precision_mode_{phi::DataType::UNDEFINED}; - - phi::Place place_; - bool enable_low_precision_io_; - bool PhiKernelSupportPrecision( const std::string& op_type, phi::Backend backend, @@ -524,42 +595,6 @@ class AutoMixedPrecisionPattern : public pir::RewritePattern { } }; -class AutoMixedPrecisionPass : public pir::Pass { - public: - 
AutoMixedPrecisionPass(const phi::Place& place, - const phi::DataType& precision_mode) - : pir::Pass("auto_mixed_precision_pass", 1), - place_(place), - precision_mode_(precision_mode) {} - - bool Initialize(pir::IrContext* context) override { - pir::RewritePatternSet ps(context); - ps.Add(context, place_, precision_mode_); - // ps.Add(context); - patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); - return true; - } - - void Run(pir::Operation* op) override { - pir::GreedyRewriteConfig cfg; - cfg.use_top_down_traversal = true; - cfg.max_iterations = 2; - pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); - } - - bool CanApplyOn(pir::Operation* op) const override { - return op->isa<::pir::ModuleOp>() && op->num_regions() > 0 && - place_ == paddle::PlaceType::kGPU && - (precision_mode_ == phi::DataType::FLOAT16 || - precision_mode_ == phi::DataType::BFLOAT16); - } - - private: - pir::FrozenRewritePatternSet patterns_; - phi::Place place_; - phi::DataType precision_mode_; -}; - } // namespace namespace pir { From 21575143ebbbb6f2c5143887a871a57570727d22 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 19 Dec 2023 12:19:11 +0000 Subject: [PATCH 10/40] add cache for cast ops --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 0621706c6b0bb8..fe4937a704aea3 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -181,6 +181,7 @@ class AutoMixedPrecisionPass : public pir::Pass { std::unordered_set white_list_; std::unordered_set op_run_low_precision_; + std::unordered_map cached_cast_ops_; void ProcessBlock(pir::Block* block) {} @@ -418,10 +419,15 @@ class AutoMixedPrecisionPass : public pir::Pass { phi::DataType precision, pir::PatternRewriter& rewriter) const { // NOLINT auto value = operand.source(); + if (cached_cast_ops_.count(value)) { + operand.set_source(cached_cast_ops_[value]->result(0)); + return; + } rewriter.set_insertion_point(op); // before op paddle::dialect::CastOp cast_op = rewriter.Build(value, precision); operand.set_source(cast_op->result(0)); + cached_cast_ops_[value] = cast_op; } void RewriteBuiltinOp(pir::Operation* op, From 203f4bd2ef5f243c70b22f4754d60055503d4a59 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 19 Dec 2023 14:24:59 +0000 Subject: [PATCH 11/40] pass compile --- .../transforms/auto_mixed_precision_pass.cc | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index fe4937a704aea3..3036030971d77a 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" #include #include +#include #include #include "paddle/common/enforce.h" @@ -180,43 +181,39 @@ class AutoMixedPrecisionPass : public pir::Pass { std::unordered_set black_list_; std::unordered_set white_list_; - std::unordered_set op_run_low_precision_; - std::unordered_map cached_cast_ops_; + mutable std::unordered_set op_run_low_precision_; + mutable std::unordered_map + cached_cast_ops_; void ProcessBlock(pir::Block* block) {} void GetOpPrecision(pir::Block* block) { for (auto& op_item : *block) { - VLOG(6) << "op name " << 
op_item.name(); - auto op_name = op_item.name(); + auto op = &op_item; + VLOG(6) << "op name " << op->name(); + auto op_name = op->name(); bool support_low_precision = true; - if (black_list.count(op_name)) { + if (black_list_.count(op_name)) { support_low_precision = false; - } else if (IsBuiltinOp(&op_item)) { // other builtin ops + } else if (IsBuiltinOp(op)) { // other builtin ops if (op->isa() || op->isa()) support_low_precision = false; - } else if (op_item->isa() || - op_item->isa()) { + } else if (op->isa() || + op->isa()) { support_low_precision = enable_low_precision_io_; - } else if (OpHasFloatResult(&op_item)) { // pd op with float result + } else if (OpHasFloatResult(op)) { // pd op with float result auto op_type = op_name.substr(op_name.find(".") + 1); auto backend = ConvertPlaceToBackend(place_); support_low_precision = OpSupportPrecision(op_type, backend, precision_mode_); } if (support_low_precision) { - op_run_low_precision_.insert(&op_item); + op_run_low_precision_.insert(op); } } } - void UpdateOpPrecision(pir::Block* block) { - for (auto& op_item : *block) { - if (op_run_low_precision_.count(&op_item)) { - RewriteOp(&op_item); - } - } - } + void UpdateOpPrecision(pir::Block* block) {} void RewriteOp(pir::Operation* op, pir::PatternRewriter& rewriter) const { // NOLINT From 1ab926494a359f79db1cc6a9863e38e14f5f5897 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 19 Dec 2023 19:28:34 +0000 Subject: [PATCH 12/40] finish refactor --- .../transforms/auto_mixed_precision_pass.cc | 211 ++++++++++++------ 1 file changed, 142 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 3036030971d77a..63687d0ce3ff7c 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -51,28 +51,6 @@ namespace { -// This pattern is used to rewrite the CastOp that has a CastOp as its operand -class FoldMultiCastOpPattern - : public pir::OpRewritePattern { - public: - using pir::OpRewritePattern::OpRewritePattern; - - bool MatchAndRewrite( - paddle::dialect::CastOp cast_op, - pir::PatternRewriter& rewriter) const override { // NOLINT - auto input_op = pir::GetDefiningOpForInput(cast_op, 0) - ->dyn_cast(); - if (!input_op) return false; - auto op_type = pir::GetDataTypeFromValue(cast_op.out()); - auto new_cast_op = rewriter.Build( - input_op.x().dyn_cast(), - paddle::dialect::TransToPhiDataType(op_type)); - rewriter.ReplaceOp(cast_op, std::vector{new_cast_op.out()}); - rewriter.EraseOp(cast_op); - return true; - } -}; - class AutoMixedPrecisionPattern : public pir::RewritePattern { public: AutoMixedPrecisionPattern( @@ -159,10 +137,21 @@ class AutoMixedPrecisionPass : public pir::Pass { } void Run(pir::Operation* op) override { - pir::GreedyRewriteConfig cfg; - cfg.use_top_down_traversal = true; - cfg.max_iterations = 2; - pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + auto module_op = op->dyn_cast(); + pir::Block* block = &module_op.block(); + LOG(INFO) << "===========Get Op Precision============" << std::endl; + GetOpPrecision(block); + LOG(INFO) << "===========Update Op Precision============" << std::endl; + UpdateOpPrecision(block); + pir::IrContext* ctx = pir::IrContext::Instance(); + pir::Builder builder = pir::Builder(ctx, block); + LOG(INFO) << "===========Process Op Precision============" << std::endl; + + ProcessBlock(block, builder); + // pir::GreedyRewriteConfig cfg; + // 
cfg.use_top_down_traversal = true; + // cfg.max_iterations = 2; + // pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); } bool CanApplyOn(pir::Operation* op) const override { @@ -185,7 +174,12 @@ class AutoMixedPrecisionPass : public pir::Pass { mutable std::unordered_map cached_cast_ops_; - void ProcessBlock(pir::Block* block) {} + void ProcessBlock(pir::Block* block, pir::Builder& builder) const { // NOLINT + for (auto& op_item : *block) { + auto op = &op_item; + RewriteOp(op, builder); + } + } void GetOpPrecision(pir::Block* block) { for (auto& op_item : *block) { @@ -206,23 +200,99 @@ class AutoMixedPrecisionPass : public pir::Pass { auto backend = ConvertPlaceToBackend(place_); support_low_precision = OpSupportPrecision(op_type, backend, precision_mode_); + } else { // pd op without float result + support_low_precision = false; } if (support_low_precision) { op_run_low_precision_.insert(op); + LOG(INFO) << "op " << op->name() << " support low precision" + << std::endl; + } else { + LOG(INFO) << "op " << op->name() << " doesn't support low precision" + << std::endl; } } } - void UpdateOpPrecision(pir::Block* block) {} + bool VectorTypeFloat(pir::VectorType vec_type) { + size_t output_num = vec_type.size(); + for (size_t j = 0; j < output_num; j++) { + auto dtype = + vec_type[j].dyn_cast().dtype(); + if (!IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { + return false; + } + } + return true; + } + + void UpdateOpPrecision(pir::Block* block) { + for (auto& op_item : *block) { + auto op = &op_item; + // remove attribute input op + if (op->HasInterface()) { + auto [input_infos, _1, _2, _3, _4] = + op->dyn_cast().GetOpInfo(); + for (size_t idx = 0; idx < input_infos.size(); ++idx) { + if (op->operand_source(idx) && + input_infos[idx].type_name.find("ScalarAttribute") != + std::string::npos) { + LOG(INFO) << "op name " << op->name() << " try to remove attribute" + << std::endl; + LOG(INFO) << "Remove op name " + << GetDefiningOpForInput(op, idx)->name() << " attribute" + << std::endl; + op_run_low_precision_.erase(GetDefiningOpForInput(op, idx)); + } + } + } + // precision should be same as input + // if (op->isa()) { + // auto input_operation = GetDefiningOpForInput(op, 0); + // if (!op_run_low_precision_.count(input_operation)) { + // op_run_low_precision_.erase(op); + // } + // } + } + for (auto& op_item : *block) { + auto op = &op_item; + for (size_t idx = 0; idx < op->num_operands(); ++idx) { + if (!op->operand_source(idx)) continue; + auto operand = op->operand(idx); + if (operand.type() && operand.type().isa()) { + // check if there are all float in the vectortype + auto vec_type = operand.type().dyn_cast(); + if (VectorTypeFloat(vec_type)) { + auto input_operation = GetDefiningOpForInput(op, idx); + // 如果有一个是高精的话,则必须都跑在高精上 + if (!op_run_low_precision_.count(op) || + !op_run_low_precision_.count(input_operation)) { + op_run_low_precision_.erase(op); + op_run_low_precision_.erase(input_operation); + } + } + } + } + } + // print if op run low precision + for (auto& op_item : *block) { + auto op = &op_item; + if (op_run_low_precision_.count(op)) { + LOG(INFO) << "op " << op->name() << " run low precision" << std::endl; + } else { + LOG(INFO) << "op " << op->name() << " run high precision" << std::endl; + } + } + } void RewriteOp(pir::Operation* op, - pir::PatternRewriter& rewriter) const { // NOLINT + pir::Builder& builder) const { // NOLINT LOG(INFO) << "Rewrite op " << op->name() << std::endl; if (IsBuiltinOp(op)) { - RewriteBuiltinOp(op, rewriter); + 
RewriteBuiltinOp(op, builder); return; } else { - RewritePdOp(op, rewriter); + RewritePdOp(op, builder); return; } } @@ -414,58 +484,60 @@ class AutoMixedPrecisionPass : public pir::Pass { void InsertCastOp(pir::Operation* op, pir::OpOperand operand, phi::DataType precision, - pir::PatternRewriter& rewriter) const { // NOLINT + pir::Builder& builder) const { // NOLINT auto value = operand.source(); if (cached_cast_ops_.count(value)) { operand.set_source(cached_cast_ops_[value]->result(0)); return; } - rewriter.set_insertion_point(op); // before op + builder.set_insertion_point(op); // before op paddle::dialect::CastOp cast_op = - rewriter.Build(value, precision); + builder.Build(value, precision); operand.set_source(cast_op->result(0)); cached_cast_ops_[value] = cast_op; } + bool OpRunLowPrecision(pir::Operation* op) const { + return op_run_low_precision_.count(op); + } + void RewriteBuiltinOp(pir::Operation* op, - pir::PatternRewriter& rewriter) const { // NOLINT + pir::Builder& builder) const { // NOLINT LOG(INFO) << "Rewrite builtin op " << op->name() << std::endl; // Rewrite CombineOp if (op->isa()) { // auto vec_type = op->result(0).type().dyn_cast(); auto input_num = op->num_operands(); - bool in_low_precision = false; - bool should_insert_cast = false; - for (size_t i = 0; i < input_num; ++i) { - auto operand = op->operand(i); - auto operand_dtype = OperandDataType(operand); - if (operand_dtype == precision_mode_) { - in_low_precision = true; - } else if (IsDataTypeFloat(operand_dtype)) { - should_insert_cast = true; + if (OpRunLowPrecision(op)) { + for (size_t i = 0; i < input_num; ++i) { + auto operand = op->operand(i); + auto operand_dtype = OperandDataType(operand); + if (IsDataTypeFloat(operand_dtype) && + operand_dtype != precision_mode_) { + InsertCastOp(op, operand, precision_mode_, builder); + } } - } - if (in_low_precision && should_insert_cast) { - LOG(INFO) << "Insert CastOp for CombineOp" << std::endl; + std::vector inputs_type(input_num); + for (size_t idx = 0; idx < input_num; ++idx) { + inputs_type[idx] = op->operand(idx).type(); + } + auto new_vec_type = + pir::VectorType::get(builder.ir_context(), inputs_type); + op->result(0).set_type(new_vec_type); + } else { for (size_t i = 0; i < input_num; ++i) { auto operand = op->operand(i); auto operand_dtype = OperandDataType(operand); - if (operand_dtype != precision_mode_) { - InsertCastOp(op, operand, precision_mode_, rewriter); + if (operand_dtype == precision_mode_) { + InsertCastOp(op, operand, phi::DataType::FLOAT32, builder); } } } - std::vector inputs_type(input_num); - for (size_t idx = 0; idx < input_num; ++idx) { - inputs_type[idx] = op->operand(idx).type(); - } - auto new_vec_type = - pir::VectorType::get(rewriter.ir_context(), inputs_type); - op->result(0).set_type(new_vec_type); } // Rewrite SliceOp if (op->isa()) { + if (!OpRunLowPrecision(op)) return; auto index = op->attribute("index").dyn_cast().data(); auto input_type = op->operand(0).type().dyn_cast(); @@ -475,6 +547,7 @@ class AutoMixedPrecisionPass : public pir::Pass { // Rewrite SplitOp if (op->isa()) { + if (!OpRunLowPrecision(op)) return; auto input_type = op->operand(0).type().dyn_cast(); int output_num = op->num_results(); for (int i = 0; i < output_num; ++i) { @@ -484,7 +557,7 @@ class AutoMixedPrecisionPass : public pir::Pass { } void RewritePdOp(pir::Operation* op, - pir::PatternRewriter& rewriter) const { // NOLINT + pir::Builder& builder) const { // NOLINT LOG(INFO) << "Rewrite pd op " << op->name() << std::endl; phi::Backend backend = 
ConvertPlaceToBackend(place_); std::string op_type = op->name().substr(op->name().find(".") + 1); @@ -493,22 +566,21 @@ class AutoMixedPrecisionPass : public pir::Pass { if (op->isa()) { auto fetch_operand = op->operand(0); auto fetch_operand_dtype = OperandDataType(fetch_operand); - if (enable_low_precision_io_) { - SetResultDataType( - op->result(0), precision_mode_, rewriter.ir_context()); + if (OpRunLowPrecision(op)) { + SetResultDataType(op->result(0), precision_mode_, builder.ir_context()); } if (!op->result(0).type().isa()) return; auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); if (fetch_operand_dtype != result_dtype) { LOG(INFO) << "Insert CastOp for FetchOp" << std::endl; - InsertCastOp(op, fetch_operand, result_dtype, rewriter); + InsertCastOp(op, fetch_operand, result_dtype, builder); } return; } // Rewrite FeedOp - if (op->isa() && enable_low_precision_io_) { - SetResultDataType(op->result(0), precision_mode_, rewriter.ir_context()); + if (op->isa() && OpRunLowPrecision(op)) { + SetResultDataType(op->result(0), precision_mode_, builder.ir_context()); return; } @@ -517,11 +589,12 @@ class AutoMixedPrecisionPass : public pir::Pass { auto share_data_operand = op->operand(0); auto share_data_operand_dtype = OperandDataType(share_data_operand); SetResultDataType( - op->result(0), share_data_operand_dtype, rewriter.ir_context()); + op->result(0), share_data_operand_dtype, builder.ir_context()); return; } - if (OpSupportPrecision(op_type, backend, precision_mode_)) { + // Other pd ops + if (OpRunLowPrecision(op)) { // change result's dtype to low precision LOG(INFO) << "Change result's dtype to low precision " << op->name() << std::endl; @@ -531,7 +604,7 @@ class AutoMixedPrecisionPass : public pir::Pass { op->attribute("dtype") .data())) { pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( - rewriter.ir_context(), precision_mode_); + builder.ir_context(), precision_mode_); op->set_attribute("dtype", attr_dtype); } @@ -564,7 +637,7 @@ class AutoMixedPrecisionPass : public pir::Pass { << std::endl; if (out_phi_dtype == phi::DataType::UNDEFINED) out_phi_dtype = precision_mode_; - SetResultDataType(result, out_phi_dtype, rewriter.ir_context()); + SetResultDataType(result, out_phi_dtype, builder.ir_context()); } // if any of the op's input is not in low precision, insert cast op @@ -577,7 +650,7 @@ class AutoMixedPrecisionPass : public pir::Pass { if (IsDataTypeFloat(operand_dtype) && operand_dtype != in_phi_dtype) { LOG(INFO) << "Support low precision, insert CastOp for " << op->name() << " operand " << i << std::endl; - InsertCastOp(op, operand, in_phi_dtype, rewriter); + InsertCastOp(op, operand, in_phi_dtype, builder); } } } else { // current op doesn't support low precision, should cast to float @@ -591,7 +664,7 @@ class AutoMixedPrecisionPass : public pir::Pass { operand_dtype == precision_mode_) { LOG(INFO) << "Not support low precision, insert CastOp for " << op->name() << " operand " << i << std::endl; - InsertCastOp(op, operand, phi_dtype, rewriter); + InsertCastOp(op, operand, phi_dtype, builder); } } } From b4642e135d833224d682fd0de317891235eb0549 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 20 Dec 2023 02:58:39 +0000 Subject: [PATCH 13/40] special rewrite sharedata op, because it actually has no kernel --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc 
b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 63687d0ce3ff7c..b4e88f6c0f5103 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -585,11 +585,8 @@ class AutoMixedPrecisionPass : public pir::Pass { } // Rewrite ShareDataOp - if (op->isa()) { - auto share_data_operand = op->operand(0); - auto share_data_operand_dtype = OperandDataType(share_data_operand); - SetResultDataType( - op->result(0), share_data_operand_dtype, builder.ir_context()); + if (op->isa() && OpRunLowPrecision(op)) { + SetResultDataType(op->result(0), precision_mode_, builder.ir_context()); return; } From 142012b238cad2c2cd405f425fc69a3e910d0bb2 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 20 Dec 2023 07:16:56 +0000 Subject: [PATCH 14/40] special rewrite FullLikeOp --- .../transforms/auto_mixed_precision_pass.cc | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index b4e88f6c0f5103..b5984a0badccc9 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -195,7 +195,11 @@ class AutoMixedPrecisionPass : public pir::Pass { } else if (op->isa() || op->isa()) { support_low_precision = enable_low_precision_io_; - } else if (OpHasFloatResult(op)) { // pd op with float result + } else if (OpHasFloatOpOperand(op) || + OpHasFloatResult(op)) { // pd op with float result, + // but op like not_equal has float input, but has not float result + // if I don't add this, not_equal will be added to op_run_low_precision_ + // if I add this, full_like op will be run at high precision auto op_type = op_name.substr(op_name.find(".") + 1); auto backend = ConvertPlaceToBackend(place_); support_low_precision = @@ -227,6 +231,17 @@ class AutoMixedPrecisionPass : public pir::Pass { } void UpdateOpPrecision(pir::Block* block) { + // handle full like op + for (auto& op_item : *block) { + auto op = &op_item; + if (op->isa()) { + auto input_operation = GetDefiningOpForInput(op, 0); + if (!op_run_low_precision_.count(input_operation)) { + op_run_low_precision_.erase(op); + } + } + } + for (auto& op_item : *block) { auto op = &op_item; // remove attribute input op @@ -441,6 +456,29 @@ class AutoMixedPrecisionPass : public pir::Pass { } } + bool OpHasFloatOpOperand(pir::Operation* op) const { + for (size_t i = 0; i < op->num_operands(); i++) { + auto operand = op->operand_source(i); + if (!operand.type()) continue; + if (operand.type().isa()) { + auto dtype = pir::GetDataTypeFromValue(operand); + if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { + return true; + } + } else if (operand.type().isa()) { + auto vec_type = operand.type().dyn_cast(); + for (size_t j = 0; j < vec_type.size(); j++) { + auto dtype = + vec_type[j].dyn_cast().dtype(); + if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { + return true; + } + } + } + } + return false; + } + bool OpHasFloatResult(pir::Operation* op) const { for (size_t i = 0; i < op->num_results(); i++) { auto result = op->result(i); @@ -634,6 +672,7 @@ class AutoMixedPrecisionPass : public pir::Pass { << std::endl; if (out_phi_dtype == phi::DataType::UNDEFINED) out_phi_dtype = precision_mode_; + if (!IsDataTypeFloat(out_phi_dtype)) continue; SetResultDataType(result, out_phi_dtype, builder.ir_context()); } From 480c43b9e348dc0cfa887478e35b6656d6dec799 Mon Sep 
17 00:00:00 2001 From: yxy Date: Thu, 21 Dec 2023 11:55:47 +0000 Subject: [PATCH 15/40] fix my own error in analysis_predictor.cc, add op_should_not_handle, and delete useless code --- .../fluid/inference/api/analysis_predictor.cc | 2 +- .../transforms/auto_mixed_precision_pass.cc | 98 +++++-------------- 2 files changed, 26 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index bc4c242c8ccb57..f06ec762a81bc4 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -804,7 +804,7 @@ bool AnalysisPredictor::PrepareExecutor() { // Functional pass // Do auto mixed precision pass first, so do not need to handle // shadowoutput. - pm_for_op_program.AddPass(::pir::CreateAutoMixedPrecisionPass( + gpu_pm.AddPass(::pir::CreateAutoMixedPrecisionPass( place_, ConvertPrecision(config_.mixed_precision_mode_))); gpu_pm.AddPass(::pir::CreateIdentityOpCleanPass()); //----------------------------------------------------------------------------------------------// diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index b5984a0badccc9..231235fe09a298 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -51,75 +51,6 @@ namespace { -class AutoMixedPrecisionPattern : public pir::RewritePattern { - public: - AutoMixedPrecisionPattern( - pir::IrContext* context, - const phi::Place& place, - const phi::DataType& precision_mode, - bool enable_low_precision_io = false, - pir::PatternBenefit benefit = 1, - const std::vector& generated_names = {}) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) { - // precision_mode_ = precision_mode; // should be set by user - // place_ = place; // should be set by user - // // enable_low_precision_io_ = enable_low_precision_io; - // SetDefaultBlacklist(); - // SetDefaultWhitelist(); - } - - void SetDefaultBlacklist() { - // black_list_.insert({ - // paddle::dialect::ExpOp::name(), - // paddle::dialect::SquareOp::name(), - // paddle::dialect::LogOp::name(), - // // paddle::dialect::FetchOp::name(), - - // // paddle::dialect::Mean::name(), - // // paddle::dialect::Sum::name(), - // paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), - // }); - } - - void SetDefaultWhitelist() { - // white_list_.insert({paddle::dialect::FullOp::name(), - // paddle::dialect::Conv2dOp::name(), - // paddle::dialect::TransposeOp::name()}); - // return; - } - - bool Match(pir::Operation* op) const override { - // if enable_low_precision_io_ is true, all the op will be transformed into, - // input and output included - if (op->isa() || op->isa() || - op->isa() || - op->isa()) - return false; - - // if (!enable_low_precision_io_) { - // if (op->isa()) return false; - // } - - // if op is a full op, its user cannot be a scale op - // if (op->isa()) { - // auto use_ops = GetUseOpsForOutput(op, 0); - // for (auto [use_op, idx] : use_ops) { - // if (use_op->isa()) { - // return false; - // } - // } - // } - - // if (!IsBuiltinOp(op)) { - // return OpHasFloatResult(op); - // } - - // return true; - // } - return true; - } -}; - class AutoMixedPrecisionPass : public pir::Pass { public: AutoMixedPrecisionPass(const phi::Place& place, @@ -129,10 +60,8 @@ class AutoMixedPrecisionPass : public pir::Pass { precision_mode_(precision_mode) {} bool Initialize(pir::IrContext* 
context) override { - pir::RewritePatternSet ps(context); - ps.Add(context, place_, precision_mode_); - // ps.Add(context); - patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + SetDefaultBlacklist(); + SetDefaultWhitelist(); return true; } @@ -171,12 +100,34 @@ class AutoMixedPrecisionPass : public pir::Pass { std::unordered_set white_list_; mutable std::unordered_set op_run_low_precision_; + mutable std::unordered_set op_should_not_handle_; mutable std::unordered_map cached_cast_ops_; + void SetDefaultBlacklist() { + // black_list_.insert({ + // paddle::dialect::ExpOp::name(), + // paddle::dialect::SquareOp::name(), + // paddle::dialect::LogOp::name(), + // // paddle::dialect::FetchOp::name(), + + // // paddle::dialect::Mean::name(), + // // paddle::dialect::Sum::name(), + // paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), + // }); + } + + void SetDefaultWhitelist() { + // white_list_.insert({paddle::dialect::FullOp::name(), + // paddle::dialect::Conv2dOp::name(), + // paddle::dialect::TransposeOp::name()}); + // return; + } + void ProcessBlock(pir::Block* block, pir::Builder& builder) const { // NOLINT for (auto& op_item : *block) { auto op = &op_item; + if (op_should_not_handle_.count(op)) continue; RewriteOp(op, builder); } } @@ -206,6 +157,7 @@ class AutoMixedPrecisionPass : public pir::Pass { OpSupportPrecision(op_type, backend, precision_mode_); } else { // pd op without float result support_low_precision = false; + op_should_not_handle_.insert(op); } if (support_low_precision) { op_run_low_precision_.insert(op); From 02d1bca6bb8996f64482d188375f0b9db68dced4 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 21 Dec 2023 12:48:34 +0000 Subject: [PATCH 16/40] add insert cast_op_num and insert TODO --- .../transforms/auto_mixed_precision_pass.cc | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 231235fe09a298..03fac19c175152 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -72,11 +72,17 @@ class AutoMixedPrecisionPass : public pir::Pass { GetOpPrecision(block); LOG(INFO) << "===========Update Op Precision============" << std::endl; UpdateOpPrecision(block); + + LOG(INFO) << "===========" << op_run_low_precision_.size() << " of " + << block->size() << " ops" + << " run low precision" << std::endl; pir::IrContext* ctx = pir::IrContext::Instance(); pir::Builder builder = pir::Builder(ctx, block); LOG(INFO) << "===========Process Op Precision============" << std::endl; ProcessBlock(block, builder); + LOG(INFO) << "===========Insert Cast Op Num : " << insert_cast_op_num_ + << "============" << std::endl; // pir::GreedyRewriteConfig cfg; // cfg.use_top_down_traversal = true; // cfg.max_iterations = 2; // pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); @@ -104,6 +110,8 @@ class AutoMixedPrecisionPass : public pir::Pass { mutable std::unordered_map cached_cast_ops_; + mutable int insert_cast_op_num_ = 0; + void SetDefaultBlacklist() { // black_list_.insert({ // paddle::dialect::ExpOp::name(), // paddle::dialect::SquareOp::name(), // paddle::dialect::LogOp::name(), // // paddle::dialect::FetchOp::name(), - - // // paddle::dialect::Mean::name(), @@ -183,6 +191,7 @@ class AutoMixedPrecisionPass : public pir::Pass { } void UpdateOpPrecision(pir::Block* block) { + bool updated = false; // handle full like op for (auto& op_item : *block) { auto op = &op_item; if (op->isa()) { auto input_operation = GetDefiningOpForInput(op, 0); if (!op_run_low_precision_.count(input_operation)) { op_run_low_precision_.erase(op); } } } @@ -213,6 +222,12 @@ class AutoMixedPrecisionPass : public pir::Pass { } } } + // reasoning by outputs would be better + // an op is bound to its outputs + // op1 -> var1 + // var1 -> op2 + // the precision of var1 + // precision should be same as input // if
(op->isa()) { // auto input_operation = GetDefiningOpForInput(op, 0); // if (!op_run_low_precision_.count(input_operation)) { // op_run_low_precision_.erase(op); // } // } } + // builtin.combine -> vector type + // reshape op + // reshape -> vector_type + for (auto& op_item : *block) { auto op = &op_item; for (size_t idx = 0; idx < op->num_operands(); ++idx) { if (!op->operand_source(idx)) continue; auto operand = op->operand(idx); if (operand.type() && operand.type().isa()) { // check if there are all float in the vectortype auto vec_type = operand.type().dyn_cast(); if (VectorTypeFloat(vec_type)) { auto input_operation = GetDefiningOpForInput(op, idx); // if either one is high precision, both must run in high precision if (!op_run_low_precision_.count(op) || !op_run_low_precision_.count(input_operation)) { op_run_low_precision_.erase(op); op_run_low_precision_.erase(input_operation); } } } } } + // the producing ops (op1, op2) run in high precision + // (op1 -> var1, op2 -> var2) => combine => var3 is op3's (attribute) input + // print if op run low precision for (auto& op_item : *block) { auto op = &op_item; if (op_run_low_precision_.count(op)) { LOG(INFO) << "op " << op->name() << " run low precision" << std::endl; } else { LOG(INFO) << "op " << op->name() << " run high precision" << std::endl; } } } void RewriteOp(pir::Operation* op, - pir::PatternRewriter& rewriter) const { // NOLINT + pir::Builder& builder) const { // NOLINT LOG(INFO) << "Rewrite op " << op->name() << std::endl; if (IsBuiltinOp(op)) { - RewriteBuiltinOp(op, rewriter); + RewriteBuiltinOp(op, builder); return; } else { - RewritePdOp(op, rewriter); + RewritePdOp(op, builder); return; } } From 8ab3566be0220816d7b952c036f15865b5da74b7 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 21 Dec 2023 18:30:50 +0000 Subject: [PATCH 17/40] special handle for CastOp; refactor UpdateOpPrecision using do while; And handle attribute use mkldnn --- .../fluid/inference/api/analysis_predictor.cc | 2 +- .../transforms/auto_mixed_precision_pass.cc | 159 +++++++++++------- 2 files changed, 100 insertions(+), 61 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f06ec762a81bc4..42f1235516e3d7 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -827,7 +827,7 @@ bool AnalysisPredictor::PrepareExecutor() { gpu_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); //----------------------------------------------------------------------------------------------// - // gpu_pm.EnableIRPrinting(); + gpu_pm.EnableIRPrinting(); gpu_pm.Run(pir_program_.get()); } diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 03fac19c175152..e257b892d23bb6 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -60,6 +60,7 @@ class AutoMixedPrecisionPass : public pir::Pass { precision_mode_(precision_mode) {} bool Initialize(pir::IrContext* context) override { + enable_low_precision_io_ = false; SetDefaultBlacklist(); SetDefaultWhitelist(); return true; } @@ -190,83 +191,107 @@ class AutoMixedPrecisionPass : public pir::Pass { return true; } - void UpdateOpPrecision(pir::Block* block) { - bool updated = false; - // handle full like op - for (auto& op_item : *block) { - auto op = &op_item; - if (op->isa()) { - auto input_operation = GetDefiningOpForInput(op, 0); - if (!op_run_low_precision_.count(input_operation)) { - op_run_low_precision_.erase(op); + bool CheckUseOpsScalaAttribute( + const std::vector>& use_ops) const { + for (auto [use_op, idx] : use_ops) { + if (use_op->HasInterface()) { + auto [input_infos, _1, _2, _3, _4] = + use_op->dyn_cast() + .GetOpInfo(); + if (input_infos[idx].type_name.find("ScalarAttribute") != + std::string::npos) { + return true; } } } + return false; + } - for (auto& op_item : *block) { - auto op = &op_item; - // remove attribute input op - if (op->HasInterface()) { - auto [input_infos, _1, _2, _3, _4] = - op->dyn_cast().GetOpInfo(); - for (size_t idx = 0; idx < input_infos.size(); ++idx) { - if (op->operand_source(idx) && - input_infos[idx].type_name.find("ScalarAttribute") != - std::string::npos) { - LOG(INFO) << "op name " << op->name() << " 
try to remove attribute" - << std::endl; - LOG(INFO) << "Remove op name " - << GetDefiningOpForInput(op, idx)->name() << " attribute" - << std::endl; - op_run_low_precision_.erase(GetDefiningOpForInput(op, idx)); + bool CheckOutputIsScalarAttribute(pir::Operation* op) { + for (uint32_t i = 0; i < op->num_results(); i++) { + auto use_ops = pir::GetUseOpsForOutput(op, i); + if (CheckUseOpsScalaAttribute(use_ops)) return true; + } + return false; + } + + void UpdateOpPrecision(pir::Block* block) { + bool precision_updated = false; + do { + precision_updated = false; + // handle full like op + for (auto& op_item : *block) { + auto op = &op_item; + if (op_should_not_handle_.count(op)) continue; + if (!OpRunLowPrecision(op)) continue; + if (op->isa()) { + auto input_operation = GetDefiningOpForInput(op, 0); + if (!op_run_low_precision_.count(input_operation)) { + op_run_low_precision_.erase(op); + precision_updated = true; + } + } + if (!OpRunLowPrecision(op)) continue; + if (op->isa()) { // add for cast op, not cast + // to float. i.e cast to bool + // or int + // if datatype of result0 is not float, then cast op should be not + // handled + auto result_dtype = paddle::dialect::TransToPhiDataType( + pir::GetDataTypeFromValue(op->result(0))); + if (!IsDataTypeFloat(result_dtype)) { + op_run_low_precision_.erase(op); + op_should_not_handle_.insert(op); + precision_updated = true; + } + } + if (!OpRunLowPrecision(op)) continue; + if (CheckOutputIsScalarAttribute(op)) { // Output is ScalarAttribute + LOG(INFO) << "op " << op->name() << " output is ScalarAttribute" + << std::endl; + op_run_low_precision_.erase(op); + precision_updated = true; + } + if (!OpRunLowPrecision(op)) continue; + for (size_t idx = 0; idx < op->num_operands(); ++idx) { + if (!op->operand_source(idx)) continue; + auto operand = op->operand(idx); + if (operand.type() && operand.type().isa()) { + // check if there are all float in the vectortype + auto vec_type = operand.type().dyn_cast(); + if (VectorTypeFloat(vec_type)) { + auto input_operation = GetDefiningOpForInput(op, idx); + // 如果有一个是高精的话,则必须都跑在高精上 + if (!op_run_low_precision_.count(op) || + !op_run_low_precision_.count(input_operation)) { + op_run_low_precision_.erase(op); + op_run_low_precision_.erase(input_operation); + precision_updated = true; + } + } } } } + // 输出的方式比较好 // 一个op和他的输出是绑定的 // op1 -> var1 // var1 -> op2 // var1 的精度 - // precision should be same as input - // if (op->isa()) { - // auto input_operation = GetDefiningOpForInput(op, 0); - // if (!op_run_low_precision_.count(input_operation)) { - // op_run_low_precision_.erase(op); - // } - // } - } - // builtin.combine -> vector type - // reshape op - // reshape -> vector_type - - for (auto& op_item : *block) { - auto op = &op_item; - for (size_t idx = 0; idx < op->num_operands(); ++idx) { - if (!op->operand_source(idx)) continue; - auto operand = op->operand(idx); - if (operand.type() && operand.type().isa()) { - // check if there are all float in the vectortype - auto vec_type = operand.type().dyn_cast(); - if (VectorTypeFloat(vec_type)) { - auto input_operation = GetDefiningOpForInput(op, idx); - // 如果有一个是高精的话,则必须都跑在高精上 - if (!op_run_low_precision_.count(op) || - !op_run_low_precision_.count(input_operation)) { - op_run_low_precision_.erase(op); - op_run_low_precision_.erase(input_operation); - } - } - } - } - } + // builtin.combine -> vector type + // reshape op + // reshape -> vector_type + } while (precision_updated); // 产生(op1, op2)的跑在高精度 // (op1 -> var1, op2 -> var2) => combine => var3 是 op3(属性输入) - 
// print if op run low precision for (auto& op_item : *block) { auto op = &op_item; - if (op_run_low_precision_.count(op)) { + if (op_should_not_handle_.count(op)) { + LOG(INFO) << "op " << op->name() << " should not be handled" + << std::endl; + } else if (op_run_low_precision_.count(op)) { LOG(INFO) << "op " << op->name() << " run low precision" << std::endl; } else { LOG(INFO) << "op " << op->name() << " run high precision" << std::endl; @@ -618,6 +643,21 @@ class AutoMixedPrecisionPass : public pir::Pass { op->set_attribute("dtype", attr_dtype); } + if (op->HasAttribute("use_mkldnn") && + op->attribute("use_mkldnn").dyn_cast().data() == + true && + op->HasAttribute("mkldnn_data_type")) { // useless now? + std::string mkldnn_data_type = op->attribute("mkldnn_data_type") + .dyn_cast() + .AsString(); + std::string low_precision = phi::DataTypeToString(precision_mode_); + if (mkldnn_data_type != low_precision) { + pir::Attribute attr_mkldnn_data_type = + pir::StrAttribute::get(builder.ir_context(), low_precision); + op->set_attribute("mkldnn_data_type", attr_mkldnn_data_type); + } + } + auto phi_kernel = GetPhiKernelInPrecision(op_type, backend, precision_mode_); PADDLE_ENFORCE( @@ -681,7 +721,6 @@ class AutoMixedPrecisionPass : public pir::Pass { } } }; - } // namespace namespace pir { From 5c5a826c723776768a7d9aa7dcff2d330a76190b Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 05:05:42 +0000 Subject: [PATCH 18/40] delete mkldnn judge --- .../pir/transforms/auto_mixed_precision_pass.cc | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index e257b892d23bb6..d04316f1f75830 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -643,21 +643,6 @@ class AutoMixedPrecisionPass : public pir::Pass { op->set_attribute("dtype", attr_dtype); } - if (op->HasAttribute("use_mkldnn") && - op->attribute("use_mkldnn").dyn_cast().data() == - true && - op->HasAttribute("mkldnn_data_type")) { // useless now? 
- std::string mkldnn_data_type = op->attribute("mkldnn_data_type") - .dyn_cast() - .AsString(); - std::string low_precision = phi::DataTypeToString(precision_mode_); - if (mkldnn_data_type != low_precision) { - pir::Attribute attr_mkldnn_data_type = - pir::StrAttribute::get(builder.ir_context(), low_precision); - op->set_attribute("mkldnn_data_type", attr_mkldnn_data_type); - } - } - auto phi_kernel = GetPhiKernelInPrecision(op_type, backend, precision_mode_); PADDLE_ENFORCE( From ef7fcbc6781dff3fd10fa401e89b2714f06f095b Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 05:09:12 +0000 Subject: [PATCH 19/40] refine get context logic --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index d04316f1f75830..87844cfc81031a 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -60,6 +60,7 @@ class AutoMixedPrecisionPass : public pir::Pass { precision_mode_(precision_mode) {} bool Initialize(pir::IrContext* context) override { + context_ = context; enable_low_precision_io_ = false; SetDefaultBlacklist(); SetDefaultWhitelist(); @@ -77,8 +78,7 @@ class AutoMixedPrecisionPass : public pir::Pass { LOG(INFO) << "===========" << op_run_low_precision_.size() << " of " << block->size() << " ops" << " run low precision" << std::endl; - pir::IrContext* ctx = pir::IrContext::Instance(); - pir::Builder builder = pir::Builder(ctx, block); + pir::Builder builder = pir::Builder(context_, block); LOG(INFO) << "===========Process Op Precision============" << std::endl; ProcessBlock(block, builder); @@ -98,10 +98,10 @@ class AutoMixedPrecisionPass : public pir::Pass { } private: - pir::FrozenRewritePatternSet patterns_; phi::Place place_; phi::DataType precision_mode_; bool enable_low_precision_io_; + pir::IrContext* context_; std::unordered_set black_list_; std::unordered_set white_list_; From 018e6fe6772e60d47d8eb2ffd9767ea40d8be808 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 13:32:26 +0000 Subject: [PATCH 20/40] handle op has multiple region --- .../transforms/auto_mixed_precision_pass.cc | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 87844cfc81031a..cf4fb4f81708c9 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -69,30 +69,29 @@ class AutoMixedPrecisionPass : public pir::Pass { void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); - pir::Block* block = &module_op.block(); - LOG(INFO) << "===========Get Op Precision============" << std::endl; - GetOpPrecision(block); - LOG(INFO) << "===========Update Op Precision============" << std::endl; - UpdateOpPrecision(block); - - LOG(INFO) << "===========" << op_run_low_precision_.size() << " of " - << block->size() << " ops" - << " run low precision" << std::endl; - pir::Builder builder = pir::Builder(context_, block); - LOG(INFO) << "===========Process Op Precision============" << std::endl; - - ProcessBlock(block, builder); - LOG(INFO) << "===========Insert Cast Op Num : " << insert_cast_op_num_ - << "============" << std::endl; - // pir::GreedyRewriteConfig cfg; - // cfg.use_top_down_traversal = true; - // 
cfg.max_iterations = 2; - // pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + for (size_t i = 0; i < op->num_regions(); ++i) { + auto& region = op->region(i); + for (auto& block : region) { + LOG(INFO) << "===========Get Op Precision============" << std::endl; + GetOpPrecision(block); + LOG(INFO) << "===========Update Op Precision============" << std::endl; + UpdateOpPrecision(block); + + LOG(INFO) << "===========" << op_run_low_precision_.size() << " of " + << block->size() << " ops" + << " run low precision" << std::endl; + pir::Builder builder = pir::Builder(context_, block); + LOG(INFO) << "===========Process Op Precision============" << std::endl; + + ProcessBlock(block, builder); + LOG(INFO) << "===========Insert Cast Op Num : " << insert_cast_op_num_ + << "============" << std::endl; + } + } } bool CanApplyOn(pir::Operation* op) const override { - return op->isa<::pir::ModuleOp>() && op->num_regions() > 0 && - place_ == paddle::PlaceType::kGPU && + return op->num_regions() > 0 && place_ == paddle::PlaceType::kGPU && (precision_mode_ == phi::DataType::FLOAT16 || precision_mode_ == phi::DataType::BFLOAT16); } From 3b7278a652a5fef37a2d06f3ad28c194c658ec4d Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 14:55:28 +0000 Subject: [PATCH 21/40] refine code --- .../transforms/auto_mixed_precision_pass.cc | 158 ++++++++---------- 1 file changed, 74 insertions(+), 84 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index cf4fb4f81708c9..b19f475417cef3 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -105,12 +105,11 @@ class AutoMixedPrecisionPass : public pir::Pass { std::unordered_set black_list_; std::unordered_set white_list_; - mutable std::unordered_set op_run_low_precision_; - mutable std::unordered_set op_should_not_handle_; - mutable std::unordered_map - cached_cast_ops_; + std::unordered_set op_run_low_precision_; + std::unordered_set op_should_not_handle_; + std::unordered_map cached_cast_ops_; - mutable int insert_cast_op_num_ = 0; + int insert_cast_op_num_ = 0; void SetDefaultBlacklist() { // black_list_.insert({ @@ -132,7 +131,7 @@ class AutoMixedPrecisionPass : public pir::Pass { // return; } - void ProcessBlock(pir::Block* block, pir::Builder& builder) const { // NOLINT + void ProcessBlock(pir::Block* block, pir::Builder& builder) { // NOLINT for (auto& op_item : *block) { auto op = &op_item; if (op_should_not_handle_.count(op)) continue; @@ -155,10 +154,7 @@ class AutoMixedPrecisionPass : public pir::Pass { op->isa()) { support_low_precision = enable_low_precision_io_; } else if (OpHasFloatOpOperand(op) || - OpHasFloatResult(op)) { // pd op with float result, - // but op like not_equal has float input, but has not float result - // if I don't add this, not_equal will be added to op_run_low_precision_ - // if I add this, full_like op will be run at high precision + OpHasFloatResult(op)) { // pd op without float result, auto op_type = op_name.substr(op_name.find(".") + 1); auto backend = ConvertPlaceToBackend(place_); support_low_precision = @@ -178,18 +174,6 @@ class AutoMixedPrecisionPass : public pir::Pass { } } - bool VectorTypeFloat(pir::VectorType vec_type) { - size_t output_num = vec_type.size(); - for (size_t j = 0; j < output_num; j++) { - auto dtype = - vec_type[j].dyn_cast().dtype(); - if (!IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { - return false; - 
} - } + return true; + } bool CheckUseOpsScalaAttribute( const std::vector>& use_ops) const { for (auto [use_op, idx] : use_ops) { @@ -224,8 +222,7 @@ class AutoMixedPrecisionPass : public pir::Pass { // handled auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); - if (!IsDataTypeFloat(result_dtype)) { + if (!IsPhiDataTypeFloat(result_dtype)) { op_run_low_precision_.erase(op); op_should_not_handle_.insert(op); precision_updated = true; } } if (!OpRunLowPrecision(op)) continue; if (CheckOutputIsScalarAttribute(op)) { // Output is ScalarAttribute @@ -258,7 +242,7 @@ class AutoMixedPrecisionPass : public pir::Pass { if (operand.type() && operand.type().isa()) { // check if there are all float in the vectortype auto vec_type = operand.type().dyn_cast(); - if (VectorTypeFloat(vec_type)) { + if (IsVectorTypeFloat(vec_type)) { auto input_operation = GetDefiningOpForInput(op, idx); // if either one is high precision, both must run in high precision if (!op_run_low_precision_.count(op) || @@ -411,11 +395,6 @@ class AutoMixedPrecisionPass : public pir::Pass { return KernelSupportPrecision(kernel_fn_str, backend, precision); } - bool ValueInPrecision(pir::Value value, phi::DataType precision) const { - auto dtype = pir::GetDataTypeFromValue(value); - return paddle::dialect::TransToPhiDataType(dtype) == precision; - } - void SetResultDataType(pir::Value result, phi::DataType precision, pir::IrContext* context) const { @@ -458,20 +437,14 @@ class AutoMixedPrecisionPass : public pir::Pass { for (size_t i = 0; i < op->num_operands(); i++) { auto operand = op->operand_source(i); if (!operand.type()) continue; - if (operand.type().isa()) { - auto dtype = pir::GetDataTypeFromValue(operand); - if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { - return true; - } - } else if (operand.type().isa()) { - auto vec_type = operand.type().dyn_cast(); - for (size_t j = 0; j < vec_type.size(); j++) { - auto dtype = - vec_type[j].dyn_cast().dtype(); - if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { - return true; - } - } + if (operand.type().isa() && + IsDenseTensorTypeFloat( + operand.type().dyn_cast())) { + return true; + } else if (operand.type().isa() && + IsVectorTypeFloat( + operand.type().dyn_cast())) { + return true; } } return false; } @@ -481,20 +454,11 @@ class AutoMixedPrecisionPass : public pir::Pass { for (size_t i = 0; i < op->num_results(); i++) { auto result = op->result(i); if (!result.type()) continue; - if (result.type().isa()) { - auto dtype = pir::GetDataTypeFromValue(result); - if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { - return true; - } - } else if (result.type().isa()) { - auto vec_type = result.type().dyn_cast(); - for (size_t j = 0; j < vec_type.size(); j++) { - auto dtype = - vec_type[j].dyn_cast().dtype(); - if (IsDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { - return true; - } - } + if (result.type().isa() && + IsDenseTensorTypeFloat( + result.type().dyn_cast())) { + } else if (result.type().isa() && + IsVectorTypeFloat(result.type().dyn_cast())) { } } LOG(INFO) << "op " << op->name() << " doesn't have float result" << std::endl; return false; } - bool IsDataTypeFloat(const phi::DataType& dtype) const { + bool IsPhiDataTypeFloat(const phi::DataType& dtype) const { return dtype == phi::DataType::FLOAT32 || dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16; } - phi::DataType OperandDataType(const pir::OpOperand& operand) const { - auto dtype = pir::GetDataTypeFromValue(operand.source()); + bool IsDenseTensorTypeFloat( 
paddle::dialect::DenseTensorType dense_type) const { + auto dtype = dense_type.dtype(); + return IsPhiDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype)); + } + + bool IsVectorTypeFloat(pir::VectorType vec_type) const { + size_t output_num = vec_type.size(); + for (size_t j = 0; j < output_num; j++) { + auto dtype = + vec_type[j].dyn_cast().dtype(); + if (!IsPhiDataTypeFloat(paddle::dialect::TransToPhiDataType(dtype))) { + return false; + } + } + return true; + } + + phi::DataType GetPhiDataTypeFromOpOperand( + const pir::OpOperand& operand) const { + return GetPhiDataTypeFromValue(operand.source()); + } + + phi::DataType GetPhiDataTypeFromValue(const pir::Value& value) const { + auto dtype = pir::GetDataTypeFromValue(value); return paddle::dialect::TransToPhiDataType(dtype); } @@ -517,10 +504,10 @@ class AutoMixedPrecisionPass : public pir::Pass { operand.type().isa(); } - void InsertCastOp(pir::Operation* op, - pir::OpOperand operand, - phi::DataType precision, - pir::Builder& builder) const { // NOLINT + void DoInsertCastOp(pir::Operation* op, + pir::OpOperand operand, + phi::DataType precision, + pir::Builder& builder) { // NOLINT auto value = operand.source(); if (cached_cast_ops_.count(value)) { operand.set_source(cached_cast_ops_[value]->result(0)); @@ -539,7 +526,7 @@ class AutoMixedPrecisionPass : public pir::Pass { } void RewriteBuiltinOp(pir::Operation* op, - pir::Builder& builder) const { // NOLINT + pir::Builder& builder) { // NOLINT LOG(INFO) << "Rewrite builtin op " << op->name() << std::endl; // Rewrite CombineOp if (op->isa()) { @@ -548,10 +535,10 @@ class AutoMixedPrecisionPass : public pir::Pass { if (OpRunLowPrecision(op)) { for (size_t i = 0; i < input_num; ++i) { auto operand = op->operand(i); - auto operand_dtype = OperandDataType(operand); - if (IsDataTypeFloat(operand_dtype) && - operand_dtype != precision_mode_) { - InsertCastOp(op, operand, precision_mode_, builder); + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype != precision_mode_) { + DoInsertCastOp(op, operand, precision_mode_, builder); } } std::vector inputs_type(input_num); @@ -564,9 +551,9 @@ class AutoMixedPrecisionPass : public pir::Pass { } else { for (size_t i = 0; i < input_num; ++i) { auto operand = op->operand(i); - auto operand_dtype = OperandDataType(operand); - if (operand_dtype == precision_mode_) { - InsertCastOp(op, operand, phi::DataType::FLOAT32, builder); + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (operand_phi_dtype == precision_mode_) { + DoInsertCastOp(op, operand, phi::DataType::FLOAT32, builder); } } } @@ -602,16 +589,16 @@ class AutoMixedPrecisionPass : public pir::Pass { // Rewrite FetchOp if (op->isa()) { auto fetch_operand = op->operand(0); - auto fetch_operand_dtype = OperandDataType(fetch_operand); + auto fetch_operand_phi_dtype = GetPhiDataTypeFromOpOperand(fetch_operand); if (OpRunLowPrecision(op)) { SetResultDataType(op->result(0), precision_mode_, builder.ir_context()); } if (!op->result(0).type().isa()) return; auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); - if (fetch_operand_dtype != result_dtype) { + if (fetch_operand_phi_dtype != result_dtype) { LOG(INFO) << "Insert CastOp for FetchOp" << std::endl; - InsertCastOp(op, fetch_operand, result_dtype, builder); + DoInsertCastOp(op, fetch_operand, result_dtype, builder); } return; } @@ -634,7 +621,7 @@ class AutoMixedPrecisionPass : public pir::Pass { << 
std::endl; if (op->HasAttribute("dtype") && - IsDataTypeFloat( + IsPhiDataTypeFloat( op->attribute("dtype") .data())) { pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( @@ -671,7 +658,9 @@ class AutoMixedPrecisionPass : public pir::Pass { << std::endl; if (out_phi_dtype == phi::DataType::UNDEFINED) out_phi_dtype = precision_mode_; - if (!IsDataTypeFloat(out_phi_dtype)) continue; + if (!IsPhiDataTypeFloat(out_phi_dtype)) + continue; // here handle op like "unequal", which has bool result + // type SetResultDataType(result, out_phi_dtype, builder.ir_context()); } @@ -681,11 +670,12 @@ class AutoMixedPrecisionPass : public pir::Pass { auto operand = op->operand(i); auto in_phi_dtype = input_defs[i].dtype; if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_dtype = OperandDataType(operand); - if (IsDataTypeFloat(operand_dtype) && operand_dtype != in_phi_dtype) { + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype != in_phi_dtype) { LOG(INFO) << "Support low precision, insert CastOp for " << op->name() << " operand " << i << std::endl; - InsertCastOp(op, operand, in_phi_dtype, builder); + DoInsertCastOp(op, operand, in_phi_dtype, builder); } } } else { // current op doesn't support low precision, should cast to float @@ -694,12 +684,12 @@ class AutoMixedPrecisionPass : public pir::Pass { for (size_t i = 0; i < op->num_operands(); i++) { auto operand = op->operand(i); if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_dtype = OperandDataType(operand); - if (IsDataTypeFloat(operand_dtype) && - operand_dtype == precision_mode_) { + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype == precision_mode_) { LOG(INFO) << "Not support low precision, insert CastOp for " << op->name() << " operand " << i << std::endl; - InsertCastOp(op, operand, phi_dtype, builder); + DoInsertCastOp(op, operand, phi_dtype, builder); } } } From 34078a3f5d63db180d24b3f4be6b630bbd8e9519 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 15:43:37 +0000 Subject: [PATCH 22/40] refine --- .../transforms/auto_mixed_precision_pass.cc | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index b19f475417cef3..533da26d1c32c1 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -68,22 +68,21 @@ class AutoMixedPrecisionPass : public pir::Pass { } void Run(pir::Operation* op) override { - auto module_op = op->dyn_cast(); for (size_t i = 0; i < op->num_regions(); ++i) { auto& region = op->region(i); for (auto& block : region) { LOG(INFO) << "===========Get Op Precision============" << std::endl; - GetOpPrecision(block); + GetOpPrecision(&block); LOG(INFO) << "===========Update Op Precision============" << std::endl; - UpdateOpPrecision(block); + UpdateOpPrecision(&block); LOG(INFO) << "===========" << op_run_low_precision_.size() << " of " - << block->size() << " ops" + << block.size() << " ops" << " run low precision" << std::endl; - pir::Builder builder = pir::Builder(context_, block); + pir::Builder builder = pir::Builder(context_, &block); LOG(INFO) << "===========Process Op Precision============" << std::endl; - ProcessBlock(block, builder); + ProcessBlock(&block, builder); LOG(INFO) << 
"===========Insert Cast Op Num : " << insert_cast_op_num_ << "============" << std::endl; } @@ -256,12 +255,6 @@ class AutoMixedPrecisionPass : public pir::Pass { } } - // 输出的方式比较好 - // 一个op和他的输出是绑定的 - // op1 -> var1 - // var1 -> op2 - // var1 的精度 - // builtin.combine -> vector type // reshape op // reshape -> vector_type @@ -283,7 +276,7 @@ class AutoMixedPrecisionPass : public pir::Pass { } void RewriteOp(pir::Operation* op, - pir::Builder& builder) const { // NOLINT + pir::Builder& builder) { // NOLINT LOG(INFO) << "Rewrite op " << op->name() << std::endl; if (IsBuiltinOp(op)) { RewriteBuiltinOp(op, builder); @@ -581,7 +574,7 @@ class AutoMixedPrecisionPass : public pir::Pass { } void RewritePdOp(pir::Operation* op, - pir::Builder& builder) const { // NOLINT + pir::Builder& builder) { // NOLINT LOG(INFO) << "Rewrite pd op " << op->name() << std::endl; phi::Backend backend = ConvertPlaceToBackend(place_); std::string op_type = op->name().substr(op->name().find(".") + 1); From a3a5fb06caa368e5e046700f493bf7b4a07188aa Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 17:02:06 +0000 Subject: [PATCH 23/40] refine pass parameter --- .../fluid/inference/api/analysis_predictor.cc | 9 +++++-- .../transforms/auto_mixed_precision_pass.cc | 24 +++++++++++++------ .../transforms/auto_mixed_precision_pass.h | 3 +-- .../auto_mixed_precision_test.cc | 10 ++++++-- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 42f1235516e3d7..72ac3d79268f34 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -804,8 +804,13 @@ bool AnalysisPredictor::PrepareExecutor() { // Functional pass // Do auto mixed precision pass first, so do not need to handle // shadowoutput. - gpu_pm.AddPass(::pir::CreateAutoMixedPrecisionPass( - place_, ConvertPrecision(config_.mixed_precision_mode_))); + auto auto_mixed_precision_pass = ::pir::CreateAutoMixedPrecisionPass(); + auto_mixed_precision_pass->SetNotOwned(pir::kPlaceAttr, &place_); + phi::DataType data_type = + ConvertPrecision(config_.mixed_precision_mode_); + auto_mixed_precision_pass->SetNotOwned("__mixed_precision_mode__", + &data_type); + gpu_pm.AddPass(std::move(auto_mixed_precision_pass)); gpu_pm.AddPass(::pir::CreateIdentityOpCleanPass()); //----------------------------------------------------------------------------------------------// diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 533da26d1c32c1..e421d55a67f5a0 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -53,13 +53,24 @@ namespace { class AutoMixedPrecisionPass : public pir::Pass { public: - AutoMixedPrecisionPass(const phi::Place& place, - const phi::DataType& precision_mode) + AutoMixedPrecisionPass() : pir::Pass("auto_mixed_precision_pass", 1), - place_(place), - precision_mode_(precision_mode) {} + place_(phi::CPUPlace{}), + precision_mode_(phi::DataType::FLOAT16) {} bool Initialize(pir::IrContext* context) override { + IR_ENFORCE(Has(pir::kPlaceAttr), + "Pass initialize failed." + "When using AutoMixedPrecisionPass, place attribute is required!" + "Use Set method to set the place attribute."); + IR_ENFORCE(Has("__mixed_precision_mode__"), + "Pass initialize failed." + "When using AutoMixedPrecisionPass, precison_mode attribute is " + "required!" 
+ "Use Set method to set the scope attribute."); + + place_ = Get(pir::kPlaceAttr); + precision_mode_ = Get("__mixed_precision_mode__"); context_ = context; enable_low_precision_io_ = false; SetDefaultBlacklist(); @@ -692,9 +703,8 @@ class AutoMixedPrecisionPass : public pir::Pass { namespace pir { -std::unique_ptr CreateAutoMixedPrecisionPass( - const phi::Place& place, const phi::DataType& precision_mode) { - return std::make_unique(place, precision_mode); +std::unique_ptr CreateAutoMixedPrecisionPass() { + return std::make_unique(); } } // namespace pir diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h index 2544219494a10f..4ab0fb12cb723a 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h @@ -23,7 +23,6 @@ namespace pir { class Pass; -IR_API std::unique_ptr CreateAutoMixedPrecisionPass( - const phi::Place& place, const phi::DataType& precision_mode); +IR_API std::unique_ptr CreateAutoMixedPrecisionPass(); } // namespace pir diff --git a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc index 6f880900e21d00..8c16fd37dd34a8 100644 --- a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc +++ b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc @@ -94,8 +94,14 @@ TEST(AutoMixedPrecisonTest, MixedPrecisionTest) { EXPECT_EQ(program.block()->size(), 11u); pir::PassManager pm(ctx); - pm.AddPass(pir::CreateAutoMixedPrecisionPass(phi::GPUPlace(), - phi::DataType::FLOAT16)); + std::unique_ptr auto_mixed_precision_pass = + pir::CreateAutoMixedPrecisionPass(); + phi::Place place = phi::GPUPlace(); + phi::DataType data_type = phi::DataType::FLOAT16; + auto_mixed_precision_pass->SetNotOwned(pir::kPlaceAttr, &place); + auto_mixed_precision_pass->SetNotOwned("__mixed_precision_mode__", + &data_type); + pm.AddPass(std::move(auto_mixed_precision_pass)); pm.AddPass(pir::CreateDeadCodeEliminationPass()); // pm.EnablePassTiming(); pm.EnableIRPrinting(); From 110793291c98a7270499c8883ce1c56c9da66e13 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 17:22:14 +0000 Subject: [PATCH 24/40] handle op1->combine->op2 --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index e421d55a67f5a0..960c0cffdf391e 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -187,7 +187,11 @@ class AutoMixedPrecisionPass : public pir::Pass { bool CheckUseOpsScalaAttribute( const std::vector>& use_ops) const { for (auto [use_op, idx] : use_ops) { - if (use_op->HasInterface()) { + if (use_op->isa()) { + if (CheckOutputIsScalarAttribute(use_op)) { + return true; + } + } else if (use_op->HasInterface()) { auto [input_infos, _1, _2, _3, _4] = use_op->dyn_cast() .GetOpInfo(); From 1b99b8864cde2b2335321a161f1ac09b8ca427ef Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 26 Dec 2023 17:23:40 +0000 Subject: [PATCH 25/40] handle op1->combine->op2 --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 960c0cffdf391e..bdd1975b45772f 100644 --- 
a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -204,7 +204,7 @@ class AutoMixedPrecisionPass : public pir::Pass { return false; } - bool CheckOutputIsScalarAttribute(pir::Operation* op) { + bool CheckOutputIsScalarAttribute(pir::Operation* op) const { for (uint32_t i = 0; i < op->num_results(); i++) { auto use_ops = pir::GetUseOpsForOutput(op, i); if (CheckUseOpsScalaAttribute(use_ops)) return true; From b59e2f9d151ac54814762b6dc6026b833c24480f Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 30 Dec 2023 02:00:59 +0000 Subject: [PATCH 26/40] add black list info --- .../transforms/auto_mixed_precision_pass.cc | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index bdd1975b45772f..cc62fc261f5517 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" @@ -74,7 +75,6 @@ class AutoMixedPrecisionPass : public pir::Pass { context_ = context; enable_low_precision_io_ = false; SetDefaultBlacklist(); - SetDefaultWhitelist(); return true; } @@ -122,23 +122,15 @@ class AutoMixedPrecisionPass : public pir::Pass { int insert_cast_op_num_ = 0; void SetDefaultBlacklist() { - // black_list_.insert({ - // paddle::dialect::ExpOp::name(), - // paddle::dialect::SquareOp::name(), - // paddle::dialect::LogOp::name(), - // // paddle::dialect::FetchOp::name(), - - // // paddle::dialect::Mean::name(), - // // paddle::dialect::Sum::name(), - // paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), - // }); - } - - void SetDefaultWhitelist() { - // white_list_.insert({paddle::dialect::FullOp::name(), - // paddle::dialect::Conv2dOp::name(), - // paddle::dialect::TransposeOp::name()}); - // return; + black_list_.insert({ + paddle::dialect::ExpOp::name(), + paddle::dialect::SquareOp::name(), + paddle::dialect::LogOp::name(), + paddle::dialect::MeanOp::name(), + paddle::dialect::SumOp::name(), + paddle::dialect::SigmoidCrossEntropyWithLogitsOp::name(), + paddle::dialect::CrossEntropyWithSoftmax_Op::name(), + }); } void ProcessBlock(pir::Block* block, pir::Builder& builder) { // NOLINT From 53ca3e8cb3b97ff011638ccc71dd4ff1579bfa14 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 5 Jan 2024 15:51:49 +0000 Subject: [PATCH 27/40] replace LOG(INFO) with VLOG(6) --- .../transforms/auto_mixed_precision_pass.cc | 78 ++++++++----------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index cc62fc261f5517..6846ccc6b57536 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -82,20 +82,20 @@ class AutoMixedPrecisionPass : public pir::Pass { for (size_t i = 0; i < op->num_regions(); ++i) { auto& region = op->region(i); for (auto& block : region) { - LOG(INFO) << "===========Get Op Precision============" << std::endl; + VLOG(6) << "===========Get Op 
Precision============" << std::endl; GetOpPrecision(&block); - LOG(INFO) << "===========Update Op Precision============" << std::endl; + VLOG(6) << "===========Update Op Precision============" << std::endl; UpdateOpPrecision(&block); - LOG(INFO) << "===========" << op_run_low_precision_.size() << " of " - << block.size() << " ops" - << " run low precision" << std::endl; + VLOG(6) << "===========" << op_run_low_precision_.size() << " of " + << block.size() << " ops" + << " run low precision" << std::endl; pir::Builder builder = pir::Builder(context_, &block); - LOG(INFO) << "===========Process Op Precision============" << std::endl; + VLOG(6) << "===========Process Op Precision============" << std::endl; ProcessBlock(&block, builder); - LOG(INFO) << "===========Insert Cast Op Num : " << insert_cast_op_num_ - << "============" << std::endl; + VLOG(6) << "===========Insert Cast Op Num : " << insert_cast_op_num_ + << "============" << std::endl; } } } @@ -167,11 +167,10 @@ class AutoMixedPrecisionPass : public pir::Pass { } if (support_low_precision) { op_run_low_precision_.insert(op); - LOG(INFO) << "op " << op->name() << " support low precision" - << std::endl; + VLOG(6) << "op " << op->name() << " support low precision" << std::endl; } else { - LOG(INFO) << "op " << op->name() << " doesn't support low precision" - << std::endl; + VLOG(6) << "op " << op->name() << " doesn't support low precision" + << std::endl; } } } @@ -224,8 +223,8 @@ class AutoMixedPrecisionPass : public pir::Pass { if (op->isa()) { // add for cast op, not cast // to float. i.e cast to bool // or int - // if datatype of result0 is not float, then cast op should be not - // handled + // if datatype of cast op result is not float, then cast op should be + // not handled auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); if (!IsPhiDataTypeFloat(result_dtype)) { @@ -236,8 +235,8 @@ class AutoMixedPrecisionPass : public pir::Pass { } if (!OpRunLowPrecision(op)) continue; if (CheckOutputIsScalarAttribute(op)) { // Output is ScalarAttribute - LOG(INFO) << "op " << op->name() << " output is ScalarAttribute" - << std::endl; + VLOG(6) << "op " << op->name() << " output is ScalarAttribute" + << std::endl; op_run_low_precision_.erase(op); precision_updated = true; } @@ -261,30 +260,22 @@ class AutoMixedPrecisionPass : public pir::Pass { } } } - - // builtin.combine -> vector type - // reshape op - // reshape -> vector_type } while (precision_updated); - // 产生(op1, op2)的跑在高精度 - // (op1 -> var1, op2 -> var2) => combine => var3 是 op3(属性输入) - // print if op run low precision for (auto& op_item : *block) { auto op = &op_item; if (op_should_not_handle_.count(op)) { - LOG(INFO) << "op " << op->name() << " should not be handled" - << std::endl; + VLOG(6) << "op " << op->name() << " should not be handled" << std::endl; } else if (op_run_low_precision_.count(op)) { - LOG(INFO) << "op " << op->name() << " run low precision" << std::endl; + VLOG(6) << "op " << op->name() << " run low precision" << std::endl; } else { - LOG(INFO) << "op " << op->name() << " run high precision" << std::endl; + VLOG(6) << "op " << op->name() << " run high precision" << std::endl; } } } void RewriteOp(pir::Operation* op, pir::Builder& builder) { // NOLINT - LOG(INFO) << "Rewrite op " << op->name() << std::endl; + VLOG(6) << "Rewrite op " << op->name() << std::endl; if (IsBuiltinOp(op)) { RewriteBuiltinOp(op, builder); return; @@ -327,7 +318,7 @@ class AutoMixedPrecisionPass : public pir::Pass { phi::DataType 
precision, phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) const { auto& phi_op_type = op_type; - LOG(INFO) << "phi_op_type = " << phi_op_type << std::endl; + VLOG(6) << "phi_op_type = " << phi_op_type << std::endl; bool support = PhiKernelSupportPrecision(phi_op_type, backend, precision, layout); @@ -428,8 +419,8 @@ class AutoMixedPrecisionPass : public pir::Pass { auto new_vec_type = pir::VectorType::get(context, results_type); result.set_type(new_vec_type); } else { - LOG(INFO) << "result type is not DenseTensorType or VectorType" - << std::endl; + VLOG(6) << "result type is not DenseTensorType or VectorType" + << std::endl; } } @@ -461,8 +452,7 @@ class AutoMixedPrecisionPass : public pir::Pass { IsVectorTypeFloat(result.type().dyn_cast())) { } } - LOG(INFO) << "op " << op->name() << " doesn't have float result" - << std::endl; + VLOG(6) << "op " << op->name() << " doesn't have float result" << std::endl; return false; } @@ -527,7 +517,7 @@ class AutoMixedPrecisionPass : public pir::Pass { void RewriteBuiltinOp(pir::Operation* op, pir::Builder& builder) { // NOLINT - LOG(INFO) << "Rewrite builtin op " << op->name() << std::endl; + VLOG(6) << "Rewrite builtin op " << op->name() << std::endl; // Rewrite CombineOp if (op->isa()) { // auto vec_type = op->result(0).type().dyn_cast(); @@ -582,7 +572,7 @@ class AutoMixedPrecisionPass : public pir::Pass { void RewritePdOp(pir::Operation* op, pir::Builder& builder) { // NOLINT - LOG(INFO) << "Rewrite pd op " << op->name() << std::endl; + VLOG(6) << "Rewrite pd op " << op->name() << std::endl; phi::Backend backend = ConvertPlaceToBackend(place_); std::string op_type = op->name().substr(op->name().find(".") + 1); @@ -597,7 +587,7 @@ class AutoMixedPrecisionPass : public pir::Pass { auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); if (fetch_operand_phi_dtype != result_dtype) { - LOG(INFO) << "Insert CastOp for FetchOp" << std::endl; + VLOG(6) << "Insert CastOp for FetchOp" << std::endl; DoInsertCastOp(op, fetch_operand, result_dtype, builder); } return; @@ -617,8 +607,8 @@ class AutoMixedPrecisionPass : public pir::Pass { // Other pd ops if (OpRunLowPrecision(op)) { // change result's dtype to low precision - LOG(INFO) << "Change result's dtype to low precision " << op->name() - << std::endl; + VLOG(6) << "Change result's dtype to low precision " << op->name() + << std::endl; if (op->HasAttribute("dtype") && IsPhiDataTypeFloat( @@ -654,8 +644,8 @@ class AutoMixedPrecisionPass : public pir::Pass { auto result = op->result(i); if (!result.type()) continue; phi::DataType out_phi_dtype = output_defs[i].dtype; - LOG(INFO) << "result dtype = " << phi::DataTypeToString(out_phi_dtype) - << std::endl; + VLOG(6) << "result dtype = " << phi::DataTypeToString(out_phi_dtype) + << std::endl; if (out_phi_dtype == phi::DataType::UNDEFINED) out_phi_dtype = precision_mode_; if (!IsPhiDataTypeFloat(out_phi_dtype)) @@ -673,8 +663,8 @@ class AutoMixedPrecisionPass : public pir::Pass { auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); if (IsPhiDataTypeFloat(operand_phi_dtype) && operand_phi_dtype != in_phi_dtype) { - LOG(INFO) << "Support low precision, insert CastOp for " << op->name() - << " operand " << i << std::endl; + VLOG(6) << "Support low precision, insert CastOp for " << op->name() + << " operand " << i << std::endl; DoInsertCastOp(op, operand, in_phi_dtype, builder); } } @@ -687,8 +677,8 @@ class AutoMixedPrecisionPass : public pir::Pass { auto operand_phi_dtype = 
GetPhiDataTypeFromOpOperand(operand); if (IsPhiDataTypeFloat(operand_phi_dtype) && operand_phi_dtype == precision_mode_) { - LOG(INFO) << "Not support low precision, insert CastOp for " - << op->name() << " operand " << i << std::endl; + VLOG(6) << "Not support low precision, insert CastOp for " + << op->name() << " operand " << i << std::endl; DoInsertCastOp(op, operand, phi_dtype, builder); } } From 2cac6f8a104653c47164402bc42c83d5607b5c18 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Mon, 15 Jan 2024 14:59:10 +0800 Subject: [PATCH 28/40] Auto mixed precision no log (#4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [DimExpr] DimExpr support hash (#60471) * open warning with `paddle.utils.deprecated` (#60458) * open_warning * update unittest * update * fix typos * fix warning in test runner * uncomment * cleanup todo * using VisibleDeprecationWarning * update comment * fix typo * fix indentation * fix * fix * fix indent level and test * update --------- Co-authored-by: SigureMo * [AutoParallel] Auto Trans PP to VPP (#60467) * [AutoParallel] Auto Trans PP to VPP * add comment * 【PIR OpTest Fix No.23】 fix test_distribute_fpn_proposals_op (#60335) * fix * fix * fix test_lookup_table_v2_bf16_op (#60332) * Fix shape error in combined-indexing setitem (#60447) * add ut * fix shape error in combine-indexing * fix ut * [auto parallel] Add pp lazy init, bug fix for xavier (#60441) * [PIR] add slice_array_dense api (#60433) * fix * fix * Set value with scalar (#60452) * set_value with scalar * fix ut * [PIR]Support custom op in PIR (#59790) * support custom op in pir * fix compile bugs * fix bugs * delete code * fix windows bugs * fix windows bugs * add symbol to paddle lib * fix windows bugs * revert code * fix bugs * fix bugs * perfect code according comment * fix py3 * revert third party * fix bugs * fix bug * fix compile bugs * fix windows * [Prim][PIR] support roll, gather, scatter, scatter_nd_add op backward in pir prim (#60481) * prim gather op backward * prim scatter op backward * prim roll op backward * prim scatter_nd op backward * [PIR] delete dense_tensor mem_desc_ (#60024) * delete dense_tensor mem_desc_ * [PIR] Complement op defs (#60475) * complement translation of legacy matmul * Complement op mappings in translation for deformable_conv_v1. 
* [pir]Supporting constant_folding_pass for train (#60355) * [pir]Supporting constant_folding_pass for train * fix * Update constant_folding_pass.cc * [Dynamic Shape] Fuse shape ops into generate shape op pass (#60490) * add shape.generate_shape op * rename shape.generate_shape to cinn_op.generate_shape * refactor GenerateShapeOp::SymbolBinding * move GenerateShapeOp related helper functions into generate_shape_util.cc * minor fix * minor fix * backup * refine signature of ConvertDimExprToAttribute * minor fix for signature of ConvertDimExprToAttributes * remove SubstituteDimExpr from generate_shape_util.h * Fix compile error * Fix unittest compile error * Code format * Code format * Fix _hiden_size to _hidden_size (#60485) * [DimExpr] Add substitute DimExpr util (#60493) * add SubstituteDimExpr * Fix compile error * Code format * Polish DimExprUtilTest * Change namesapce * Fix unittest * Polish DimExprUtilTest * [xpu]add sine_pos fuse pass and sine_pos xpu kernel (#60025) * add split with variable in factors and rewrite vectorize,unroll,bind error handling mechanism (#60449) * [CodeStyle] Fix regression of Ruff in sot (#60483) * support cast op from FP32 to low precision (#60385) * test=document_fix (#60399) * [XPU] refine flash attention ut (#60474) * [XPU] refine flash attention ut * refine tolerance * [Inference] support collect shape in sub block (#60451) * support collect shape in sub block * udpate * udpate * fix process mesh incorrect set in converter (#60504) * 【CMake opt No.13】Remove CINN DEPS in test/cpp/pir/shape_dialect/CMakeLists.txt (#60517) * Update CMakeLists.txt * Apply suggestions from code review * Apply suggestions from code review * Update CMakeLists.txt * Update CMakeLists.txt * 【pir】 add tensorarray op createarrylike, add_n (#60460) * optimize backward * [PIR] add vjp interface for while op * [PIR] fix ci error. * modify while stopgradient * merge * modify while grad bug * modify while grad op * modify * increment vp * [PIR] add get_used_external_value interface for block. 
* while case * delete print * delete print * Update python/paddle/autograd/ir_backward.py * [PIR] add unit_test for get_used_external_value * modify while_loop * code_style * modofy ci bug * modify while api * modify ci * modify array * Update python/paddle/autograd/ir_backward.py * Update test/legacy_test/test_cond.py * update * modify array_write grad info * merge * add_n and createarraylike * conflict * modify exe bug * modify kernel choose --------- Co-authored-by: winter-wang <1030748926@qq.com> * Add align iter space tactic (#60498) Add align iter space tactic * [Dynamic Shape] Add helper function MakeGenerateShapeOpAttribute (#60512) * add helper function MakeGenerateShapeOpAttribute * fix complier complaint * Code format * [Prim][PIR] Set prim gflag for pure cpp (#60505) * inference support decomp * polish code * add decomp base define * add decomp base define2 * change decomp infer * fix symbol overload * fix test case * debug * debug * decomp add debug info * add cpp flag * revert * remove unused flag * [PIR] Refine and fix pir exe (#60443) * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * update 2023 security advisory, test=document_fix (#60527) * [Inference] refine common/*.h for inference lib (#60513) * 【complex op】No.19 add complex support for triangular_solve (#59529) * fix reshard dist_attr (#60535) * 【auto parallel】剔除切分推导相关的头文件对proto 的依赖 (#60543) * decouple proto * format * format * strcuct pre def * [PIR] Support Operation::Clone Interface (#60536) * [PIR] Support Operation::Clone Interface * modify into shared_ptr * [Dynamic Shape] Add FullyInsertBroadcastPass and Broadcast Op (#60511) * add ShapeBroadcastOp * add pass FullyInsertBroadcastPass * InferSymbolicShape of BroadcastShape Op * Delete unit test * Fix return error * Code format * Fix error message * Update paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --------- Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> * Fix OpTranslatorTest name (#60518) * fix name * fix name * fix name * fix name * [PIR] migrate DataFeeder into pir (#60434) * 【PIR API adaptor No.90,92】Migrate some ops into pir (#59801) * [DimExpr] Convert Broadcast to BroadcastTree (#60440) * backup BroadcastTree * add SubstituteDimExpr * add helper function ConstructBroadcastTree * Fix compile error * Code format * Polish DimExprUtilTest * Add cmake file * Change namesapce * Fix compile error * Fix unittest * reconstruct BroadcastTree * Polish DimExprUtilTest * Reconstruct BroadcastTree * Finish BroadcastBranch * Finish BroadcastBranch * Finish BroadcastBranch * Add Unittest * Remove unnecessary dim_expr_util * Add header file * [Dynamic Shape] Erase expand (#60525) * EraseExpandOp * minor fix * minor fix * Code format * [inference] Support wint4 groupwise with cutlass gemm (#60422) * support gemv-groupwise func && weightQuanter-groupwise && weightDeQuanter-groupwise * fix build bug * add unit_test && fix bug * delete useless code * fix ci build bug * fix ci && optimize * fix merge conflict * add op change info * fix weight_only_linear_pass * fix format * solve ci unit_test * init * support cutlass gemm with groupwise * add unit test * fix strange bug * delete random bug * fix sm70 build bug * try to fix ci build bug * fix bug * fix volta build bug * skip sm70 in groupwise mode * change cutlass branch * simplify extent of loop after fuse and add corresponding test case (#60538) * fix bug of put_along_axis (#60551) * 
remove clearPass to allow custom device use fusion under fp16 (#60541) * fix fleetutil get_online_pass_interval bug2; test=develop (#60544) * fix vs2017 limit (#60528) * 【Hackathon 5th No.20】为 Paddle 新增 Exponential 和 Gamma API (#57899) * add exponential * add gamma distribution * refine docs * add kl_divergence and test * resolve conflicts * resolve conflicts * fix bug * refine test * fix test timeout * refine code * add standard_gamma kernel * fix comments * fix tests * fix tests * fix comments * fix tests * fix gamma grad * fix yaml * fix bugs * fix tests * fix standard_gamma_grad * fix test * fix test * add cdf & icdf * add cdf & icdf * refine comments * fix * fix * fix head file * fix * fix cuda op * fix * fix * refine test * fix test * refine comments * fix comments * fix * fix * fix type check * fix docs * delete useless comments * [CINN] Add IntrinsicOps into ir_codes_collector (#60556) This PR fixed a bug of running Resnet PaddleClas. The bug is due to vectorize introduce an intrinsic GetAddr and we didn't collect the tensor of GetAddr in ir_node_collector, this would caused tensor alias won't create in cuda code. TODO: we may modify IntrinsicOp in the near future * 【auto parallel】custom op spmd rule register (#60509) * custom op spmd rule register * custom op spmd rule register * custom op spmd rule register * custom op spmd rule register * polish * 【AutoParallel】Add master grad in AMP-O2 of AutoParallel (#59987) * add master_grad in auto-parallel * reset third_party * fix coverage * support bf16 master_grad * fix bug in master_grad * change code according to review * change the way to find optimizer op * [Dy2St] Fix `NameloadJstTransformer` missing transform call kwargs (#60515) --------- Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com> * cinn(backends): generate infer shape kernel to infer shape of output tensor (#60519) 通过二维指针来返回后端infer shape的结果。生成的cinn ir如下。tensor_shape_args是一个二维指针。 infer_shape_set_value(0, 0, S1, tensor_shape_args) 表示将第0个output tensor的第0维设置为S1。 * fix tensor math method inplace converter (#60546) * [xpu]Add vis_decoder_attention_xpu_pass && modify qkv_attention_xpu_kernel (#60361) * [Prim][PIR] support abs, instance_norm op backward in prim pir (#60444) * abs op backward * add test case * update code * update code * update code * update code * update code * instance_norm op backward * add instance_norm_v2 test cast * custom op * [PIR] remove log simply name mechnism from phi to common. (#60507) * [InferSymbolicShape] Delete redundent value_id_to_shapeordata_ (#60554) * 【Hackathon 5th No.25】add gammaln api (#60553) * fix (#60570) * [CINN] Add tile tactic and bind cuda tactic (#60534) * [CINN] Add tile tactic * [CINN] Add bind cuda tactic * 【PIR OpTest Fix No.8】 fix test_shuffle_batch_op (#59631) * fix test_shuffle_batch_op * fix * 【PIR OpTest Fix No.14】 fix test_nce (#60255) * fix test_nce * fix test_nce * Update ops.yaml * fix * Update utils.cc * Update ops.yaml * 【PIR OpTest Fix No.19】 fix test_ftrl_op (#60329) * fix test_ftrl_op * fix * [auto parallel] Lazy init for MP. Add reshard infer shape. 
(#60563) * [PIR] Add unittest for Operation::Clone and Group::Clone (#60577) * [PIR] dce pass disable custom op (#60578) * [Inference] Fix bug of RunWithExternalStream API in new executor (#60122) * fix bug of RunWithExternalStream API in new executor * add test * fix bug of RunWithExternalStream API in new executor * reset flage in RunWithExternalStream * fix bug * add param swith_stream * fix bug * modify python api * fix bug * Resubmit PR-58859 (#60310) * allow multiple rng state in generator * Fix 60142; Fix some comments from sneaxiy * Overwrite copy constructors * add api * pre-commit * tensor_array slice in PIR (#60503) * use slice_array, now will meet error of destory opresult still in use * disable the pir test until the bug fixed * Set DistModel state_dict keys to structure_names (#60478) * exclude xpu * check structure name mapping * test pp * polish * support dynamic save static load * support dygraph save static load * polish * polish * use structured_name as key in DistModel state_dict * polish * polish * fix checkpoint path conflict * test get_rank_to_files * static save dynamic load test * fix sm75 build bug (#60583) * replace LOG(INFO) with VLOG(6) * Add CanProveDivisible for symbolic calculation (#60572) * add CanProveDivisible for symbolic calculation * delete extra cout for debug * fix according to some comments * [PIR][DynamicShape] make shape pass default and fix some bugs (#60548) att, make shape pass default and fix some bugs * Fix words (#60603) * 【auto parallel】custom op use spmd rule (#60571) * custom op use smpd rule * custom op use smpd rule * [auto parallel] add lazy init ut to llama (#60585) * 【pir】 modify array_write and array_read vjp , add a simple while with array_write (#60575) * optimize backward * [PIR] add vjp interface for while op * [PIR] fix ci error. * modify while stopgradient * merge * modify while grad bug * modify while grad op * modify * increment vp * [PIR] add get_used_external_value interface for block. 
* while case * delete print * delete print * Update python/paddle/autograd/ir_backward.py * [PIR] add unit_test for get_used_external_value * modify while_loop * code_style * modofy ci bug * modify while api * modify ci * modify array * Update python/paddle/autograd/ir_backward.py * Update test/legacy_test/test_cond.py * update * modify array_write grad info * merge * add_n and createarraylike * conflict * modify array_write vjp * modify array_write vjp * Update paddle/fluid/pybind/manual_static_op_function.h * modify array_write vjp * modify ci bug * modify * modify * Update test/legacy_test/test_while_loop_op.py * modify inplace array_read * Update test/legacy_test/test_while_op.py * Update test/ir/pir/test_while_api.py --------- Co-authored-by: winter-wang <1030748926@qq.com> * [Prim][PIR] add leaky_relu, sigmoid, instance_norm op forward prim (#60564) * hardswish op prim sink * hardswish op prim * add composite * add leaky_relu, sigmoid op forward prim * remove hardswish op forward * add instance_norm op forward prim * [CINN]Add bucket context (#60549) * [CINN] Add tile tactic * [CINN] Add bind cuda tactic * [CINN] Add bucket contexts * fix group output args bug * Add CUDNNv8 max pooling (#59413) * Add CUDNNv8 version of pool2d * Minor fix * Fix build failure * Remove dygraph API * Fix CI failure * Fix CI failure * Fix timeout * Fix timeout * Add comments * Minor fix * update lbfgs to avoid the randomness caused by paddle.dot() temporarily (#60591) * update lbfgs to avoid the randomness caused by paddle.dot() temporarily * add note * set_pir_tests_properties for some tests (#60401) * fix * Update CMakeLists.txt * Update pir_op_test_white_list * Update pir_op_test_white_list * Update pir_op_test_white_list * Add tests to whitelist (#60522) * fix * add * fix double grad without convert inplace (#60614) * fix fleetutil get_online_pass_interval bug3 (#60615) * fix fleetutil get_online_pass_interval bug3; test=develop * fix fleetutil get_online_pass_interval bug3; test=develop * fix fleetutil get_online_pass_interval bug3; test=develop * [PIR][DynamicShape] Add an example for broadcast in dynamic shape infer (#60608) * Add an example for broadcast in dynamic shape infer * fix_convert_all_blocks (#60613) * fix_convert_all_blocks * [Paddle-TRT] support set_value dynamic shape (#60508) [Paddle-TRT] support set_value dynamic shape (#60508) * fix (#60625) * [PIR] Support Region Clone in Operation::Clone (#60590) * deg2rad test passed (#60619) * [PIR+CINN]Fix Pool2d Variant Attibute for kernel_size (#60623) * [PIR+CINN]Fix Pool2d Variant Attibute for kernel_size * fix padding_size * fix pooling_type * [SOT] move_gpu_pinned_to_gpu (#60395) * PIR API adaptor No.35、40】 Migrate paddle.nn.ChannelShuffle/ClipGradByNorm into pir (#60445) * fix some bugs * fix bugs * Update clip.py * Update test_channel_shuffle.py * Update test_clip_by_norm_op.py * Update test_clip_by_norm_op.py * add param name for dist_tensor parameter (#60574) * Fix (#60631) * [PIR] Reify InferSymbolicShapeInterface (#60438) * Reify InferSymbolicShapeInterface * [Dynamic Shape] Remove ShapeBroadcastOp redundant codes (#60609) * [Dy2St] fix `test_grad` in PIR mode (#60621) --------- Co-authored-by: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> * reconstruct llama ci cases (#60637) * 【AutoParallel】Unify the fp16 and bf16 in auto-parallel (#60514) * unify the fp16 and bf16 * change white_list in AMP * add dtype support * fix bug in dtype * [Dynamic Shape] Add SplitGenerateShapeIntoShapeOpsPass (#60624) * [Dynamic 
Shape] Add SplitGenerateShapeIntoShapeOpsPass * Fix compile error * Fix compile error * update pdsa-2023-019, test=document_fix (#60646) * [SOT] sot export test files (#60547) * Improve the performence of put_along_axis (#60618) * fix bug of put_along_axis * improve performence of put_along_axis * [AutoParallel] Fit vpp for gradient_merge pass (#60560) * add dist attr * add op namescope * add test_semi_auto_parallel_hybrid_strategy (#60537) * [PIR]Open uts for AdaptiveAvgPool3D (#60636) * test (#60654) * [CINN] Add OptimizeReductionTactic (#60661) * [Paddle-Trt]update set_value cmakelist (#60664) [Paddle-Trt]update set_value cmakelist * [auto parallel] fix reshape infer shape (#60632) * [CINN+PIR]Clean Old GroupScheduler logic and switch into new_group_scheduler (#60642) * [CINN]Fix HasDynamicShape Bug while Type is NULL (#60658) * [PIR] pir onednn support legact istruction and lrn (#60502) * pir onednn support legact istruction and lrn * c_softmax_with_cross_entropy support bf16 for xpu (#60472) * enable custom device to use silu_fuse_pass (#60595) move SetUseCustomDevice to all platform * [XPU] add empty_like op and test, update XHPC to 20240105 (#60617) * [XPU] update XHPC date and refine FA ut (#60598) * [XPU] update XHPC date * update comments for ut * correct adamw bf16 unit test and the way to get data type (#60565) * Fix some PADDLE_THROW error type and change test cases (#60487) * fix error type * fix TypeError fix type fix fix fix fix * fix typo * as_complex as_real check_grad (#60666) * [Fix Bug] Fix Bugs of Two Pass (#60626) * [Fix Bug] Fix Bugs of Two Pass * Fix GenerateShapeOp bug * Modify unit test * Fix MakeGetterDimExpr4SymbolName * 【Hackathon 5th No.34】为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API (#58092) * This PR enable offset of generator for custom device. 
(#60616) * [SOT] Convert dtype to `DataType` in PIR mode (#60627) * [PIR] Change output to block_arg from copy to a shared for the execution of while (#60607) * test * fix * fix * fix * 【auto parallel】custom op spmd infer add args check (#60633) * add bound check * add bound check * [PIR] Open PIR flag for test_ifelse (#60685) * open pir flag for test_ifelse * Update test_ifelse.py * Update test_ifelse.py * [CIN+PIR]Fix SplitOpPattern Bug in pd_to_cinn_pass (#60669) * [CIN+PIR]Fix SplitOpPattern Bug in pd_to_cinn_pass * fix index error * refine pir_all_path UT * fix bug * fix uncontiguous tensor resize bug (#60684) * fix uncontiguous tensor resize bug * [PIR]Support inplace custom op in pir (#60529) * support inplace in pir * fix inference ut * fix win bugs * fix win bug * fix * polish code * polish code * print log * print log * debug * fix win bugs * fix windows * fix (#60634) * [Docs] Update latest release version in README (#60691) * [CINN] Refine cmake for pass in cinn (#60683) * refine cmake for pass in cinn * add dependency in cmake * add dependency in cmake * [PIR]Open uts for PReLU (#60645) * [PIR]Open uts for ReLU6 (#60650) * [PIR]Open uts for RReLU (#60660) * [NPU] fix storage_properties type mismatch with OneDNN and NPU (#60566) * fix ttfnet_darknet53_1x_coco in pir mode (#60663) * [auto parallel] shard tensor stop gradient support (#60699) * [PIR][DynamicShape] Polish some codes (#60651) att, polish some codes * [PIR] fix onednn double reg (#60720) * fix onednn double reg * 【pir】modify add_n in while use blockarg instead of input value (#60668) * test * fix * fix * fix * modify add_n block_arg * modify increment return value * merge * modfiy whiel_op.py --------- Co-authored-by: zhangbo9674 * [PIR] Open test_case ut (#60721) * fix * fix * [PIR] rename data_layout (#60678) * rename data_layout * [xpu]: check op is null (#60656) * 【Hackathon 5th No.1】 为 Paddle 新增 copysign API (#57785) * add copysign op * fix codestyle * codestyle * fix test * fix std bug * merge init * merge init * merge init * add static cast * add std * static cast * static cast * copysignf * static cast to float input * float input * static cast to double input * fix * add inplace test * fix api * fix cast when grad * modify paddle.cast_ to cast_ * remove cast in python api * support fp16 && bf16 * set grad y to zero * fix en doc * support number input * add hostdevice * refactor kernel * fix nan when backward * add broadcast unit test * modify .cu * Update __init__.py * Update __init__.py * for ci test * static float * codestyle * static double * fix broadcast, try coverage * Delete paddle/phi/kernels/funcs/broadcast_function.h * remove unused * Update math.py * Update math.py * fix en doc * add test for output dtype, integer unsupported for now * update * update * fix * fix * add cast for input * fix * add pir test * fix doc * fix doc * fix doc * detail doc * adjust for MSVC * fix * Update python/paddle/tensor/math.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> * Update python/paddle/tensor/math.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> * fix doc output dtype, fix Equation * codestyle * codestyle * Update math.py --------- Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> * rms_norm_infer_spmd (#60709) * [PIR]Open more tests for bernoulli and celu (#60706) * bernoulli && celu * celu test_error * [PIR]Open uts for scatter_nd_add (#60698) * [PIR]Open uts for scatter_nd_add * Fix ut * [PIR]Open uts for sinh 
(#60714) * [PIR]Open uts for Softshrink and Softsign (#60716) * [PIR] polish the ir_mapping implimentation. (#60675) * [PIR] fix onednn layout transform yaml format (#60680) * fix onednn layout transform yaml format * 【CINN】Complete error handler mechanism of dynamic schedule (#60718) * complete error handler mechanism of dynamic schedule * fix some output info * fix windows C++17 bug (#60736) * [XPU] fc pass and delete pass nodes check (#60314) * fix_local_windows_compile (#60682) * [PIR] fix onednn dialect name (#60665) * fix onednn dialect name * 【pir】add tesnor to array kernel etc (#60703) * merge * modfiy kernel * modify net * modify print * Fix defition definition (#60679) * cholesky and cholesky_solve tests (#60726) * [PIR]Open uts for searchsorted (#60700) * [PIR]Open uts for selu (#60702) * [PIR]Open uts for selu * Fix ut * [PIR]Open uts for sequence_mask (#60704) * [PIR] adjust pir pass log printing (#60723) * adjust pir pass log printing * update * update * update * fix compile * Fix Throughtput Throughput (#60741) * please last md (#60749) * [CINN+PIR]Fix Fetch XShape Variable logic (#60722) * [PIR][DynamicShape] Remove redundant code for shapeAnalysis and shapedTypeInterface (#60744) att, remove redundant code for shapeAnalysis and shapedTypeInterface * 【PIR Dist Op Reg No.1】 reg push_sparse_v2 (#60473) * code reg push_sparse_v2 * [Dynamic Shape] Provide operator<< For BroadcastTree (#60730) * [PIR] change IR clone to const and support clone operation successors (#60752) * support ir clone const and support clone operation successors * refine ir_mapping * refine region clone * [CINN] Refine fully_insert_broadcast_pass (#60676) * refine fully_insert_broadcast_pass * fix complie bug * fix complie * fix conflict * [PIR] einsum's inner_cache and xshape set to optional (#60748) * einsum's inner_cache and xshape set to intermediate * Update paddle/fluid/pir/dialect/operator/ir/ops.yaml --------- Co-authored-by: kangguangli * reduce runtime of unit-tests in windows-trt (#60731) * modify trt test to deal with Timeout * windows * [Paddle-TRT] upgrade EnqueueV2 to EnqueueV3 (#59950) * 【Hackathon 5th No.110】为 Paddle 增强 sparse.matmul API (#59890) * Fix rank_relatvie rank_relative (#60770) * add graph_key to specific graph's varmap (#60567) * add graph_key to specific graph's varmap * fix inpalce case * fix inpalce case * 【Hackathon 5th No.38】为 Paddle 新增 FractionalMaxPool2d / FractionalMaxPool3d API -kernel (#59847) * [Init] add fractional max pool kernel and api * [Fix] pooling.cu seed offset * [Change] remove adaptive from fractional max pool * [Change] fractional max 2d gpu pooling.cu grad * [Change] fractional max 2d gpu pooling.cu grad with dim3 * [Change] use UnchangedInferMeta * [Change] test api with uint16 * [Change] wrap test disable_static * [Change] regiester float16/bfloat16 * [Change] remove bfloat16 from cpu kernrl * [Change] test dtypes in cpu and gpu * [Change] test_fractional_max_pool3d_2d/3d timeout to 30s * [Fix] resolve conflict * [Change] win32 cannot detect bfloat16 correctly * [Change] force set_device * [Add] test random_u is None * [Change] use kernel_size for overlapping mode * [Change] clean headers * [CodeStyle] pooling * [Change] rename op * [Change] rename func without index * [Prim][PIR] Recover pir bn (#60689) * reopen bn prim pir * fix atol * decomp support batch_norm_ * fix test case * fix bug * fix code * [PIR]fc_with_special_op_fuse_pass bug fix (#60751) * bug fix update * update * delete all debug message * add code deleted wrong at last commit * 
delete createAutoMixedPrecisionPass in analysis_predictor.cc --------- Co-authored-by: HongyuJia Co-authored-by: ooo oo <106524776+ooooo-create@users.noreply.github.com> Co-authored-by: SigureMo Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Co-authored-by: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Co-authored-by: JYChen Co-authored-by: Yuang Liu Co-authored-by: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Co-authored-by: YuanRisheng Co-authored-by: kevin Co-authored-by: wanghuancoder Co-authored-by: kangguangli Co-authored-by: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com> Co-authored-by: co63oc Co-authored-by: NeroLoh <745827440@qq.com> Co-authored-by: 傅剑寒 Co-authored-by: lzydev Co-authored-by: tianshuo78520a <707759223@qq.com> Co-authored-by: houj04 <35131887+houj04@users.noreply.github.com> Co-authored-by: Yuanle Liu Co-authored-by: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Co-authored-by: 张春乔 <83450930+Liyulingyue@users.noreply.github.com> Co-authored-by: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Co-authored-by: winter-wang <1030748926@qq.com> Co-authored-by: BiynXu <62832681+BiynXu@users.noreply.github.com> Co-authored-by: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Co-authored-by: Vigi Zhang Co-authored-by: zbt78 <1095497213@qq.com> Co-authored-by: liuzhenhai93 Co-authored-by: Aurelius84 Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Co-authored-by: Lu Qi <61354321+MarioLulab@users.noreply.github.com> Co-authored-by: LoneRanger <836253168@qq.com> Co-authored-by: freeliuzc Co-authored-by: YibLiu <68105073+YibinLiu666@users.noreply.github.com> Co-authored-by: engineer1109 Co-authored-by: danleifeng <52735331+danleifeng@users.noreply.github.com> Co-authored-by: xuxinyi389 <104957571+xuxinyi389@users.noreply.github.com> Co-authored-by: MayYouBeProsperous Co-authored-by: Huihuang Zheng Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com> Co-authored-by: 6clc Co-authored-by: Terry <38135104+TR666@users.noreply.github.com> Co-authored-by: winter-wang <78149749+winter-wang@users.noreply.github.com> Co-authored-by: Wang Xin Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com> Co-authored-by: Frank Lin Co-authored-by: pangengzheng <117730991+pangengzheng@users.noreply.github.com> Co-authored-by: lanxianghit <47554610+lanxianghit@users.noreply.github.com> Co-authored-by: Tian Zheng Co-authored-by: lijialin03 <124568209+lijialin03@users.noreply.github.com> Co-authored-by: Wangzheee <634486483@qq.com> Co-authored-by: zhink <33270771+zhink@users.noreply.github.com> Co-authored-by: huangjiyi <43315610+huangjiyi@users.noreply.github.com> Co-authored-by: Chen Zhiyang <1792266893@qq.com> Co-authored-by: feifei-111 <2364819892@qq.com> Co-authored-by: fsczz <57291768+fsczz@users.noreply.github.com> Co-authored-by: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Co-authored-by: Sonder <55493212+AndSonder@users.noreply.github.com> Co-authored-by: Liujie0926 <44688141+Liujie0926@users.noreply.github.com> Co-authored-by: WangZhen <23097963+0x45f@users.noreply.github.com> Co-authored-by: risemeup1 <62429225+risemeup1@users.noreply.github.com> Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com> Co-authored-by: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Co-authored-by: Jianbang Yang Co-authored-by: enzodechine Co-authored-by: Zhan Rongrui 
<46243324+zrr1999@users.noreply.github.com> Co-authored-by: coco <69197635+cocoshe@users.noreply.github.com> Co-authored-by: zhaohaixu <49297029+zhaohaixu@users.noreply.github.com> Co-authored-by: chen2016013 <111894720+chen2016013@users.noreply.github.com> Co-authored-by: zyfncg Co-authored-by: Qi Li Co-authored-by: zhangbo9674 Co-authored-by: Liuyinfeng <30849840+gitliuyf@users.noreply.github.com> Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> Co-authored-by: wendaxiao <113992173+wenxiaohahaha@users.noreply.github.com> Co-authored-by: cyberslack_lee Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com> Co-authored-by: GGBond8488 <33050871+GGBond8488@users.noreply.github.com> Co-authored-by: megemini --- .../fluid/inference/api/analysis_predictor.cc | 10 ---- .../transforms/auto_mixed_precision_pass.cc | 50 ++----------------- 2 files changed, 3 insertions(+), 57 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index de6a2d0c97189c..a8c73c32183988 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -105,7 +105,6 @@ #endif #include "paddle/fluid/ir_adaptor/translator/translate.h" -#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" @@ -807,15 +806,6 @@ bool AnalysisPredictor::PrepareExecutor() { //----------------------------------------------------------------------------------------------// // Functional pass - // Do auto mixed precision pass first, so do not need to handle - // shadowoutput. 
- auto auto_mixed_precision_pass = ::pir::CreateAutoMixedPrecisionPass(); - auto_mixed_precision_pass->SetNotOwned(pir::kPlaceAttr, &place_); - phi::DataType data_type = - ConvertPrecision(config_.mixed_precision_mode_); - auto_mixed_precision_pass->SetNotOwned("__mixed_precision_mode__", - &data_type); - gpu_pm.AddPass(std::move(auto_mixed_precision_pass)); gpu_pm.AddPass(::pir::CreateIdentityOpCleanPass()); //----------------------------------------------------------------------------------------------// diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 6846ccc6b57536..0bf137bb09a23b 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -82,20 +82,10 @@ class AutoMixedPrecisionPass : public pir::Pass { for (size_t i = 0; i < op->num_regions(); ++i) { auto& region = op->region(i); for (auto& block : region) { - VLOG(6) << "===========Get Op Precision============" << std::endl; GetOpPrecision(&block); - VLOG(6) << "===========Update Op Precision============" << std::endl; UpdateOpPrecision(&block); - - VLOG(6) << "===========" << op_run_low_precision_.size() << " of " - << block.size() << " ops" - << " run low precision" << std::endl; pir::Builder builder = pir::Builder(context_, &block); - VLOG(6) << "===========Process Op Precision============" << std::endl; - ProcessBlock(&block, builder); - VLOG(6) << "===========Insert Cast Op Num : " << insert_cast_op_num_ - << "============" << std::endl; } } } @@ -144,7 +134,6 @@ class AutoMixedPrecisionPass : public pir::Pass { void GetOpPrecision(pir::Block* block) { for (auto& op_item : *block) { auto op = &op_item; - VLOG(6) << "op name " << op->name(); auto op_name = op->name(); bool support_low_precision = true; if (black_list_.count(op_name)) { @@ -167,10 +156,6 @@ class AutoMixedPrecisionPass : public pir::Pass { } if (support_low_precision) { op_run_low_precision_.insert(op); - VLOG(6) << "op " << op->name() << " support low precision" << std::endl; - } else { - VLOG(6) << "op " << op->name() << " doesn't support low precision" - << std::endl; } } } @@ -235,8 +220,6 @@ class AutoMixedPrecisionPass : public pir::Pass { } if (!OpRunLowPrecision(op)) continue; if (CheckOutputIsScalarAttribute(op)) { // Output is ScalarAttribute - VLOG(6) << "op " << op->name() << " output is ScalarAttribute" - << std::endl; op_run_low_precision_.erase(op); precision_updated = true; } @@ -261,21 +244,10 @@ class AutoMixedPrecisionPass : public pir::Pass { } } } while (precision_updated); - for (auto& op_item : *block) { - auto op = &op_item; - if (op_should_not_handle_.count(op)) { - VLOG(6) << "op " << op->name() << " should not be handled" << std::endl; - } else if (op_run_low_precision_.count(op)) { - VLOG(6) << "op " << op->name() << " run low precision" << std::endl; - } else { - VLOG(6) << "op " << op->name() << " run high precision" << std::endl; - } - } } void RewriteOp(pir::Operation* op, pir::Builder& builder) { // NOLINT - VLOG(6) << "Rewrite op " << op->name() << std::endl; if (IsBuiltinOp(op)) { RewriteBuiltinOp(op, builder); return; @@ -318,7 +290,6 @@ class AutoMixedPrecisionPass : public pir::Pass { phi::DataType precision, phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) const { auto& phi_op_type = op_type; - VLOG(6) << "phi_op_type = " << phi_op_type << std::endl; bool support = PhiKernelSupportPrecision(phi_op_type, backend, precision, layout); @@ -419,8 +390,8 @@ class 
AutoMixedPrecisionPass : public pir::Pass { auto new_vec_type = pir::VectorType::get(context, results_type); result.set_type(new_vec_type); } else { - VLOG(6) << "result type is not DenseTensorType or VectorType" - << std::endl; + PADDLE_THROW(phi::errors::Unimplemented( + "result type is not DenseTensorType or VectorType")); } } @@ -452,7 +423,6 @@ class AutoMixedPrecisionPass : public pir::Pass { IsVectorTypeFloat(result.type().dyn_cast())) { } } - VLOG(6) << "op " << op->name() << " doesn't have float result" << std::endl; return false; } @@ -517,10 +487,8 @@ class AutoMixedPrecisionPass : public pir::Pass { void RewriteBuiltinOp(pir::Operation* op, pir::Builder& builder) { // NOLINT - VLOG(6) << "Rewrite builtin op " << op->name() << std::endl; // Rewrite CombineOp if (op->isa()) { - // auto vec_type = op->result(0).type().dyn_cast(); auto input_num = op->num_operands(); if (OpRunLowPrecision(op)) { for (size_t i = 0; i < input_num; ++i) { @@ -572,10 +540,8 @@ class AutoMixedPrecisionPass : public pir::Pass { void RewritePdOp(pir::Operation* op, pir::Builder& builder) { // NOLINT - VLOG(6) << "Rewrite pd op " << op->name() << std::endl; - phi::Backend backend = ConvertPlaceToBackend(place_); std::string op_type = op->name().substr(op->name().find(".") + 1); - + phi::Backend backend = ConvertPlaceToBackend(place_); // Rewrite FetchOp if (op->isa()) { auto fetch_operand = op->operand(0); @@ -587,7 +553,6 @@ class AutoMixedPrecisionPass : public pir::Pass { auto result_dtype = paddle::dialect::TransToPhiDataType( pir::GetDataTypeFromValue(op->result(0))); if (fetch_operand_phi_dtype != result_dtype) { - VLOG(6) << "Insert CastOp for FetchOp" << std::endl; DoInsertCastOp(op, fetch_operand, result_dtype, builder); } return; @@ -607,9 +572,6 @@ class AutoMixedPrecisionPass : public pir::Pass { // Other pd ops if (OpRunLowPrecision(op)) { // change result's dtype to low precision - VLOG(6) << "Change result's dtype to low precision " << op->name() - << std::endl; - if (op->HasAttribute("dtype") && IsPhiDataTypeFloat( op->attribute("dtype") @@ -644,8 +606,6 @@ class AutoMixedPrecisionPass : public pir::Pass { auto result = op->result(i); if (!result.type()) continue; phi::DataType out_phi_dtype = output_defs[i].dtype; - VLOG(6) << "result dtype = " << phi::DataTypeToString(out_phi_dtype) - << std::endl; if (out_phi_dtype == phi::DataType::UNDEFINED) out_phi_dtype = precision_mode_; if (!IsPhiDataTypeFloat(out_phi_dtype)) @@ -663,8 +623,6 @@ class AutoMixedPrecisionPass : public pir::Pass { auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); if (IsPhiDataTypeFloat(operand_phi_dtype) && operand_phi_dtype != in_phi_dtype) { - VLOG(6) << "Support low precision, insert CastOp for " << op->name() - << " operand " << i << std::endl; DoInsertCastOp(op, operand, in_phi_dtype, builder); } } @@ -677,8 +635,6 @@ class AutoMixedPrecisionPass : public pir::Pass { auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); if (IsPhiDataTypeFloat(operand_phi_dtype) && operand_phi_dtype == precision_mode_) { - VLOG(6) << "Not support low precision, insert CastOp for " - << op->name() << " operand " << i << std::endl; DoInsertCastOp(op, operand, phi_dtype, builder); } } From f395366ea2b958761957a362d9d1aacb9bbdf699 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 15 Jan 2024 07:02:14 +0000 Subject: [PATCH 29/40] delete useless code --- paddle/fluid/framework/ir/auto_mixed_precision_pass.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc 
b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index 69292e18edabf0..ff1ec70ec6292c 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -42,9 +42,6 @@ bool PhiKernelSupportPrecision( phi::DataType data_type, phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { const auto& kernels = phi::KernelFactory::Instance().kernels(); - // for (auto [k, v] : kernels) { - // LOG(INFO) << "kernel name " << k << std::endl; - // } if (kernels.count(op_type) == 0) { return false; } From e48e71e0fa293e0c3435504cb27f775e3e662efc Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 15 Jan 2024 07:09:29 +0000 Subject: [PATCH 30/40] refine code --- .../transforms/auto_mixed_precision_pass.cc | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 0bf137bb09a23b..86d4b95de004ed 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -205,13 +205,12 @@ class AutoMixedPrecisionPass : public pir::Pass { } } if (!OpRunLowPrecision(op)) continue; - if (op->isa()) { // add for cast op, not cast - // to float. i.e cast to bool - // or int - // if datatype of cast op result is not float, then cast op should be - // not handled - auto result_dtype = paddle::dialect::TransToPhiDataType( - pir::GetDataTypeFromValue(op->result(0))); + // if datatype of cast op result is not float, then cast op should be + // not handled + if (op->isa()) { + t.i.e cast to bool auto result_dtype = + paddle::dialect::TransToPhiDataType( + pir::GetDataTypeFromValue(op->result(0))); if (!IsPhiDataTypeFloat(result_dtype)) { op_run_low_precision_.erase(op); op_should_not_handle_.insert(op); @@ -219,11 +218,15 @@ class AutoMixedPrecisionPass : public pir::Pass { } } if (!OpRunLowPrecision(op)) continue; - if (CheckOutputIsScalarAttribute(op)) { // Output is ScalarAttribute + // if consumer's input is a ScalarAttribute, the producer should be in + // high precision + if (CheckOutputIsScalarAttribute(op)) { op_run_low_precision_.erase(op); precision_updated = true; } if (!OpRunLowPrecision(op)) continue; + // if the producer's output is in float VectorType, then the precsion + // between two op should be the same for (size_t idx = 0; idx < op->num_operands(); ++idx) { if (!op->operand_source(idx)) continue; auto operand = op->operand(idx); @@ -232,7 +235,6 @@ class AutoMixedPrecisionPass : public pir::Pass { auto vec_type = operand.type().dyn_cast(); if (IsVectorTypeFloat(vec_type)) { auto input_operation = GetDefiningOpForInput(op, idx); - // 如果有一个是高精的话,则必须都跑在高精上 if (!op_run_low_precision_.count(op) || !op_run_low_precision_.count(input_operation)) { op_run_low_precision_.erase(op); From d1c6e0e74d07522e671f68531e0ce4334f64980b Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 15 Jan 2024 07:10:59 +0000 Subject: [PATCH 31/40] refine code --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 86d4b95de004ed..09526ec02dcb1e 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -208,9 +208,8 @@ class AutoMixedPrecisionPass : public pir::Pass { // if datatype of cast 
op result is not float, then cast op should be // not handled if (op->isa()) { - t.i.e cast to bool auto result_dtype = - paddle::dialect::TransToPhiDataType( - pir::GetDataTypeFromValue(op->result(0))); + auto result_dtype = paddle::dialect::TransToPhiDataType( + pir::GetDataTypeFromValue(op->result(0))); if (!IsPhiDataTypeFloat(result_dtype)) { op_run_low_precision_.erase(op); op_should_not_handle_.insert(op); From 61ce089ece63513f98c1c23d1b28e225ee15fb67 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 15 Jan 2024 10:36:38 +0000 Subject: [PATCH 32/40] fix comment --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 09526ec02dcb1e..340be3aafd30d5 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -616,7 +616,6 @@ class AutoMixedPrecisionPass : public pir::Pass { } // if any of the op's input is not in low precision, insert cast op - // input_defs will always be the smaller one? for (size_t i = 0; i < input_defs.size(); i++) { auto operand = op->operand(i); auto in_phi_dtype = input_defs[i].dtype; @@ -627,7 +626,8 @@ class AutoMixedPrecisionPass : public pir::Pass { DoInsertCastOp(op, operand, in_phi_dtype, builder); } } - } else { // current op doesn't support low precision, should cast to float + } else { + // current op doesn't support low precision // if the op's input is in low precision, insert cast op auto phi_dtype = phi::DataType::FLOAT32; for (size_t i = 0; i < op->num_operands(); i++) { From dbc02ac817e66fa67cf1f5f31fe5767eec1cfcb8 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 15 Jan 2024 14:31:18 +0000 Subject: [PATCH 33/40] replace cc_test with paddle_test --- test/cpp/pir/pattern_rewrite/CMakeLists.txt | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 5bc063240fb4af..0df9cd88ad4347 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -3,10 +3,14 @@ cc_test( SRCS pattern_rewrite_test.cc DEPS gtest op_dialect_vjp pir pir_transforms) -cc_test( +paddle_test( auto_mixed_precision_test - SRCS auto_mixed_precision_test.cc - DEPS gtest pir pir_transforms) + SRCS + auto_mixed_precision_test.cc + DEPS + gtest + pir + pir_transforms) cc_test( drr_test From 16a5ef6cafe91b733e55dc3585e410402e43314c Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 16 Jan 2024 05:31:38 +0000 Subject: [PATCH 34/40] rm modify of auto_mixed_precision_pass --- paddle/fluid/framework/ir/auto_mixed_precision_pass.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index ff1ec70ec6292c..fb75c18a6fae65 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -26,8 +26,6 @@ #include "paddle/phi/backends/device_manager.h" #endif -PHI_DECLARE_bool(enable_pir_in_executor); - namespace paddle { namespace framework { namespace ir { @@ -272,9 +270,6 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const { } void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const { - if (FLAGS_enable_pir_in_executor) { - return; - } PADDLE_ENFORCE_NOT_NULL(graph, 
platform::errors::PreconditionNotMet( "During the auto_mixed_precision_pass, the graph " From 59f4f5a5f91ab666bd8515d35a0b5244f92d1bd8 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 16 Jan 2024 05:32:40 +0000 Subject: [PATCH 35/40] modify cpp file header --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc | 2 +- paddle/fluid/pir/transforms/auto_mixed_precision_pass.h | 2 +- test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 340be3aafd30d5..5cfbf9270dec16 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h index 4ab0fb12cb723a..888eb64c1a6113 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc index 8c16fd37dd34a8..695078d8fbd809 100644 --- a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc +++ b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
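A note on the rule these patches keep adjusting: throughout the series (the "refine code" patches above and the "do not rewrite output in some cases" fix below), an operand only gets a cast op inserted in front of it when it carries a floating-point dtype that differs from the dtype the selected kernel expects for that input slot; integer and bool operands are never touched. A minimal standalone C++ sketch of that decision rule follows; the enum and function names are illustrative stand-ins, not the actual phi/pir types.

#include <iostream>

// Stand-in for phi::DataType and the float check used by the pass.
enum class DType { FLOAT32, FLOAT16, BFLOAT16, INT32, BOOL };

bool IsFloat(DType d) {
  return d == DType::FLOAT32 || d == DType::FLOAT16 || d == DType::BFLOAT16;
}

// Mirrors the condition guarding DoInsertCastOp in RewritePdOp: cast only
// when the operand is floating point and its dtype differs from what the
// chosen kernel expects for this input.
bool NeedCastBeforeOperand(DType operand_dtype, DType kernel_input_dtype) {
  return IsFloat(operand_dtype) && operand_dtype != kernel_input_dtype;
}

int main() {
  // fp32 value feeding an fp16 kernel input: a cast is required.
  std::cout << NeedCastBeforeOperand(DType::FLOAT32, DType::FLOAT16) << "\n";  // 1
  // already in the target precision: nothing to do.
  std::cout << NeedCastBeforeOperand(DType::FLOAT16, DType::FLOAT16) << "\n";  // 0
  // non-float operands are left alone even when the op runs in fp16.
  std::cout << NeedCastBeforeOperand(DType::INT32, DType::FLOAT16) << "\n";    // 0
  return 0;
}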
From 0e8b6b786b66d824688bb9e1405e374af9807ee1 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 16 Jan 2024 06:57:53 +0000 Subject: [PATCH 36/40] fix bug: do not rewrite output in some cases --- .../transforms/auto_mixed_precision_pass.cc | 54 +++++++++++-------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc index 5cfbf9270dec16..8431517a4db244 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc @@ -572,16 +572,6 @@ class AutoMixedPrecisionPass : public pir::Pass { // Other pd ops if (OpRunLowPrecision(op)) { - // change result's dtype to low precision - if (op->HasAttribute("dtype") && - IsPhiDataTypeFloat( - op->attribute("dtype") - .data())) { - pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( - builder.ir_context(), precision_mode_); - op->set_attribute("dtype", attr_dtype); - } - auto phi_kernel = GetPhiKernelInPrecision(op_type, backend, precision_mode_); PADDLE_ENFORCE( @@ -596,6 +586,38 @@ class AutoMixedPrecisionPass : public pir::Pass { auto input_defs = args_def.input_defs(); auto output_defs = args_def.output_defs(); + // if any of the op's input is not in low precision, insert cast op + for (size_t i = 0; i < input_defs.size(); i++) { + auto operand = op->operand(i); + auto in_phi_dtype = input_defs[i].dtype; + if (!IsOperandHasDenseTensorType(operand)) continue; + auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); + if (IsPhiDataTypeFloat(operand_phi_dtype) && + operand_phi_dtype != in_phi_dtype) { + DoInsertCastOp(op, operand, in_phi_dtype, builder); + } + } + + // change result's dtype to low precision + if (op->HasAttribute("dtype")) { + auto phi_dtype = op->attribute("dtype") + .dyn_cast() + .data(); + if (IsPhiDataTypeFloat(phi_dtype)) { + pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( + builder.ir_context(), precision_mode_); + op->set_attribute("dtype", attr_dtype); + } else if (phi_dtype == + phi::DataType::UNDEFINED) { // dtype is not set, means all + // ok + pir::Attribute attr_dtype = paddle::dialect::DataTypeAttribute::get( + builder.ir_context(), precision_mode_); + op->set_attribute("dtype", attr_dtype); + } else { + return; // don't modify output dtype + } + } + PADDLE_ENFORCE_EQ( op->num_results(), output_defs.size(), @@ -614,18 +636,6 @@ class AutoMixedPrecisionPass : public pir::Pass { // type SetResultDataType(result, out_phi_dtype, builder.ir_context()); } - - // if any of the op's input is not in low precision, insert cast op - for (size_t i = 0; i < input_defs.size(); i++) { - auto operand = op->operand(i); - auto in_phi_dtype = input_defs[i].dtype; - if (!IsOperandHasDenseTensorType(operand)) continue; - auto operand_phi_dtype = GetPhiDataTypeFromOpOperand(operand); - if (IsPhiDataTypeFloat(operand_phi_dtype) && - operand_phi_dtype != in_phi_dtype) { - DoInsertCastOp(op, operand, in_phi_dtype, builder); - } - } } else { // current op doesn't support low precision // if the op's input is in low precision, insert cast op From 29cc05e1b8532cf54600f08beb7455ca7950f5a3 Mon Sep 17 00:00:00 2001 From: yxy Date: Sat, 20 Jan 2024 14:30:29 +0000 Subject: [PATCH 37/40] delete useless code --- test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc 
b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc index 695078d8fbd809..004341cbe30fc0 100644 --- a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc +++ b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc @@ -18,7 +18,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/pir/core/builtin_dialect.h" From ef15b1a085d87692d0b7df41ead005e68dda4094 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 22 Jan 2024 02:58:29 +0000 Subject: [PATCH 38/40] recover CMakeLists.txt --- test/cpp/pir/pattern_rewrite/CMakeLists.txt | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index a6ca1435b5fa6d..0a40d380317e0d 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -3,14 +3,10 @@ cc_test( SRCS pattern_rewrite_test.cc DEPS gtest op_dialect_vjp pir pir_transforms) -paddle_test( +cc_test( auto_mixed_precision_test - SRCS - auto_mixed_precision_test.cc - DEPS - gtest - pir - pir_transforms) + SRCS auto_mixed_precision_test.cc + DEPS gtest pir pir_transforms) cc_test( drr_test From aefe1b7afe580be97afd767a5e2346143b7acaec Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 22 Jan 2024 03:26:57 +0000 Subject: [PATCH 39/40] delete header --- paddle/fluid/pir/transforms/auto_mixed_precision_pass.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h index 888eb64c1a6113..5d28438c5d9690 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h +++ b/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h @@ -15,8 +15,6 @@ #pragma once #include -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/place.h" #include "paddle/pir/core/dll_decl.h" namespace pir { From e9c6516aa810dcc4710957883e0e06c9d64385d8 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 22 Jan 2024 10:49:10 +0000 Subject: [PATCH 40/40] rm test --- test/cpp/pir/pattern_rewrite/CMakeLists.txt | 5 - .../auto_mixed_precision_test.cc | 109 ------------------ 2 files changed, 114 deletions(-) delete mode 100644 test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 0a40d380317e0d..359950e796d155 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -3,11 +3,6 @@ cc_test( SRCS pattern_rewrite_test.cc DEPS gtest op_dialect_vjp pir pir_transforms) -cc_test( - auto_mixed_precision_test - SRCS auto_mixed_precision_test.cc - DEPS gtest pir pir_transforms) - cc_test( drr_test SRCS drr_test.cc diff --git a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc b/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc deleted file mode 100644 index 004341cbe30fc0..00000000000000 --- a/test/cpp/pir/pattern_rewrite/auto_mixed_precision_test.cc +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" -#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/pir/core/builtin_dialect.h" -#include "paddle/pir/pass/pass.h" -#include "paddle/pir/pass/pass_manager.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" - -void BuildProgram(pir::Builder &builder) { // NOLINT - paddle::dialect::FullOp full_input_op = - builder.Build(std::vector{4, 3, 16, 16}, - 1.5, - phi::DataType::FLOAT32, - phi::CPUPlace()); - - paddle::dialect::FullOp full_filter_op = - builder.Build(std::vector{64, 3, 3, 3}, - 1.5, - phi::DataType::FLOAT32, - phi::CPUPlace()); - - paddle::dialect::FullOp full_mean_op = builder.Build( - std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); - - paddle::dialect::FullOp full_variance_op = - builder.Build(std::vector{64}, - 1.5, - phi::DataType::FLOAT32, - phi::CPUPlace()); - - paddle::dialect::FullOp full_scale_op = - builder.Build(std::vector{64}, - 1.5, - phi::DataType::FLOAT32, - phi::CPUPlace()); - - paddle::dialect::FullOp full_bias_op = builder.Build( - std::vector{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); - - paddle::dialect::Conv2dOp conv2d_op = - builder.Build(full_input_op.out(), - full_filter_op.out()); - - paddle::dialect::BatchNormOp batch_norm_op = - builder.Build(conv2d_op.out(), - full_mean_op.out(), - full_variance_op.out(), - full_scale_op.out(), - full_bias_op.out(), - true, - 0.9, - 1e-6, - "NCHW", - false, - false); - - auto transpose1_op = builder.Build( - batch_norm_op.out(), std::vector{0, 2, 3, 1}); - - auto transpose2_op = builder.Build( - transpose1_op.out(), std::vector{0, 3, 1, 2}); - - builder.Build(transpose2_op.out(), "out", 0); -} - -TEST(AutoMixedPrecisonTest, MixedPrecisionTest) { - pir::IrContext *ctx = pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - pir::Program program(ctx); - pir::Builder builder = pir::Builder(ctx, program.block()); - BuildProgram(builder); - - EXPECT_EQ(program.block()->size(), 11u); - - pir::PassManager pm(ctx); - std::unique_ptr auto_mixed_precision_pass = - pir::CreateAutoMixedPrecisionPass(); - phi::Place place = phi::GPUPlace(); - phi::DataType data_type = phi::DataType::FLOAT16; - auto_mixed_precision_pass->SetNotOwned(pir::kPlaceAttr, &place); - auto_mixed_precision_pass->SetNotOwned("__mixed_precision_mode__", - &data_type); - pm.AddPass(std::move(auto_mixed_precision_pass)); - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - // pm.EnablePassTiming(); - pm.EnableIRPrinting(); - - CHECK_EQ(pm.Run(&program), true); -}
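With the C++ unit test deleted in the final patch, the removed file itself is the clearest record of how the new pass is driven. The following is a trimmed sketch based on that deleted test: it assumes the pass still receives the target place and low-precision dtype via SetNotOwned, and the template arguments (OperatorDialect, BuiltinDialect) are reconstructed here because they were stripped from the quoted source, so they should be checked against the current headers before reuse.

#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h"
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"

// Runs auto mixed precision followed by dead code elimination on an fp32
// program, mirroring the flow of the removed auto_mixed_precision_test.cc.
bool RunAutoMixedPrecision(pir::Program* program) {
  pir::IrContext* ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
  ctx->GetOrRegisterDialect<pir::BuiltinDialect>();

  pir::PassManager pm(ctx);
  auto auto_mixed_precision_pass = pir::CreateAutoMixedPrecisionPass();

  // Same attributes the deleted test (and the old analysis_predictor hook)
  // handed to the pass: the target place and the low-precision data type.
  phi::Place place = phi::GPUPlace();
  phi::DataType precision = phi::DataType::FLOAT16;
  auto_mixed_precision_pass->SetNotOwned(pir::kPlaceAttr, &place);
  auto_mixed_precision_pass->SetNotOwned("__mixed_precision_mode__", &precision);

  pm.AddPass(std::move(auto_mixed_precision_pass));
  pm.AddPass(pir::CreateDeadCodeEliminationPass());
  return pm.Run(program);
}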