From 0ee3b5577779672d054d69786f8c5f63608be202 Mon Sep 17 00:00:00 2001 From: YanXu Date: Mon, 19 Aug 2024 16:02:43 +0800 Subject: [PATCH 1/4] init auto remat(offloading) in dynamic shape graph --- tao_compiler/mlir/disc/BUILD | 31 +- tao_compiler/mlir/disc/disc_compiler.cc | 14 +- tao_compiler/mlir/disc/disc_compiler_main.cc | 2 + .../mlir/disc/transforms/disc_offloading.cc | 624 ++++++++++++++++++ .../mlir/disc/transforms/disc_passes.td | 5 + .../mlir/disc/transforms/disc_remat_utils.cc | 446 +++++++++++++ .../mlir/disc/transforms/disc_remat_utils.h | 128 ++++ .../transforms/disc_shape_optimization.cc | 24 +- .../disc_shape_optimization_utils.cc | 54 +- .../disc_shape_optimization_utils.h | 10 +- .../disc/transforms/disc_shape_propagate.cc | 429 ++++++++++-- .../disc/transforms/lhlo_elemental_utils.cc | 3 + tao_compiler/mlir/disc/transforms/passes.h | 1 + .../mlir/disc/transforms/shape_utils.cc | 17 + .../mlir/disc/transforms/shape_utils.h | 1 + 15 files changed, 1697 insertions(+), 92 deletions(-) create mode 100644 tao_compiler/mlir/disc/transforms/disc_offloading.cc create mode 100644 tao_compiler/mlir/disc/transforms/disc_remat_utils.cc create mode 100644 tao_compiler/mlir/disc/transforms/disc_remat_utils.h diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index d3dfd3fbc9b..2a05e43b130 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -341,6 +341,7 @@ cc_library( ":codegen_utils", ":disc_shape_optimization_utils", ":disc_util", + ":disc_offloading", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:map_lmhlo_to_scalar_op", @@ -2228,7 +2229,35 @@ cc_library( ], alwayslink = 1, ) - +cc_library( + name = "disc_offloading", + srcs = [ + "transforms/disc_offloading.cc", + "transforms/disc_remat_utils.cc" + ], + hdrs = [ + "transforms/passes.h", + "transforms/disc_remat_utils.h" + ], + includes = ["include"], + deps = [ + ":disc_ral", + ":disc_util", + ":mhlo_disc", + ":pass_details", + ":shape_utils", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:map_lmhlo_to_scalar_op", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Transforms", + ] +) cc_library( name = "disc_custom_call_rewriter", srcs = ["transforms/disc_custom_call_rewriter.cc"], diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index e8f0eb06ae3..f329721b594 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -242,10 +242,9 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); - + pm.addPass(disc_ral::createDiscShapePropagatePass()); pm.addNestedPass(disc_ral::createDiscAlgebraicSimplifierPass()); pm.addPass(disc_ral::createDiscInputOutputAliasPass()); - pm.addPass(disc_ral::createDiscShapePropagatePass()); pm.addPass(mlir::createInlinerPass()); // TODO(disc): Lower HLO shape constraints instead of eliding them here. 
pm.addNestedPass(disc_ral::createDiscCollectiveOpsRewriterPass()); @@ -269,8 +268,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass( disc_ral::createDiscLowerQuantizeAndDequantizePass()); } - bool enable_shape_constraint_ir = useShapeConstraintIR(); + if (!enable_shape_constraint_ir) { // propagate some known shape information. pm.addPass(disc_ral::createDiscShapeSimplifierPass()); @@ -279,7 +278,6 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { // shape-related optimization pm.addPass(disc_ral::createDiscShapeOptimizationPass()); } - pm.addNestedPass(disc_ral::createDiscConvertTensorToStandardPass()); pm.addNestedPass(disc_ral::createDiscConvertHloToStandardPass()); pm.addNestedPass(createCanonicalizerPass()); @@ -500,9 +498,9 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { if (gpu_enabled) { // TODO: Support cpu stitch with splat const pm.addNestedPass(disc_ral::createDiscFuseSplatConstPass()); - pm.addNestedPass( - disc_ral::createDiscSpecializeFusionWithSpeculationPass( - gpu_options.sm_count, gpu_options.max_threads_per_sm)); + // pm.addNestedPass( + // disc_ral::createDiscSpecializeFusionWithSpeculationPass( + // gpu_options.sm_count, gpu_options.max_threads_per_sm)); } else { pm.addNestedPass( disc_ral::createDiscDuplicateComputationAfterFusionPass()); @@ -545,6 +543,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(disc_ral::createDiscBufferDeallocationPass()); pm.addPass(disc_ral::createRalInjectExecutionContextPass()); + // pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass()); + pm.addNestedPass(disc_ral::createDiscOffloadingPass()); pm.addNestedPass( disc_ral::createDiscLowerToLibraryCallPass(gpu_enabled)); pm.addPass(disc_ral::createDiscConstToRALPass(options.metadata_file_path)); diff --git a/tao_compiler/mlir/disc/disc_compiler_main.cc b/tao_compiler/mlir/disc/disc_compiler_main.cc index 366eebb88b1..97f065366c3 100644 --- a/tao_compiler/mlir/disc/disc_compiler_main.cc +++ b/tao_compiler/mlir/disc/disc_compiler_main.cc @@ -210,11 +210,13 @@ int RealMain() { << " s.\n"; llvm::dbgs() << "[[ INFO ]] Running TF2XLA\n"; + /* auto s = tensorflow::ConvertTF2MlirHlo(module); if (!s.ok()) { llvm::dbgs() << "ConvertTF2MlirHlo failed: " << s.ToString() << "\n"; return 1; } + */ if (VLOG_IS_ON(0)) { llvm::dbgs() << "======== BEGIN After TF2HLO =========\n"; diff --git a/tao_compiler/mlir/disc/transforms/disc_offloading.cc b/tao_compiler/mlir/disc/transforms/disc_offloading.cc new file mode 100644 index 00000000000..357b80ec848 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_offloading.cc @@ -0,0 +1,624 @@ +// Copyright 2024 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
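+//
+// This pass reduces peak GPU memory on dynamic shape graphs by conditionally
+// offloading long-lived device buffers to the host and reloading them right
+// before their next use. A minimal sketch of the rewrite, assuming a symbolic
+// dim S0 and a threshold minS0 derived from the memory limitation (the copy
+// op names are illustrative; the real ops are created by InsertRematBlock):
+//
+//   %cond = arith.cmpi sgt, %S0, %minS0
+//   %host = scf.if %cond -> memref<?xf16> {     // "offload" block
+//     %h = memref.alloc(%d0) : memref<?xf16>    // host buffer
+//     d2h(%device, %h)                          // copy device to host
+//     memref.dealloc %device
+//     scf.yield %h
+//   } else {
+//     %dummy = memref.alloc(%d0)                // dummy buffer, never read
+//     scf.yield %dummy
+//   }
+//   ...                                         // memory-peak region
+//   %reload = scf.if %cond -> memref<?xf16> {   // "reload" block
+//     %d = memref.alloc(%d0) : memref<?xf16>    // device buffer
+//     h2d(%host, %d)                            // copy host to device
+//     scf.yield %d
+//   } else {
+//     scf.yield %device
+//   }
+//   // consumers after the reload block are rewired to use %reload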
+ +#include + +#include "absl/strings/str_split.h" +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project +#include "mlir/IR/Dominance.h" +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/disc/IR/hlo_disc_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/disc_util.h" +#include "mlir/disc/transforms/PassDetail.h" +#include "mlir/disc/transforms/disc_remat_utils.h" +#include "mlir/disc/transforms/shape_utils.h" + +namespace mlir { +namespace disc_ral { +constexpr StringRef kRematBlockTypeAttr = "disc.remat.type"; +constexpr StringRef kRematBufferAttr = "disc.remat.dummy-buffer"; +constexpr StringRef kRematMinSymDim = "disc.remat.min-sym-dim"; + +struct DiscOffloadingPass : public DiscOffloadingPassBase { + DiscOffloadingPass() + : DiscOffloadingPassBase::DiscOffloadingPassBase() {} + void getDependentDialects(DialectRegistry& registry) const override { + DiscOffloadingPassBase::getDependentDialects(registry); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + void runOnOperation() override; + void InsertRematBlock(mlir::OpBuilder& b, LivingBuffer& livingBuffer, + Value rematCond, std::vector& ops, + int64_t minSymValue); + void InsertOffloadingOp(mlir::OpBuilder& rewriter, Operation* prevOp, + Operation* op, Value buffer, + std::vector consumers, Value symbS0, + int64_t cstMinS0); +}; +Location getFusionLocation(OpBuilder& b, Operation* op) { + if (auto fusionOp = dyn_cast(op->getParentOp())) { + b.setInsertionPointAfter(fusionOp); + return fusionOp->getLoc(); + } + b.setInsertionPointAfter(op); + return op->getLoc(); +} + +using FuncOp = mlir::func::FuncOp; + +SymbolicDimProduct getSymbolicMemRefSize(Value value, SymbolicDimMgr* mgr, + MLIRContext* ctx) { + auto memRefType = value.getType().cast(); + // get symbolic dims of the memref + // auto symbolics = getSymbolicDims(value); + auto symbolics = mgr->getOrCreateSymbolicDimsForRankedValue(value); + SymbolicDimProduct prod{symbolics}; + return prod; +} +/* +int64_t getMemRefSize(Value value) { + auto memRefType = value.getType().cast(); + int64_t elementSize = getElementSize(memRefType.getElementType()); + if (elementSize < 0) { + return -1; // Unsupported type + } + + int64_t numElements = 1; + for (int64_t dim : memRefType.getShape()) { + numElements *= dim; + } + return numElements * elementSize; +} +*/ + +bool shouldSkipBufferInPeakMemoryEstimator(Value value) { + if (IsHostBuffer(value)) return true; + + // skip buffer if it is a temp buffer which only used inside of a fusion op + // alloc = memref.alloc + // lmhlo.fusion() { + // op1(buffer0, buffer1, alloc) + // op2(alloc, buffer2, buffer3) + // } + // dealloc = memref.dealloc alloc + Operation* prevFusionOp = nullptr; + for (auto user : value.getUsers()) { + if (isa(user)) continue; + if (auto fusionOp = 
dyn_cast(user->getParentOp())) { + if (!prevFusionOp) prevFusionOp = fusionOp; + if (prevFusionOp != fusionOp) { + return false; + } + } else { + return false; + } + } + return true; +} + +bool IsDynamicShapeBuffer(Value buffer) { + auto memrefType = buffer.getType().cast(); + for (auto dim : memrefType.getShape()) { + if (dim == ShapedType::kDynamic) { + return true; + } + } + return false; +} +Value createAllocOp(OpBuilder& b, Location loc, Value refBuffer) { + MemRefType type = refBuffer.getType().cast(); + SmallVector dynShape; + for (size_t i = 0; i < type.getRank(); i++) { + if (type.getShape()[i] == ShapedType::kDynamic) { + dynShape.push_back(b.create(loc, refBuffer, i)); + } + } + auto allocOp = b.create(loc, type, dynShape); + StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName(); + if (refBuffer.getDefiningOp()->hasAttr(attrName)) { + allocOp.getOperation()->setAttr( + attrName, refBuffer.getDefiningOp()->getAttr(attrName)); + } + return allocOp.getResult(); +} + +void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, + LivingBuffer& livingBuffer, + Value rematCond, + std::vector& ops, + int64_t minSymValue) { + auto buffer = livingBuffer.buffer; + auto startOp = livingBuffer.start; + auto endOp = livingBuffer.end; + StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName(); + auto deviceMemrefType = buffer.getType().cast(); + auto hostMemrefType = MemRefType::get( + deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0); + // offloading to host + // get dynamic dim of buffer and insert into ValueRange + // if remat_cond: + // yield offload(buffer) + // else: + // yield dummy_buffer + auto offloadIfOp = + b.create(startOp->getLoc(), + /*resultTypes*/ hostMemrefType, rematCond, + /*hasElseRegion*/ true); + offloadIfOp.getOperation()->setAttr( + attrName, buffer.getDefiningOp()->getAttr(attrName)); + offloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("offload")); + offloadIfOp.getOperation()->setAttr(kRematMinSymDim, + b.getI64IntegerAttr(minSymValue)); + offloadIfOp.getThenRegion().front().clear(); + b.setInsertionPointToEnd(&offloadIfOp.getThenRegion().front()); + auto hostBuffer = createAllocOp(b, endOp->getLoc(), buffer); + hostBuffer.setType(MemRefType::get(deviceMemrefType.getShape(), + deviceMemrefType.getElementType(), {}, 0)); + b.create(endOp->getLoc(), buffer, hostBuffer); + b.create(endOp->getLoc(), buffer); + b.create(endOp->getLoc(), hostBuffer); + + offloadIfOp.getElseRegion().front().clear(); + b.setInsertionPointToStart(&offloadIfOp.getElseRegion().front()); + auto dummyHostBuffer = createAllocOp(b, endOp->getLoc(), buffer); + dummyHostBuffer.getDefiningOp()->setAttr(kRematBufferAttr, + b.getBoolAttr(true)); + dummyHostBuffer.setType(MemRefType::get( + deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0)); + b.create(endOp->getLoc(), dummyHostBuffer); + b.setInsertionPointAfter(offloadIfOp); + + if (auto fusionOp = dyn_cast(endOp->getParentOp())) { + b.setInsertionPoint(fusionOp); + } else { + b.setInsertionPoint(endOp); + } + + // insert reload block + scf::IfOp reloadIfOp = + b.create(endOp->getLoc(), + /*resultTypes*/ deviceMemrefType, rematCond, + /*hasElseRegion*/ true); + if (buffer.getDefiningOp()->hasAttr(attrName)) { + reloadIfOp.getOperation()->setAttr( + attrName, buffer.getDefiningOp()->getAttr(attrName)); + } + reloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("reload")); + reloadIfOp.getOperation()->setAttr(kRematMinSymDim, + 
b.getI64IntegerAttr(minSymValue)); + reloadIfOp.getThenRegion().front().clear(); + b.setInsertionPointToStart(&reloadIfOp.getThenRegion().front()); + + auto deviceBuffer = createAllocOp(b, endOp->getLoc(), buffer); + deviceBuffer.setType(deviceMemrefType); + auto h2dOp = b.create( + endOp->getLoc(), offloadIfOp.getResult(0), deviceBuffer); + b.create(endOp->getLoc(), deviceBuffer); + reloadIfOp.getElseRegion().front().clear(); + b.setInsertionPointToStart(&reloadIfOp.getElseRegion().front()); + auto dummyDeviceBuffer = createAllocOp(b, endOp->getLoc(), buffer); + dummyDeviceBuffer.setType(deviceMemrefType); + dummyDeviceBuffer.getDefiningOp()->setAttr( + b.getStringAttr("disc.remat.dummy-buffer"), b.getBoolAttr(true)); + b.create(endOp->getLoc(), buffer); + for (auto pair : ops) { + auto op = pair.first; + auto position = pair.second; + if (position > livingBuffer.end_position) { + for (size_t i = 0; i < op->getNumOperands(); i++) { + if (op->getOperand(i) == buffer) { + op->setOperand(i, reloadIfOp.getResult(0)); + } + } + } + } +} +void DiscOffloadingPass::InsertOffloadingOp(mlir::OpBuilder& b, + Operation* prevOp, Operation* op, + Value buffer, + std::vector consumers, + Value symbS0, int64_t cstMinS0) { + Location loc = prevOp->getLoc(); + if (auto fusionOp = dyn_cast(prevOp->getParentOp())) { + b.setInsertionPointAfter(fusionOp); + loc = fusionOp.getLoc(); + } else { + b.setInsertionPointAfter(prevOp); + } + auto offloadCond = b.create( + loc, arith::CmpIPredicate::sgt, symbS0, + b.create(op->getLoc(), cstMinS0)); + b.setInsertionPointAfter(offloadCond); + StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName(); + auto deviceMemrefType = buffer.getType().cast(); + auto hostMemrefType = MemRefType::get( + deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0); + // offloading to host + // get dynamic dim of buffer and insert into ValueRange + auto offloadIfOp = + b.create(op->getLoc(), + /*resultTypes*/ hostMemrefType, offloadCond, + /*hasElseRegion*/ true); + offloadIfOp.getOperation()->setAttr( + attrName, buffer.getDefiningOp()->getAttr(attrName)); + offloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("offload")); + offloadIfOp.getOperation()->setAttr(kRematMinSymDim, + b.getI64IntegerAttr(cstMinS0)); + offloadIfOp.getThenRegion().front().clear(); + b.setInsertionPointToEnd(&offloadIfOp.getThenRegion().front()); + auto hostBuffer = createAllocOp(b, op->getLoc(), buffer); + hostBuffer.setType(MemRefType::get(deviceMemrefType.getShape(), + deviceMemrefType.getElementType(), {}, 0)); + b.create(op->getLoc(), buffer, hostBuffer); + b.create(op->getLoc(), buffer); + b.create(op->getLoc(), hostBuffer); + + offloadIfOp.getElseRegion().front().clear(); + b.setInsertionPointToStart(&offloadIfOp.getElseRegion().front()); + auto dummyHostBuffer = createAllocOp(b, op->getLoc(), buffer); + dummyHostBuffer.getDefiningOp()->setAttr(kRematBufferAttr, + b.getBoolAttr(true)); + dummyHostBuffer.setType(MemRefType::get( + deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0)); + b.create(op->getLoc(), dummyHostBuffer); + b.setInsertionPointAfter(offloadIfOp); + + if (auto fusionOp = dyn_cast(op->getParentOp())) { + b.setInsertionPoint(fusionOp); + } else { + b.setInsertionPoint(op); + } + + // insert reload block + scf::IfOp reloadIfOp = + b.create(op->getLoc(), + /*resultTypes*/ deviceMemrefType, offloadCond, + /*hasElseRegion*/ true); + if (buffer.getDefiningOp()->hasAttr(attrName)) { + reloadIfOp.getOperation()->setAttr( + attrName, 
buffer.getDefiningOp()->getAttr(attrName)); + } + reloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("reload")); + reloadIfOp.getOperation()->setAttr(kRematMinSymDim, + b.getI64IntegerAttr(cstMinS0)); + reloadIfOp.getThenRegion().front().clear(); + b.setInsertionPointToStart(&reloadIfOp.getThenRegion().front()); + + auto deviceBuffer = createAllocOp(b, op->getLoc(), buffer); + deviceBuffer.setType(deviceMemrefType); + auto h2dOp = b.create( + op->getLoc(), offloadIfOp.getResult(0), deviceBuffer); + b.create(op->getLoc(), deviceBuffer); + reloadIfOp.getElseRegion().front().clear(); + b.setInsertionPointToStart(&reloadIfOp.getElseRegion().front()); + auto dummyDeviceBuffer = createAllocOp(b, op->getLoc(), buffer); + dummyDeviceBuffer.setType(deviceMemrefType); + dummyDeviceBuffer.getDefiningOp()->setAttr( + b.getStringAttr("disc.remat.dummy-buffer"), b.getBoolAttr(true)); + b.create(op->getLoc(), buffer); + + for (size_t i = 0; i < consumers.size(); i++) { + auto consumer = consumers[i]; + for (size_t j = 0; j < consumer->getNumOperands(); j++) { + if (consumer->getOperand(j) == buffer) { + consumer->setOperand(j, reloadIfOp.getResult(0)); + } + } + } +} +Value getSymBufferSize(OpBuilder& b, Location loc, Value buffer, + ShapeConstraintIRAnalysis& shapeAnalysis) { + int64_t factor = 1; + SmallVector symbols; + SmallVector dimValues; + auto memrefType = buffer.getType().cast(); + for (size_t i = 0; i < memrefType.getRank(); ++i) { + auto dim = memrefType.getShape()[i]; + if (dim == ShapedType::kDynamic) { + dimValues.push_back(b.create(loc, buffer, i)); + } else { + factor *= dim; + } + } + Value numal = b.create(loc, factor); + for (auto dim : dimValues) { + numal = b.create(loc, numal, dim); + } + return numal; +} + +struct SymbolicDimProductSum { + llvm::SmallVector symbolProds; +}; + +llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const SymbolicDimProduct& prod) { + // print prod + os << prod.factor; + if (prod.symbols.size() > 0) os << "*"; + for (size_t j = 0; j < prod.symbols.size(); ++j) { + auto dimOp = prod.symbols[j]; + if (j != prod.symbols.size() - 1) { + os << dimOp.getName() << "*"; + } else { + os << dimOp.getName(); + } + } + return os; +} +/* +SymbolicDimProductSum symbolicDimProductSub(SymbolicDimProductSum sum, + SymbolicDimProduct b) { + bool mergeSymbols = false; + SymbolicDimProductSum result; + + for (auto symbolProd : sum.symbolProds) { + if (symbolProd.symbols == b.symbols) { + mergeSymbols = true; + symbolProd.factor -= b.factor; + if (symbolProd.factor == 0) { + continue; + } + result.symbolProds.push_back(symbolProd); + } else { + result.symbolProds.push_back(symbolProd); + } + } + if (!mergeSymbols) { + b.factor *= -b.factor; + result.symbolProds.push_back(b); + } + return result; +} +SymbolicDimProductSum symbolicDimProductSumAdd(SymbolicDimProductSum sum, + SymbolicDimProduct b) { + SymbolicDimProductSum result; + bool mergeSymbols = false; + for (auto symbolProd : sum.symbolProds) { + if (symbolProd.symbols == b.symbols) { + mergeSymbols = true; + symbolProd.factor += b.factor; + result.symbolProds.push_back(symbolProd); + } else { + result.symbolProds.push_back(symbolProd); + } + } + if (!mergeSymbols) { + result.symbolProds.push_back(b); + } + return result; +} +*/ +// Dumps SymbolicDimProduct to the output stream + +std::tuple solveQuadratic(int64_t A, int64_t B, + int64_t C) { + if (A == 0) { + throw std::invalid_argument( + "Coefficient A cannot be zero in a quadratic equation."); + } + + double discriminant = B * B - 4 * A 
* C;
+  // no solution if b^2 - 4ac < 0
+  if (discriminant < 0) {
+    return std::make_tuple(false, 0, 0);
+  }
+
+  double sqrtDiscriminant = std::sqrt(discriminant);
+
+  // compute the two roots; truncating towards zero is acceptable because the
+  // result is only used as an offloading threshold
+  int64_t X1 = static_cast<int64_t>((-B + sqrtDiscriminant) / (2 * A));
+  int64_t X2 = static_cast<int64_t>((-B - sqrtDiscriminant) / (2 * A));
+  return std::make_tuple(true, X1, X2);
+}
+int64_t findMinSymbolicDimValue(MemoryUsage memoryPeakExpr,
+                                int64_t memoryLimitation) {
+  int64_t a = 0, b = 0, c = 0;
+  // memoryPeakExpr(S0) = a * S0^2 + b * S0 + c. Find the smallest S0 at
+  // which the peak exceeds the limitation by solving
+  //   a * S0^2 + b * S0 + (c - memoryLimitation) = 0
+  for (auto prod : memoryPeakExpr) {
+    if (prod.symbols.size() == 0) {
+      c = prod.factor;
+    } else if (prod.symbols.size() == 1) {
+      b = prod.factor;
+    } else if (prod.symbols.size() == 2) {
+      a = prod.factor;
+    }
+  }
+  c -= memoryLimitation;
+  auto [hasSolution, x0, x1] = solveQuadratic(a, b, c);
+  if (!hasSolution) {
+    throw std::invalid_argument("No solution for the quadratic equation.");
+  }
+  return std::max(x0, x1);
+}
+
+int64_t getConcretValuewithCst(SymbolicDimProductSum prod, int64_t cstValue) {
+  // evaluates the symbolic sum at S0 == cstValue; the accumulator must start
+  // at 0, not 1
+  int64_t result = 0;
+  for (auto symProd : prod.symbolProds) {
+    if (symProd.symbols.size() == 0) {
+      result += symProd.factor;
+    } else if (symProd.symbols.size() == 1) {
+      result += symProd.factor * cstValue;
+    } else if (symProd.symbols.size() == 2) {
+      result += symProd.factor * cstValue * cstValue;
+    }
+  }
+  return result;
+}
+bool inRematOffloadBlock(Value value) {
+  if (auto ifOp = dyn_cast<scf::IfOp>(value.getDefiningOp()->getParentOp())) {
+    auto blockType = ifOp->getAttrOfType<StringAttr>(kRematBlockTypeAttr);
+    if (blockType && blockType.getValue() == "offload") {
+      return true;
+    }
+  }
+  return false;
+}
+
+std::optional<Value> findSymbolicDimValue(FuncOp& main,
+                                          const std::string& key) {
+  std::unordered_map<std::string, Value> symValueMap;
+  main.walk([&](Operation* op) {
+    auto allocOp = dyn_cast<memref::AllocOp>(op);
+    if (!allocOp) {
+      return;
+    }
+    if (op->getNumOperands() == 0) {
+      return;
+    }
+    auto attrs =
+        op->getAttrOfType<ArrayAttr>(SymbolicDimOp::getSymbolicDimAttrName());
+    if (!attrs) {
+      return;
+    }
+    auto memrefTy = allocOp.getType();
+    int symbDimIndex = 0;
+    for (const auto& en : llvm::enumerate(attrs)) {
+      // every dynamic dimension of the alloc consumes one operand, so the
+      // operand index has to advance even for already-recorded symbols
+      if (memrefTy.getShape()[en.index()] != ShapedType::kDynamic) continue;
+      Value dimValue = op->getOperand(symbDimIndex++);
+      auto name = en.value().cast<FlatSymbolRefAttr>().getValue();
+      if (name.startswith("S") && symValueMap.count(name.str()) == 0) {
+        symValueMap[name.str()] = dimValue;
+      }
+    }
+  });
+  if (symValueMap.count(key) == 0) {
+    return std::nullopt;
+  }
+  return symValueMap[key];
+}
+Value InsertRematCond(mlir::OpBuilder& b, Location loc, Value s0,
+                      int64_t minS0) {
+  return b.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sgt, s0,
+      b.create<arith::ConstantIndexOp>(loc, minS0));
+}
+std::vector<LivingBuffer> FilterLivingBuffers(
+    const std::vector<LivingBuffer>& livingBuffers) {
+  std::vector<LivingBuffer> result;
+  for (const auto& lb : livingBuffers) {
+    // TODO(yancey): offloading only dynamic-shape buffers is an experimental
+    // restriction, let's remove this condition in the future
+    if (!IsDynamicShapeBuffer(lb.buffer)) {
+      continue;
+    }
+    // skip the buffer if it is already inside a remat block
+    if (isa<scf::IfOp>(lb.start->getParentOp())) continue;
+    result.push_back(lb);
+  }
+  return result;
+}
+void SortByPriority(std::vector<LivingBuffer>& livingBuffers) {
+  std::sort(livingBuffers.begin(), livingBuffers.end(),
+            [](const LivingBuffer& a, const LivingBuffer& b) {
+              return a.living_range > b.living_range;
+            });
+}
+std::optional<LivingBuffer> PickHighestPriorityLivingBuffer(
+    const std::vector<LivingBuffer>& livingBuffers) {
+  // step1: filter out living buffers that cannot reduce the peak memory,
+  // e.g. static-shape or too-small buffers
+  auto buffers = FilterLivingBuffers(livingBuffers);
+  if (buffers.empty()) {
+    return std::nullopt;
+  }
+  // step2: sort living buffers by priority, i.e. the length of the live range
+  SortByPriority(buffers);
+  return buffers[0];
+}
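+// The driver below ties these pieces together: profile the symbolic peak,
+// solve for the smallest S0 that still fits within the memory limitation,
+// pick the longest-lived dynamic buffer, and wrap it in offload/reload
+// blocks guarded by `S0 > minS0`. Worked example (illustrative numbers
+// only): if the peak expression is
+//   peak(S0) = 512 * S0^2 + 4096 * S0 + 1073741824   // bytes
+// and the limitation is 21474836480 bytes (20 GiB), findMinSymbolicDimValue
+// solves 512 * S0^2 + 4096 * S0 - 20401094656 = 0 and returns the larger
+// root, so rematerialization only triggers for runtime shapes beyond it.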
+void DiscOffloadingPass::runOnOperation() {
+  FuncOp main = getOperation();
+  if (main.getName() == SymbolicDimMgr::getShapeConstraintGraphFunctionName())
+    return;
+  mlir::OpBuilder b(main);
+  const int64_t memoryLimitation = 21474836480;  // 20 GiB
+  llvm::dbgs() << "memory limitation: " << memoryLimitation << "\n";
+  // 1. analyze the live range of every buffer
+  bool changed = true;
+  int maxIteration = 200;
+  std::unique_ptr<SymbolicMemoryProfiler> profiler(
+      new SymbolicMemoryProfiler(main));
+
+  std::unique_ptr<ShapeConstraintIRAnalysis> shapeAnalysisPtr;
+  std::unique_ptr<DiscBufferLivingRange> bufferLivingRange;
+  bufferLivingRange.reset(new DiscBufferLivingRange(main));
+  shapeAnalysisPtr.reset(new ShapeConstraintIRAnalysis(main));
+
+  auto symS0Value = findSymbolicDimValue(main, "S0");
+  if (!symS0Value) {
+    llvm::errs() << "failed to find the runtime value of S0\n";
+    return;
+  }
+  while (changed && maxIteration--) {
+    if (failed(profiler->Analysis())) {
+      llvm::errs() << "failed to analyze the symbolic memory usage\n";
+      return;
+    }
+    if (failed(bufferLivingRange->Analysis())) {
+      llvm::errs() << "failed to analyze the buffer live ranges\n";
+      return;
+    }
+    changed = false;
+    auto memoryPeakExpr = profiler->GetPeakMemory();
+    int64_t minS0 = findMinSymbolicDimValue(memoryPeakExpr, memoryLimitation);
+
+    llvm::dbgs() << "min s0: " << minS0 << "\n";
+    auto livingBuffer = PickHighestPriorityLivingBuffer(
+        bufferLivingRange->GetLivingBuffers());
+    if (livingBuffer.has_value()) {
+      auto& lb = livingBuffer.value();
+      llvm::dbgs() << "living range: " << lb.living_range << "\n";
+      auto users = bufferLivingRange->GetUsersOrderByPosition(lb.buffer);
+
+      auto loc = getFusionLocation(b, lb.start);
+      auto rematCond = InsertRematCond(b, loc, symS0Value.value(), minS0);
+      InsertRematBlock(b, lb, rematCond, users, minS0);
+      changed = true;
+    }
+  }
+}
+std::unique_ptr<OperationPass<FuncOp>> createDiscOffloadingPass() {
+  return std::make_unique<DiscOffloadingPass>();
+}
+
+}  // namespace disc_ral
+}  // namespace mlir
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td
index 3352f7d5e74..f8fc0dece58 100755
--- a/tao_compiler/mlir/disc/transforms/disc_passes.td
+++ b/tao_compiler/mlir/disc/transforms/disc_passes.td
@@ -677,4 +677,9 @@ def DiscReduceBufferLiveRangePass : Pass<"disc-reduce-buffer-live-range", "mlir:
 def DiscShapePropagatePass : Pass<"disc-shape-propagate", "ModuleOp"> {
   let summary = "shape analysis pass";
   let constructor = "createDiscShapePropagatePass()";
+}
+
+def DiscOffloadingPass : Pass<"disc-offloading", "mlir::func::FuncOp"> {
+  let summary = "automatic buffer offloading pass";
+  let constructor = "createDiscOffloadingPass()";
 }
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc b/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc
new file mode 100644
index 00000000000..268ddb0433a
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc
@@ -0,0 +1,446 @@
+// Copyright 2024 The BladeDISC Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/disc/transforms/disc_remat_utils.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/disc/IR/disc_ral_ops.h" +#include "mlir/disc/IR/hlo_disc_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +namespace mlir { +namespace disc_ral { +bool IsHostBuffer(Value value) { + auto memRefType = value.getType().cast(); + return memRefType.getMemorySpace() == 0; +} + +LogicalResult DiscBufferLivingRange::Analysis() { + buffer_list_.clear(); + living_buffers_.clear(); + int64_t position; + buffer_map_.clear(); + auto reccordOpWithPosition = [&](Value value, Operation* op) { + buffer_map_[value].push_back(std::make_pair(op, position++)); + }; + main_.walk([&](Operation* op) { + // Traverse the function's blocks and operations. + if (auto allocOp = dyn_cast(op)) { + auto buffer = allocOp.getResult(); + if (IsHostBuffer(allocOp.getResult())) { + return; + } + buffer_list_.push_back(buffer); + reccordOpWithPosition(buffer, op); + } else if (isa(op)) { + auto buffer = op->getOperand(0); + if (IsHostBuffer(buffer)) { + return; + } + reccordOpWithPosition(buffer, op); + } else if (auto returnOp = dyn_cast(op)) { + for (auto operand : returnOp.getOperands()) { + reccordOpWithPosition(operand, op); + } + } else if (isa(op) || isa(op) || + isa(op) || isa(op)) { + return; + } else { + for (Value operand : op->getOperands()) { + if (buffer_map_.count(operand)) reccordOpWithPosition(operand, op); + } + } + }); + for (auto iter : buffer_map_) { + auto buffer = iter.first; + for (size_t i = 1; i < iter.second.size(); i++) { + auto start = iter.second[i - 1].first; + auto startPosition = iter.second[i - 1].second; + + auto end = iter.second[i].first; + auto endPosition = iter.second[i].second; + LivingBuffer livingBuffer(start, startPosition, end, endPosition, buffer); + living_buffers_.push_back(livingBuffer); + } + } + return success(); +} +std::vector DiscBufferLivingRange::GetUsersOrderByPosition( + Value buffer) { + std::vector ops; + if (buffer_map_.find(buffer) == buffer_map_.end()) { + return ops; + } + return buffer_map_[buffer]; +} + +//////////////// Symbolic Memory Profiler Utils ////////////////// +bool hasDynamicDimension(Value buffer) { + auto memrefType = buffer.getType().cast(); + for (auto dim : memrefType.getShape()) { + if (dim == ShapedType::kDynamic) { + return true; + } + } + return false; +} +int64_t getElementSize(Type elementType) { + if (elementType.isF32()) { + return sizeof(float); + } else if (elementType.isF16()) { + return sizeof(uint16_t); + } else if (elementType.isBF16()) { + return sizeof(uint16_t); + } else if (elementType.isF64()) { + return sizeof(double); + } else if (elementType.isInteger(1)) { + return sizeof(bool); + } else if (elementType.isInteger(8)) { + return sizeof(int8_t); + } else if (elementType.isInteger(16)) { + return sizeof(int16_t); + } else if (elementType.isInteger(32)) { + return sizeof(int32_t); + } else if (elementType.isInteger(64)) { + return sizeof(int64_t); + } else if (elementType.isIndex()) { + return sizeof(int32_t); + } 
else { + llvm::dbgs() << elementType << "\n"; + // Add more types as needed + llvm::errs() << "Unsupported element type\n"; + return -1; + } +} + +MemoryUsage& operator+=(MemoryUsage& lhs, const SymbolicDimProduct& rhs) { + MemoryUsage result; + bool mergeSymbols = false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (lhs[i].symbols == rhs.symbols) { + mergeSymbols = true; + lhs[i].factor += rhs.factor; + } + } + if (!mergeSymbols) { + lhs.push_back(rhs); + } + return lhs; +} +MemoryUsage& operator-=(MemoryUsage& lhs, const SymbolicDimProduct& rhs) { + bool mergeSymbols = false; + size_t removeIndex = -1; + for (size_t i = 0; i < lhs.size(); ++i) { + if (lhs[i].symbols == rhs.symbols) { + mergeSymbols = true; + lhs[i].factor -= rhs.factor; + if (lhs[i].factor == 0) { + removeIndex = i; + break; + } + } + } + if (removeIndex != -1) { + lhs.erase(lhs.begin() + removeIndex); + } + if (!mergeSymbols) { + auto newRhs = rhs; + newRhs.factor *= -1; + lhs.push_back(newRhs); + } + return lhs; +} +llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const MemoryUsage& memoryUsage) { + for (size_t i = 0; i < memoryUsage.size(); ++i) { + // print prod + os << memoryUsage[i].factor; + if (memoryUsage[i].symbols.size() > 0) os << "*"; + for (size_t j = 0; j < memoryUsage[i].symbols.size(); ++j) { + auto dimOp = memoryUsage[i].symbols[j]; + if (j != memoryUsage[i].symbols.size() - 1) { + os << dimOp.getName() << "*"; + } else { + os << dimOp.getName(); + } + } + if (i != memoryUsage.size() - 1 && memoryUsage[i].factor > 0) { + os << " + "; + } + } + return os; +} +// TODO(yancey): just a PoC implementation, need to implement this function +// by traveling the shape operators +SymbolicDimProduct SimplifySymbolicDims(SymbolicDimProduct proudct, + SymbolicDimOp s0) { + auto factor = proudct.factor; + SmallVector symbols; + for (auto symbol : proudct.symbols) { + StringRef symbolName = const_cast(symbol).getName(); + if (symbolName.startswith("S")) { + if (symbolName == "S0") { + factor *= 1; + } else if (symbolName == "S1") { + factor *= 4; + } else if (symbolName == "S2") { + factor *= 32001; + } else if (symbolName == "S3") { + factor *= 128; + } else { + llvm::errs() << "unsupported symbol name\n"; + } + symbols.push_back(s0); + } else { + symbols.push_back(symbol); + } + } + return SymbolicDimProduct{symbols, factor}; +} +int64_t getMemRefSize(Value value) { + auto memRefType = value.getType().cast(); + int64_t elementSize = getElementSize(memRefType.getElementType()); + if (elementSize < 0) { + return -1; // Unsupported type + } + + int64_t numElements = 1; + for (int64_t dim : memRefType.getShape()) { + numElements *= dim; + } + return numElements * elementSize; +} +SymbolicDimProduct getSymbolicMemrefBytes(Value buffer, SymbolicDimMgr* mgr) { + auto s0 = mgr->findSymbolicDimOp("S0"); + auto ty = buffer.getType().cast(); + auto elementBytes = getElementSize(ty.getElementType()); + if (hasDynamicDimension(buffer)) { + if (auto recvOp = dyn_cast(buffer.getDefiningOp())) { + // get first user of the buffer + if (auto rcastOp = + dyn_cast(*buffer.getUsers().begin())) { + buffer = rcastOp.getResult(); + } + } + auto symDims = getMemRefValueSymbolicDims(*mgr, buffer); + SymbolicDimProduct prod{symDims.value()}; + prod = SimplifySymbolicDims(prod, s0.value()); + prod.factor *= elementBytes; + return mgr->simplifySymbolicDimProduct(prod); + } + SymbolicDimProduct prod{}; + int64_t dimProduct = 1; + for (int64_t dim : ty.getShape()) { + dimProduct *= dim; + } + prod.factor = dimProduct * elementBytes; + 
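+  // At this point the buffer shape is fully static, so the byte size folds
+  // into the constant factor and the symbol list stays empty. Worked example
+  // (illustrative): memref<4x128xf16> gives factor = 4 * 128 * 2 = 1024
+  // bytes, while a dynamic memref<?x128xf16> keeps S0 in `symbols` with
+  // factor = 128 * 2 = 256, i.e. 256 * S0 bytes.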
return prod;
+}
+bool isHostBuffer(Value value) {
+  auto memRefType = value.getType().cast<MemRefType>();
+  return memRefType.getMemorySpaceAsInt() == 0;
+}
+Value maybeInReloadBlock(Value buffer) {
+  if (auto ifOp = dyn_cast<scf::IfOp>(buffer.getDefiningOp()->getParentOp())) {
+    return ifOp.getResult(0);
+  }
+  return buffer;
+}
+/////////////////////// SymbolicMemoryProfiler ///////////////////////
+bool SymbolicMemoryProfiler::isTempBuffer(Value value) {
+  Operation* prevFusionOp = nullptr;
+  for (auto user : value.getUsers()) {
+    if (isa<memref::DeallocOp>(user)) continue;
+    if (auto fusionOp = dyn_cast<lmhlo::FusionOp>(user->getParentOp())) {
+      if (!prevFusionOp) prevFusionOp = fusionOp;
+      if (prevFusionOp != fusionOp) {
+        return false;
+      }
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+
+int64_t getConcretValuewithFakeValue(MemoryUsage memoryUsages,
+                                     int64_t cstValue) {
+  int64_t result = 0;
+  for (auto prod : memoryUsages) {
+    int64_t factor = prod.factor;
+    if (prod.symbols.size() > 0)
+      factor *= std::pow(cstValue, prod.symbols.size());
+    result += factor;
+  }
+  return result;
+}
+float bytesToMB(int64_t bytes) { return bytes * 1.0 / (1024.0 * 1024.0); }
+MemoryUsage findPeakMemoryWithFakeValue(
+    std::vector<MemoryUsage> memoryUsageList, int64_t fakeValue) {
+  if (memoryUsageList.empty()) return {};
+  size_t instIndex = 0;
+  int64_t maxMemoryUsage = 0;
+  for (size_t i = 0; i < memoryUsageList.size(); ++i) {
+    int64_t memoryUsage =
+        getConcretValuewithFakeValue(memoryUsageList[i], fakeValue);
+    if (memoryUsage > maxMemoryUsage) {
+      maxMemoryUsage = memoryUsage;
+      instIndex = i;
+    }
+  }
+  llvm::dbgs() << "fakeValue: " << fakeValue << " maxMemory: " << maxMemoryUsage
+               << " bytes"
+               << " instIndex: " << instIndex
+               << " expr: " << memoryUsageList[instIndex] << "\n";
+  return memoryUsageList[instIndex];
+}
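+// Worked example for the fuzzy peak search above (illustrative numbers):
+// with fakeValue = 2048, a usage entry `256*S0 + 1048576` evaluates to
+// 256 * 2048 + 1048576 = 1572864 bytes (1.5 MB); the entry with the largest
+// evaluated value is returned as the symbolic peak expression.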
+std::vector<MemoryUsage> SymbolicMemoryProfiler::ConcretMemoryUsageSimulator(
+    int64_t concretValue) {
+  /*
+  SymbolicDimProductSum memoryUsage;
+  SmallVector memoryUsageList;
+  std::unordered_set set, skipBuffers;
+  mlir::OpBuilder b(main);
+  llvm::dbgs() << "memory usage with seqlen=" << cstS0 << "\n";
+  int rematBuffers = 0;
+  main.walk([&](Operation* op) {
+    if (auto allocOp = dyn_cast(op)) {
+      if (allocOp.getOperation()->getAttr(
+              b.getStringAttr("disc.remat.dummy-buffer"))) {
+        return;
+      }
+      // alloc operator maybe inside a remat block
+      auto buffer = allocOp.getResult();
+      auto reloadBuffer = maybeInReloadBlock(buffer);
+      if (auto rematBlock = getRematBlock(allocOp.getOperation())) {
+        return;
+      }
+      if (shouldSkipBufferInPeakMemoryEstimator(buffer)) {
+        skipBuffers.insert(reloadBuffer);
+        return;
+      }
+      auto bufferBytes = getSymMemrefSize(buffer, mgr);
+      memoryUsage = symbolicDimProductSumAdd(memoryUsage, bufferBytes);
+      memoryUsageList.push_back(memoryUsage);
+      set.insert(reloadBuffer);
+      llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0))
+                   << "\n";
+
+    } else if (auto deallocOp = dyn_cast(op)) {
+      auto buffer = deallocOp.getOperand();
+      if (skipBuffers.count(buffer)) {
+        return;
+      }
+      if (auto rematBlock = getRematBlock(deallocOp.getOperation())) {
+        return;
+      }
+      auto bufferBytes = getSymMemrefSize(buffer, mgr);
+      memoryUsage = symbolicDimProductSub(memoryUsage, bufferBytes);
+      memoryUsageList.push_back(memoryUsage);
+      set.erase(buffer);
+      llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0))
+                   << "\n";
+    } else if (auto d2hOp = dyn_cast(op)) {
+      auto buffer = d2hOp.getOperand(0);
+      auto rematBlock = getRematBlock(d2hOp.getOperation());
+      int64_t minS0 =
+          rematBlock->getAttrOfType(kRematMinSymDim).getInt();
+      if (cstS0 > minS0) {
+        auto bufferBytes = getSymMemrefSize(buffer, mgr);
+        memoryUsage = symbolicDimProductSub(memoryUsage, bufferBytes);
+        llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0))
+                     << "\n";
+      }
+    } else if (auto h2dOp = dyn_cast(op)) {
+      auto buffer = h2dOp.getOperand(1);
+      auto rematBlock = getRematBlock(h2dOp.getOperation());
+      int64_t minS0 =
+          rematBlock->getAttrOfType(kRematMinSymDim).getInt();
+      if (cstS0 > minS0) {
+        auto bufferBytes = getSymMemrefSize(buffer, mgr);
+        memoryUsage = symbolicDimProductSumAdd(memoryUsage, bufferBytes);
+        llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0))
+                     << "\n";
+        rematBuffers++;
+      }
+    }
+  });
+  */
+  // TODO(yancey): the simulator above is a PoC and is disabled for now;
+  // return an empty list so callers at least get a well-defined value.
+  return {};
+}
+
+LogicalResult SymbolicMemoryProfiler::Analysis() {
+  mlir::OpBuilder b(main_);
+  std::unique_ptr<ShapeConstraintIRAnalysis> shapeAnalysisPtr;
+  shapeAnalysisPtr.reset(new ShapeConstraintIRAnalysis(main_));
+  mgr_ = &shapeAnalysisPtr->symbolicDimMgr();
+  memory_usage_list_.clear();
+
+  MemoryUsage currentUsage;
+  std::unordered_set<Value, ValueHash> skipBuffers;
+  main_.walk([&](Operation* op) {
+    if (auto recvOp = dyn_cast<RecvInputOp>(op)) {
+      auto buffer = recvOp.getResult();
+      auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_);
+      currentUsage += bufferBytes;
+      memory_usage_list_.push_back(currentUsage);
+    } else if (auto sendOp = dyn_cast<SendOutputOp>(op)) {
+      auto buffer = sendOp.getOperand(2);
+      auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_);
+      currentUsage -= bufferBytes;
+      memory_usage_list_.push_back(currentUsage);
+    } else if (auto allocOp = dyn_cast<memref::AllocOp>(op)) {
+      // TODO(yancey): the dummy attribute is a workaround for remat, a
+      // custom operator would be cleaner
+      if (allocOp->hasAttr("disc.remat.dummy-buffer")) {
+        return;
+      }
+      // the alloc operator may live inside a remat block
+      auto buffer = allocOp.getResult();
+      auto reloadBuffer = maybeInReloadBlock(buffer);
+      // skip the buffer if it is a host buffer or a temp buffer of a fusion
+      if (isHostBuffer(buffer) || isTempBuffer(buffer)) {
+        skipBuffers.insert(reloadBuffer);
+        return;
+      }
+
+      auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_);
+      currentUsage += bufferBytes;
+      memory_usage_list_.push_back(currentUsage);
+    } else if (auto deallocOp = dyn_cast<memref::DeallocOp>(op)) {
+      auto buffer = deallocOp.getOperand();
+      if (skipBuffers.count(buffer)) {
+        return;
+      }
+      auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_);
+      currentUsage -= bufferBytes;
+      memory_usage_list_.push_back(currentUsage);
+    }
+  });
+  // NOTE: the peak is searched with a fake symbolic value; this is a fuzzy
+  // search and the result may be inaccurate, but it is good enough for most
+  // cases
+  peak_memory_ = findPeakMemoryWithFakeValue(memory_usage_list_, 2048);
+  return success();
+}
+
+}  // namespace disc_ral
+}  // namespace mlir
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/transforms/disc_remat_utils.h b/tao_compiler/mlir/disc/transforms/disc_remat_utils.h
new file mode 100644
index 00000000000..cddaf665a2d
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/disc_remat_utils.h
@@ -0,0 +1,128 @@
+/* Copyright 2024 The BladeDISC Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#pragma once
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_split.h"
+#include "absl/types/span.h"
+#include "lhlo/IR/lhlo_ops.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Format.h"
+#include "mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/disc/transforms/disc_shape_optimization_utils.h"
+#include "mlir/disc/transforms/shape_utils.h"
+namespace mlir {
+namespace disc_ral {
+
+struct ValueHash {
+  std::size_t operator()(const Value& operand) const {
+    return mlir::hash_value(operand);
+  }
+};
+using OpWithPosition = std::pair<Operation*, int64_t>;
+
+bool IsHostBuffer(Value value);
+// returns sizeof(dtype), or -1 for unsupported element types
+int64_t getElementSize(Type elementType);
+// TODO(yancey): just a PoC implementation, need to implement this function
+// by traversing the shape operators
+SymbolicDimProduct SimplifySymbolicDims(SymbolicDimProduct product,
+                                        SymbolicDimOp s0);
+
+struct LivingBuffer {
+  LivingBuffer(Operation* start, int64_t start_position, Operation* end,
+               int64_t end_position, Value buffer)
+      : start(start),
+        start_position(start_position),
+        end(end),
+        end_position(end_position),
+        buffer(buffer) {
+    living_range = end_position - start_position;
+  }
+
+  Operation* start;
+  int64_t start_position;
+  Operation* end;
+  int64_t end_position;
+  Value buffer;
+  int64_t living_range;
+};
+
+class DiscBufferLivingRange {
+ public:
+  explicit DiscBufferLivingRange(mlir::func::FuncOp main) : main_(main) {}
+
+  LogicalResult Analysis();
+
+  std::vector<LivingBuffer> GetLivingBuffers() { return living_buffers_; }
+  // get all users of the buffer, ordered by position
+  std::vector<OpWithPosition> GetUsersOrderByPosition(Value buffer);
+
+ private:
+  mlir::func::FuncOp main_;
+  std::vector<LivingBuffer> living_buffers_;
+  // maps each buffer to its users and their positions
+  std::unordered_map<Value, std::vector<OpWithPosition>, ValueHash>
+      buffer_map_;
+  std::vector<Value> buffer_list_;
+};
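+
+// Example (illustrative op names): given the flattened operator sequence
+//   %0 = memref.alloc(...)   // recorded at position 0
+//   lmhlo.abs(%0, ...)       // recorded at position 1
+//   ...                      // many unrelated operators
+//   lmhlo.add(%0, ...)       // recorded at position 57
+// Analysis() records every (user, position) pair of %0 and emits one
+// LivingBuffer per consecutive pair, e.g. (start=1, end=57, living_range=56).
+// Long gaps between two uses of a buffer are the best offloading candidates.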
+using MemoryUsage = llvm::SmallVector<SymbolicDimProduct, 4>;
+
+MemoryUsage& operator+=(MemoryUsage& lhs, const SymbolicDimProduct& rhs);
+MemoryUsage& operator-=(MemoryUsage& lhs, const SymbolicDimProduct& rhs);
+llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
+                              const MemoryUsage& memoryUsage);
+
+// SymbolicMemoryProfiler analyzes the memory usage of an mlir function; it
+// returns the peak memory expression and the per-operator memory usage list
+class SymbolicMemoryProfiler {
+ public:
+  explicit SymbolicMemoryProfiler(mlir::func::FuncOp& main) : main_(main) {}
+  LogicalResult Analysis();
+  MemoryUsage GetPeakMemory() { return peak_memory_; }
+  std::vector<MemoryUsage> GetMemoryUsageList() { return memory_usage_list_; }
+
+  std::vector<MemoryUsage> ConcretMemoryUsageSimulator(int64_t concretValue);
+
+ private:
+  // returns true if the value is a temp buffer which is only used inside of
+  // a single fusion op; such a buffer is removed after codegen. An example
+  // pattern:
+  //
+  //   alloc = memref.alloc
+  //   lmhlo.fusion() {
+  //     op1(buffer0, buffer1, alloc)
+  //     op2(alloc, buffer2, buffer3)
+  //   }
+  //   dealloc = memref.dealloc alloc
+  bool isTempBuffer(Value value);
+  MemoryUsage searchPeakMemory(SmallVector<MemoryUsage>& memoryUsageList);
+
+  SymbolicDimMgr* mgr_;
+  mlir::func::FuncOp main_;
+  MemoryUsage peak_memory_;
+  std::vector<MemoryUsage> memory_usage_list_;
+};
+
+}  // namespace disc_ral
+}  // namespace mlir
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc b/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc
index 0a290ad624a..906fd42727e 100644
--- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc
@@ -1698,10 +1702,11 @@ LogicalResult injectStaticKnownInfo(ShapeComputationIRAnalysis& analysis,
     auto outDims = analysis.rankedTensor2SymDims(out);
     if (inDims.has_value() && outDims.has_value()) {
       for (const auto& en : llvm::enumerate(llvm::zip(*inDims, *outDims))) {
-        if (std::get<0>(en.value()) == std::get<1>(en.value()))
+        if (std::get<0>(en.value()) == std::get<1>(en.value())) {
           if (failed(helper.markAsFullySlicedAxis(en.index())))
-            return op->emitError() << "failed to mark axis" << en.index()
+            return op->emitError() << "failed to mark axis " << en.index()
                                    << " to be fully sliced\n";
+        }
       }
     }
 
diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc
index fc87b49ee20..0182da7f1fa 100644
--- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc
@@ -84,13 +84,13 @@ bool compareSymbolicDimProduct(const SymbolicDimProduct& lhs,
   return false;
 }
 
-llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
-                              const SymbolicDimProduct& product) {
-  os << "SymbolicDimProduct[\n\tfactor: " << product.factor << ",\n";
-  for (auto& s : product.symbols) os << "\tsymbol: " << s << "\n";
-  os << "]\n";
-  return os;
-}
+// Commented out: disc_offloading.cc now defines its own streaming operator
+// for SymbolicDimProduct in the same namespace; keeping this definition too
+// would produce duplicate-symbol errors.
+// llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
+//                               const SymbolicDimProduct& product) {
+//   os << "SymbolicDimProduct[\n\tfactor: " << product.factor << ",\n";
+//   for (auto& s : product.symbols) os << "\tsymbol: " << s << "\n";
+//   os << "]\n";
+//   return os;
+// }
 
 SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m), symbolTable_(m_) {}
 
@@ -101,7 +101,14 @@ LogicalResult SymbolicDimMgr::load() {
   });
   return loadShapeConstraintGraph();
 }
-
+std::optional<SymbolicDimOp> SymbolicDimMgr::findSymbolicDimOp(StringRef name) {
+  for (auto p : symbolDimUnionSet_) {
+    if (p.first.getName() == name) {
+      return p.first;
+    }
+  }
+  return std::nullopt;
+}
 std::string SymbolicDimMgr::getNextName() {
   std::string name;
   do {
@@ -248,7 +255,26 @@ SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x,
 
   return std::make_pair(std::move(newLhs), std::move(newRhs));
 }
-
+SymbolicDimProduct SymbolicDimMgr::symbolicDimProductAdd(
+    const SymbolicDimProduct& x, const SymbolicDimProduct& y) {
+  // Addition can only fold like terms, e.g. 2*S0*S1 + 3*S0*S1 == 5*S0*S1;
+  // mixed symbol lists are not representable as a single product.
+  assert(x.symbols == y.symbols && "can only add like symbolic dim terms");
+  SymbolicDimProduct result;
+  result.factor = x.factor + y.factor;
+  result.symbols = x.symbols;
+  return simplifySymbolicDimProduct(result);
+}
+SymbolicDimProduct SymbolicDimMgr::symbolicDimProductSub(
+    const SymbolicDimProduct& x, const SymbolicDimProduct& y) {
+  // Same like-term restriction as symbolicDimProductAdd.
+  assert(x.symbols == y.symbols && "can only sub like symbolic dim terms");
+  SymbolicDimProduct result;
+  result.factor = x.factor - y.factor;
+  result.symbols = x.symbols;
+  return simplifySymbolicDimProduct(result);
+}
 std::optional<SymbolicDimProduct> SymbolicDimMgr::symbolicDimProductDivide(
     const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) {
   LLVM_DEBUG(llvm::dbgs() << "Try to check if x % y == 0?\nx = " << lhs
@@ -985,7 +1011,6 @@ SliceOpShapeHelper::SliceOpShapeHelper(Operation* op) : op(op) {
   assert(startAttr.getNumElements() == ty.getRank());
   assert(limitAttr.getNumElements() == ty.getRank());
   assert(strideAttr.getNumElements() == ty.getRank());
-
   for (int i = 0; i < ty.getRank(); ++i) {
     mergeStartIndex(i, startAttr.getValues<int64_t>()[i]);
     mergeLimitIndex(i, limitAttr.getValues<int64_t>()[i]);
@@ -1020,9 +1045,12 @@ LogicalResult SliceOpShapeHelper::markAsFullySlicedAxis(int axis) {
 
 LogicalResult SliceOpShapeHelper::mergeStartIndex(int axis, int64_t value) {
   assert(axis < static_cast<int>(startIndices.size()));
-  if (startIndices[axis] != value && value != ShapeValueState::kUnknown &&
-      startIndices[axis] != ShapeValueState::kUnknown)
-    return failure();
+  // NOTE(yancey): the following check is disabled to support dynamic slice;
+  // a start index greater than 0 is valid there. Restore this check if any
+  // issue occurs.
+  // if (startIndices[axis] != value && value != ShapeValueState::kUnknown &&
+  //     startIndices[axis] != ShapeValueState::kUnknown)
+  //   return failure();
 
   if (startIndices[axis] == ShapeValueState::kUnknown)
     startIndices[axis] = value;
diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h
index b04cf424c1d..5146c6b836a 100644
--- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h
+++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h
@@ -361,6 +361,11 @@ class SymbolicDimMgr {
   std::optional<SymbolicDimProduct> symbolicDimProductDivide(
       const SymbolicDimProduct& x, const SymbolicDimProduct& y);
 
+  SymbolicDimProduct symbolicDimProductAdd(const SymbolicDimProduct& x,
+                                           const SymbolicDimProduct& y);
+  SymbolicDimProduct symbolicDimProductSub(const SymbolicDimProduct& x,
+                                           const SymbolicDimProduct& y);
+
   // mark group [a0, b0, ...] and group [a1, b1, c1, ...] are group
   // multiplication equal `a0 * b0 * ... = a1 * b1 * c1 * ...`
   bool isSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
@@ -378,9 +383,10 @@ class SymbolicDimMgr {
 
   // Returns a clone of the original symbol
   SymbolicDimOp cloneSymbol(SymbolicDimOp symbol);
+  std::optional<SymbolicDimOp> findSymbolicDimOp(StringRef name);
 
-  // Clones a group of symbols and the relationships among the symbols in the
-  // group. Returns ok if success, otherwise failure.
+  // Clones a group of symbols and the relationships among the symbols in
+  // the group. Returns ok if success, otherwise failure.
LogicalResult cloneSymbolGroup( const DenseSet& symbols, DenseMap& mapping); diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc index a17817b83a5..c3b5d0e86d9 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc +++ b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/disc/disc_util.h" #include "mlir/disc/transforms/PassDetail.h" #include "mlir/disc/transforms/shape_utils.h" +#include "tensorflow/tsl/platform/default/logging.h" namespace mlir { namespace disc_ral { @@ -72,7 +73,8 @@ struct DiscShapePropagatePass registry.insert(); } void visitOperator(ModuleOp& m, OpBuilder& rewriter, Operation* op, - std::stack& ctxStack); + ShapeContext& ctx, + std::unordered_map& symbolicMap); void runOnOperation() override; }; bool isBinaryOp(Operation* op) { @@ -110,33 +112,29 @@ std::optional getConstTensor(OpBuilder& b, Operation* op, return const_op.getResult(); } +void PrintCtx(ShapeContext& ctx) { + llvm::dbgs() << "value: "; + ctx.value.dump(); + for (size_t i = 0; i < ctx.shape.size(); ++i) { + llvm::dbgs() << "shape: [" << i << "] = " << ctx.shape[i] << "\n"; + } +} + std::optional HandleBinaryOp(OpBuilder& b, Operation* op, ShapeContext& inputCtx) { - auto bcastOp = dyn_cast_or_null( - op->getOperand(1).getDefiningOp()); - if (!bcastOp) { - return ShapeContext(op->getResult(0), inputCtx.shape); - } - if (bcastOp) { - auto constOp = dyn_cast_or_null( - bcastOp->getOperand(0).getDefiningOp()); - if (!constOp) { - return ShapeContext(op->getResult(0), inputCtx.shape); - } + if (auto bcastOp = dyn_cast_or_null( + op->getOperand(1).getDefiningOp())) { auto elemTy = op->getOperand(0).getType().cast().getElementType(); b.setInsertionPoint(op); - auto dense_attr = constOp.getValue().dyn_cast(); - int64_t value = dense_attr.getValues()[0]; - auto scalar_const_op = getConstTensor(b, op, {value}, {}); - Value inputShape = - b.create(op->getLoc(), op->getOperand(0)); - auto rank = inputCtx.shape.size(); - + auto shapeOf = b.create(op->getLoc(), op->getOperand(0)); + auto outShape = + op->getOperand(0).getType().cast().getShape(); auto dynBcastOp = b.create( op->getLoc(), RankedTensorType::get(inputCtx.shape, elemTy), - scalar_const_op.value(), inputShape, b.getI64TensorAttr({})); + bcastOp->getOperand(0), shapeOf, bcastOp.getBroadcastDimensions()); bcastOp.getResult().replaceAllUsesWith(dynBcastOp.getResult()); + bcastOp->erase(); } return ShapeContext(op->getResult(0), inputCtx.shape); } @@ -176,9 +174,10 @@ std::optional propagateHelper( } } -template <> -std::optional propagateHelper( - OpBuilder& b, Operation* op, ShapeContext& inputCtx) { +std::optional HandleReshapeOp( + OpBuilder& b, Operation* op, ShapeContext& inputCtx, + std::unordered_map& symbolicMap) { + b.setInsertionPoint(op); auto reshape_op = dyn_cast(op); if (!reshape_op) return std::nullopt; Type intType = b.getIntegerType(32); @@ -188,13 +187,47 @@ std::optional propagateHelper( reshape_op.getResult().getType().cast(); auto resultRank = resultRankType.getRank(); auto resultShape = resultRankType.getShape(); + auto inputShape = reshape_op.getOperand().getType().cast(); SmallVector newShape(resultRank, ShapedType::kDynamic); + bool symbolicSeqlen = false; + for (size_t i = 0; i < resultShape.size(); ++i) { + if (symbolicMap.count(resultShape[i])) { + symbolicSeqlen = true; + } + } + if (symbolicSeqlen) { + SmallVector newShapeValues; + SmallVector newShape; + for (size_t 
i = 0; i < resultShape.size(); ++i) { + if (symbolicMap.count(resultShape[i])) { + newShape.push_back(ShapedType::kDynamic); + newShapeValues.push_back(symbolicMap[resultShape[i]]); + } else { + newShape.push_back(resultShape[i]); + newShapeValues.push_back( + b.create(op->getLoc(), resultShape[i])); + } + } + Value shapeValue = + b.create(op->getLoc(), newShapeValues); + auto shape = b.create(op->getLoc(), op->getOperand(0)); + auto numElems = b.create(op->getLoc(), shape); + auto computeReshapeShape = b.create( + op->getLoc(), shapeValue.getType(), numElems.getResult(), shapeValue); + auto dynReshapeOpResultType = + RankedTensorType::get(newShape, resultRankType.getElementType()); + auto dynReshapeOp = b.create( + op->getLoc(), dynReshapeOpResultType, reshape_op.getOperand(), + computeReshapeShape); + op->getResult(0).replaceAllUsesWith(dynReshapeOp.getResult()); + op->erase(); + return ShapeContext(dynReshapeOp->getResult(0), newShape); + } int64_t numel = std::accumulate(inputCtx.shape.begin(), inputCtx.shape.end(), int64_t(1), [](int64_t acc, int64_t num) { return num == ShapedType::kDynamic ? acc : acc * num; }); - bool inferenced = true; while (inferenced) { inferenced = false; @@ -280,15 +313,21 @@ std::optional propagateHelper( auto stridesCst = slice_op.getStrides().getValues()[i]; startIndices[i] = b.create(slice_op.getLoc(), startIndicesCst); - // using dynamic dim if limitIndices is the same as input shape - if (limitIndicesCst == inputShape[i] && - inputCtx.shape[i] == ShapedType::kDynamic) { - limitIndices[i] = b.create(loc, slice_op.getOperand(), i); - newShape[i] = inputCtx.shape[i]; + if (inputCtx.shape[i] == ShapedType::kDynamic) { + newShape[i] = ShapedType::kDynamic; + if (limitIndicesCst == inputShape[i]) { + limitIndices[i] = + b.create(loc, slice_op.getOperand(), i); + } else { + auto limitOffset = inputShape[i] - limitIndicesCst; + limitIndices[i] = b.create( + loc, b.create(loc, slice_op.getOperand(), i), + b.create(loc, limitOffset)); + } } else { + newShape[i] = (limitIndicesCst - startIndicesCst - 1) / stridesCst + 1; limitIndices[i] = b.create(slice_op.getLoc(), limitIndicesCst); - newShape[i] = (limitIndicesCst - startIndicesCst - 1) / stridesCst + 1; } strides[i] = b.create(slice_op.getLoc(), stridesCst); @@ -297,13 +336,13 @@ std::optional propagateHelper( Value stridesValue = b.create(loc, strides); Value limitIndicesValue = b.create(loc, limitIndices); auto sliceOpResultType = - RankedTensorType::get(newShape, rankType.getElementType()); - auto dyncSliceOp = b.create( + RankedTensorType::get(inputCtx.shape, rankType.getElementType()); + auto dynSliceOp = b.create( loc, sliceOpResultType, slice_op.getOperand(), baseIndicesValue, limitIndicesValue, stridesValue); - op->getResult(0).replaceAllUsesWith(dyncSliceOp.getResult()); + op->getResult(0).replaceAllUsesWith(dynSliceOp.getResult()); op->erase(); - return ShapeContext(dyncSliceOp->getResult(0), newShape); + return ShapeContext(dynSliceOp->getResult(0), newShape); } template <> @@ -472,6 +511,68 @@ std::optional propagateHelper( SmallVector newShape(resultShape.begin(), resultShape.end()); return ShapeContext(op->getResult(0), newShape); } +template <> +std::optional propagateHelper( + OpBuilder& b, Operation* op, ShapeContext& inputCtx) { + auto resultShape = + op->getResult(0).getType().cast().getShape(); + SmallVector newShape(resultShape.begin(), resultShape.end()); + return ShapeContext(op->getResult(0), newShape); +} + +template <> +std::optional propagateHelper( + OpBuilder& rewriter, Operation* 
op, ShapeContext& inputCtx) { + auto padOp = dyn_cast(op); + if (!padOp) return std::nullopt; + rewriter.setInsertionPoint(op); + Value input = op->getOperand(0); + Value paddingValue = op->getOperand(1); + auto padOpType = padOp.getType().cast(); + + auto inputTy = input.getType().cast(); + auto resultTy = op->getResult(0).getType().cast(); + auto rank = inputTy.getRank(); + auto loc = padOp.getLoc(); + SmallVector paddingLow, paddingHigh, paddingInterior; + for (auto padDimension : + llvm::enumerate(padOp.getEdgePaddingLow().getValues())) { + paddingLow.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padDimension.value()))); + } + for (auto padDimension : + llvm::enumerate(padOp.getEdgePaddingHigh().getValues())) { + paddingHigh.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padDimension.value()))); + } + for (auto padDimension : + llvm::enumerate(padOp.getInteriorPadding().getValues())) { + paddingInterior.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padDimension.value()))); + } + + SmallVector newShape; + for (size_t i = 0; i < resultTy.getRank(); ++i) { + if (inputCtx.shape[i] == ShapedType::kDynamic) { + newShape.push_back(ShapedType::kDynamic); + } else { + newShape.push_back(inputCtx.shape[i]); + } + } + auto dynPadType = RankedTensorType::get(newShape, padOpType.getElementType()); + Value paddingLowTensor = + rewriter.create(loc, paddingLow); + Value paddingHighTensor = + rewriter.create(loc, paddingHigh); + Value paddingInterTensor = + rewriter.create(loc, paddingInterior); + auto dynPadOp = rewriter.create( + op->getLoc(), dynPadType, input, paddingValue, paddingLowTensor, + paddingHighTensor, paddingInterTensor); + op->getResult(0).replaceAllUsesWith(dynPadOp.getResult()); + op->erase(); + return ShapeContext(dynPadOp->getResult(0), newShape); +} template <> std::optional propagateHelper( @@ -548,6 +649,7 @@ std::optional propagateHelper( attr.getStartIndexMap(), attr.getIndexVectorDim()), gather_op.getIndicesAreSorted()); gather_op.getResult().replaceAllUsesWith(dynamic_gather_op.getResult()); + gather_op->erase(); // Update DynamicGatherOp result shape information return propagateHelper( @@ -588,17 +690,22 @@ LogicalResult parseInputDynamicDims( } void applyShapeContext(ShapeContext& ctx) { - if (!ctx.value) return; auto res_ty = ctx.value.getType().dyn_cast(); if (!res_ty) return; auto elemTy = res_ty.getElementType(); ctx.value.setType(RankedTensorType::get(ctx.shape, elemTy)); } -std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, - ShapeContext& inputCtx) { +std::optional propagateOpShape( + OpBuilder& rewriter, Operation* op, ShapeContext& inputCtx, + std::unordered_map& symbolicMap) { if (isUnaryOp(op)) { - return ShapeContext(op->getResult(0), inputCtx.shape); + if (inputCtx.value == op->getOperand(0)) { + return ShapeContext(op->getResult(0), inputCtx.shape); + } + SmallVector newShape = llvm::to_vector<4>( + op->getOperand(0).getType().cast().getShape()); + return ShapeContext(op->getResult(0), newShape); } if (isBinaryOp(op)) { return HandleBinaryOp(rewriter, op, inputCtx); @@ -606,8 +713,19 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, if (isa(op)) { return propagateHelper(rewriter, op, inputCtx); } - if (isa(op)) { - return ShapeContext(op->getResult(0), inputCtx.shape); + if (isa(op)) { + return HandleReshapeOp(rewriter, op, inputCtx, symbolicMap); + } + if (auto bcastOp = dyn_cast(op)) { + auto result = op->getResult(0); + auto resultTy = result.getType().cast(); + auto elemTy = 
resultTy.getElementType(); + bool withSymbolicShape = false; + SmallVector mhloShape; + SmallVector shapes; + + SmallVector newShape = llvm::to_vector<4>(resultTy.getShape()); + return ShapeContext(result, newShape); } #define PROPAGATE_OP_HANDLER(opType) \ if (auto t##opType = dyn_cast(op)) { \ @@ -616,7 +734,7 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, } PROPAGATE_OP_HANDLER(DotOp); PROPAGATE_OP_HANDLER(SliceOp); - PROPAGATE_OP_HANDLER(ReshapeOp); + PROPAGATE_OP_HANDLER(PadOp); PROPAGATE_OP_HANDLER(ConcatenateOp); PROPAGATE_OP_HANDLER(ReduceOp); PROPAGATE_OP_HANDLER(TransposeOp); @@ -626,6 +744,7 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, PROPAGATE_OP_HANDLER(DynamicReshapeOp); PROPAGATE_OP_HANDLER(RealDynamicSliceOp); PROPAGATE_OP_HANDLER(DynamicBroadcastInDimOp); + PROPAGATE_OP_HANDLER(DynamicPadOp); // PROPAGATE_OP_HANDLER(DimOp); #undef PROPAGATE_OP_HANDLER return std::nullopt; @@ -633,38 +752,169 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, bool shouldStopPropagation(Operation* op, ShapeContext& ctx) { if (isConcreteShape(ctx)) return true; - if (isa(op)) + if (isa(op)) return true; if (isa(op->getParentOp())) return true; return false; } +struct OperationValuePairHash { + std::size_t operator()(const std::pair& pair) const { + return llvm::hash_combine(pair.first, pair.second); + } +}; +// Custom equality function for std::pair +struct OperationValuePairEqual { + bool operator()(const std::pair& lhs, + const std::pair& rhs) const { + return lhs.first == rhs.first && lhs.second == rhs.second; + } +}; -void DiscShapePropagatePass::visitOperator(ModuleOp& m, OpBuilder& rewriter, - Operation* op, - std::stack& ctxStack) { - auto ctx = ctxStack.top(); - if (shouldStopPropagation(op, ctx)) { - return; +void DiscShapePropagatePass::visitOperator( + ModuleOp& m, OpBuilder& rewriter, Operation* rootOp, ShapeContext& ctx, + std::unordered_map& symbolicMap) { + std::stack> stack; + std::stack ctxStack; + std::unordered_set, OperationValuePairHash, + OperationValuePairEqual> + visited; + if (VLOG_IS_ON(1)) { + VLOG(1) << "rootOp:"; + rootOp->dump(); } - auto resultShapeCtx = propagateOpShape(rewriter, op, ctx); - if (!resultShapeCtx) { - m.emitError("failed propagate shape on op:" + - op->getName().stripDialect().str()); - signalPassFailure(); - return; + stack.push({rootOp, ctx}); + auto pair = std::make_pair(rootOp, ctx.value); + visited.insert(pair); + while (!stack.empty()) { + auto [op, inputCtx] = stack.top(); + stack.pop(); + if (shouldStopPropagation(op, inputCtx)) { + while (!ctxStack.empty()) { + auto ctx = ctxStack.top(); + ctxStack.pop(); + applyShapeContext(ctx); + } + continue; + } + auto resultCtx = propagateOpShape(rewriter, op, inputCtx, symbolicMap); + if (!resultCtx) { + m.emitError("failed propagate shape on op:" + + op->getName().stripDialect().str()); + signalPassFailure(); + return; + } + ctxStack.push(*resultCtx); + SmallVector users(resultCtx->value.getUsers().begin(), + resultCtx->value.getUsers().end()); + for (size_t i = 0; i < users.size(); ++i) { + auto pair = std::make_pair(users[i], resultCtx->value); + if (visited.count(pair)) continue; + stack.push({users[i], *resultCtx}); + visited.insert(pair); + } } - ctxStack.push(*resultShapeCtx); - SmallVector ctxUsers(resultShapeCtx->value.getUsers().begin(), - resultShapeCtx->value.getUsers().end()); - for (size_t i = 0; i < ctxUsers.size(); ++i) { - visitOperator(m, rewriter, ctxUsers[i], ctxStack); + while (!ctxStack.empty()) { + auto ctx 
= ctxStack.top(); + ctxStack.pop(); + applyShapeContext(ctx); } - auto context = ctxStack.top(); - ctxStack.pop(); - applyShapeContext(context); } +std::optional HandleConstOp( + OpBuilder& rewriter, Operation* op, + std::unordered_map& symbolicMap) { + auto constOp = dyn_cast(op); + if (!constOp) return std::nullopt; + auto resultTy = op->getResult(0).getType().dyn_cast(); + if (!resultTy) return std::nullopt; + + auto elemTy = resultTy.getElementType(); + bool withSymbolicShape = false; + SmallVector mhloShape; + SmallVector shapes; + rewriter.setInsertionPoint(op); + for (auto dim : resultTy.getShape()) { + if (symbolicMap.count(dim)) { + withSymbolicShape = true; + mhloShape.push_back(symbolicMap[dim]); + shapes.push_back(ShapedType::kDynamic); + } else { + mhloShape.push_back( + rewriter.create(op->getLoc(), dim)); + shapes.push_back(dim); + } + } + if (withSymbolicShape) { + auto mhloShapeValue = + rewriter.create(op->getLoc(), mhloShape); + auto dense_attr = constOp.getValue().dyn_cast(); + Value scalar_const_op; + if (elemTy.isIntOrIndex()) { + auto value = (*dense_attr.getValues().begin()).getSExtValue(); + auto const_type = RankedTensorType::get({}, elemTy); + auto const_attr = DenseElementsAttr::get(const_type, {value}); + scalar_const_op = rewriter.create( + op->getLoc(), const_type, const_attr); + } else if (elemTy.isa()) { + auto value = (*dense_attr.getValues().begin()); + auto const_type = RankedTensorType::get({}, elemTy); + auto const_attr = DenseElementsAttr::get(const_type, {value}); + scalar_const_op = rewriter.create( + op->getLoc(), const_type, const_attr); + } else if (elemTy.isa()) { + auto value = (*dense_attr.getValues().begin()).convertToDouble(); + auto const_type = RankedTensorType::get({}, elemTy); + auto const_attr = DenseElementsAttr::get(const_type, {value}); + scalar_const_op = rewriter.create( + op->getLoc(), const_type, const_attr); + } + auto mhloBroadcastInDimOp = rewriter.create( + op->getLoc(), RankedTensorType::get(shapes, elemTy), scalar_const_op, + mhloShapeValue, /*broadcast_dimensions=*/rewriter.getI64TensorAttr({})); + op->getResult(0).replaceAllUsesWith(mhloBroadcastInDimOp.getResult()); + constOp->erase(); + mhloBroadcastInDimOp.dump(); + return mhloBroadcastInDimOp; + } + return std::nullopt; +} +std::optional HandleDyncBroadcastOp( + OpBuilder& rewriter, Operation* op, + std::unordered_map& symbolicMap) { + auto bcastOp = dyn_cast(op); + if (!bcastOp) return std::nullopt; + auto result = op->getResult(0); + auto resultTy = result.getType().cast(); + auto elemTy = resultTy.getElementType(); + bool withSymbolicShape = false; + SmallVector mhloShape; + SmallVector shapes; + rewriter.setInsertionPoint(op); + for (auto dim : resultTy.getShape()) { + if (symbolicMap.count(dim)) { + withSymbolicShape = true; + mhloShape.push_back(symbolicMap[dim]); + shapes.push_back(ShapedType::kDynamic); + } else { + mhloShape.push_back( + rewriter.create(op->getLoc(), dim)); + shapes.push_back(dim); + } + } + if (withSymbolicShape) { + auto mhloShapeValue = + rewriter.create(op->getLoc(), mhloShape); + auto mhloBroadcastInDimOp = rewriter.create( + op->getLoc(), RankedTensorType::get(shapes, elemTy), op->getOperand(0), + mhloShapeValue, bcastOp.getBroadcastDimensions()); + op->getResult(0).replaceAllUsesWith(mhloBroadcastInDimOp.getResult()); + op->erase(); + return mhloBroadcastInDimOp; + } + return std::nullopt; +} void DiscShapePropagatePass::runOnOperation() { ModuleOp m = getOperation(); auto main = m.lookupSymbol("main"); @@ -695,15 +945,62 @@ void 
DiscShapePropagatePass::runOnOperation() { SmallVector newShape; std::copy(ty.getShape().begin(), ty.getShape().end(), std::back_inserter(newShape)); + if (pair.second.size() != 1 || pair.second[0] != 1) { + main.emitError("only support sequence length dims equal to 1"); + signalPassFailure(); + return; + } for (auto dim : pair.second) { newShape[dim] = ShapedType::kDynamic; } - std::stack ctxStack; ShapeContext ctx(value, newShape); - ctxStack.push(ctx); auto newType = RankedTensorType::get(newShape, ty.getElementType()); - for (auto user : main.getArgument(argIdx).getUsers()) { - visitOperator(m, rewriter, user, ctxStack); + SmallVector users(value.getUsers().begin(), + value.getUsers().end()); + rewriter.setInsertionPointToStart(&main.getBody().front()); + std::unordered_map symbolicMap; + // seqlen + auto seqlen = ty.getShape()[pair.second[0]]; + auto seqlenValue = rewriter.create(users[0]->getLoc(), value, + pair.second[0]); + symbolicMap.insert({seqlen, seqlenValue}); + // seq sliced: seqlen - 1 + auto seqlenSliced = seqlen - 1; + auto seqlenSlicedValue = rewriter.create( + users[0]->getLoc(), seqlenValue, + rewriter.create(users[0]->getLoc(), 1)); + symbolicMap.insert({seqlenSliced, seqlenValue}); + // batch size + auto bsz = ty.getShape()[0]; + auto bszValue = + rewriter.create(users[0]->getLoc(), bsz); + // symbolicMap.insert({bsz, bszValue}); + + // bsz * seqlen + auto bszSeq = bsz * seqlen; + auto bszSeqValue = rewriter.create(users[0]->getLoc(), + bszValue, seqlenValue); + symbolicMap.insert({bszSeq, bszSeqValue}); + + // bszSeqlen = bsz * (seqlen - 1) + auto bszSlicedSeqlen = ty.getShape()[0] * (seqlen - 1); + auto bszSlicedSeqlenValue = rewriter.create( + users[0]->getLoc(), bszValue, + rewriter.create( + users[0]->getLoc(), seqlenSlicedValue, + rewriter.create(users[0]->getLoc(), 1))); + + symbolicMap.insert({bszSlicedSeqlen, bszSlicedSeqlenValue}); + main.walk([&](Operation* op) { + if (auto dynOp = HandleDyncBroadcastOp(rewriter, op, symbolicMap)) { + users.push_back(dynOp.value()); + } + if (auto dynOp = HandleConstOp(rewriter, op, symbolicMap)) { + users.push_back(dynOp.value()); + } + }); + for (auto user : users) { + visitOperator(m, rewriter, user, ctx, symbolicMap); } new_arg_types[argIdx] = newType; applyShapeContext(ctx); diff --git a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc index fcf33a6e4e5..4da4d2d187e 100644 --- a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc @@ -1305,6 +1305,8 @@ Value elementalLower(OpBuilder* b, Location loc, b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front()); if (i == num_input_operands - 1) { + b->create(loc, zero_element); + /* input_index[axis] = b->create(loc, out_idx, low_bound); auto operand_memref = op.getOperand(i); auto ret_value = @@ -1315,6 +1317,7 @@ Value elementalLower(OpBuilder* b, Location loc, operand_memref, input_index, lower_config); b->create(loc, ret_value); + */ } else { b->create(loc, if_inbound_ops[i + 1].getResults()); } diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h index b09b0bef411..c9c047661dc 100644 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -337,6 +337,7 @@ std::unique_ptr> createDiscReduceBufferLiveRangePass(); // rewrite mhlo collective ops to disc custom library call std::unique_ptr> createDiscCollectiveOpsRewriterPass(); 
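The symbolicMap seeded in DiscShapePropagatePass::runOnOperation above is keyed on the concrete dimension sizes observed in the static types (seqlen, seqlen - 1, bsz * seqlen, bsz * (seqlen - 1)), each mapped to the SSA index value that recomputes that size at runtime. A minimal standalone model of the keying scheme, offered here as an editorial sketch: std::string stands in for mlir::Value, and the sizes are illustrative, not taken from the patch. Because the key is a plain int64_t, any two dimensions that happen to share a static size are conflated, which is consistent with the single-symbolic-dim restriction enforced above.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  int64_t bsz = 4, seqlen = 128;
  // Key: static dim size as seen in the tensor type; value: the SSA value
  // that recomputes it at runtime (tensor.dim / arith.muli in the real pass).
  std::unordered_map<int64_t, std::string> symbolicMap;
  symbolicMap.insert({seqlen, "%seqlen"});
  // Mirrors the pass, which inserts seqlenValue (not the sliced value)
  // for the seqlen - 1 key.
  symbolicMap.insert({seqlen - 1, "%seqlen"});
  symbolicMap.insert({bsz * seqlen, "%bsz_x_seqlen"});
  symbolicMap.insert({bsz * (seqlen - 1), "%bsz_x_seqlen_sliced"});
  for (const auto& [dim, ssa] : symbolicMap)
    std::cout << dim << " -> " << ssa << "\n";
  return 0;
}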
+std::unique_ptr> createDiscOffloadingPass(); } // namespace disc_ral } // namespace mlir diff --git a/tao_compiler/mlir/disc/transforms/shape_utils.cc b/tao_compiler/mlir/disc/transforms/shape_utils.cc index 5ef90ada316..08350b761f8 100644 --- a/tao_compiler/mlir/disc/transforms/shape_utils.cc +++ b/tao_compiler/mlir/disc/transforms/shape_utils.cc @@ -1989,6 +1989,23 @@ bool ShapeConstraintIRAnalysis::isShapeEqual(Value lhs, Value rhs) { return lhsSyms == rhsSyms; } +bool ShapeConstraintIRAnalysis::buildSymbolicDimProduct( + SymbolicDimProduct& prod, Value value) { + auto ty = value.getType().dyn_cast(); + auto it = memrefValue2SymDims_.find(value); + if (!ty || !ty.hasRank()) return false; + for (size_t idx = 0; idx < ty.getRank(); ++idx) { + if (ty.getShape()[idx] == ShapedType::kDynamic) { + if (it == memrefValue2SymDims_.end() || it->second.size() <= idx) + return false; + prod.symbols.push_back(it->second[idx]); + } else { + prod.factor *= ty.getShape()[idx]; + } + } + return true; +} + bool ShapeConstraintIRAnalysis::isProductEqual(Value lhs, ArrayRef lhsDimIdxs, Value rhs, diff --git a/tao_compiler/mlir/disc/transforms/shape_utils.h b/tao_compiler/mlir/disc/transforms/shape_utils.h index 5ed178207ab..d43919834ba 100644 --- a/tao_compiler/mlir/disc/transforms/shape_utils.h +++ b/tao_compiler/mlir/disc/transforms/shape_utils.h @@ -213,6 +213,7 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis { // rhs.shape[rd0] * rhs.shape[rd1] * ...` bool isProductEqual(Value lhs, ArrayRef lhsDimIdxs, Value rhs, ArrayRef rhsDimIdxs) override; + bool buildSymbolicDimProduct(SymbolicDimProduct& prod, Value value); private: // The operation this analysis runs on. From 6bd9ec4285ef4de191e7a8a03f90823d58538636 Mon Sep 17 00:00:00 2001 From: YanXu Date: Mon, 19 Aug 2024 16:40:17 +0800 Subject: [PATCH 2/4] update --- tao_compiler/mlir/disc/BUILD | 2 +- ...floading.cc => disc_dynamic_offloading.cc} | 274 ++++-------------- 2 files changed, 50 insertions(+), 226 deletions(-) rename tao_compiler/mlir/disc/transforms/{disc_offloading.cc => disc_dynamic_offloading.cc} (60%) diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 2a05e43b130..77baa72c29b 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -2232,7 +2232,7 @@ cc_library( cc_library( name = "disc_offloading", srcs = [ - "transforms/disc_offloading.cc", + "transforms/disc_dynamic_offloading.cc", "transforms/disc_remat_utils.cc" ], hdrs = [ diff --git a/tao_compiler/mlir/disc/transforms/disc_offloading.cc b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc similarity index 60% rename from tao_compiler/mlir/disc/transforms/disc_offloading.cc rename to tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc index 357b80ec848..5d1a3b7d2f1 100644 --- a/tao_compiler/mlir/disc/transforms/disc_offloading.cc +++ b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc @@ -42,8 +42,8 @@ namespace mlir { namespace disc_ral { constexpr StringRef kRematBlockTypeAttr = "disc.remat.type"; -constexpr StringRef kRematBufferAttr = "disc.remat.dummy-buffer"; -constexpr StringRef kRematMinSymDim = "disc.remat.min-sym-dim"; +constexpr StringRef kRematBufferAttr = "disc.remat.is_dummy_buffer"; +constexpr StringRef kRematMinSymDim = "disc.remat.min_symbolic_dim"; struct DiscOffloadingPass : public DiscOffloadingPassBase { DiscOffloadingPass() @@ -61,10 +61,6 @@ struct DiscOffloadingPass : public DiscOffloadingPassBase { void InsertRematBlock(mlir::OpBuilder& b, LivingBuffer& livingBuffer, 
Value rematCond, std::vector& ops, int64_t minSymValue); - void InsertOffloadingOp(mlir::OpBuilder& rewriter, Operation* prevOp, - Operation* op, Value buffer, - std::vector consumers, Value symbS0, - int64_t cstMinS0); }; Location getFusionLocation(OpBuilder& b, Operation* op) { if (auto fusionOp = dyn_cast(op->getParentOp())) { @@ -136,7 +132,7 @@ bool IsDynamicShapeBuffer(Value buffer) { } return false; } -Value createAllocOp(OpBuilder& b, Location loc, Value refBuffer) { +Value cloneBuffer(OpBuilder& b, Location loc, Value buffer) { MemRefType type = refBuffer.getType().cast(); SmallVector dynShape; for (size_t i = 0; i < type.getRank(); i++) { @@ -153,11 +149,25 @@ Value createAllocOp(OpBuilder& b, Location loc, Value refBuffer) { return allocOp.getResult(); } +// InsertRematBlock create a remat block for the living buffer: +// reload and offload blocks are always pair in graph: +// +// if remat_cond: +// return offload(buffer) +// else: +// return dummy_buffer +// ...... +// if remat_cond: +// return reload(buffer) +// else: +// return buffer void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, LivingBuffer& livingBuffer, Value rematCond, std::vector& ops, int64_t minSymValue) { + // TODO(yancey): we need a custom operator to handle the dummy buffer to avoid + // call alloc or dealloc operator auto buffer = livingBuffer.buffer; auto startOp = livingBuffer.start; auto endOp = livingBuffer.end; @@ -165,12 +175,7 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, auto deviceMemrefType = buffer.getType().cast(); auto hostMemrefType = MemRefType::get( deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0); - // offloading to host - // get dynamic dim of buffer and insert into ValueRange - // if remat_cond: - // yield offload(buffer) - // else: - // yield dummy_buffer + // insert offload block auto offloadIfOp = b.create(startOp->getLoc(), /*resultTypes*/ hostMemrefType, rematCond, @@ -183,7 +188,7 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, b.getI64IntegerAttr(minSymValue)); offloadIfOp.getThenRegion().front().clear(); b.setInsertionPointToEnd(&offloadIfOp.getThenRegion().front()); - auto hostBuffer = createAllocOp(b, endOp->getLoc(), buffer); + auto hostBuffer = cloneBuffer(b, endOp->getLoc(), buffer); hostBuffer.setType(MemRefType::get(deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0)); b.create(endOp->getLoc(), buffer, hostBuffer); @@ -192,7 +197,7 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, offloadIfOp.getElseRegion().front().clear(); b.setInsertionPointToStart(&offloadIfOp.getElseRegion().front()); - auto dummyHostBuffer = createAllocOp(b, endOp->getLoc(), buffer); + auto dummyHostBuffer = cloneBuffer(b, endOp->getLoc(), buffer); dummyHostBuffer.getDefiningOp()->setAttr(kRematBufferAttr, b.getBoolAttr(true)); dummyHostBuffer.setType(MemRefType::get( @@ -222,17 +227,17 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, reloadIfOp.getThenRegion().front().clear(); b.setInsertionPointToStart(&reloadIfOp.getThenRegion().front()); - auto deviceBuffer = createAllocOp(b, endOp->getLoc(), buffer); + auto deviceBuffer = cloneBuffer(b, endOp->getLoc(), buffer); deviceBuffer.setType(deviceMemrefType); auto h2dOp = b.create( endOp->getLoc(), offloadIfOp.getResult(0), deviceBuffer); b.create(endOp->getLoc(), deviceBuffer); reloadIfOp.getElseRegion().front().clear(); b.setInsertionPointToStart(&reloadIfOp.getElseRegion().front()); - auto dummyDeviceBuffer = createAllocOp(b, 
endOp->getLoc(), buffer); + auto dummyDeviceBuffer = cloneBuffer(b, endOp->getLoc(), buffer); dummyDeviceBuffer.setType(deviceMemrefType); - dummyDeviceBuffer.getDefiningOp()->setAttr( - b.getStringAttr("disc.remat.dummy-buffer"), b.getBoolAttr(true)); + dummyDeviceBuffer.getDefiningOp()->setAttr(kRematBufferAttr, + b.getBoolAttr(true)); b.create(endOp->getLoc(), buffer); for (auto pair : ops) { auto op = pair.first; @@ -246,125 +251,6 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, } } } -void DiscOffloadingPass::InsertOffloadingOp(mlir::OpBuilder& b, - Operation* prevOp, Operation* op, - Value buffer, - std::vector consumers, - Value symbS0, int64_t cstMinS0) { - Location loc = prevOp->getLoc(); - if (auto fusionOp = dyn_cast(prevOp->getParentOp())) { - b.setInsertionPointAfter(fusionOp); - loc = fusionOp.getLoc(); - } else { - b.setInsertionPointAfter(prevOp); - } - auto offloadCond = b.create( - loc, arith::CmpIPredicate::sgt, symbS0, - b.create(op->getLoc(), cstMinS0)); - b.setInsertionPointAfter(offloadCond); - StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName(); - auto deviceMemrefType = buffer.getType().cast(); - auto hostMemrefType = MemRefType::get( - deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0); - // offloading to host - // get dynamic dim of buffer and insert into ValueRange - auto offloadIfOp = - b.create(op->getLoc(), - /*resultTypes*/ hostMemrefType, offloadCond, - /*hasElseRegion*/ true); - offloadIfOp.getOperation()->setAttr( - attrName, buffer.getDefiningOp()->getAttr(attrName)); - offloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, - b.getStringAttr("offload")); - offloadIfOp.getOperation()->setAttr(kRematMinSymDim, - b.getI64IntegerAttr(cstMinS0)); - offloadIfOp.getThenRegion().front().clear(); - b.setInsertionPointToEnd(&offloadIfOp.getThenRegion().front()); - auto hostBuffer = createAllocOp(b, op->getLoc(), buffer); - hostBuffer.setType(MemRefType::get(deviceMemrefType.getShape(), - deviceMemrefType.getElementType(), {}, 0)); - b.create(op->getLoc(), buffer, hostBuffer); - b.create(op->getLoc(), buffer); - b.create(op->getLoc(), hostBuffer); - - offloadIfOp.getElseRegion().front().clear(); - b.setInsertionPointToStart(&offloadIfOp.getElseRegion().front()); - auto dummyHostBuffer = createAllocOp(b, op->getLoc(), buffer); - dummyHostBuffer.getDefiningOp()->setAttr(kRematBufferAttr, - b.getBoolAttr(true)); - dummyHostBuffer.setType(MemRefType::get( - deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0)); - b.create(op->getLoc(), dummyHostBuffer); - b.setInsertionPointAfter(offloadIfOp); - - if (auto fusionOp = dyn_cast(op->getParentOp())) { - b.setInsertionPoint(fusionOp); - } else { - b.setInsertionPoint(op); - } - - // insert reload block - scf::IfOp reloadIfOp = - b.create(op->getLoc(), - /*resultTypes*/ deviceMemrefType, offloadCond, - /*hasElseRegion*/ true); - if (buffer.getDefiningOp()->hasAttr(attrName)) { - reloadIfOp.getOperation()->setAttr( - attrName, buffer.getDefiningOp()->getAttr(attrName)); - } - reloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, - b.getStringAttr("reload")); - reloadIfOp.getOperation()->setAttr(kRematMinSymDim, - b.getI64IntegerAttr(cstMinS0)); - reloadIfOp.getThenRegion().front().clear(); - b.setInsertionPointToStart(&reloadIfOp.getThenRegion().front()); - - auto deviceBuffer = createAllocOp(b, op->getLoc(), buffer); - deviceBuffer.setType(deviceMemrefType); - auto h2dOp = b.create( - op->getLoc(), offloadIfOp.getResult(0), deviceBuffer); - 
b.create(op->getLoc(), deviceBuffer); - reloadIfOp.getElseRegion().front().clear(); - b.setInsertionPointToStart(&reloadIfOp.getElseRegion().front()); - auto dummyDeviceBuffer = createAllocOp(b, op->getLoc(), buffer); - dummyDeviceBuffer.setType(deviceMemrefType); - dummyDeviceBuffer.getDefiningOp()->setAttr( - b.getStringAttr("disc.remat.dummy-buffer"), b.getBoolAttr(true)); - b.create(op->getLoc(), buffer); - - for (size_t i = 0; i < consumers.size(); i++) { - auto consumer = consumers[i]; - for (size_t j = 0; j < consumer->getNumOperands(); j++) { - if (consumer->getOperand(j) == buffer) { - consumer->setOperand(j, reloadIfOp.getResult(0)); - } - } - } -} -Value getSymBufferSize(OpBuilder& b, Location loc, Value buffer, - ShapeConstraintIRAnalysis& shapeAnalysis) { - int64_t factor = 1; - SmallVector symbols; - SmallVector dimValues; - auto memrefType = buffer.getType().cast(); - for (size_t i = 0; i < memrefType.getRank(); ++i) { - auto dim = memrefType.getShape()[i]; - if (dim == ShapedType::kDynamic) { - dimValues.push_back(b.create(loc, buffer, i)); - } else { - factor *= dim; - } - } - Value numal = b.create(loc, factor); - for (auto dim : dimValues) { - numal = b.create(loc, numal, dim); - } - return numal; -} - -struct SymbolicDimProductSum { - llvm::SmallVector symbolProds; -}; llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const SymbolicDimProduct& prod) { @@ -381,59 +267,14 @@ llvm::raw_ostream& operator<<(llvm::raw_ostream& os, } return os; } -/* -SymbolicDimProductSum symbolicDimProductSub(SymbolicDimProductSum sum, - SymbolicDimProduct b) { - bool mergeSymbols = false; - SymbolicDimProductSum result; - - for (auto symbolProd : sum.symbolProds) { - if (symbolProd.symbols == b.symbols) { - mergeSymbols = true; - symbolProd.factor -= b.factor; - if (symbolProd.factor == 0) { - continue; - } - result.symbolProds.push_back(symbolProd); - } else { - result.symbolProds.push_back(symbolProd); - } - } - if (!mergeSymbols) { - b.factor *= -b.factor; - result.symbolProds.push_back(b); - } - return result; -} -SymbolicDimProductSum symbolicDimProductSumAdd(SymbolicDimProductSum sum, - SymbolicDimProduct b) { - SymbolicDimProductSum result; - bool mergeSymbols = false; - for (auto symbolProd : sum.symbolProds) { - if (symbolProd.symbols == b.symbols) { - mergeSymbols = true; - symbolProd.factor += b.factor; - result.symbolProds.push_back(symbolProd); - } else { - result.symbolProds.push_back(symbolProd); - } - } - if (!mergeSymbols) { - result.symbolProds.push_back(b); - } - return result; -} -*/ -// Dumps SymbolicDimProduct to the output stream - -std::tuple solveQuadratic(int64_t A, int64_t B, - int64_t C) { - if (A == 0) { +std::tuple solveQuadratic(int64_t a, int64_t b, + int64_t c) { + if (a == 0) { throw std::invalid_argument( "Coefficient A cannot be zero in a quadratic equation."); } - double discriminant = B * B - 4 * A * C; + double discriminant = b * b - 4 * a * c; // no solution if b^2 - 4ac < 0 if (discriminant < 0) { return std::make_tuple(false, 0.0, 0.0); @@ -442,9 +283,9 @@ std::tuple solveQuadratic(int64_t A, int64_t B, double sqrtDiscriminant = std::sqrt(discriminant); // compute x1, x2 - int64_t X1 = (-B + sqrtDiscriminant) / (2 * A); - int64_t X2 = (-B - sqrtDiscriminant) / (2 * A); - return std::make_tuple(true, X1, X2); + int64_t x1 = (-b + sqrtDiscriminant) / (2 * a); + int64_t x2 = (-b - sqrtDiscriminant) / (2 * a); + return std::make_tuple(true, x1, x2); } int64_t findMinSymbolicDimValue(MemoryUsage memoryPeakExpr, int64_t memoryLimitation) { @@ -462,6 
+303,9 @@ int64_t findMinSymbolicDimValue(MemoryUsage memoryPeakExpr,
     }
   }
   c -= memoryLimitation;
+  if (a == 0) {
+    return -c / b;
+  }
   auto [hasSolution, x0, x1] = solveQuadratic(a, b, c);
   if (!hasSolution) {
     throw std::invalid_argument("No solution for the quadratic equation.");
@@ -469,19 +313,6 @@ int64_t findMinSymbolicDimValue(MemoryUsage memoryPeakExpr,
   return std::max(x0, x1);
 }
-int64_t getConcretValuewithCst(SymbolicDimProductSum prod, int64_t cstValue) {
-  int64_t factor = 1;
-  for (auto symProd : prod.symbolProds) {
-    if (symProd.symbols.size() == 0) {
-      factor += symProd.factor;
-    } else if (symProd.symbols.size() == 1) {
-      factor += symProd.factor * cstValue;
-    } else if (symProd.symbols.size() == 2) {
-      factor += symProd.factor * cstValue * cstValue;
-    }
-  }
-  return factor;
-}
 bool inRematOffloadBlock(Value value) {
   if (auto ifOp = dyn_cast(value.getDefiningOp()->getParentOp())) {
     auto blockType =
@@ -530,7 +361,7 @@ Value InsertRematCond(mlir::OpBuilder& b, Location loc, Value s0,
                            b.create(s0.getLoc(), minS0));
   return offloadCond;
 }
-vector FilterLivingBuffers(
+std::vector FilterBuffers(
     std::vector livingBuffers) {
   std::vector result;
   for (auto lb : livingBuffers) {
@@ -545,22 +376,22 @@ vector FilterLivingBuffers(
   }
   return result;
 }
-void SortByPrioriy(std::vector& livingBuffers) {
+void SortBuffersByPriority(std::vector& livingBuffers) {
   std::sort(livingBuffers.begin(), livingBuffers.end(),
             [](const LivingBuffer& a, const LivingBuffer& b) {
               return a.living_range > b.living_range;
             });
 }
 std::optional PickHighestPriorityLivingBuffer(
-    std::vector& livingBuffers) {
+    const std::vector& livingBuffers) {
   // step1: filter living buffers which can not reduce the peak memory or too
   // small buffer
-  auto buffers = FilterLivingBuffers(livingBuffers);
+  auto buffers = FilterBuffers(livingBuffers);
   if (buffers.size() == 0) {
     return std::nullopt;
   }
   // step2: sort living buffers by priority, e.g. living range value
-  SortByPrioriy(buffers);
+  SortBuffersByPriority(buffers);
   return buffers[0];
 }
 void DiscOffloadingPass::runOnOperation() {
@@ -570,17 +401,15 @@ void DiscOffloadingPass::runOnOperation() {
   mlir::OpBuilder b(main);
   const int64_t memoryLimitation = 21474836480;  // 30GB
   llvm::dbgs() << "memory limitation: " << memoryLimitation << "\n";
-  // 1. find all buffer live-range
   bool changed = true;
   int maxIteration = 200;
   std::unique_ptr profiler(
       new SymbolicMemoryProfiler(main));
-
-  std::unique_ptr shapeAnalysisPtr;
-  std::unique_ptr bufferLivingRange;
-  bufferLivingRange.reset(new DiscBufferLivingRange(main));
-  shapeAnalysisPtr.reset(new ShapeConstraintIRAnalysis(main));
-
+  std::unique_ptr bufferLivingRange(
+      new DiscBufferLivingRange(main));
+  // mapping symbolic dim (S0) in the shape constraint graph to an SSA value
+  // TODO(yancey): please note, we support only one symbolic dim now;
+  // let's find a way to enhance this.
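To make the budget inversion in findMinSymbolicDimValue concrete: the profiler reduces peak memory to a polynomial a*S0^2 + b*S0 + c over the single symbolic dim S0, the memory budget is folded into the constant term, and the threshold value of S0 is recovered with the quadratic formula, using the a == 0 linear fallback this patch adds. The following is a self-contained editorial sketch of that arithmetic only; the extraction of a, b, c from MemoryUsage is elided, and the sample figures are invented.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Largest root of a*x^2 + b*x + (c - limit) = 0, i.e. the S0 value at which
// the symbolic peak-memory expression crosses the memory budget.
int64_t minSymbolicDimValue(int64_t a, int64_t b, int64_t c, int64_t limit) {
  c -= limit;
  if (a == 0) return -c / b;  // linear peak expression (the new fallback)
  double discriminant = static_cast<double>(b) * b - 4.0 * a * c;
  if (discriminant < 0) return -1;  // peak never reaches the budget
  double s = std::sqrt(discriminant);
  int64_t x1 = static_cast<int64_t>((-b + s) / (2 * a));
  int64_t x2 = static_cast<int64_t>((-b - s) / (2 * a));
  return std::max(x1, x2);
}

int main() {
  // e.g. peak(S0) = 64*S0^2 + 1024*S0 bytes against a 1 MiB budget
  std::cout << minSymbolicDimValue(64, 1024, 0, 1 << 20) << "\n";
  return 0;
}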
auto symS0Value = findSymbolicDimValue(main, "S0"); if (!symS0Value) { llvm::errs() << "failed to find S0 value\n"; @@ -599,17 +428,12 @@ void DiscOffloadingPass::runOnOperation() { auto memoryPeakExpr = profiler->GetPeakMemory(); int64_t minS0 = findMinSymbolicDimValue(memoryPeakExpr, memoryLimitation); - llvm::dbgs() << "min s0: " << minS0 << "\n"; - auto livingBuffer = PickHighestPriorityLivingBuffer( - std::vector & livingBuffers); - if (auto buffer = livingBuffer.value()) { - llvm::dbgs() << "living range: " << buffer.living_range << "\n"; - auto startOp = buffer.start; - auto endOp = buffer.end; - auto buffer = buffer.buffer; + auto livingBuffer = + PickHighestPriorityLivingBuffer(bufferLivingRange->GetLivingBuffers()); + if (livingBuffer.has_value()) { + auto buffer = livingBuffer.value(); auto users = bufferLivingRange->GetUsersOrderByPosition(buffer.buffer); - - auto loc = getFusionLocation(b, startOp); + auto loc = getFusionLocation(b, buffer.start); auto rematCond = InsertRematCond(b, loc, symS0Value.value(), minS0); InsertRematBlock(b, buffer, rematCond, users, minS0); changed = true; From 6d6a1d57a38a34cf83b945af6302d5e90fcbe4cc Mon Sep 17 00:00:00 2001 From: YanXu Date: Mon, 19 Aug 2024 16:45:48 +0800 Subject: [PATCH 3/4] update --- .../transforms/disc_dynamic_offloading.cc | 15 ------- .../transforms/disc_shape_optimization.cc | 14 ------- .../disc_shape_optimization_utils.cc | 41 +++++++------------ .../disc_shape_optimization_utils.h | 5 --- .../disc/transforms/lhlo_elemental_utils.cc | 3 -- .../mlir/disc/transforms/shape_utils.cc | 17 -------- .../mlir/disc/transforms/shape_utils.h | 1 - 7 files changed, 14 insertions(+), 82 deletions(-) diff --git a/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc index 5d1a3b7d2f1..4c57b894b2a 100644 --- a/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc +++ b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc @@ -252,21 +252,6 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, } } -llvm::raw_ostream& operator<<(llvm::raw_ostream& os, - const SymbolicDimProduct& prod) { - // print prod - os << prod.factor; - if (prod.symbols.size() > 0) os << "*"; - for (size_t j = 0; j < prod.symbols.size(); ++j) { - auto dimOp = prod.symbols[j]; - if (j != prod.symbols.size() - 1) { - os << dimOp.getName() << "*"; - } else { - os << dimOp.getName(); - } - } - return os; -} std::tuple solveQuadratic(int64_t a, int64_t b, int64_t c) { if (a == 0) { diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc b/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc index 906fd42727e..1d781150a4a 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc +++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc @@ -1766,27 +1766,23 @@ LogicalResult applyShapeComputationOptimization( if (failed(useSameSSAValueIfSymbolicEqual(analysis, changed))) return analysis.getFunc()->emitError( "useSameSSAValueIfSymbolicEqual failed\n"); - analysis.getFunc().dump(); // 2, After propagation some (partial) known dim size infos, refined // the ranked tensor type. if (failed(refineTensorType(analysis, changed))) return analysis.getFunc()->emitError("refineTensorType failed\n"); - analysis.getFunc().dump(); // 3, simplify some expression after propagation shape constraint info. // e.g. if symbolic dim %d is known not negative, then `arith.cmpi eq, %d, // %c-1` could be replaced with a const. 
if (failed(simplifyAccordingToShapeConstraintInfo(analysis, changed))) return analysis.getFunc()->emitError("fail to simplify\n"); - analysis.getFunc().dump(); // 4, inject some static known infos. For example, // - some axes of a slice op is fully sliced; // - some axes of a pad op are not padded; if (failed(injectStaticKnownInfo(analysis, changed))) return analysis.getFunc()->emitError("fail to injectStaticKnownInfo\n"); - analysis.getFunc().dump(); return success(); } @@ -1805,13 +1801,9 @@ LogicalResult optimizeShapeComputation(ModuleOp m, FuncOp main, changed = false; std::chrono::steady_clock::time_point begin, end; DISC_DEBUG(begin = std::chrono::steady_clock::now()); - llvm::dbgs() << "before runCanonicalizer\n"; - main.dump(); if (failed(runCanonicalizer(m, runner))) { return failure(); } - llvm::dbgs() << "after runCanonicalizer\n"; - main.dump(); DISC_DEBUG(end = std::chrono::steady_clock::now()); DISC_DEBUG(llvm::dbgs() << " runCanonicalizer takes: " @@ -1852,8 +1844,6 @@ LogicalResult optimizeShapeComputation(ModuleOp m, FuncOp main, if (failed(analysis.run())) { return m.emitError() << "fail to analysis shape computation IR\n"; } - llvm::dbgs() << "after analysis.run()\n"; - main.dump(); DISC_DEBUG(end = std::chrono::steady_clock::now()); DISC_DEBUG(llvm::dbgs() << " Building ShapeComputationIRAnalysis takes: " @@ -1866,8 +1856,6 @@ LogicalResult optimizeShapeComputation(ModuleOp m, FuncOp main, if (failed(applyShapeComputationOptimization(analysis, changed))) { return m.emitError() << "fail to optimize shape computation IR\n"; } - llvm::dbgs() << "after applyShapeComputationOptimization\n"; - main.dump(); DISC_DEBUG(end = std::chrono::steady_clock::now()); DISC_DEBUG(llvm::dbgs() << " applyShapeComputationOptimization takes: " @@ -1901,8 +1889,6 @@ LogicalResult optimizeShapeComputation(ModuleOp m, FuncOp main, if (failed(runCanonicalizer(m, runner))) { return failure(); } - llvm::dbgs() << "after runCanonicalizer\n"; - main.dump(); LLVM_DEBUG(llvm::dbgs() << "Module after optimizeShapeComputation:\n" << m << "\n"); return success(); diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc index 0182da7f1fa..6f2f1e6b21d 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc +++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc @@ -84,13 +84,20 @@ bool compareSymbolicDimProduct(const SymbolicDimProduct& lhs, return false; } -// llvm::raw_ostream& operator<<(llvm::raw_ostream& os, -// const SymbolicDimProduct& product) { -// os << "SymbolicDimProduct[\n\tfactor: " << product.factor << ",\n"; -// for (auto& s : product.symbols) os << "\tsymbol: " << s << "\n"; -// os << "]\n"; -// return os; -// } +llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const SymbolicDimProduct& prod) { + os << prod.factor; + if (prod.symbols.size() > 0) os << "*"; + for (size_t j = 0; j < prod.symbols.size(); ++j) { + auto dimOp = prod.symbols[j]; + if (j != prod.symbols.size() - 1) { + os << dimOp.getName() << "*"; + } else { + os << dimOp.getName(); + } + } + return os; +} SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m), symbolTable_(m_) {} @@ -255,26 +262,6 @@ SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, return std::make_pair(std::move(newLhs), std::move(newRhs)); } -SymbolicDimProduct SymbolicDimMgr::symbolicDimProductAdd( - const SymbolicDimProduct& x, const SymbolicDimProduct& y) { - SymbolicDimProduct result; - 
result.factor = x.factor + y.factor; - llvm::dbgs() << "x: " << x << "\n"; - llvm::dbgs() << "y: " << y << "\n"; - - // SymbolicDimProduct newLhs, newRhs; - for (auto sym : x.symbols) result.symbols.push_back(sym); - for (auto sym : y.symbols) result.symbols.push_back(sym); - llvm::dbgs() << "add result: " << result << "\n"; - auto newResult = simplifySymbolicDimProduct(result); - llvm::dbgs() << "new result: " << result << "\n"; - return newResult; -} -SymbolicDimProduct symbolicDimProductSub(const SymbolicDimProduct& x, - const SymbolicDimProduct& y) { - SymbolicDimProduct result; - return result; -} std::optional SymbolicDimMgr::symbolicDimProductDivide( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { LLVM_DEBUG(llvm::dbgs() << "Try to check if x % y == 0?\nx = " << lhs diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h index 5146c6b836a..7e126c76ced 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h +++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h @@ -361,11 +361,6 @@ class SymbolicDimMgr { std::optional symbolicDimProductDivide( const SymbolicDimProduct& x, const SymbolicDimProduct& y); - SymbolicDimProduct symbolicDimProductAdd(const SymbolicDimProduct& x, - const SymbolicDimProduct& y); - SymbolicDimProduct symbolicDimProductSub(const SymbolicDimProduct& x, - const SymbolicDimProduct& y); - // mark group [a0, b0, ...] and group [a1, b1, c1, ...] are group // multiplication equal `a0 * b0 * ... = a1 * b1 * c1 * ...` bool isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, diff --git a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc index 4da4d2d187e..fcf33a6e4e5 100644 --- a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc @@ -1305,8 +1305,6 @@ Value elementalLower(OpBuilder* b, Location loc, b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front()); if (i == num_input_operands - 1) { - b->create(loc, zero_element); - /* input_index[axis] = b->create(loc, out_idx, low_bound); auto operand_memref = op.getOperand(i); auto ret_value = @@ -1317,7 +1315,6 @@ Value elementalLower(OpBuilder* b, Location loc, operand_memref, input_index, lower_config); b->create(loc, ret_value); - */ } else { b->create(loc, if_inbound_ops[i + 1].getResults()); } diff --git a/tao_compiler/mlir/disc/transforms/shape_utils.cc b/tao_compiler/mlir/disc/transforms/shape_utils.cc index 08350b761f8..5ef90ada316 100644 --- a/tao_compiler/mlir/disc/transforms/shape_utils.cc +++ b/tao_compiler/mlir/disc/transforms/shape_utils.cc @@ -1989,23 +1989,6 @@ bool ShapeConstraintIRAnalysis::isShapeEqual(Value lhs, Value rhs) { return lhsSyms == rhsSyms; } -bool ShapeConstraintIRAnalysis::buildSymbolicDimProduct( - SymbolicDimProduct& prod, Value value) { - auto ty = value.getType().dyn_cast(); - auto it = memrefValue2SymDims_.find(value); - if (!ty || !ty.hasRank()) return false; - for (size_t idx = 0; idx < ty.getRank(); ++idx) { - if (ty.getShape()[idx] == ShapedType::kDynamic) { - if (it == memrefValue2SymDims_.end() || it->second.size() <= idx) - return false; - prod.symbols.push_back(it->second[idx]); - } else { - prod.factor *= ty.getShape()[idx]; - } - } - return true; -} - bool ShapeConstraintIRAnalysis::isProductEqual(Value lhs, ArrayRef lhsDimIdxs, Value rhs, diff --git 
a/tao_compiler/mlir/disc/transforms/shape_utils.h b/tao_compiler/mlir/disc/transforms/shape_utils.h index d43919834ba..5ed178207ab 100644 --- a/tao_compiler/mlir/disc/transforms/shape_utils.h +++ b/tao_compiler/mlir/disc/transforms/shape_utils.h @@ -213,7 +213,6 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis { // rhs.shape[rd0] * rhs.shape[rd1] * ...` bool isProductEqual(Value lhs, ArrayRef lhsDimIdxs, Value rhs, ArrayRef rhsDimIdxs) override; - bool buildSymbolicDimProduct(SymbolicDimProduct& prod, Value value); private: // The operation this analysis runs on. From f610ff7ea5c0c7ebc7b7e26a4a062de9b5076534 Mon Sep 17 00:00:00 2001 From: YanXu Date: Wed, 28 Aug 2024 16:34:03 +0800 Subject: [PATCH 4/4] init async offloading --- tao_compiler/mlir/disc/disc_compiler.cc | 2 +- .../transforms/disc_dynamic_offloading.cc | 149 ++++++++++-------- .../transforms/disc_lower_to_library_call.cc | 4 +- .../mlir/disc/transforms/disc_remat_utils.cc | 26 +-- .../mlir/disc/transforms/disc_remat_utils.h | 6 +- .../mlir/disc/transforms/disc_to_llvm.cc | 53 +++++-- .../context/base/cuda/cuda_context_impl.cc | 61 ++++++- .../ral/context/base/cuda/cuda_context_impl.h | 1 + .../mlir/ral/device/gpu/gpu_driver.cc | 2 + tao_compiler/mlir/ral/device/gpu/gpu_driver.h | 2 + 10 files changed, 215 insertions(+), 91 deletions(-) diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index f329721b594..930210d809f 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -544,7 +544,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addPass(disc_ral::createRalInjectExecutionContextPass()); // pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass()); - pm.addNestedPass(disc_ral::createDiscOffloadingPass()); + // pm.addNestedPass(disc_ral::createDiscOffloadingPass()); pm.addNestedPass( disc_ral::createDiscLowerToLibraryCallPass(gpu_enabled)); pm.addPass(disc_ral::createDiscConstToRALPass(options.metadata_file_path)); diff --git a/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc index 4c57b894b2a..5ed21da8852 100644 --- a/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc +++ b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc @@ -32,6 +32,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/disc/IR/disc_ral_ops.h" #include "mlir/disc/IR/hlo_disc_ops.h" #include "mlir/disc/IR/lhlo_disc_ops.h" #include "mlir/disc/disc_util.h" @@ -76,52 +77,10 @@ using FuncOp = mlir::func::FuncOp; SymbolicDimProduct getSymbolicMemRefSize(Value value, SymbolicDimMgr* mgr, MLIRContext* ctx) { auto memRefType = value.getType().cast(); - // get symbolic dims of the memref - // auto symbolics = getSymbolicDims(value); auto symbolics = mgr->getOrCreateSymbolicDimsForRankedValue(value); SymbolicDimProduct prod{symbolics}; return prod; } -/* -int64_t getMemRefSize(Value value) { - auto memRefType = value.getType().cast(); - int64_t elementSize = getElementSize(memRefType.getElementType()); - if (elementSize < 0) { - return -1; // Unsupported type - } - - int64_t numElements = 1; - for (int64_t dim : memRefType.getShape()) { - numElements *= dim; - } - return numElements * elementSize; -} -*/ - -bool shouldSkipBufferInPeakMemoryEstimator(Value value) { - if (IsHostBuffer(value)) return true; - - // skip buffer if it 
is a temp buffer which only used inside of a fusion op
-  //  alloc = memref.alloc
-  //  lmhlo.fusion() {
-  //    op1(buffer0, buffer1, alloc)
-  //    op2(alloc, buffer2, buffer3)
-  //  }
-  //  dealloc = memref.dealloc alloc
-  Operation* prevFusionOp = nullptr;
-  for (auto user : value.getUsers()) {
-    if (isa(user)) continue;
-    if (auto fusionOp = dyn_cast(user->getParentOp())) {
-      if (!prevFusionOp) prevFusionOp = fusionOp;
-      if (prevFusionOp != fusionOp) {
-        return false;
-      }
-    } else {
-      return false;
-    }
-  }
-  return true;
-}
 bool IsDynamicShapeBuffer(Value buffer) {
   auto memrefType = buffer.getType().cast();
@@ -133,34 +92,61 @@ bool IsDynamicShapeBuffer(Value buffer) {
   return false;
 }
 Value cloneBuffer(OpBuilder& b, Location loc, Value buffer) {
-  MemRefType type = refBuffer.getType().cast();
+  MemRefType type = buffer.getType().cast();
   SmallVector dynShape;
   for (size_t i = 0; i < type.getRank(); i++) {
     if (type.getShape()[i] == ShapedType::kDynamic) {
-      dynShape.push_back(b.create(loc, refBuffer, i));
+      dynShape.push_back(b.create(loc, buffer, i));
     }
   }
   auto allocOp = b.create(loc, type, dynShape);
   StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName();
-  if (refBuffer.getDefiningOp()->hasAttr(attrName)) {
-    allocOp.getOperation()->setAttr(
-        attrName, refBuffer.getDefiningOp()->getAttr(attrName));
+  if (buffer.getDefiningOp()->hasAttr(attrName)) {
+    allocOp.getOperation()->setAttr(attrName,
+                                    buffer.getDefiningOp()->getAttr(attrName));
   }
   return allocOp.getResult();
 }
+Value GetContextValueFromFunctionArguments(Operation* op) {
+  Value ctx;
+  if (auto func = op->getParentOfType()) {
+    if (func.getArgument(0).getType().isa()) {
+      return func.getArgument(0);
+    }
+    op->emitError() << "Argument#0 must be RalExecutionContextType.";
+  }
+  return ctx;
+}
+Value GetDefaultStreamHandle(Operation* op, OpBuilder& rewriter) {
+  Location loc = op->getLoc();
+  MLIRContext* ctx = rewriter.getContext();
+  Type llvm_int32_type = IntegerType::get(ctx, 32);
+  Value zero = rewriter.create(loc, llvm_int32_type,
                                rewriter.getI32IntegerAttr(0));
+  Type pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  Value stream_idx = rewriter.create(loc, pointer_type, zero);
+  return stream_idx;
+}
 // InsertRematBlock create a remat block for the living buffer:
 // reload and offload blocks are always pair in graph:
 //
-// if remat_cond:
-//   return offload(buffer)
+// if remat_cond: // pre-offloading block
+//   buffer = pin_memory_alloc()
+//
+// if remat_cond: // offload block
+//   wait_on_stream()
+//   host_buffer = offload(buffer)
+//   return host_buffer
 // else:
 //   return dummy_buffer
 // ......
-// if remat_cond:
+// if remat_cond: // async reload block
 //   return reload(buffer)
 // else:
 //   return buffer
+//
+// if remat_cond: // wait reload block
+//   stream_wait(host_buffer)
 void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b,
                                           LivingBuffer& livingBuffer,
                                           Value rematCond,
@@ -234,10 +220,10 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b,
   b.create(endOp->getLoc(), deviceBuffer);
   reloadIfOp.getElseRegion().front().clear();
   b.setInsertionPointToStart(&reloadIfOp.getElseRegion().front());
-  auto dummyDeviceBuffer = cloneBuffer(b, endOp->getLoc(), buffer);
-  dummyDeviceBuffer.setType(deviceMemrefType);
-  dummyDeviceBuffer.getDefiningOp()->setAttr(kRematBufferAttr,
-                                             b.getBoolAttr(true));
+  // auto dummyDeviceBuffer = cloneBuffer(b, endOp->getLoc(), buffer);
+  // dummyDeviceBuffer.setType(deviceMemrefType);
+  // dummyDeviceBuffer.getDefiningOp()->setAttr(kRematBufferAttr,
+  //                                            b.getBoolAttr(true));
   b.create(endOp->getLoc(), buffer);
   for (auto pair : ops) {
     auto op = pair.first;
@@ -250,6 +236,27 @@ void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b,
       }
     }
   }
+  if (auto fusionOp = dyn_cast(endOp->getParentOp())) {
+    b.setInsertionPoint(fusionOp);
+  } else {
+    b.setInsertionPoint(endOp);
+  }
+
+  // insert reload sync block
+  scf::IfOp waitIfOp = b.create(
+      endOp->getLoc(), /*resultTypes*/ ArrayRef{}, rematCond,
+      /*hasElseRegion*/ false);
+  waitIfOp.getOperation()->setAttr(kRematBlockTypeAttr,
+                                   b.getStringAttr("reload_sync"));
+  b.setInsertionPointToStart(&waitIfOp.getThenRegion().front());
+  auto ctx = GetContextValueFromFunctionArguments(endOp);
+  Value stream_handle = GetDefaultStreamHandle(endOp, b);
+  auto sync_op =
+      b.create(endOp->getLoc(), TypeRange{}, ctx, stream_handle,
+               "sync_on_stream", false, "gpu");
+
+  // b.create(endOp->getLoc(), offloadIfOp.getResult(0));
+  b.setInsertionPointAfter(waitIfOp);
 }
 std::tuple solveQuadratic(int64_t a, int64_t b,
@@ -350,13 +357,19 @@ std::vector FilterBuffers(
     std::vector livingBuffers) {
   std::vector result;
   for (auto lb : livingBuffers) {
-    // TODO(yancey): just for experiment, let's remove this condition in the
-    // future
-    if (!IsDynamicShapeBuffer(lb.buffer)) {
-      continue;
-    }
+    // filter buffers whose living range is too small
+    if (lb.living_range < 1000) continue;
+    // filter scalar buffers
+    auto type = lb.buffer.getType().cast();
+    if (type.getRank() == 0) continue;
     // filter buffer if already in remat block
    if (isa((lb.start->getParentOp()))) continue;
+    bool isEscapedBuffer = std::any_of(
+        lb.buffer.getUsers().begin(), lb.buffer.getUsers().end(),
+        [](auto user) {
+          return isa(user);
+        });
+    if (isEscapedBuffer) continue;
     result.push_back(lb);
   }
   return result;
@@ -384,12 +397,21 @@ void DiscOffloadingPass::runOnOperation() {
   if (main.getName() == SymbolicDimMgr::getShapeConstraintGraphFunctionName())
     return;
   mlir::OpBuilder b(main);
-  const int64_t memoryLimitation = 21474836480;  // 30GB
+  // TODO(yancey): using a ratio to control the memory limitation
+  const int64_t memoryLimitation = 32212254720;  // 30GB
   llvm::dbgs() << "memory limitation: " << memoryLimitation << "\n";
   bool changed = true;
-  int maxIteration = 200;
+  int maxIteration = 100;
+  std::unique_ptr shapeAnalysisPtr(
+      new ShapeConstraintIRAnalysis(main));
+  auto shapeIRAnalysis =
+      dynamic_cast(shapeAnalysisPtr.get());
+  if (!shapeIRAnalysis) {
+    llvm::errs() << "shape analysis failed\n";
+    return;
+  }
   std::unique_ptr profiler(
-      new SymbolicMemoryProfiler(main));
+      new SymbolicMemoryProfiler(main, *shapeIRAnalysis));
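Taken together, FilterBuffers and SortBuffersByPriority define the candidate-selection policy: discard buffers whose living range is below 1000, rank-0 (scalar) buffers, buffers already wrapped in a remat scf.if block, and buffers that escape through the function return, then offload the survivor with the longest living range first. A standalone editorial restatement of that policy over a plain struct follows; the struct and its fields are illustrative stand-ins for the MLIR values the pass actually inspects.

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

struct Candidate {
  int64_t living_range;  // distance between producing op and last consumer
  int64_t rank;          // tensor rank; 0 means scalar
  bool in_remat_block;   // already inside an scf.if remat block
  bool escapes;          // reaches the function's return
};

std::optional<Candidate> pickCandidate(std::vector<Candidate> bufs) {
  // Drop everything the filter rejects.
  bufs.erase(std::remove_if(bufs.begin(), bufs.end(),
                            [](const Candidate& c) {
                              return c.living_range < 1000 || c.rank == 0 ||
                                     c.in_remat_block || c.escapes;
                            }),
             bufs.end());
  if (bufs.empty()) return std::nullopt;
  // Highest priority == longest living range.
  std::sort(bufs.begin(), bufs.end(),
            [](const Candidate& a, const Candidate& b) {
              return a.living_range > b.living_range;
            });
  return bufs.front();
}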
std::unique_ptr bufferLivingRange( new DiscBufferLivingRange(main)); // mapping symbolic dim(S0) in shape constrint graph to SSA value @@ -424,6 +446,7 @@ void DiscOffloadingPass::runOnOperation() { changed = true; } } + main.dump(); } std::unique_ptr> createDiscOffloadingPass() { return std::make_unique(); diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc index a3de94b8677..3f0adef4bdd 100755 --- a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc +++ b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc @@ -184,7 +184,9 @@ struct GpuCopyOpConvertor : public OpRewritePattern { op.getLoc(), TypeRange{}, ctx, newOperands, target_, false, "gpu"); // TODO(disc): Re-visit this is necessary. // TODO(disc): add a pass to merge sync_on_stream call. - InsertSyncOnStream(op, ctx, stream_handle, rewriter); + if (!isa(op->getParentOp())) { + InsertSyncOnStream(op, ctx, stream_handle, rewriter); + } rewriter.replaceOp(op, newOp.getResults()); return success(); } diff --git a/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc b/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc index 268ddb0433a..0ccabd76c34 100644 --- a/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc +++ b/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc @@ -222,15 +222,23 @@ int64_t getMemRefSize(Value value) { return numElements * elementSize; } SymbolicDimProduct getSymbolicMemrefBytes(Value buffer, SymbolicDimMgr* mgr) { + if (!mgr) { + llvm::errs() << "mgr is nullptr\n"; + return SymbolicDimProduct{}; + } auto s0 = mgr->findSymbolicDimOp("S0"); + if (!s0.has_value()) { + llvm::errs() << "SymbolicDimOp S0 not found\n"; + } auto ty = buffer.getType().cast(); auto elementBytes = getElementSize(ty.getElementType()); if (hasDynamicDimension(buffer)) { if (auto recvOp = dyn_cast(buffer.getDefiningOp())) { // get first user of the buffer - if (auto rcastOp = - dyn_cast(*buffer.getUsers().begin())) { - buffer = rcastOp.getResult(); + for (auto user : buffer.getUsers()) { + if (isa(user)) { + buffer = user->getResult(0); + } } } auto symDims = getMemRefValueSymbolicDims(*mgr, buffer); @@ -382,15 +390,11 @@ std::vector ConcretMemoryUsageSimulator(int64_t concretValue) { LogicalResult SymbolicMemoryProfiler::Analysis() { mlir::OpBuilder b(main_); - std::unique_ptr shapeAnalysisPtr; - shapeAnalysisPtr.reset(new ShapeConstraintIRAnalysis(main_)); - auto shapeIRAnalysis = - dynamic_cast(shapeAnalysisPtr.get()); - if (!shapeIRAnalysis) { - llvm::errs() << "shape analysis failed\n"; + mgr_ = &shapeAnalysis_.symbolicDimMgr(); + if (!mgr_) { + llvm::errs() << "mgr is nullptr\n"; return failure(); } - mgr_ = &shapeIRAnalysis->symbolicDimMgr(); memory_usage_list_.clear(); MemoryUsage currentUsage; @@ -438,7 +442,7 @@ LogicalResult SymbolicMemoryProfiler::Analysis() { // NOTE: search peak memory expr with fake symbolic value, please // note, it's a fuzzy search algorithm, the result may not be accurate // but it's good enough for most cases - peak_memory_ = findPeakMemoryWithFakeValue(memory_usage_list_, 2048); + peak_memory_ = findPeakMemoryWithFakeValue(memory_usage_list_, 4096); return success(); } diff --git a/tao_compiler/mlir/disc/transforms/disc_remat_utils.h b/tao_compiler/mlir/disc/transforms/disc_remat_utils.h index cddaf665a2d..b825d6ae59e 100644 --- a/tao_compiler/mlir/disc/transforms/disc_remat_utils.h +++ b/tao_compiler/mlir/disc/transforms/disc_remat_utils.h @@ -88,7 +88,6 @@ class 
DiscBufferLivingRange { using MemoryUsage = llvm::SmallVector; -// MemoryUsage operator+=(MemoryUsage& lhs, const SymbolicDimProduct& rhs); MemoryUsage& operator+=(MemoryUsage& lhs, const SymbolicDimProduct& rhs); MemoryUsage& operator-=(MemoryUsage& lhs, const SymbolicDimProduct& rhs); llvm::raw_ostream& operator<<(llvm::raw_ostream& os, @@ -98,7 +97,9 @@ llvm::raw_ostream& operator<<(llvm::raw_ostream& os, // mlir function; it returns the peak memory usage and the memory usage list class SymbolicMemoryProfiler { public: - explicit SymbolicMemoryProfiler(mlir::func::FuncOp& main) : main_(main) {} + explicit SymbolicMemoryProfiler(mlir::func::FuncOp& main, + ShapeConstraintIRAnalysis& shapeAnalysis) + : main_(main), shapeAnalysis_(shapeAnalysis) {} LogicalResult Analysis(); MemoryUsage GetPeakMemory() { return peak_memory_; } std::vector GetMemoryUsageList() { return memory_usage_list_; } @@ -122,6 +123,7 @@ class SymbolicMemoryProfiler { mlir::func::FuncOp main_; MemoryUsage peak_memory_; std::vector memory_usage_list_; + ShapeConstraintIRAnalysis& shapeAnalysis_; }; } // namespace disc_ral diff --git a/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc b/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc index 5e4e84c400b..b8d45161738 100644 --- a/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc +++ b/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc @@ -67,6 +67,8 @@ constexpr const char* kRalDispatchFunctionName = "disc_ral_call"; constexpr const char* kRalGpuLaunch = "ral_kernel_launch"; constexpr const char* kRalCpuLaunch = "ral_kernel_launch"; constexpr const char* kMalloc = "alloc"; +constexpr const char* kPinnedMalloc = "ral_gpu_pinned_alloc"; +constexpr const char* kPinnedFree = "ral_gpu_pinned_dealloc"; constexpr const char* kFree = "dealloc"; constexpr const char* kRalCompIntensFusion = "ral_comp_intens_fusion"; @@ -736,11 +738,28 @@ LogicalResult ConvertMemRefAllocOpToDispatchOpPattern::matchAndRewrite( getMemRefDescriptorSizes(loc, memref_type, llvm::to_vector<4>(adaptor.getOperands()), rewriter, sizes, strides, sizeBytes); - - // create dispatch op - auto dispatch_op = rewriter.create( - loc, getVoidPtrType(), context_arg, sizeBytes, kMalloc, false, device); - Value allocated_byte_ptr = dispatch_op.getResult(0); + // use pinned host memory so data transfers can overlap with computation kernels + bool pinnedMemory = false; + for (auto user : memref.getUsers()) { + if (auto dispatch = dyn_cast(user)) { + if (dispatch.getCallTargetName() == "d2h") { + pinnedMemory = true; + break; + } + } + } + Operation* dispatch_op; + if (device == "cpu" && pinnedMemory) { + // allocate pinned host memory through the gpu driver + dispatch_op = rewriter.create( + loc, getVoidPtrType(), context_arg, sizeBytes, kPinnedMalloc, false, + "gpu"); + } else { + dispatch_op = rewriter.create( + loc, getVoidPtrType(), context_arg, sizeBytes, kMalloc, false, device); + } + Value allocated_byte_ptr = dispatch_op->getResult(0); // Create the MemRef descriptor. 
MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( @@ -793,10 +812,26 @@ LogicalResult ConvertMemRefDeallocOpToDispatchOpPattern::matchAndRewrite( MemRefDescriptor memref(adaptor.getMemref()); Value allocated_bytes_ptr = rewriter.create( loc, getVoidPtrType(), memref.allocatedPtr(rewriter, loc)); - - ModuleOp module = op->getParentOfType(); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, context_arg, allocated_bytes_ptr, kFree, false, device); + bool pinnedMemory = false; + for (auto user : dealloc_op.getMemref().getUsers()) { + if (auto dispatch = dyn_cast(user)) { + if (dispatch.getCallTargetName() == "h2d") { + pinnedMemory = true; + break; + } + } + } + if (device == "cpu" && pinnedMemory) { + // use pinned host memory to achieve higher data-transfer performance + // between host and device + rewriter.replaceOpWithNewOp( + dealloc_op, TypeRange{}, context_arg, allocated_bytes_ptr, kPinnedFree, + false, "gpu"); + } else { + rewriter.replaceOpWithNewOp( + dealloc_op, TypeRange{}, context_arg, allocated_bytes_ptr, kFree, false, + device); + } return success(); } diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc index 374d39103ce..616250b267e 100644 --- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc +++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc @@ -126,6 +126,7 @@ struct BaseCudaContextState : public tao::ral::Context::Resource { std::map, GpuFunctionHandle> kernels; std::shared_ptr gpu_allocator; + std::shared_ptr cuda_host_allocator; bool cache_workspace_mem_across_execution; #ifdef TAO_RAL_USE_STREAM_EXECUTOR ::stream_executor::Stream* se_stream; @@ -183,6 +184,7 @@ std::unique_ptr MakeBaseCudaContext( } else { state->gpu_allocator.reset(new InternalAllocator(gpu_alloc, gpu_dealloc)); } + state->cuda_host_allocator = gpu_opt.cuda_host_allocator; state->cache_workspace_mem_across_execution = opt.cache_workspace_mem_across_execution; @@ -268,7 +270,18 @@ buffer_t ral_base_cuda_alloc(ExecutionContext* ctx, size_t bytes) { exec_ctx->device_ptr_map.insert(std::make_pair(ptr, 1)); return ptr; } - +buffer_t ral_base_cuda_pinned_alloc(ExecutionContext* ctx, size_t bytes) { + auto* state = + ctx->getResource(kRalBaseCudaContextState); + auto exec_ctx = dynamic_cast(ctx); + std::lock_guard lock(state->mu); + TAO_VLOG(1) << "before ral_base_cuda_pinned_alloc with bytes = " << bytes; + bytes = (bytes ? 
bytes : 1); + void* ptr = state->cuda_host_allocator->alloc(bytes); + TAO_VLOG(1) << "after ral_base_cuda_pinned_alloc with ptr = " << ptr; + exec_ctx->host_ptr_map.insert(std::make_pair(ptr, 1)); + return ptr; +} void ral_base_cuda_memset(ExecutionContext* ctx, stream_t handle, buffer_t buffer, int value, size_t bytes) { if (!buffer) { @@ -321,6 +334,32 @@ void ral_base_cuda_dealloc(ExecutionContext* ctx, buffer_t buffer) { } TAO_VLOG(1) << "after ral_base_cuda_dealloc with ptr = " << buffer; } +void ral_base_cuda_pinned_dealloc(ExecutionContext* ctx, buffer_t buffer) { + /* + if (!buffer) { + TAO_VLOG(1) << "ral_base_cuda_pinned_dealloc early return for nullptr"; + return; + } + + auto* state = + ctx->getResource(kRalBaseCudaContextState); + auto exec_ctx = dynamic_cast(ctx); + + std::lock_guard lock(state->mu); + TAO_VLOG(1) << "before ral_base_cuda_pinned_dealloc with ptr = " << buffer; + if (state->device_persistent_buffers.count(buffer)) return; + auto it = exec_ctx->host_ptr_map.find(buffer); + CHECK(it != exec_ctx->host_ptr_map.end()); + if (--it->second == 0) { + cudaFreeHost(buffer); + exec_ctx->host_ptr_map.erase(buffer); + TAO_VLOG(1) << "delete buffer after ref-count becoming zero"; + } + TAO_VLOG(1) << "after ral_base_cuda_pinned_dealloc with ptr = " << buffer; + */ + + // NOTE: the real deallocation path above is disabled for now, so pinned + // buffers are intentionally leaked here. + TAO_VLOG(1) << "ral_base_cuda_pinned_dealloc (currently a no-op) with ptr = " << buffer; +} buffer_t ral_base_cuda_raw_alloc(Context* ctx, size_t bytes) { auto* state = static_cast( @@ -498,6 +537,8 @@ void ral_base_cuda_sync_on_stream(ExecutionContext* ctx, stream_t sidx) { ctx->signalError(Context::FAILURE, "not a valid stream idx"); return; } + auto comm_stream = + static_cast(ctx)->getCommStream(); auto* state = ctx->getResource(kRalBaseCudaContextState); @@ -505,7 +546,7 @@ reportErrorIfAny(stream_executor::wrap::hipStreamSynchronize(state->stream), ctx, "StreamSync"); #else - reportErrorIfAny(cuStreamSynchronize(state->stream), ctx, "StreamSync"); + reportErrorIfAny(cuStreamSynchronize(comm_stream), ctx, "StreamSync"); #endif } @@ -626,6 +667,10 @@ ::tao::ral::MemRefType ral_base_cuda_bitcast_0d( void ral_base_cuda_h2d(ExecutionContext* ctx, void* stream_handle, const void* h_src, buffer_t d_dst, size_t bytes) { + TAO_VLOG(1) << "ral_base_cuda_h2d, h_src: " << h_src << ", d_dst: " << d_dst + << ", bytes: " << bytes; + auto comm_stream = + static_cast(ctx)->getCommStream(); auto* state = ctx->getResource(kRalBaseCudaContextState); #if TENSORFLOW_USE_ROCM @@ -635,15 +680,19 @@ void ral_base_cuda_h2d(ExecutionContext* ctx, void* stream_handle, ctx, "cuMemcpyHtoDAsync"); #else reportErrorIfAny( - cuMemcpyHtoDAsync((GpuDevicePtr)d_dst, h_src, bytes, state->stream), ctx, + cuMemcpyHtoDAsync((GpuDevicePtr)d_dst, h_src, bytes, comm_stream), ctx, "cuMemcpyHtoDAsync"); #endif } void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle, buffer_t d_src, buffer_t h_dst, size_t bytes) { + TAO_VLOG(1) << "ral_base_cuda_d2h, d_src: " << d_src << ", h_dst: " << h_dst + << ", bytes: " << bytes; auto* state = ctx->getResource(kRalBaseCudaContextState); + auto comm_stream = + static_cast(ctx)->getCommStream(); #if TENSORFLOW_USE_ROCM reportErrorIfAny( stream_executor::wrap::hipMemcpyDtoHAsync( d_src), ctx, "cuMemcpyDtoHAsync"); #else reportErrorIfAny( - cuMemcpyDtoHAsync(h_dst, (GpuDevicePtr)d_src, bytes, state->stream), ctx, + cuMemcpyDtoHAsync(h_dst, (GpuDevicePtr)d_src, bytes, comm_stream), 
ctx, "cuMemcpyDtoHAsync"); #endif } @@ -745,10 +794,14 @@ RAL_REGISTER_BITCAST_FUNC(bool, 7); RAL_REGISTER_BITCAST_FUNC(bool, 8); TAO_RAL_API(tao::ral::gpu::kRalGpuAlloc, "gpu", ral_base_cuda_alloc); +TAO_RAL_API(tao::ral::gpu::kRalGpuPinnedAlloc, "gpu", + ral_base_cuda_pinned_alloc); TAO_RAL_API(tao::ral::gpu::kRalGpuAllocPersistent, "gpu", ral_base_cuda_alloc_persistent); TAO_RAL_API(tao::ral::gpu::kRalGpuDealloc, "gpu", ral_base_cuda_dealloc); TAO_RAL_API(tao::ral::gpu::kRalGpuRawAlloc, "gpu", ral_base_cuda_raw_alloc); +TAO_RAL_API(tao::ral::gpu::kRalGpuPinnedDealloc, "gpu", + ral_base_cuda_pinned_dealloc); TAO_RAL_API(tao::ral::gpu::kRalGpuRawDealloc, "gpu", ral_base_cuda_raw_dealloc); TAO_RAL_API(tao::ral::gpu::kRalGpuLaunch, "gpu", ral_base_cuda_launch); TAO_RAL_API(tao::ral::gpu::kRalGpuGetStream, "gpu", ral_base_cuda_get_stream); diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h index 8d1fd2a7293..822984dfb48 100755 --- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h +++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h @@ -55,6 +55,7 @@ struct BaseCudaContextOption { bool use_stream_executor = true; bool cache_workspace_mem_across_execution = false; std::shared_ptr gpu_allocator; + std::shared_ptr cuda_host_allocator; }; std::unique_ptr MakeBaseCudaContext( diff --git a/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc b/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc index d7dad76b545..45e4605074a 100644 --- a/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc +++ b/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc @@ -28,6 +28,8 @@ namespace gpu { const char* kRalGpuAlloc = "alloc"; const char* kRalGpuAllocPersistent = "ral_gpu_alloc_persistent"; +const char* kRalGpuPinnedAlloc = "ral_gpu_pinned_alloc"; +const char* kRalGpuPinnedDealloc = "ral_gpu_pinned_dealloc"; const char* kRalGpuDealloc = "dealloc"; const char* kRalGpuRawAlloc = "raw_gpu_alloc"; const char* kRalGpuRawDealloc = "raw_gpu_dealloc"; diff --git a/tao_compiler/mlir/ral/device/gpu/gpu_driver.h b/tao_compiler/mlir/ral/device/gpu/gpu_driver.h index eb934cd6e84..9fc48ec49de 100644 --- a/tao_compiler/mlir/ral/device/gpu/gpu_driver.h +++ b/tao_compiler/mlir/ral/device/gpu/gpu_driver.h @@ -35,6 +35,8 @@ using stream_t = void*; extern const char* kRalGpuAlloc; extern const char* kRalGpuAllocPersistent; +extern const char* kRalGpuPinnedAlloc; +extern const char* kRalGpuPinnedDealloc; extern const char* kRalGpuDealloc; extern const char* kRalGpuRawAlloc; extern const char* kRalGpuRawDealloc;