diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index d3dfd3fbc9b..77baa72c29b 100755 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -341,6 +341,7 @@ cc_library( ":codegen_utils", ":disc_shape_optimization_utils", ":disc_util", + ":disc_offloading", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:map_lmhlo_to_scalar_op", @@ -2228,7 +2229,35 @@ cc_library( ], alwayslink = 1, ) - +cc_library( + name = "disc_offloading", + srcs = [ + "transforms/disc_dynamic_offloading.cc", + "transforms/disc_remat_utils.cc" + ], + hdrs = [ + "transforms/passes.h", + "transforms/disc_remat_utils.h" + ], + includes = ["include"], + deps = [ + ":disc_ral", + ":disc_util", + ":mhlo_disc", + ":pass_details", + ":shape_utils", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:mlir_hlo", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo", + "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:map_lmhlo_to_scalar_op", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Transforms", + ] +) cc_library( name = "disc_custom_call_rewriter", srcs = ["transforms/disc_custom_call_rewriter.cc"], diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index e8f0eb06ae3..930210d809f 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -242,10 +242,9 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); - + pm.addPass(disc_ral::createDiscShapePropagatePass()); pm.addNestedPass(disc_ral::createDiscAlgebraicSimplifierPass()); pm.addPass(disc_ral::createDiscInputOutputAliasPass()); - pm.addPass(disc_ral::createDiscShapePropagatePass()); pm.addPass(mlir::createInlinerPass()); // TODO(disc): Lower HLO shape constraints instead of eliding them here. pm.addNestedPass(disc_ral::createDiscCollectiveOpsRewriterPass()); @@ -269,8 +268,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass( disc_ral::createDiscLowerQuantizeAndDequantizePass()); } - bool enable_shape_constraint_ir = useShapeConstraintIR(); + if (!enable_shape_constraint_ir) { // propagate some known shape information. 
pm.addPass(disc_ral::createDiscShapeSimplifierPass()); @@ -279,7 +278,6 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { // shape-related optimization pm.addPass(disc_ral::createDiscShapeOptimizationPass()); } - pm.addNestedPass(disc_ral::createDiscConvertTensorToStandardPass()); pm.addNestedPass(disc_ral::createDiscConvertHloToStandardPass()); pm.addNestedPass(createCanonicalizerPass()); @@ -500,9 +498,9 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { if (gpu_enabled) { // TODO: Support cpu stitch with splat const pm.addNestedPass(disc_ral::createDiscFuseSplatConstPass()); - pm.addNestedPass( - disc_ral::createDiscSpecializeFusionWithSpeculationPass( - gpu_options.sm_count, gpu_options.max_threads_per_sm)); + // pm.addNestedPass( + // disc_ral::createDiscSpecializeFusionWithSpeculationPass( + // gpu_options.sm_count, gpu_options.max_threads_per_sm)); } else { pm.addNestedPass( disc_ral::createDiscDuplicateComputationAfterFusionPass()); @@ -545,6 +543,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { pm.addNestedPass(disc_ral::createDiscBufferDeallocationPass()); pm.addPass(disc_ral::createRalInjectExecutionContextPass()); + // pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass()); + // pm.addNestedPass(disc_ral::createDiscOffloadingPass()); pm.addNestedPass( disc_ral::createDiscLowerToLibraryCallPass(gpu_enabled)); pm.addPass(disc_ral::createDiscConstToRALPass(options.metadata_file_path)); diff --git a/tao_compiler/mlir/disc/disc_compiler_main.cc b/tao_compiler/mlir/disc/disc_compiler_main.cc index 366eebb88b1..97f065366c3 100644 --- a/tao_compiler/mlir/disc/disc_compiler_main.cc +++ b/tao_compiler/mlir/disc/disc_compiler_main.cc @@ -210,11 +210,13 @@ int RealMain() { << " s.\n"; llvm::dbgs() << "[[ INFO ]] Running TF2XLA\n"; + /* auto s = tensorflow::ConvertTF2MlirHlo(module); if (!s.ok()) { llvm::dbgs() << "ConvertTF2MlirHlo failed: " << s.ToString() << "\n"; return 1; } + */ if (VLOG_IS_ON(0)) { llvm::dbgs() << "======== BEGIN After TF2HLO =========\n"; diff --git a/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc new file mode 100644 index 00000000000..5ed21da8852 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_dynamic_offloading.cc @@ -0,0 +1,456 @@ +// Copyright 2024 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
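+ +// DiscOffloadingPass (dynamic offloading / rematerialization) inserts +// conditional offload and reload blocks around long-lived device buffers: +// when the symbolic dimension (e.g. S0) exceeds a threshold derived from the +// symbolic peak-memory estimate, a buffer is copied to host memory after its +// producer and copied back to the device right before its next use, trading +// extra transfers for a lower GPU memory peak.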
+ +#include + +#include "absl/strings/str_split.h" +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project +#include "mlir/IR/Dominance.h" +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" // TF:llvm-project +#include "mlir/disc/IR/disc_ral_ops.h" +#include "mlir/disc/IR/hlo_disc_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/disc_util.h" +#include "mlir/disc/transforms/PassDetail.h" +#include "mlir/disc/transforms/disc_remat_utils.h" +#include "mlir/disc/transforms/shape_utils.h" + +namespace mlir { +namespace disc_ral { +constexpr StringRef kRematBlockTypeAttr = "disc.remat.type"; +constexpr StringRef kRematBufferAttr = "disc.remat.is_dummy_buffer"; +constexpr StringRef kRematMinSymDim = "disc.remat.min_symbolic_dim"; + +struct DiscOffloadingPass : public DiscOffloadingPassBase { + DiscOffloadingPass() + : DiscOffloadingPassBase::DiscOffloadingPassBase() {} + void getDependentDialects(DialectRegistry& registry) const override { + DiscOffloadingPassBase::getDependentDialects(registry); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + void runOnOperation() override; + void InsertRematBlock(mlir::OpBuilder& b, LivingBuffer& livingBuffer, + Value rematCond, std::vector& ops, + int64_t minSymValue); +}; +Location getFusionLocation(OpBuilder& b, Operation* op) { + if (auto fusionOp = dyn_cast(op->getParentOp())) { + b.setInsertionPointAfter(fusionOp); + return fusionOp->getLoc(); + } + b.setInsertionPointAfter(op); + return op->getLoc(); +} + +using FuncOp = mlir::func::FuncOp; + +SymbolicDimProduct getSymbolicMemRefSize(Value value, SymbolicDimMgr* mgr, + MLIRContext* ctx) { + auto memRefType = value.getType().cast(); + auto symbolics = mgr->getOrCreateSymbolicDimsForRankedValue(value); + SymbolicDimProduct prod{symbolics}; + return prod; +} + +bool IsDynamicShapeBuffer(Value buffer) { + auto memrefType = buffer.getType().cast(); + for (auto dim : memrefType.getShape()) { + if (dim == ShapedType::kDynamic) { + return true; + } + } + return false; +} +Value cloneBuffer(OpBuilder& b, Location loc, Value buffer) { + MemRefType type = buffer.getType().cast(); + SmallVector dynShape; + for (size_t i = 0; i < type.getRank(); i++) { + if (type.getShape()[i] == ShapedType::kDynamic) { + dynShape.push_back(b.create(loc, buffer, i)); + } + } + auto allocOp = b.create(loc, type, dynShape); + StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName(); + if (buffer.getDefiningOp()->hasAttr(attrName)) { + allocOp.getOperation()->setAttr(attrName, + buffer.getDefiningOp()->getAttr(attrName)); + } + return allocOp.getResult(); +} +Value GetContextValueFromFunctionArguments(Operation* op) { + Value ctx; + if (auto func = op->getParentOfType()) { + if (func.getArgument(0).getType().isa()) { + return func.getArgument(0); + } + op->emitError() << "Argument#0 must be 
RalExecutionContextType."; + } + return ctx; +} +Value GetDefaultStreamHandle(Operation* op, OpBuilder& rewriter) { + Location loc = op->getLoc(); + MLIRContext* ctx = rewriter.getContext(); + Type llvm_int32_type = IntegerType::get(ctx, 32); + Value zero = rewriter.create(loc, llvm_int32_type, + rewriter.getI32IntegerAttr(0)); + Type pointer_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + Value stream_idx = rewriter.create(loc, pointer_type, zero); + return stream_idx; +} +// InsertRematBlock create a remat block for the living buffer: +// reload and offload blocks are always pair in graph: +// +// if remat_cond: // pre-offloading block +// buffer = pin_memory_alloc() +// +// if remat_cond: // offload block +// wait_on_stream() +// host_buffer = offload(buffer) +// return host_bufer +// else: +// return dummy_buffer +// ...... +// if remat_cond: // async reload block +// return reload(buffer) +// else: +// return buffer +// +// if remat_cond: // wait reload block +// stream_wait(host_buffer) +void DiscOffloadingPass::InsertRematBlock(mlir::OpBuilder& b, + LivingBuffer& livingBuffer, + Value rematCond, + std::vector& ops, + int64_t minSymValue) { + // TODO(yancey): we need a custom operator to handle the dummy buffer to avoid + // call alloc or dealloc operator + auto buffer = livingBuffer.buffer; + auto startOp = livingBuffer.start; + auto endOp = livingBuffer.end; + StringRef attrName = SymbolicDimOp::getSymbolicDimAttrName(); + auto deviceMemrefType = buffer.getType().cast(); + auto hostMemrefType = MemRefType::get( + deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0); + // insert offload block + auto offloadIfOp = + b.create(startOp->getLoc(), + /*resultTypes*/ hostMemrefType, rematCond, + /*hasElseRegion*/ true); + offloadIfOp.getOperation()->setAttr( + attrName, buffer.getDefiningOp()->getAttr(attrName)); + offloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("offload")); + offloadIfOp.getOperation()->setAttr(kRematMinSymDim, + b.getI64IntegerAttr(minSymValue)); + offloadIfOp.getThenRegion().front().clear(); + b.setInsertionPointToEnd(&offloadIfOp.getThenRegion().front()); + auto hostBuffer = cloneBuffer(b, endOp->getLoc(), buffer); + hostBuffer.setType(MemRefType::get(deviceMemrefType.getShape(), + deviceMemrefType.getElementType(), {}, 0)); + b.create(endOp->getLoc(), buffer, hostBuffer); + b.create(endOp->getLoc(), buffer); + b.create(endOp->getLoc(), hostBuffer); + + offloadIfOp.getElseRegion().front().clear(); + b.setInsertionPointToStart(&offloadIfOp.getElseRegion().front()); + auto dummyHostBuffer = cloneBuffer(b, endOp->getLoc(), buffer); + dummyHostBuffer.getDefiningOp()->setAttr(kRematBufferAttr, + b.getBoolAttr(true)); + dummyHostBuffer.setType(MemRefType::get( + deviceMemrefType.getShape(), deviceMemrefType.getElementType(), {}, 0)); + b.create(endOp->getLoc(), dummyHostBuffer); + b.setInsertionPointAfter(offloadIfOp); + + if (auto fusionOp = dyn_cast(endOp->getParentOp())) { + b.setInsertionPoint(fusionOp); + } else { + b.setInsertionPoint(endOp); + } + + // insert reload block + scf::IfOp reloadIfOp = + b.create(endOp->getLoc(), + /*resultTypes*/ deviceMemrefType, rematCond, + /*hasElseRegion*/ true); + if (buffer.getDefiningOp()->hasAttr(attrName)) { + reloadIfOp.getOperation()->setAttr( + attrName, buffer.getDefiningOp()->getAttr(attrName)); + } + reloadIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("reload")); + reloadIfOp.getOperation()->setAttr(kRematMinSymDim, + 
b.getI64IntegerAttr(minSymValue)); + reloadIfOp.getThenRegion().front().clear(); + b.setInsertionPointToStart(&reloadIfOp.getThenRegion().front()); + + auto deviceBuffer = cloneBuffer(b, endOp->getLoc(), buffer); + deviceBuffer.setType(deviceMemrefType); + auto h2dOp = b.create( + endOp->getLoc(), offloadIfOp.getResult(0), deviceBuffer); + b.create(endOp->getLoc(), deviceBuffer); + reloadIfOp.getElseRegion().front().clear(); + b.setInsertionPointToStart(&reloadIfOp.getElseRegion().front()); + // auto dummyDeviceBuffer = cloneBuffer(b, endOp->getLoc(), buffer); + // dummyDeviceBuffer.setType(deviceMemrefType); + // dummyDeviceBuffer.getDefiningOp()->setAttr(kRematBufferAttr, + // b.getBoolAttr(true)); + b.create(endOp->getLoc(), buffer); + for (auto pair : ops) { + auto op = pair.first; + auto position = pair.second; + if (position > livingBuffer.end_position) { + for (size_t i = 0; i < op->getNumOperands(); i++) { + if (op->getOperand(i) == buffer) { + op->setOperand(i, reloadIfOp.getResult(0)); + } + } + } + } + if (auto fusionOp = dyn_cast(endOp->getParentOp())) { + b.setInsertionPoint(fusionOp); + } else { + b.setInsertionPoint(endOp); + } + + // insert reload sync block + scf::IfOp waitIfOp = b.create( + endOp->getLoc(), /*resultTypes*/ ArrayRef{}, rematCond, + /*hasElseRegion*/ false); + waitIfOp.getOperation()->setAttr(kRematBlockTypeAttr, + b.getStringAttr("reload_sync")); + b.setInsertionPointToStart(&waitIfOp.getThenRegion().front()); + auto ctx = GetContextValueFromFunctionArguments(endOp); + Value stream_handle = GetDefaultStreamHandle(endOp, b); + auto sync_op = + b.create(endOp->getLoc(), TypeRange{}, ctx, stream_handle, + "sync_on_stream", false, "gpu"); + + // b.create(endOp->getLoc(), offloadIfOp.getResult(0)); + b.setInsertionPointAfter(waitIfOp); +} + +std::tuple solveQuadratic(int64_t a, int64_t b, + int64_t c) { + if (a == 0) { + throw std::invalid_argument( + "Coefficient A cannot be zero in a quadratic equation."); + } + + double discriminant = b * b - 4 * a * c; + // no solution if b^2 - 4ac < 0 + if (discriminant < 0) { + return std::make_tuple(false, 0.0, 0.0); + } + + double sqrtDiscriminant = std::sqrt(discriminant); + + // compute x1, x2 + int64_t x1 = (-b + sqrtDiscriminant) / (2 * a); + int64_t x2 = (-b - sqrtDiscriminant) / (2 * a); + return std::make_tuple(true, x1, x2); +} +int64_t findMinSymbolicDimValue(MemoryUsage memoryPeakExpr, + int64_t memoryLimitation) { + int64_t a, b, c = 0; + // memoryPeakExpr = a * S0^2 + b * S0 + c + // memoryPeakExpr >= memoryLimitation + // a * S0^2 + b * S0 + c - memoryLimitation >= 0 + for (auto prod : memoryPeakExpr) { + if (prod.symbols.size() == 0) { + c = prod.factor; + } else if (prod.symbols.size() == 1) { + b = prod.factor; + } else if (prod.symbols.size() == 2) { + a = prod.factor; + } + } + c -= memoryLimitation; + if (a == 0) { + return -c / b; + } + auto [hasSolution, x0, x1] = solveQuadratic(a, b, c); + if (!hasSolution) { + throw std::invalid_argument("No solution for the quadratic equation."); + } + return std::max(x0, x1); +} + +bool inRematOffloadBlock(Value value) { + if (auto ifOp = dyn_cast(value.getDefiningOp()->getParentOp())) { + auto blockType = + ifOp.getOperation()->getAttrOfType(kRematBlockTypeAttr); + if (blockType && blockType.getValue() == "offload") { + return true; + } + } + return false; +} + +std::optional findSymbolicDimValue(FuncOp& main, + const std::string& key) { + SmallVector dynInputs; + std::unordered_map symValueMap; + main.walk([&](Operation* op) { + auto allocOp = dyn_cast(op); + 
if (!allocOp) { + return; + } + if (op->getNumOperands() == 0) { + return; + } + auto attrs = + op->getAttrOfType(SymbolicDimOp::getSymbolicDimAttrName()); + int symbDimIndex = 0; + for (auto attr : attrs) { + auto name = attr.cast().getValue(); + if (name.startswith("S")) { + if (symValueMap.count(name.str()) == 0) { + symValueMap[name.str()] = op->getOperand(symbDimIndex++); + } + } + } + return; + }); + if (symValueMap.count(key) == 0) { + return std::nullopt; + } + return symValueMap[key]; +} +Value InsertRematCond(mlir::OpBuilder& b, Location loc, Value s0, + int64_t minS0) { + auto offloadCond = b.create( + s0.getLoc(), arith::CmpIPredicate::sgt, s0, + b.create(s0.getLoc(), minS0)); + return offloadCond; +} +std::vector FilterBuffers( + std::vector livingBuffers) { + std::vector result; + for (auto lb : livingBuffers) { + // filter out buffers whose living range is too small + if (lb.living_range < 1000) continue; + // filter out scalar buffers + auto type = lb.buffer.getType().cast(); + if (type.getRank() == 0) continue; + // filter out buffers that are already inside a remat block + if (isa((lb.start->getParentOp()))) continue; + bool isEscapedBuffer = std::any_of( + lb.buffer.getUsers().begin(), lb.buffer.getUsers().end(), + [](auto user) { + return isa(user); + }); + if (isEscapedBuffer) continue; + result.push_back(lb); + } + return result; +} +void SortBuffersByPriority(std::vector& livingBuffers) { + std::sort(livingBuffers.begin(), livingBuffers.end(), + [](const LivingBuffer& a, const LivingBuffer& b) { + return a.living_range > b.living_range; + }); +} +std::optional PickHighestPriorityLivingBuffer( + const std::vector& livingBuffers) { + // step1: filter out living buffers that cannot reduce the peak memory or + // are too small + auto buffers = FilterBuffers(livingBuffers); + if (buffers.size() == 0) { + return std::nullopt; + } + // step2: sort living buffers by priority, e.g. the length of the living range + SortBuffersByPriority(buffers); + return buffers[0]; +} +void DiscOffloadingPass::runOnOperation() { + FuncOp main = getOperation(); + if (main.getName() == SymbolicDimMgr::getShapeConstraintGraphFunctionName()) + return; + mlir::OpBuilder b(main); + // TODO(yancey): use a ratio to control the memory limitation + const int64_t memoryLimitation = 32212254720; // 30GB + llvm::dbgs() << "memory limitation: " << memoryLimitation << "\n"; + bool changed = true; + int maxIteration = 100; + std::unique_ptr shapeAnalysisPtr( + new ShapeConstraintIRAnalysis(main)); + auto shapeIRAnalysis = + dynamic_cast(shapeAnalysisPtr.get()); + if (!shapeIRAnalysis) { + llvm::errs() << "shape analysis failed\n"; + return; + } + std::unique_ptr profiler( + new SymbolicMemoryProfiler(main, *shapeIRAnalysis)); + std::unique_ptr bufferLivingRange( + new DiscBufferLivingRange(main)); + // map the symbolic dim (S0) in the shape constraint graph to its SSA value + // TODO(yancey): please note, we only support one symbolic dim for now, + // let's find a way to enhance this.
+ auto symS0Value = findSymbolicDimValue(main, "S0"); + if (!symS0Value) { + llvm::errs() << "failed to find S0 value\n"; + return; + } + while (changed && maxIteration--) { + if (failed(profiler->Analysis())) { + llvm::errs() << "failed to run memory analysis\n"; + return; + } + if (failed(bufferLivingRange->Analysis())) { + llvm::errs() << "failed to analyze buffer living range\n"; + return; + } + changed = false; + auto memoryPeakExpr = profiler->GetPeakMemory(); + int64_t minS0 = findMinSymbolicDimValue(memoryPeakExpr, memoryLimitation); + + auto livingBuffer = + PickHighestPriorityLivingBuffer(bufferLivingRange->GetLivingBuffers()); + if (livingBuffer.has_value()) { + auto buffer = livingBuffer.value(); + auto users = bufferLivingRange->GetUsersOrderByPosition(buffer.buffer); + auto loc = getFusionLocation(b, buffer.start); + auto rematCond = InsertRematCond(b, loc, symS0Value.value(), minS0); + InsertRematBlock(b, buffer, rematCond, users, minS0); + changed = true; + } + } + main.dump(); +} +std::unique_ptr> createDiscOffloadingPass() { + return std::make_unique(); +} + +} // namespace disc_ral +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc index a3de94b8677..3f0adef4bdd 100755 --- a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc +++ b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc @@ -184,7 +184,9 @@ struct GpuCopyOpConvertor : public OpRewritePattern { op.getLoc(), TypeRange{}, ctx, newOperands, target_, false, "gpu"); // TODO(disc): Re-visit this is necessary. // TODO(disc): add a pass to merge sync_on_stream call. - InsertSyncOnStream(op, ctx, stream_handle, rewriter); + if (!isa(op->getParentOp())) { + InsertSyncOnStream(op, ctx, stream_handle, rewriter); + } rewriter.replaceOp(op, newOp.getResults()); return success(); } diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td index 3352f7d5e74..f8fc0dece58 100755 --- a/tao_compiler/mlir/disc/transforms/disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/disc_passes.td @@ -677,4 +677,9 @@ def DiscReduceBufferLiveRangePass : Pass<"disc-reduce-buffer-live-range", "mlir: def DiscShapePropagatePass : Pass<"disc-shape-propagate", "ModuleOp"> { let summary = "shape analysis pass"; let constructor = "createDiscShapePropagatePass()"; +} + +def DiscOffloadingPass : Pass<"disc-offloading", "mlir::func::FuncOp"> { + let summary = "auto offloading pass"; + let constructor = "createDiscOffloadingPass()"; } \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc b/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc new file mode 100644 index 00000000000..0ccabd76c34 --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_remat_utils.cc @@ -0,0 +1,450 @@ +// Copyright 2024 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
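+ +// Rematerialization analysis utilities shared by the offloading pass: +// DiscBufferLivingRange records every device buffer together with its users +// and their positions so that long-lived buffers can be identified, and +// SymbolicMemoryProfiler accumulates symbolic buffer sizes (SymbolicDimProduct +// terms) across the op sequence to estimate a symbolic peak-memory expression.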
+ +#include "mlir/disc/transforms/disc_remat_utils.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/disc/IR/disc_ral_ops.h" +#include "mlir/disc/IR/hlo_disc_ops.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +namespace mlir { +namespace disc_ral { +bool IsHostBuffer(Value value) { + auto memRefType = value.getType().cast(); + return memRefType.getMemorySpace() == 0; +} + +LogicalResult DiscBufferLivingRange::Analysis() { + buffer_list_.clear(); + living_buffers_.clear(); + int64_t position; + buffer_map_.clear(); + auto reccordOpWithPosition = [&](Value value, Operation* op) { + buffer_map_[value].push_back(std::make_pair(op, position++)); + }; + main_.walk([&](Operation* op) { + // Traverse the function's blocks and operations. + if (auto allocOp = dyn_cast(op)) { + auto buffer = allocOp.getResult(); + if (IsHostBuffer(allocOp.getResult())) { + return; + } + buffer_list_.push_back(buffer); + reccordOpWithPosition(buffer, op); + } else if (isa(op)) { + auto buffer = op->getOperand(0); + if (IsHostBuffer(buffer)) { + return; + } + reccordOpWithPosition(buffer, op); + } else if (auto returnOp = dyn_cast(op)) { + for (auto operand : returnOp.getOperands()) { + reccordOpWithPosition(operand, op); + } + } else if (isa(op) || isa(op) || + isa(op) || isa(op)) { + return; + } else { + for (Value operand : op->getOperands()) { + if (buffer_map_.count(operand)) reccordOpWithPosition(operand, op); + } + } + }); + for (auto iter : buffer_map_) { + auto buffer = iter.first; + for (size_t i = 1; i < iter.second.size(); i++) { + auto start = iter.second[i - 1].first; + auto startPosition = iter.second[i - 1].second; + + auto end = iter.second[i].first; + auto endPosition = iter.second[i].second; + LivingBuffer livingBuffer(start, startPosition, end, endPosition, buffer); + living_buffers_.push_back(livingBuffer); + } + } + return success(); +} +std::vector DiscBufferLivingRange::GetUsersOrderByPosition( + Value buffer) { + std::vector ops; + if (buffer_map_.find(buffer) == buffer_map_.end()) { + return ops; + } + return buffer_map_[buffer]; +} + +//////////////// Symbolic Memory Profiler Utils ////////////////// +bool hasDynamicDimension(Value buffer) { + auto memrefType = buffer.getType().cast(); + for (auto dim : memrefType.getShape()) { + if (dim == ShapedType::kDynamic) { + return true; + } + } + return false; +} +int64_t getElementSize(Type elementType) { + if (elementType.isF32()) { + return sizeof(float); + } else if (elementType.isF16()) { + return sizeof(uint16_t); + } else if (elementType.isBF16()) { + return sizeof(uint16_t); + } else if (elementType.isF64()) { + return sizeof(double); + } else if (elementType.isInteger(1)) { + return sizeof(bool); + } else if (elementType.isInteger(8)) { + return sizeof(int8_t); + } else if (elementType.isInteger(16)) { + return sizeof(int16_t); + } else if (elementType.isInteger(32)) { + return sizeof(int32_t); + } else if (elementType.isInteger(64)) { + return sizeof(int64_t); + } else if (elementType.isIndex()) { + return sizeof(int32_t); + } else { + llvm::dbgs() << elementType << "\n"; + // Add more types as needed + llvm::errs() << "Unsupported element type\n"; + return -1; + } +} + +MemoryUsage& operator+=(MemoryUsage& lhs, const SymbolicDimProduct& rhs) { + MemoryUsage result; + bool mergeSymbols = false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (lhs[i].symbols == rhs.symbols) { + mergeSymbols = true; + lhs[i].factor += rhs.factor; + } + } 
+ if (!mergeSymbols) { + lhs.push_back(rhs); + } + return lhs; +} +MemoryUsage& operator-=(MemoryUsage& lhs, const SymbolicDimProduct& rhs) { + bool mergeSymbols = false; + size_t removeIndex = -1; + for (size_t i = 0; i < lhs.size(); ++i) { + if (lhs[i].symbols == rhs.symbols) { + mergeSymbols = true; + lhs[i].factor -= rhs.factor; + if (lhs[i].factor == 0) { + removeIndex = i; + break; + } + } + } + if (removeIndex != -1) { + lhs.erase(lhs.begin() + removeIndex); + } + if (!mergeSymbols) { + auto newRhs = rhs; + newRhs.factor *= -1; + lhs.push_back(newRhs); + } + return lhs; +} +llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const MemoryUsage& memoryUsage) { + for (size_t i = 0; i < memoryUsage.size(); ++i) { + // print prod + os << memoryUsage[i].factor; + if (memoryUsage[i].symbols.size() > 0) os << "*"; + for (size_t j = 0; j < memoryUsage[i].symbols.size(); ++j) { + auto dimOp = memoryUsage[i].symbols[j]; + if (j != memoryUsage[i].symbols.size() - 1) { + os << dimOp.getName() << "*"; + } else { + os << dimOp.getName(); + } + } + if (i != memoryUsage.size() - 1 && memoryUsage[i].factor > 0) { + os << " + "; + } + } + return os; +} +// TODO(yancey): just a PoC implementation, need to implement this function +// by traveling the shape operators +SymbolicDimProduct SimplifySymbolicDims(SymbolicDimProduct proudct, + SymbolicDimOp s0) { + auto factor = proudct.factor; + SmallVector symbols; + for (auto symbol : proudct.symbols) { + StringRef symbolName = const_cast(symbol).getName(); + if (symbolName.startswith("S")) { + if (symbolName == "S0") { + factor *= 1; + } else if (symbolName == "S1") { + factor *= 4; + } else if (symbolName == "S2") { + factor *= 32001; + } else if (symbolName == "S3") { + factor *= 128; + } else { + llvm::errs() << "unsupported symbol name\n"; + } + symbols.push_back(s0); + } else { + symbols.push_back(symbol); + } + } + return SymbolicDimProduct{symbols, factor}; +} +int64_t getMemRefSize(Value value) { + auto memRefType = value.getType().cast(); + int64_t elementSize = getElementSize(memRefType.getElementType()); + if (elementSize < 0) { + return -1; // Unsupported type + } + + int64_t numElements = 1; + for (int64_t dim : memRefType.getShape()) { + numElements *= dim; + } + return numElements * elementSize; +} +SymbolicDimProduct getSymbolicMemrefBytes(Value buffer, SymbolicDimMgr* mgr) { + if (!mgr) { + llvm::errs() << "mgr is nullptr\n"; + return SymbolicDimProduct{}; + } + auto s0 = mgr->findSymbolicDimOp("S0"); + if (!s0.has_value()) { + llvm::errs() << "SymbolicDimOp S0 not found\n"; + } + auto ty = buffer.getType().cast(); + auto elementBytes = getElementSize(ty.getElementType()); + if (hasDynamicDimension(buffer)) { + if (auto recvOp = dyn_cast(buffer.getDefiningOp())) { + // get first user of the buffer + for (auto user : buffer.getUsers()) { + if (isa(user)) { + buffer = user->getResult(0); + } + } + } + auto symDims = getMemRefValueSymbolicDims(*mgr, buffer); + SymbolicDimProduct prod{symDims.value()}; + prod = SimplifySymbolicDims(prod, s0.value()); + prod.factor *= elementBytes; + return mgr->simplifySymbolicDimProduct(prod); + } + SymbolicDimProduct prod{}; + int64_t dimProduct = 1; + for (int64_t dim : ty.getShape()) { + dimProduct *= dim; + } + prod.factor = dimProduct * elementBytes; + return prod; +} +bool isHostBuffer(Value value) { + auto memRefType = value.getType().cast(); + return memRefType.getMemorySpace() == 0; +} +Value maybeInReloadBlock(Value buffer) { + if (auto ifOp = dyn_cast(buffer.getDefiningOp()->getParentOp())) { + 
return ifOp.getResult(0); + } + return buffer; +} +/////////////////////// SymbolicMemoryProfiler /////////////////////// +bool SymbolicMemoryProfiler::isTempBuffer(Value value) { + Operation* prevFusionOp = nullptr; + for (auto user : value.getUsers()) { + if (isa(user)) continue; + if (auto fusionOp = dyn_cast(user->getParentOp())) { + if (!prevFusionOp) prevFusionOp = fusionOp; + if (prevFusionOp != fusionOp) { + return false; + } + } else { + return false; + } + } + return true; +} + +int64_t getConcretValuewithFakeValue(MemoryUsage memoryUsages, + int64_t cstValue) { + int64_t result = 0; + for (auto prod : memoryUsages) { + int64_t factor = prod.factor; + if (prod.symbols.size() > 0) + factor *= std::pow(cstValue, prod.symbols.size()); + result += factor; + } + return result; +} +float bytesToMB(int64_t bytes) { return bytes * 1.0 / (1024.0 * 1024.0); } +MemoryUsage findPeakMemoryWithFakeValue( + std::vector memoryUsageList, int64_t fakeValue) { + size_t instIndex = 0; + int64_t maxMemoryUsage = 0; + for (size_t i = 0; i < memoryUsageList.size(); ++i) { + int64_t memoryUsage = + getConcretValuewithFakeValue(memoryUsageList[i], fakeValue); + if (memoryUsage > maxMemoryUsage) { + maxMemoryUsage = memoryUsage; + instIndex = i; + } + } + llvm::dbgs() << "fakeValue: " << fakeValue << " maxMemory: " << maxMemoryUsage + << " bytes" + << " instIndex: " << instIndex + << " expr: " << memoryUsageList[instIndex] << "\n"; + return memoryUsageList[instIndex]; +} +std::vector ConcretMemoryUsageSimulator(int64_t concretValue) { + /* + SymbolicDimProductSum memoryUsage; + SmallVector memoryUsageList; + std::unordered_set set, skipBuffers; + mlir::OpBuilder b(main); + llvm::dbgs() << "memory usage with seqlen=" << cstS0 << "\n"; + int rematBuffers = 0; + main.walk([&](Operation* op) { + if (auto allocOp = dyn_cast(op)) { + if (allocOp.getOperation()->getAttr( + b.getStringAttr("disc.remat.dummy-buffer"))) { + return; + } + // alloc operator maybe inside a remat block + auto buffer = allocOp.getResult(); + auto reloadBuffer = maybeInReloadBlock(buffer); + if (auto rematBlock = getRematBlock(allocOp.getOperation())) { + return; + } + if (shouldSkipBufferInPeakMemoryEstimator(buffer)) { + skipBuffers.insert(reloadBuffer); + return; + } + auto bufferBytes = getSymMemrefSize(buffer, mgr); + memoryUsage = symbolicDimProductSumAdd(memoryUsage, bufferBytes); + memoryUsageList.push_back(memoryUsage); + set.insert(reloadBuffer); + llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0)) + << "\n"; + + } else if (auto deallocOp = dyn_cast(op)) { + auto buffer = deallocOp.getOperand(); + if (skipBuffers.count(buffer)) { + return; + } + if (auto rematBlock = getRematBlock(deallocOp.getOperation())) { + return; + } + auto bufferBytes = getSymMemrefSize(buffer, mgr); + memoryUsage = symbolicDimProductSub(memoryUsage, bufferBytes); + memoryUsageList.push_back(memoryUsage); + set.erase(buffer); + llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0)) + << "\n"; + } else if (auto d2hOp = dyn_cast(op)) { + auto buffer = d2hOp.getOperand(0); + auto rematBlock = getRematBlock(d2hOp.getOperation()); + auto rematBuffer = rematBlock->getResult(0); + int64_t minS0 = + rematBlock->getAttrOfType(kRematMinSymDim).getInt(); + if (cstS0 > minS0) { + auto bufferBytes = getSymMemrefSize(buffer, mgr); + memoryUsage = symbolicDimProductSub(memoryUsage, bufferBytes); + llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0)) + << "\n"; + } + } else if (auto h2dOp = dyn_cast(op)) { + auto buffer 
= h2dOp.getOperand(1); + auto rematBlock = getRematBlock(h2dOp.getOperation()); + auto rematBuffer = rematBlock->getResult(0); + int64_t minS0 = + rematBlock->getAttrOfType(kRematMinSymDim).getInt(); + if (cstS0 > minS0) { + auto rematBuffer = rematBlock->getResult(0); + auto bufferBytes = getSymMemrefSize(buffer, mgr); + memoryUsage = symbolicDimProductSumAdd(memoryUsage, bufferBytes); + llvm::dbgs() << bytesToMB(getConcretValuewithCst(memoryUsage, cstS0)) + << "\n"; + rematBuffers++; + } + } + }); + */ +} + +LogicalResult SymbolicMemoryProfiler::Analysis() { + mlir::OpBuilder b(main_); + mgr_ = &shapeAnalysis_.symbolicDimMgr(); + if (!mgr_) { + llvm::errs() << "mgr is nullptr\n"; + return failure(); + } + memory_usage_list_.clear(); + + MemoryUsage currentUsage; + std::unordered_set skipBuffers; + main_.walk([&](Operation* op) { + if (auto recvOp = dyn_cast(op)) { + auto buffer = recvOp.getResult(); + auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_); + currentUsage += bufferBytes; + memory_usage_list_.push_back(currentUsage); + } else if (auto sendOp = dyn_cast(op)) { + auto buffer = sendOp.getOperand(2); + auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_); + currentUsage -= bufferBytes; + memory_usage_list_.push_back(currentUsage); + } else if (auto allocOp = dyn_cast(op)) { + // TODO(yancey): dummy attr is a workaround implement for remat, + // better to have a custom operator + if (allocOp.getOperation()->getAttr( + b.getStringAttr("disc.remat.dummy-buffer"))) { + return; + } + // alloc operator maybe inside a remat block + auto buffer = allocOp.getResult(); + // skip the buffer if it is a temp buffer of fusion op + auto reloadBuffer = maybeInReloadBlock(buffer); + if (isHostBuffer(buffer) || isTempBuffer(buffer)) { + skipBuffers.insert(reloadBuffer); + return; + } + + auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_); + currentUsage += bufferBytes; + memory_usage_list_.push_back(currentUsage); + } else if (auto deallocOp = dyn_cast(op)) { + auto buffer = deallocOp.getOperand(); + if (skipBuffers.count(buffer)) { + return; + } + auto bufferBytes = getSymbolicMemrefBytes(buffer, mgr_); + currentUsage -= bufferBytes; + memory_usage_list_.push_back(currentUsage); + } + }); + // NOTE: search peak memory expr with fake symbolic value, please + // note, it's a fuzzy search algorithm, the result may not be accurate + // but it's good enough for most cases + peak_memory_ = findPeakMemoryWithFakeValue(memory_usage_list_, 4096); + return success(); +} + +} // namespace disc_ral +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_remat_utils.h b/tao_compiler/mlir/disc/transforms/disc_remat_utils.h new file mode 100644 index 00000000000..b825d6ae59e --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_remat_utils.h @@ -0,0 +1,130 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "absl/types/span.h" +#include "lhlo/IR/lhlo_ops.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/disc/transforms/disc_shape_optimization_utils.h" +#include "mlir/disc/transforms/shape_utils.h" +namespace mlir { +namespace disc_ral { + +struct ValueHash { + std::size_t operator()(const Value& operand) const { + std::size_t hash = mlir::hash_value(operand); + return hash; + } +}; +using OpWithPosition = std::pair; + +bool IsHostBuffer(Value value); +// return sizeof(dtype) +int64_t getElementSize(Type elementType); +// TODO(yancey): just a PoC implementation, need to implement this function +// by traveling the shape operators +SymbolicDimProduct SimplifySymbolicDims(SymbolicDimProduct proudct, + SymbolicDimOp s0); + +struct LivingBuffer { + LivingBuffer(Operation* start, int64_t start_position, Operation* end, + int64_t end_position, Value buffer) + : start(start), + start_position(start_position), + end(end), + end_position(end_position), + buffer(buffer) { + living_range = end_position - start_position; + } + + Operation* start; + int64_t start_position; + Operation* end; + int64_t end_position; + Value buffer; + int64_t living_range; +}; + +class DiscBufferLivingRange { + public: + explicit DiscBufferLivingRange(mlir::func::FuncOp main) : main_(main) {} + + LogicalResult Analysis(); + + std::vector GetLivingBuffers() { return living_buffers_; } + // get all users of the buffer, ordered by position + std::vector GetUsersOrderByPosition(Value buffer); + + private: + mlir::func::FuncOp main_; + std::vector living_buffers_; + // mapping buffer to operator and position + std::unordered_map>, + ValueHash> + buffer_map_; + std::vector buffer_list_; +}; + +using MemoryUsage = llvm::SmallVector; + +MemoryUsage& operator+=(MemoryUsage& lhs, const SymbolicDimProduct& rhs); +MemoryUsage& operator-=(MemoryUsage& lhs, const SymbolicDimProduct& rhs); +llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const MemoryUsage& memoryUsage); + +// SymbolicMemoryProfiler is used to analyze the memory usage of a +// mlir function, it will return the peak memory usage and the memory usage +class SymbolicMemoryProfiler { + public: + explicit SymbolicMemoryProfiler(mlir::func::FuncOp& main, + ShapeConstraintIRAnalysis& shapeAnalysis) + : main_(main), shapeAnalysis_(shapeAnalysis) {} + LogicalResult Analysis(); + MemoryUsage GetPeakMemory() { return peak_memory_; } + std::vector GetMemoryUsageList() { return memory_usage_list_; } + + std::vector ConcretMemoryUsageSimulator(int64_t concretValue); + + private: + // return true if it is a temp buffer which only used inside of a fusion op, + // this buffer would be removed after codegen, an example pattern: + // + // alloc = memref.alloc + // lmhlo.fusion() { + // op1(buffer0, buffer1, alloc) + // op2(alloc, buffer2, buffer3) + // } + // dealloc = memref.dealloc alloc + bool isTempBuffer(Value value); + MemoryUsage searchPeakMemory(SmallVector& memoryUsageList); + + SymbolicDimMgr* mgr_; + mlir::func::FuncOp main_; + MemoryUsage peak_memory_; + std::vector memory_usage_list_; + ShapeConstraintIRAnalysis& shapeAnalysis_; +}; + +} // namespace 
disc_ral +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc b/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc index 0a290ad624a..1d781150a4a 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc +++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization.cc @@ -1503,6 +1503,9 @@ DenseMap buildSymbolDimInstancesDominantMap( // we should let as much values as possible to be dominated by a same root for (Value root : roots) { for (Value v : instances) { + // llvm::dbgs() << "print root/v: "; + // root.dump(); + // v.dump(); if (dominantMap.find(v) == dominantMap.end() && dominanceInfo.dominates(root, v.getDefiningOp())) { dominantMap[v] = root; @@ -1689,6 +1692,7 @@ LogicalResult injectStaticKnownInfo(ShapeComputationIRAnalysis& analysis, }; for (mhlo::RealDynamicSliceOp op : sliceOps) { + op->dump(); SliceOpShapeHelper helper(op); Value in = op->getOperand(0); Value out = op->getResult(0); @@ -1698,10 +1702,11 @@ LogicalResult injectStaticKnownInfo(ShapeComputationIRAnalysis& analysis, auto outDims = analysis.rankedTensor2SymDims(out); if (inDims.has_value() && outDims.has_value()) { for (const auto& en : llvm::enumerate(llvm::zip(*inDims, *outDims))) { - if (std::get<0>(en.value()) == std::get<1>(en.value())) + if (std::get<0>(en.value()) == std::get<1>(en.value())) { if (failed(helper.markAsFullySlicedAxis(en.index()))) - return op->emitError() << "failed to mark axis" << en.index() + return op->emitError() << "failed to mark axis " << en.index() << " to be fully sliced\n"; + } } } @@ -1778,7 +1783,6 @@ LogicalResult applyShapeComputationOptimization( // - some axes of a pad op are not padded; if (failed(injectStaticKnownInfo(analysis, changed))) return analysis.getFunc()->emitError("fail to injectStaticKnownInfo\n"); - return success(); } diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc index fc87b49ee20..6f2f1e6b21d 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc +++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.cc @@ -85,10 +85,17 @@ bool compareSymbolicDimProduct(const SymbolicDimProduct& lhs, } llvm::raw_ostream& operator<<(llvm::raw_ostream& os, - const SymbolicDimProduct& product) { - os << "SymbolicDimProduct[\n\tfactor: " << product.factor << ",\n"; - for (auto& s : product.symbols) os << "\tsymbol: " << s << "\n"; - os << "]\n"; + const SymbolicDimProduct& prod) { + os << prod.factor; + if (prod.symbols.size() > 0) os << "*"; + for (size_t j = 0; j < prod.symbols.size(); ++j) { + auto dimOp = prod.symbols[j]; + if (j != prod.symbols.size() - 1) { + os << dimOp.getName() << "*"; + } else { + os << dimOp.getName(); + } + } return os; } @@ -101,7 +108,14 @@ LogicalResult SymbolicDimMgr::load() { }); return loadShapeConstraintGraph(); } - +std::optional SymbolicDimMgr::findSymbolicDimOp(StringRef name) { + for (auto p : symbolDimUnionSet_) { + if (p.first.getName() == name) { + return p.first; + } + } + return std::nullopt; +} std::string SymbolicDimMgr::getNextName() { std::string name; do { @@ -248,7 +262,6 @@ SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, return std::make_pair(std::move(newLhs), std::move(newRhs)); } - std::optional SymbolicDimMgr::symbolicDimProductDivide( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { LLVM_DEBUG(llvm::dbgs() << "Try to check if x % y == 0?\nx = 
" << lhs @@ -985,7 +998,6 @@ SliceOpShapeHelper::SliceOpShapeHelper(Operation* op) : op(op) { assert(startAttr.getNumElements() == ty.getRank()); assert(limitAttr.getNumElements() == ty.getRank()); assert(strideAttr.getNumElements() == ty.getRank()); - for (int i = 0; i < ty.getRank(); ++i) { mergeStartIndex(i, startAttr.getValues()[i]); mergeLimitIndex(i, limitAttr.getValues()[i]); @@ -1020,9 +1032,12 @@ LogicalResult SliceOpShapeHelper::markAsFullySlicedAxis(int axis) { LogicalResult SliceOpShapeHelper::mergeStartIndex(int axis, int64_t value) { assert(axis < static_cast(startIndices.size())); - if (startIndices[axis] != value && value != ShapeValueState::kUnknown && - startIndices[axis] != ShapeValueState::kUnknown) - return failure(); + // NOTE(yancey): commented out the following check to support dynamic slice, + // it's valid if start index greater than 0, let's recovery this if statement + // if any issue occurs. + // if (startIndices[axis] != value && value != ShapeValueState::kUnknown && + // startIndices[axis] != ShapeValueState::kUnknown) + // return failure(); if (startIndices[axis] == ShapeValueState::kUnknown) startIndices[axis] = value; diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h index b04cf424c1d..7e126c76ced 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h +++ b/tao_compiler/mlir/disc/transforms/disc_shape_optimization_utils.h @@ -378,9 +378,10 @@ class SymbolicDimMgr { // Returns a clone of the original symbol SymbolicDimOp cloneSymbol(SymbolicDimOp symbol); + std::optional findSymbolicDimOp(StringRef name); - // Clones a group of symbols and the relationships among the symbols in the - // group. Returns ok if success, otherwise failure. + // Clones a group of symbols and the relationships among the symbols in + // the group. Returns ok if success, otherwise failure. LogicalResult cloneSymbolGroup( const DenseSet& symbols, DenseMap& mapping); diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc index a17817b83a5..c3b5d0e86d9 100644 --- a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc +++ b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc @@ -44,6 +44,7 @@ limitations under the License. 
#include "mlir/disc/disc_util.h" #include "mlir/disc/transforms/PassDetail.h" #include "mlir/disc/transforms/shape_utils.h" +#include "tensorflow/tsl/platform/default/logging.h" namespace mlir { namespace disc_ral { @@ -72,7 +73,8 @@ struct DiscShapePropagatePass registry.insert(); } void visitOperator(ModuleOp& m, OpBuilder& rewriter, Operation* op, - std::stack& ctxStack); + ShapeContext& ctx, + std::unordered_map& symbolicMap); void runOnOperation() override; }; bool isBinaryOp(Operation* op) { @@ -110,33 +112,29 @@ std::optional getConstTensor(OpBuilder& b, Operation* op, return const_op.getResult(); } +void PrintCtx(ShapeContext& ctx) { + llvm::dbgs() << "value: "; + ctx.value.dump(); + for (size_t i = 0; i < ctx.shape.size(); ++i) { + llvm::dbgs() << "shape: [" << i << "] = " << ctx.shape[i] << "\n"; + } +} + std::optional HandleBinaryOp(OpBuilder& b, Operation* op, ShapeContext& inputCtx) { - auto bcastOp = dyn_cast_or_null( - op->getOperand(1).getDefiningOp()); - if (!bcastOp) { - return ShapeContext(op->getResult(0), inputCtx.shape); - } - if (bcastOp) { - auto constOp = dyn_cast_or_null( - bcastOp->getOperand(0).getDefiningOp()); - if (!constOp) { - return ShapeContext(op->getResult(0), inputCtx.shape); - } + if (auto bcastOp = dyn_cast_or_null( + op->getOperand(1).getDefiningOp())) { auto elemTy = op->getOperand(0).getType().cast().getElementType(); b.setInsertionPoint(op); - auto dense_attr = constOp.getValue().dyn_cast(); - int64_t value = dense_attr.getValues()[0]; - auto scalar_const_op = getConstTensor(b, op, {value}, {}); - Value inputShape = - b.create(op->getLoc(), op->getOperand(0)); - auto rank = inputCtx.shape.size(); - + auto shapeOf = b.create(op->getLoc(), op->getOperand(0)); + auto outShape = + op->getOperand(0).getType().cast().getShape(); auto dynBcastOp = b.create( op->getLoc(), RankedTensorType::get(inputCtx.shape, elemTy), - scalar_const_op.value(), inputShape, b.getI64TensorAttr({})); + bcastOp->getOperand(0), shapeOf, bcastOp.getBroadcastDimensions()); bcastOp.getResult().replaceAllUsesWith(dynBcastOp.getResult()); + bcastOp->erase(); } return ShapeContext(op->getResult(0), inputCtx.shape); } @@ -176,9 +174,10 @@ std::optional propagateHelper( } } -template <> -std::optional propagateHelper( - OpBuilder& b, Operation* op, ShapeContext& inputCtx) { +std::optional HandleReshapeOp( + OpBuilder& b, Operation* op, ShapeContext& inputCtx, + std::unordered_map& symbolicMap) { + b.setInsertionPoint(op); auto reshape_op = dyn_cast(op); if (!reshape_op) return std::nullopt; Type intType = b.getIntegerType(32); @@ -188,13 +187,47 @@ std::optional propagateHelper( reshape_op.getResult().getType().cast(); auto resultRank = resultRankType.getRank(); auto resultShape = resultRankType.getShape(); + auto inputShape = reshape_op.getOperand().getType().cast(); SmallVector newShape(resultRank, ShapedType::kDynamic); + bool symbolicSeqlen = false; + for (size_t i = 0; i < resultShape.size(); ++i) { + if (symbolicMap.count(resultShape[i])) { + symbolicSeqlen = true; + } + } + if (symbolicSeqlen) { + SmallVector newShapeValues; + SmallVector newShape; + for (size_t i = 0; i < resultShape.size(); ++i) { + if (symbolicMap.count(resultShape[i])) { + newShape.push_back(ShapedType::kDynamic); + newShapeValues.push_back(symbolicMap[resultShape[i]]); + } else { + newShape.push_back(resultShape[i]); + newShapeValues.push_back( + b.create(op->getLoc(), resultShape[i])); + } + } + Value shapeValue = + b.create(op->getLoc(), newShapeValues); + auto shape = b.create(op->getLoc(), 
op->getOperand(0)); + auto numElems = b.create(op->getLoc(), shape); + auto computeReshapeShape = b.create( + op->getLoc(), shapeValue.getType(), numElems.getResult(), shapeValue); + auto dynReshapeOpResultType = + RankedTensorType::get(newShape, resultRankType.getElementType()); + auto dynReshapeOp = b.create( + op->getLoc(), dynReshapeOpResultType, reshape_op.getOperand(), + computeReshapeShape); + op->getResult(0).replaceAllUsesWith(dynReshapeOp.getResult()); + op->erase(); + return ShapeContext(dynReshapeOp->getResult(0), newShape); + } int64_t numel = std::accumulate(inputCtx.shape.begin(), inputCtx.shape.end(), int64_t(1), [](int64_t acc, int64_t num) { return num == ShapedType::kDynamic ? acc : acc * num; }); - bool inferenced = true; while (inferenced) { inferenced = false; @@ -280,15 +313,21 @@ std::optional propagateHelper( auto stridesCst = slice_op.getStrides().getValues()[i]; startIndices[i] = b.create(slice_op.getLoc(), startIndicesCst); - // using dynamic dim if limitIndices is the same as input shape - if (limitIndicesCst == inputShape[i] && - inputCtx.shape[i] == ShapedType::kDynamic) { - limitIndices[i] = b.create(loc, slice_op.getOperand(), i); - newShape[i] = inputCtx.shape[i]; + if (inputCtx.shape[i] == ShapedType::kDynamic) { + newShape[i] = ShapedType::kDynamic; + if (limitIndicesCst == inputShape[i]) { + limitIndices[i] = + b.create(loc, slice_op.getOperand(), i); + } else { + auto limitOffset = inputShape[i] - limitIndicesCst; + limitIndices[i] = b.create( + loc, b.create(loc, slice_op.getOperand(), i), + b.create(loc, limitOffset)); + } } else { + newShape[i] = (limitIndicesCst - startIndicesCst - 1) / stridesCst + 1; limitIndices[i] = b.create(slice_op.getLoc(), limitIndicesCst); - newShape[i] = (limitIndicesCst - startIndicesCst - 1) / stridesCst + 1; } strides[i] = b.create(slice_op.getLoc(), stridesCst); @@ -297,13 +336,13 @@ std::optional propagateHelper( Value stridesValue = b.create(loc, strides); Value limitIndicesValue = b.create(loc, limitIndices); auto sliceOpResultType = - RankedTensorType::get(newShape, rankType.getElementType()); - auto dyncSliceOp = b.create( + RankedTensorType::get(inputCtx.shape, rankType.getElementType()); + auto dynSliceOp = b.create( loc, sliceOpResultType, slice_op.getOperand(), baseIndicesValue, limitIndicesValue, stridesValue); - op->getResult(0).replaceAllUsesWith(dyncSliceOp.getResult()); + op->getResult(0).replaceAllUsesWith(dynSliceOp.getResult()); op->erase(); - return ShapeContext(dyncSliceOp->getResult(0), newShape); + return ShapeContext(dynSliceOp->getResult(0), newShape); } template <> @@ -472,6 +511,68 @@ std::optional propagateHelper( SmallVector newShape(resultShape.begin(), resultShape.end()); return ShapeContext(op->getResult(0), newShape); } +template <> +std::optional propagateHelper( + OpBuilder& b, Operation* op, ShapeContext& inputCtx) { + auto resultShape = + op->getResult(0).getType().cast().getShape(); + SmallVector newShape(resultShape.begin(), resultShape.end()); + return ShapeContext(op->getResult(0), newShape); +} + +template <> +std::optional propagateHelper( + OpBuilder& rewriter, Operation* op, ShapeContext& inputCtx) { + auto padOp = dyn_cast(op); + if (!padOp) return std::nullopt; + rewriter.setInsertionPoint(op); + Value input = op->getOperand(0); + Value paddingValue = op->getOperand(1); + auto padOpType = padOp.getType().cast(); + + auto inputTy = input.getType().cast(); + auto resultTy = op->getResult(0).getType().cast(); + auto rank = inputTy.getRank(); + auto loc = padOp.getLoc(); + 
SmallVector paddingLow, paddingHigh, paddingInterior; + for (auto padDimension : + llvm::enumerate(padOp.getEdgePaddingLow().getValues())) { + paddingLow.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padDimension.value()))); + } + for (auto padDimension : + llvm::enumerate(padOp.getEdgePaddingHigh().getValues())) { + paddingHigh.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padDimension.value()))); + } + for (auto padDimension : + llvm::enumerate(padOp.getInteriorPadding().getValues())) { + paddingInterior.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padDimension.value()))); + } + + SmallVector newShape; + for (size_t i = 0; i < resultTy.getRank(); ++i) { + if (inputCtx.shape[i] == ShapedType::kDynamic) { + newShape.push_back(ShapedType::kDynamic); + } else { + newShape.push_back(inputCtx.shape[i]); + } + } + auto dynPadType = RankedTensorType::get(newShape, padOpType.getElementType()); + Value paddingLowTensor = + rewriter.create(loc, paddingLow); + Value paddingHighTensor = + rewriter.create(loc, paddingHigh); + Value paddingInterTensor = + rewriter.create(loc, paddingInterior); + auto dynPadOp = rewriter.create( + op->getLoc(), dynPadType, input, paddingValue, paddingLowTensor, + paddingHighTensor, paddingInterTensor); + op->getResult(0).replaceAllUsesWith(dynPadOp.getResult()); + op->erase(); + return ShapeContext(dynPadOp->getResult(0), newShape); +} template <> std::optional propagateHelper( @@ -548,6 +649,7 @@ std::optional propagateHelper( attr.getStartIndexMap(), attr.getIndexVectorDim()), gather_op.getIndicesAreSorted()); gather_op.getResult().replaceAllUsesWith(dynamic_gather_op.getResult()); + gather_op->erase(); // Update DynamicGatherOp result shape information return propagateHelper( @@ -588,17 +690,22 @@ LogicalResult parseInputDynamicDims( } void applyShapeContext(ShapeContext& ctx) { - if (!ctx.value) return; auto res_ty = ctx.value.getType().dyn_cast(); if (!res_ty) return; auto elemTy = res_ty.getElementType(); ctx.value.setType(RankedTensorType::get(ctx.shape, elemTy)); } -std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, - ShapeContext& inputCtx) { +std::optional propagateOpShape( + OpBuilder& rewriter, Operation* op, ShapeContext& inputCtx, + std::unordered_map& symbolicMap) { if (isUnaryOp(op)) { - return ShapeContext(op->getResult(0), inputCtx.shape); + if (inputCtx.value == op->getOperand(0)) { + return ShapeContext(op->getResult(0), inputCtx.shape); + } + SmallVector newShape = llvm::to_vector<4>( + op->getOperand(0).getType().cast().getShape()); + return ShapeContext(op->getResult(0), newShape); } if (isBinaryOp(op)) { return HandleBinaryOp(rewriter, op, inputCtx); @@ -606,8 +713,19 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, if (isa(op)) { return propagateHelper(rewriter, op, inputCtx); } - if (isa(op)) { - return ShapeContext(op->getResult(0), inputCtx.shape); + if (isa(op)) { + return HandleReshapeOp(rewriter, op, inputCtx, symbolicMap); + } + if (auto bcastOp = dyn_cast(op)) { + auto result = op->getResult(0); + auto resultTy = result.getType().cast(); + auto elemTy = resultTy.getElementType(); + bool withSymbolicShape = false; + SmallVector mhloShape; + SmallVector shapes; + + SmallVector newShape = llvm::to_vector<4>(resultTy.getShape()); + return ShapeContext(result, newShape); } #define PROPAGATE_OP_HANDLER(opType) \ if (auto t##opType = dyn_cast(op)) { \ @@ -616,7 +734,7 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, } 
PROPAGATE_OP_HANDLER(DotOp); PROPAGATE_OP_HANDLER(SliceOp); - PROPAGATE_OP_HANDLER(ReshapeOp); + PROPAGATE_OP_HANDLER(PadOp); PROPAGATE_OP_HANDLER(ConcatenateOp); PROPAGATE_OP_HANDLER(ReduceOp); PROPAGATE_OP_HANDLER(TransposeOp); @@ -626,6 +744,7 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, PROPAGATE_OP_HANDLER(DynamicReshapeOp); PROPAGATE_OP_HANDLER(RealDynamicSliceOp); PROPAGATE_OP_HANDLER(DynamicBroadcastInDimOp); + PROPAGATE_OP_HANDLER(DynamicPadOp); // PROPAGATE_OP_HANDLER(DimOp); #undef PROPAGATE_OP_HANDLER return std::nullopt; @@ -633,38 +752,169 @@ std::optional propagateOpShape(OpBuilder& rewriter, Operation* op, bool shouldStopPropagation(Operation* op, ShapeContext& ctx) { if (isConcreteShape(ctx)) return true; - if (isa(op)) + if (isa(op)) return true; if (isa(op->getParentOp())) return true; return false; } +struct OperationValuePairHash { + std::size_t operator()(const std::pair& pair) const { + return llvm::hash_combine(pair.first, pair.second); + } +}; +// Custom equality function for std::pair +struct OperationValuePairEqual { + bool operator()(const std::pair& lhs, + const std::pair& rhs) const { + return lhs.first == rhs.first && lhs.second == rhs.second; + } +}; -void DiscShapePropagatePass::visitOperator(ModuleOp& m, OpBuilder& rewriter, - Operation* op, - std::stack& ctxStack) { - auto ctx = ctxStack.top(); - if (shouldStopPropagation(op, ctx)) { - return; +void DiscShapePropagatePass::visitOperator( + ModuleOp& m, OpBuilder& rewriter, Operation* rootOp, ShapeContext& ctx, + std::unordered_map& symbolicMap) { + std::stack> stack; + std::stack ctxStack; + std::unordered_set, OperationValuePairHash, + OperationValuePairEqual> + visited; + if (VLOG_IS_ON(1)) { + VLOG(1) << "rootOp:"; + rootOp->dump(); } - auto resultShapeCtx = propagateOpShape(rewriter, op, ctx); - if (!resultShapeCtx) { - m.emitError("failed propagate shape on op:" + - op->getName().stripDialect().str()); - signalPassFailure(); - return; + stack.push({rootOp, ctx}); + auto pair = std::make_pair(rootOp, ctx.value); + visited.insert(pair); + while (!stack.empty()) { + auto [op, inputCtx] = stack.top(); + stack.pop(); + if (shouldStopPropagation(op, inputCtx)) { + while (!ctxStack.empty()) { + auto ctx = ctxStack.top(); + ctxStack.pop(); + applyShapeContext(ctx); + } + continue; + } + auto resultCtx = propagateOpShape(rewriter, op, inputCtx, symbolicMap); + if (!resultCtx) { + m.emitError("failed propagate shape on op:" + + op->getName().stripDialect().str()); + signalPassFailure(); + return; + } + ctxStack.push(*resultCtx); + SmallVector users(resultCtx->value.getUsers().begin(), + resultCtx->value.getUsers().end()); + for (size_t i = 0; i < users.size(); ++i) { + auto pair = std::make_pair(users[i], resultCtx->value); + if (visited.count(pair)) continue; + stack.push({users[i], *resultCtx}); + visited.insert(pair); + } } - ctxStack.push(*resultShapeCtx); - SmallVector ctxUsers(resultShapeCtx->value.getUsers().begin(), - resultShapeCtx->value.getUsers().end()); - for (size_t i = 0; i < ctxUsers.size(); ++i) { - visitOperator(m, rewriter, ctxUsers[i], ctxStack); + while (!ctxStack.empty()) { + auto ctx = ctxStack.top(); + ctxStack.pop(); + applyShapeContext(ctx); } - auto context = ctxStack.top(); - ctxStack.pop(); - applyShapeContext(context); } +std::optional HandleConstOp( + OpBuilder& rewriter, Operation* op, + std::unordered_map& symbolicMap) { + auto constOp = dyn_cast(op); + if (!constOp) return std::nullopt; + auto resultTy = op->getResult(0).getType().dyn_cast(); + if 
(!resultTy) return std::nullopt; + + auto elemTy = resultTy.getElementType(); + bool withSymbolicShape = false; + SmallVector mhloShape; + SmallVector shapes; + rewriter.setInsertionPoint(op); + for (auto dim : resultTy.getShape()) { + if (symbolicMap.count(dim)) { + withSymbolicShape = true; + mhloShape.push_back(symbolicMap[dim]); + shapes.push_back(ShapedType::kDynamic); + } else { + mhloShape.push_back( + rewriter.create(op->getLoc(), dim)); + shapes.push_back(dim); + } + } + if (withSymbolicShape) { + auto mhloShapeValue = + rewriter.create(op->getLoc(), mhloShape); + auto dense_attr = constOp.getValue().dyn_cast(); + Value scalar_const_op; + if (elemTy.isIntOrIndex()) { + auto value = (*dense_attr.getValues().begin()).getSExtValue(); + auto const_type = RankedTensorType::get({}, elemTy); + auto const_attr = DenseElementsAttr::get(const_type, {value}); + scalar_const_op = rewriter.create( + op->getLoc(), const_type, const_attr); + } else if (elemTy.isa()) { + auto value = (*dense_attr.getValues().begin()); + auto const_type = RankedTensorType::get({}, elemTy); + auto const_attr = DenseElementsAttr::get(const_type, {value}); + scalar_const_op = rewriter.create( + op->getLoc(), const_type, const_attr); + } else if (elemTy.isa()) { + auto value = (*dense_attr.getValues().begin()).convertToDouble(); + auto const_type = RankedTensorType::get({}, elemTy); + auto const_attr = DenseElementsAttr::get(const_type, {value}); + scalar_const_op = rewriter.create( + op->getLoc(), const_type, const_attr); + } + auto mhloBroadcastInDimOp = rewriter.create( + op->getLoc(), RankedTensorType::get(shapes, elemTy), scalar_const_op, + mhloShapeValue, /*broadcast_dimensions=*/rewriter.getI64TensorAttr({})); + op->getResult(0).replaceAllUsesWith(mhloBroadcastInDimOp.getResult()); + constOp->erase(); + mhloBroadcastInDimOp.dump(); + return mhloBroadcastInDimOp; + } + return std::nullopt; +} +std::optional HandleDyncBroadcastOp( + OpBuilder& rewriter, Operation* op, + std::unordered_map& symbolicMap) { + auto bcastOp = dyn_cast(op); + if (!bcastOp) return std::nullopt; + auto result = op->getResult(0); + auto resultTy = result.getType().cast(); + auto elemTy = resultTy.getElementType(); + bool withSymbolicShape = false; + SmallVector mhloShape; + SmallVector shapes; + rewriter.setInsertionPoint(op); + for (auto dim : resultTy.getShape()) { + if (symbolicMap.count(dim)) { + withSymbolicShape = true; + mhloShape.push_back(symbolicMap[dim]); + shapes.push_back(ShapedType::kDynamic); + } else { + mhloShape.push_back( + rewriter.create(op->getLoc(), dim)); + shapes.push_back(dim); + } + } + if (withSymbolicShape) { + auto mhloShapeValue = + rewriter.create(op->getLoc(), mhloShape); + auto mhloBroadcastInDimOp = rewriter.create( + op->getLoc(), RankedTensorType::get(shapes, elemTy), op->getOperand(0), + mhloShapeValue, bcastOp.getBroadcastDimensions()); + op->getResult(0).replaceAllUsesWith(mhloBroadcastInDimOp.getResult()); + op->erase(); + return mhloBroadcastInDimOp; + } + return std::nullopt; +} void DiscShapePropagatePass::runOnOperation() { ModuleOp m = getOperation(); auto main = m.lookupSymbol("main"); @@ -695,15 +945,62 @@ void DiscShapePropagatePass::runOnOperation() { SmallVector newShape; std::copy(ty.getShape().begin(), ty.getShape().end(), std::back_inserter(newShape)); + if (pair.second.size() != 1 || pair.second[0] != 1) { + main.emitError("only support sequence length dims equal to 1"); + signalPassFailure(); + return; + } for (auto dim : pair.second) { newShape[dim] = ShapedType::kDynamic; } - 
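+      // The block below specializes propagation for a batch x seqlen style
+      // argument: only a single dynamic dim at index 1 (the sequence axis)
+      // is accepted, and symbolicMap records which static dimension sizes
+      // stand for symbolic quantities, keyed by their concrete value in the
+      // original type:
+      //   seqlen           -> runtime dim of the argument at the seq axis
+      //   seqlen - 1       -> the sliced sequence length
+      //   bsz * seqlen     -> batch size times sequence length
+      //   bsz * (seqlen-1) -> the sliced token count
+      // HandleConstOp / HandleDyncBroadcastOp then rewrite constants and
+      // broadcasts whose result dims match these keys into
+      // mhlo.dynamic_broadcast_in_dim using the mapped SSA shape values, and
+      // the rewritten ops are appended to the propagation worklist.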
std::stack ctxStack; ShapeContext ctx(value, newShape); - ctxStack.push(ctx); auto newType = RankedTensorType::get(newShape, ty.getElementType()); - for (auto user : main.getArgument(argIdx).getUsers()) { - visitOperator(m, rewriter, user, ctxStack); + SmallVector users(value.getUsers().begin(), + value.getUsers().end()); + rewriter.setInsertionPointToStart(&main.getBody().front()); + std::unordered_map symbolicMap; + // seqlen + auto seqlen = ty.getShape()[pair.second[0]]; + auto seqlenValue = rewriter.create(users[0]->getLoc(), value, + pair.second[0]); + symbolicMap.insert({seqlen, seqlenValue}); + // seq sliced: seqlen - 1 + auto seqlenSliced = seqlen - 1; + auto seqlenSlicedValue = rewriter.create( + users[0]->getLoc(), seqlenValue, + rewriter.create(users[0]->getLoc(), 1)); + symbolicMap.insert({seqlenSliced, seqlenValue}); + // batch size + auto bsz = ty.getShape()[0]; + auto bszValue = + rewriter.create(users[0]->getLoc(), bsz); + // symbolicMap.insert({bsz, bszValue}); + + // bsz * seqlen + auto bszSeq = bsz * seqlen; + auto bszSeqValue = rewriter.create(users[0]->getLoc(), + bszValue, seqlenValue); + symbolicMap.insert({bszSeq, bszSeqValue}); + + // bszSeqlen = bsz * (seqlen - 1) + auto bszSlicedSeqlen = ty.getShape()[0] * (seqlen - 1); + auto bszSlicedSeqlenValue = rewriter.create( + users[0]->getLoc(), bszValue, + rewriter.create( + users[0]->getLoc(), seqlenSlicedValue, + rewriter.create(users[0]->getLoc(), 1))); + + symbolicMap.insert({bszSlicedSeqlen, bszSlicedSeqlenValue}); + main.walk([&](Operation* op) { + if (auto dynOp = HandleDyncBroadcastOp(rewriter, op, symbolicMap)) { + users.push_back(dynOp.value()); + } + if (auto dynOp = HandleConstOp(rewriter, op, symbolicMap)) { + users.push_back(dynOp.value()); + } + }); + for (auto user : users) { + visitOperator(m, rewriter, user, ctx, symbolicMap); } new_arg_types[argIdx] = newType; applyShapeContext(ctx); diff --git a/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc b/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc index 5e4e84c400b..b8d45161738 100644 --- a/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc +++ b/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc @@ -67,6 +67,8 @@ constexpr const char* kRalDispatchFunctionName = "disc_ral_call"; constexpr const char* kRalGpuLaunch = "ral_kernel_launch"; constexpr const char* kRalCpuLaunch = "ral_kernel_launch"; constexpr const char* kMalloc = "alloc"; +constexpr const char* kPinnedMalloc = "ral_gpu_pinned_alloc"; +constexpr const char* kPinnedFree = "ral_gpu_pinned_dealloc"; constexpr const char* kFree = "dealloc"; constexpr const char* kRalCompIntensFusion = "ral_comp_intens_fusion"; @@ -736,11 +738,28 @@ LogicalResult ConvertMemRefAllocOpToDispatchOpPattern::matchAndRewrite( getMemRefDescriptorSizes(loc, memref_type, llvm::to_vector<4>(adaptor.getOperands()), rewriter, sizes, strides, sizeBytes); - - // create dispatch op - auto dispatch_op = rewriter.create( - loc, getVoidPtrType(), context_arg, sizeBytes, kMalloc, false, device); - Value allocated_byte_ptr = dispatch_op.getResult(0); + // using pinned memory to overlap data transfer and computation kernels + bool pinnedMemory = false; + for (auto user : memref.getUsers()) { + if (auto dispatch = dyn_cast(user)) { + if (dispatch.getCallTargetName() == "d2h") { + pinnedMemory = true; + break; + } + } + } + StringRef targetName = kMalloc; + Operation* dispatch_op; + if (device == "cpu" && pinnedMemory) { + // create dispatch op + dispatch_op = rewriter.create( + loc, getVoidPtrType(), context_arg, sizeBytes, 
kPinnedMalloc, false, + "gpu"); + } else { + dispatch_op = rewriter.create( + loc, getVoidPtrType(), context_arg, sizeBytes, kMalloc, false, device); + } + Value allocated_byte_ptr = dispatch_op->getResult(0); // Create the MemRef descriptor. MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( @@ -793,10 +812,26 @@ LogicalResult ConvertMemRefDeallocOpToDispatchOpPattern::matchAndRewrite( MemRefDescriptor memref(adaptor.getMemref()); Value allocated_bytes_ptr = rewriter.create( loc, getVoidPtrType(), memref.allocatedPtr(rewriter, loc)); - - ModuleOp module = op->getParentOfType(); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, context_arg, allocated_bytes_ptr, kFree, false, device); + bool pinnedMemory = false; + for (auto user : dealloc_op.getMemref().getUsers()) { + if (auto dispatch = dyn_cast(user)) { + if (dispatch.getCallTargetName() == "h2d") { + pinnedMemory = true; + break; + } + } + } + if (device == "cpu" && pinnedMemory) { + // using pinned memory to achieve higher performance when transferring + // data between host and device + rewriter.replaceOpWithNewOp( + dealloc_op, TypeRange{}, context_arg, allocated_bytes_ptr, kPinnedFree, + false, "gpu"); + } else { + rewriter.replaceOpWithNewOp( + dealloc_op, TypeRange{}, context_arg, allocated_bytes_ptr, kFree, false, + device); + } return success(); } diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h index b09b0bef411..c9c047661dc 100644 --- a/tao_compiler/mlir/disc/transforms/passes.h +++ b/tao_compiler/mlir/disc/transforms/passes.h @@ -337,6 +337,7 @@ std::unique_ptr> createDiscReduceBufferLiveRangePass(); // rewrite mhlo collective ops to disc custom library call std::unique_ptr> createDiscCollectiveOpsRewriterPass(); +std::unique_ptr> createDiscOffloadingPass(); } // namespace disc_ral } // namespace mlir diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc index 374d39103ce..616250b267e 100644 --- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc +++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc @@ -126,6 +126,7 @@ struct BaseCudaContextState : public tao::ral::Context::Resource { std::map, GpuFunctionHandle> kernels; std::shared_ptr gpu_allocator; + std::shared_ptr cuda_host_allocator; bool cache_workspace_mem_across_execution; #ifdef TAO_RAL_USE_STREAM_EXECUTOR ::stream_executor::Stream* se_stream; @@ -183,6 +184,7 @@ std::unique_ptr MakeBaseCudaContext( } else { state->gpu_allocator.reset(new InternalAllocator(gpu_alloc, gpu_dealloc)); } + state->cuda_host_allocator = gpu_opt.cuda_host_allocator; state->cache_workspace_mem_across_execution = opt.cache_workspace_mem_across_execution; @@ -268,7 +270,18 @@ buffer_t ral_base_cuda_alloc(ExecutionContext* ctx, size_t bytes) { exec_ctx->device_ptr_map.insert(std::make_pair(ptr, 1)); return ptr; } - +buffer_t ral_base_cuda_pinned_alloc(ExecutionContext* ctx, size_t bytes) { + auto* state = + ctx->getResource(kRalBaseCudaContextState); + auto exec_ctx = dynamic_cast(ctx); + std::lock_guard lock(state->mu); + TAO_VLOG(1) << "before ral_base_cuda_pinned_alloc alloc " << bytes; + bytes = (bytes ?
bytes : 1); + void* ptr = state->cuda_host_allocator->alloc(bytes); + TAO_VLOG(1) << "after ral_base_cuda_pinned_alloc with ptr= " << ptr; + exec_ctx->host_ptr_map.insert(std::make_pair(ptr, 1)); + return ptr; +} void ral_base_cuda_memset(ExecutionContext* ctx, stream_t handle, buffer_t buffer, int value, size_t bytes) { if (!buffer) { @@ -321,6 +334,32 @@ void ral_base_cuda_dealloc(ExecutionContext* ctx, buffer_t buffer) { } TAO_VLOG(1) << "after ral_base_cuda_dealloc with ptr = " << buffer; } +void ral_base_cuda_pinned_dealloc(ExecutionContext* ctx, buffer_t buffer) { + /* + if (!buffer) { + TAO_VLOG(1) << "ral_base_cuda_dealloc early return for nullptr"; + return; + } + + auto* state = + ctx->getResource(kRalBaseCudaContextState); + auto exec_ctx = dynamic_cast(ctx); + + std::lock_guard lock(state->mu); + TAO_VLOG(1) << "before ral_base_cuda_pinned_dealloc with ptr = " << buffer; + if (state->device_persistent_buffers.count(buffer)) return; + auto it = exec_ctx->host_ptr_map.find(buffer); + CHECK(it != exec_ctx->host_ptr_map.end()); + if (--it->second == 0) { + cudaFreeHost(buffer); + exec_ctx->host_ptr_map.erase(buffer); + TAO_VLOG(1) << "delete buffer after ref-count becoming zero"; + } + TAO_VLOG(1) << "after ral_base_cuda_pinned_dealloc with ptr = " << buffer; + */ + + TAO_VLOG(1) << "after ral_base_cuda_pinned_dealloc with ptr = " << buffer; +} buffer_t ral_base_cuda_raw_alloc(Context* ctx, size_t bytes) { auto* state = static_cast( @@ -498,6 +537,8 @@ void ral_base_cuda_sync_on_stream(ExecutionContext* ctx, stream_t sidx) { ctx->signalError(Context::FAILURE, "not a valid stream idx"); return; } + auto comm_stream = + static_cast(ctx)->getCommStream(); auto* state = ctx->getResource(kRalBaseCudaContextState); @@ -505,7 +546,7 @@ void ral_base_cuda_sync_on_stream(ExecutionContext* ctx, stream_t sidx) { reportErrorIfAny(stream_executor::wrap::hipStreamSynchronize(state->stream), ctx, "StreamSync"); #else - reportErrorIfAny(cuStreamSynchronize(state->stream), ctx, "StreamSync"); + reportErrorIfAny(cuStreamSynchronize(comm_stream), ctx, "StreamSync"); #endif } @@ -626,6 +667,10 @@ ::tao::ral::MemRefType ral_base_cuda_bitcast_0d( void ral_base_cuda_h2d(ExecutionContext* ctx, void* stream_handle, const void* h_src, buffer_t d_dst, size_t bytes) { + TAO_VLOG(1) << "ral_base_cuda_h2d, h_src: " << h_src << ", d_dst: " << d_dst + << ", bytes: " << bytes; + auto comm_stream = + static_cast(ctx)->getCommStream(); auto* state = ctx->getResource(kRalBaseCudaContextState); #if TENSORFLOW_USE_ROCM @@ -635,15 +680,19 @@ void ral_base_cuda_h2d(ExecutionContext* ctx, void* stream_handle, ctx, "cuMemcpyHtoDAsync"); #else reportErrorIfAny( - cuMemcpyHtoDAsync((GpuDevicePtr)d_dst, h_src, bytes, state->stream), ctx, + cuMemcpyHtoDAsync((GpuDevicePtr)d_dst, h_src, bytes, comm_stream), ctx, "cuMemcpyHtoDAsync"); #endif } void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle, buffer_t d_src, buffer_t h_dst, size_t bytes) { + TAO_VLOG(1) << "ral_base_cuda_d2h, d_src: " << d_src << ", h_dst: " << h_dst + << ", bytes: "; auto* state = ctx->getResource(kRalBaseCudaContextState); + auto comm_stream = + static_cast(ctx)->getCommStream(); #if TENSORFLOW_USE_ROCM reportErrorIfAny( stream_executor::wrap::hipMemcpyDtoHAsync( @@ -651,7 +700,7 @@ void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle, ctx, "cuMemcpyDtoHAsync"); #else reportErrorIfAny( - cuMemcpyDtoHAsync(h_dst, (GpuDevicePtr)d_src, bytes, state->stream), ctx, + cuMemcpyDtoHAsync(h_dst, (GpuDevicePtr)d_src, bytes, comm_stream), 
ctx, "cuMemcpyDtoHAsync"); #endif } @@ -745,10 +794,14 @@ RAL_REGISTER_BITCAST_FUNC(bool, 7); RAL_REGISTER_BITCAST_FUNC(bool, 8); TAO_RAL_API(tao::ral::gpu::kRalGpuAlloc, "gpu", ral_base_cuda_alloc); +TAO_RAL_API(tao::ral::gpu::kRalGpuPinnedAlloc, "gpu", + ral_base_cuda_pinned_alloc); TAO_RAL_API(tao::ral::gpu::kRalGpuAllocPersistent, "gpu", ral_base_cuda_alloc_persistent); TAO_RAL_API(tao::ral::gpu::kRalGpuDealloc, "gpu", ral_base_cuda_dealloc); TAO_RAL_API(tao::ral::gpu::kRalGpuRawAlloc, "gpu", ral_base_cuda_raw_alloc); +TAO_RAL_API(tao::ral::gpu::kRalGpuPinnedDealloc, "gpu", + ral_base_cuda_pinned_dealloc); TAO_RAL_API(tao::ral::gpu::kRalGpuRawDealloc, "gpu", ral_base_cuda_raw_dealloc); TAO_RAL_API(tao::ral::gpu::kRalGpuLaunch, "gpu", ral_base_cuda_launch); TAO_RAL_API(tao::ral::gpu::kRalGpuGetStream, "gpu", ral_base_cuda_get_stream); diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h index 8d1fd2a7293..822984dfb48 100755 --- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h +++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h @@ -55,6 +55,7 @@ struct BaseCudaContextOption { bool use_stream_executor = true; bool cache_workspace_mem_across_execution = false; std::shared_ptr gpu_allocator; + std::shared_ptr cuda_host_allocator; }; std::unique_ptr MakeBaseCudaContext( diff --git a/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc b/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc index d7dad76b545..45e4605074a 100644 --- a/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc +++ b/tao_compiler/mlir/ral/device/gpu/gpu_driver.cc @@ -28,6 +28,8 @@ namespace gpu { const char* kRalGpuAlloc = "alloc"; const char* kRalGpuAllocPersistent = "ral_gpu_alloc_persistent"; +const char* kRalGpuPinnedAlloc = "ral_gpu_pinned_alloc"; +const char* kRalGpuPinnedDealloc = "ral_gpu_pinned_dealloc"; const char* kRalGpuDealloc = "dealloc"; const char* kRalGpuRawAlloc = "raw_gpu_alloc"; const char* kRalGpuRawDealloc = "raw_gpu_dealloc"; diff --git a/tao_compiler/mlir/ral/device/gpu/gpu_driver.h b/tao_compiler/mlir/ral/device/gpu/gpu_driver.h index eb934cd6e84..9fc48ec49de 100644 --- a/tao_compiler/mlir/ral/device/gpu/gpu_driver.h +++ b/tao_compiler/mlir/ral/device/gpu/gpu_driver.h @@ -35,6 +35,8 @@ using stream_t = void*; extern const char* kRalGpuAlloc; extern const char* kRalGpuAllocPersistent; +extern const char* kRalGpuPinnedAlloc; +extern const char* kRalGpuPinnedDealloc; extern const char* kRalGpuDealloc; extern const char* kRalGpuRawAlloc; extern const char* kRalGpuRawDealloc;