From 349ff70f765e92efde4ca910e1e8baef24e91344 Mon Sep 17 00:00:00 2001 From: zhengxuegui Date: Mon, 24 Jun 2024 20:25:29 +0800 Subject: [PATCH 01/13] [compiler] add forall-tiling pass --- compiler/include/byteir/Dialect/SCF/Passes.h | 1 + compiler/include/byteir/Dialect/SCF/Passes.td | 24 ++ .../Dialect/SCF/Transforms/ForallTiling.h | 33 +++ .../lib/Dialect/SCF/Transforms/CMakeLists.txt | 3 +- .../Dialect/SCF/Transforms/ForallTiling.cpp | 264 ++++++++++++++++++ .../lib/Dialect/SCF/Transforms/PassDetail.h | 4 + compiler/test/Dialect/SCF/forallTiling.mlir | 69 +++++ 7 files changed, 397 insertions(+), 1 deletion(-) create mode 100644 compiler/include/byteir/Dialect/SCF/Transforms/ForallTiling.h create mode 100644 compiler/lib/Dialect/SCF/Transforms/ForallTiling.cpp create mode 100644 compiler/test/Dialect/SCF/forallTiling.mlir diff --git a/compiler/include/byteir/Dialect/SCF/Passes.h b/compiler/include/byteir/Dialect/SCF/Passes.h index e8d7427d5..dd6b5a5c8 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.h +++ b/compiler/include/byteir/Dialect/SCF/Passes.h @@ -19,6 +19,7 @@ #define BYTEIR_DIALECT_SCF_PASSES_H #include "byteir/Dialect/SCF/Transforms/ForallCollapsing.h" +#include "byteir/Dialect/SCF/Transforms/ForallTiling.h" #include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" #include "byteir/Dialect/SCF/Transforms/InsertTrivialSCFLoop.h" diff --git a/compiler/include/byteir/Dialect/SCF/Passes.td b/compiler/include/byteir/Dialect/SCF/Passes.td index a4f3cc0c5..44309c9f3 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.td +++ b/compiler/include/byteir/Dialect/SCF/Passes.td @@ -72,4 +72,28 @@ def ForallCollapsing : Pass<"forall-collapsing", "mlir::func::FuncOp"> { ]; } +//===----------------------------------------------------------------------===// +// ForallTiling +//===----------------------------------------------------------------------===// + +def ForallTiling : Pass<"forall-tiling"> { + let summary = "tile forall Op with specific tileSize"; + let constructor = "mlir::createForallTilingPass()"; + let dependentDialects = [ + "scf::SCFDialect", + "affine::AffineDialect" + ]; + let options = [ + ListOption<"tileSizes", "tile-sizes", "int64_t", + "Factors to tile forall">, + Option<"noMinMaxBounds", "no-min-max-bounds", "bool", + /*default=*/"false", + "Perform tiling with fixed upper bound with inbound check " + "inside the internal loops">, + Option<"anchorTag", "anchor-tag", "std::string", + /*default=*/"", + "Optional unitAttr anchored tag to apply this pass">, + ]; +} + #endif // BYTEIR_DIALECT_SCF_PASSES \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/SCF/Transforms/ForallTiling.h b/compiler/include/byteir/Dialect/SCF/Transforms/ForallTiling.h new file mode 100644 index 000000000..3363fde21 --- /dev/null +++ b/compiler/include/byteir/Dialect/SCF/Transforms/ForallTiling.h @@ -0,0 +1,33 @@ +//===- ForallTiling.h ------------------------------------- C++ --===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLTILING_H
+#define BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLTILING_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+std::unique_ptr<Pass>
+createForallTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
+                       bool noMinMaxBounds = false,
+                       llvm::StringRef anchorTag = "");
+
+} // namespace mlir
+
+#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLTILING_H
diff --git a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 967c79b5f..4891d0741 100644
--- a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(ByteIRSCFPasses
   ForallCollapsing.cpp
+  ForallTiling.cpp
   FuseNestedForall.cpp
   InsertTrivialSCFLoop.cpp
   TilingInterfaceToSCFFor.cpp
@@ -21,4 +22,4 @@ add_mlir_dialect_library(ByteIRSCFPasses
   MLIRSCFTransforms
   MLIRSideEffectInterfaces
   MLIRSupport
-  )
+)
diff --git a/compiler/lib/Dialect/SCF/Transforms/ForallTiling.cpp b/compiler/lib/Dialect/SCF/Transforms/ForallTiling.cpp
new file mode 100644
index 000000000..5bd827494
--- /dev/null
+++ b/compiler/lib/Dialect/SCF/Transforms/ForallTiling.cpp
@@ -0,0 +1,264 @@
+//===- ForallTiling.cpp ------------------------------------ C++ --===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+// Some code comes from mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+// in LLVM project
+// Original license:
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "byteir/Dialect/SCF/Transforms/ForallTiling.h"
+#include "byteir/Utils/LoopUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include <utility>
+
+#include "PassDetail.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+std::pair<ForallOp, ForallOp> tileForall(ForallOp forallOp,
+                                         ArrayRef<int64_t> tileSizes,
+                                         bool noMinMaxBounds) {
+  OpBuilder builder(forallOp);
+  auto loc = forallOp.getLoc();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileSizeConstants;
+  int64_t rank = forallOp.getRank();
+  tileSizeConstants.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    tileSizeConstants.push_back(
+        builder.create<arith::ConstantIndexOp>(loc, tileSizes[i]));
+  }
+
+  SmallVector<Value> oriSteps;
+  oriSteps = forallOp.getStep(builder);
+
+  SmallVector<Value> outerSteps, outerLowerBounds, outerUpperBounds;
+  outerLowerBounds = forallOp.getLowerBound(builder);
+  outerUpperBounds = forallOp.getUpperBound(builder);
+
+  outerSteps.reserve(rank);
+
+  for (int64_t i = 0; i < rank; ++i) {
+    if (tileSizes[i] == 0) {
+      outerSteps.push_back(oriSteps[i]);
+    } else {
+      outerSteps.push_back(builder.create<arith::MulIOp>(
+          loc, oriSteps[i], tileSizeConstants[i]));
+    }
+  }
+
+  auto outerForall = builder.create<ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      getAsOpFoldResult(outerSteps), ValueRange(), forallOp.getMapping());
+
+  builder.setInsertionPointToStart(outerForall.getBody());
+
+  // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
+  auto minMap = AffineMap::get(
+      /*dimCount=*/3, /*symbolCount=*/0,
+      {getAffineDimExpr(/*position=*/0, builder.getContext()),
+       getAffineDimExpr(/*position=*/1, builder.getContext()) -
+           getAffineDimExpr(/*position=*/2, builder.getContext())},
+      builder.getContext());
+
+  SmallVector<Value> innerUpperBounds, innerSteps;
+  SmallVector<Value> tiledOuterUpperBounds, tiledOuterIVs;
+
+  innerUpperBounds.reserve(rank);
+  bool needInboundCheck = false;
+  for (auto [lowerBound, upperBound, newStep, iv, oriStep, tileSizeConstant] :
+       llvm::zip(outerLowerBounds, outerUpperBounds, outerSteps,
+                 outerForall.getInductionVars(), oriSteps, tileSizeConstants)) {
+    // Collect the statically known loop bounds
+    auto lowerBoundConstant =
+        dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
+    auto upperBoundConstant =
+        dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
+    auto stepConstant =
+        dyn_cast_or_null<arith::ConstantIndexOp>(oriStep.getDefiningOp());
+    auto tileSize =
+        cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
+    if (tileSize == 0) {
+      continue;
+    }
+    innerSteps.push_back(oriStep);
+    tiledOuterUpperBounds.push_back(upperBound);
+    tiledOuterIVs.push_back(iv);
+    // If the loop bounds and the loop step are constant and if the number of
+    // loop iterations is an integer multiple of the tile size, we use a static
+    // bound for the inner loop.
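+    // Illustrative sketch of this even-division case (tile size 256, names
+    // made up for the example; it matches @Copy in forallTiling.mlir below):
+    //   scf.forall (%i) in (2048)
+    // becomes
+    //   scf.forall (%i) = (0) to (2048) step (256) {
+    //     scf.forall (%j) in (256) { ... uses %i + %j ... }
+    //   }
+    // with no affine.min bound and no in-bound guard needed.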
+    if (lowerBoundConstant && upperBoundConstant && stepConstant) {
+      auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
+                                                lowerBoundConstant.value(),
+                                            stepConstant.value());
+      if (numIterations % tileSize == 0) {
+        innerUpperBounds.push_back(newStep);
+        continue;
+      }
+    }
+
+    // For InboundCheck mode, just use the variable outer step
+    if (noMinMaxBounds) {
+      innerUpperBounds.push_back(newStep);
+      needInboundCheck = true;
+      continue;
+    }
+
+    // Otherwise, we dynamically compute the bound for
+    // each iteration of the outer loop.
+    innerUpperBounds.push_back(builder.create<affine::AffineMinOp>(
+        loc, builder.getIndexType(), minMap,
+        ValueRange{newStep, upperBound, iv}));
+  }
+
+  auto innerForall = builder.create<ForallOp>(
+      loc,
+      getAsOpFoldResult(SmallVector<Value>(innerUpperBounds.size(), zero)),
+      getAsOpFoldResult(innerUpperBounds), getAsOpFoldResult(innerSteps),
+      ValueRange(), std::nullopt);
+
+  if (noMinMaxBounds && needInboundCheck) {
+    builder.setInsertionPointToStart(innerForall.getBody());
+    // Insert in-bound check
+    Value inbound = builder.create<arith::ConstantIntOp>(
+        loc, 1, builder.getIntegerType(1));
+    for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
+         llvm::zip(tiledOuterUpperBounds, tiledOuterIVs,
+                   innerForall.getInductionVars(), innerSteps)) {
+      // %in_bound = %in_bound &&
+      //             (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
+      Value index = builder.create<arith::AddIOp>(
+          loc, builder.create<arith::MulIOp>(loc, innerIV, innerStep), outerIV);
+      Value dimInbound = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, index, outerUpperBound);
+      inbound = builder.create<arith::AndIOp>(loc, inbound, dimInbound);
+    }
+    auto ifInbound =
+        builder.create<scf::IfOp>(loc,
+                                  /*resultTypes*/ ArrayRef<Type>{}, inbound,
+                                  /*hasElseRegion*/ false);
+    builder.setInsertionPointToStart(innerForall.getBody());
+    for (int64_t i = 0, tiled = 0; i < rank; ++i) {
+      Value iv;
+      if (tileSizes[i] == 0) {
+        iv = outerForall.getInductionVars()[i];
+      } else {
+        Value innerIndex = innerForall.getInductionVars()[tiled];
+        Value outerIndex = tiledOuterIVs[tiled];
+        iv = builder.create<arith::AddIOp>(loc, innerIndex, outerIndex);
+        tiled += 1;
+      }
+      replaceAllUsesInRegionWith(forallOp.getBody()->getArgument(i), iv,
+                                 forallOp.getRegion());
+    }
+    Block &thenBlock = ifInbound.getThenRegion().front();
+    forallOp.getBody()->back().erase();
+    // Move the old forall's body into the guarded then-block.
+    thenBlock.getOperations().splice(Block::iterator(thenBlock.back()),
+                                     forallOp.getBody()->getOperations());
+  } else {
+    builder.setInsertionPointToStart(innerForall.getBody());
+    for (int64_t i = 0, tiled = 0; i < rank; ++i) {
+      Value iv;
+      if (tileSizes[i] == 0) {
+        iv = outerForall.getInductionVars()[i];
+      } else {
+        Value innerIndex = innerForall.getInductionVars()[tiled];
+        Value outerIndex = tiledOuterIVs[tiled];
+        iv = builder.create<arith::AddIOp>(loc, innerIndex, outerIndex);
+        tiled += 1;
+      }
+      replaceAllUsesInRegionWith(forallOp.getBody()->getArgument(i), iv,
+                                 forallOp.getRegion());
+    }
+    // Move the old forall's body into the inner forall.
+    innerForall.getBody()->getOperations().splice(
+        Block::iterator(innerForall.getBody()->back()),
+        forallOp.getBody()->getOperations());
+    // erase redundant scf.forall.in_parallel
+    innerForall.getBody()->back().erase();
+  }
+
+  // erase old forall
+  forallOp.erase();
+  return std::make_pair(outerForall, innerForall);
+}
+
+struct ForallTilingPass : public ForallTilingBase<ForallTilingPass> {
+  ForallTilingPass(ArrayRef<int64_t> tileSizes, bool noMinMaxBounds,
+                   llvm::StringRef anchor)
+      : ForallTilingBase() {
+    anchorTag = anchor.str();
+    this->tileSizes = tileSizes;
+    this->noMinMaxBounds = noMinMaxBounds;
+  }
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+
+    SmallVector<scf::ForallOp> candidateForall;
+    if (llvm::all_of(tileSizes, [](int64_t val) { return val == 0; })) {
+      return;
+    }
+
+    rootOp->walk([&](scf::ForallOp forallOp) {
+      // skip non-anchored
+      if (!anchorTag.empty() && !forallOp->hasAttr(anchorTag)) {
+        return;
+      }
+
+      if (forallOp.getRank() != tileSizes.size()) {
+        mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()),
+                        "the number of tile sizes does not match the rank of "
+                        "the forallOp");
+        return signalPassFailure();
+      }
+
+      if (forallOp.getOutputs().size() > 0) {
+        mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()),
+                        "forall with tensor shared_outs is not supported.");
+        return signalPassFailure();
+      }
+      candidateForall.emplace_back(forallOp);
+    });
+
+    for (auto forallOp : candidateForall) {
+      tileForall(forallOp, tileSizes, noMinMaxBounds);
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallTilingPass(ArrayRef<int64_t> tileSizes,
+                                                   bool noMinMaxBounds,
+                                                   llvm::StringRef anchor) {
+  return std::make_unique<ForallTilingPass>(tileSizes, noMinMaxBounds, anchor);
+}
diff --git a/compiler/lib/Dialect/SCF/Transforms/PassDetail.h b/compiler/lib/Dialect/SCF/Transforms/PassDetail.h
index 150246c55..f728745e4 100644
--- a/compiler/lib/Dialect/SCF/Transforms/PassDetail.h
+++ b/compiler/lib/Dialect/SCF/Transforms/PassDetail.h
@@ -27,6 +27,10 @@ namespace scf {
 class SCFDialect;
 } // namespace scf
 
+namespace affine {
+class AffineDialect;
+} // namepsace affine
+
 #define GEN_PASS_CLASSES
 #include "byteir/Dialect/SCF/Passes.h.inc"
 
diff --git a/compiler/test/Dialect/SCF/forallTiling.mlir b/compiler/test/Dialect/SCF/forallTiling.mlir
new file mode 100644
index 000000000..6add041b5
--- /dev/null
+++ b/compiler/test/Dialect/SCF/forallTiling.mlir
@@ -0,0 +1,69 @@
+// RUN: byteir-opt %s --forall-tiling="tile-sizes=256" --split-input-file --canonicalize --cse | FileCheck %s
+
+func.func @Copy(%arg0: memref<32x64xf32>, %arg1: memref<32x64xf32>) attributes {__byteir_reduction_fusion__} {
+  %c64 = arith.constant 64 : index
+  scf.forall (%arg2) in (2048) {
+    %0 = arith.remsi %arg2, %c64 : index
+    %1 = arith.divsi %arg2, %c64 : index
+    %2 = memref.load %arg0[%1, %0] : memref<32x64xf32>
+    memref.store %2, %arg1[%1, %0] : memref<32x64xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @Copy
+// CHECK-NEXT: %[[C64:.*]] = arith.constant 64 : index
+// CHECK-NEXT: scf.forall (%[[ARG2:.*]]) = (0) to (2048) step (256) {
+  // CHECK-NEXT: scf.forall (%[[ARG3:.*]]) in (256) {
+  // CHECK-NEXT: %[[V0:.*]] = arith.addi %[[ARG3]], %[[ARG2]] : index
+  // CHECK-NEXT: %[[V1:.*]] = arith.remsi %[[V0]], %[[C64]] : index
+  // CHECK-NEXT: %[[V2:.*]] = arith.divsi %[[V0]], %[[C64]] : index
+  // CHECK-NEXT: %[[V3:.*]] = memref.load %arg0[%[[V2]], %[[V1]]] : memref<32x64xf32>
+  // CHECK-NEXT: memref.store %[[V3]], %arg1[%[[V2]], %[[V1]]] : memref<32x64xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @Elementwise(%arg0:
memref<32x1024x?x30xf32>) -> memref<32768x?x30xf32> attributes {__byteir_elementwise_fusion__} { + %c983040 = arith.constant 983040 : index + %c30 = arith.constant 30 : index + %c2 = arith.constant 2 : index + %collapse_shape = memref.collapse_shape %arg0 [[0, 1], [2], [3]] : memref<32x1024x?x30xf32> into memref<32768x?x30xf32> + %dim = memref.dim %arg0, %c2 : memref<32x1024x?x30xf32> + %alloc = memref.alloc(%dim) : memref<32768x?x30xf32> + %0 = arith.muli %dim, %c983040 : index + scf.forall (%arg1) in (%0) { + %1 = arith.remsi %arg1, %c30 : index + %2 = arith.divsi %arg1, %c30 : index + %3 = arith.remsi %2, %dim : index + %4 = arith.divsi %2, %dim : index + %subview = memref.subview %collapse_shape[%4, %3, %1] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> + %subview_0 = memref.subview %alloc[%4, %3, %1] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%subview : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) outs(%subview_0 : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} { + ^bb0(%in: f32, %out: f32): + %5 = arith.mulf %in, %in : f32 + linalg.yield %5 : f32 + } + } + return %alloc : memref<32768x?x30xf32> +} + +// CHECK: #[[$MAP_LOOP_SIZE:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 256)> +// CHECK-LABEL: func.func @Elementwise +// CHECK-DAG: %[[C983040:.*]] = arith.constant 983040 : index +// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 +// CHECK: %[[DIM:.*]] = memref.dim %arg0, %[[C2]] : memref<32x1024x?x30xf32> +// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<32768x?x30xf32> +// CHECK-NEXT: %[[LB:.*]] = arith.muli %dim, %[[C983040]] : index +// CHECK-NEXT: scf.forall (%[[ARG1:.*]]) = (0) to (%[[LB]]) step (256) { + // CHECK-NEXT: %[[V1:.*]] = affine.min #[[$MAP_LOOP_SIZE]](%[[ARG1]])[%[[LB]]] + // CHECK-NEXT: scf.forall (%[[ARG2:.*]]) in (%[[V1:.*]]) { + // CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG2]], %[[ARG1]] : index + // CHECK-NEXT: %[[V3:.*]] = arith.remsi %[[V2]], %[[C30]] : index + // CHECK-NEXT: %[[V4:.*]] = arith.divsi %[[V2]], %[[C30]] : index + // CHECK-NEXT: %[[V5:.*]] = arith.remsi %[[V4]], %[[DIM]] : index + // CHECK-NEXT: %[[V6:.*]] = arith.divsi %[[V4]], %[[DIM]] : index + // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[COLLAPSE]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> + // CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> \ No newline at end of file From 262a6a42ec4a44e2ee24ac5b363fbee287801bd8 Mon Sep 17 00:00:00 2001 From: zhengxuegui Date: Mon, 24 Jun 2024 21:17:03 +0800 Subject: [PATCH 02/13] format --- compiler/include/byteir/Dialect/SCF/Passes.td | 2 +- compiler/lib/Dialect/SCF/Transforms/PassDetail.h | 2 +- compiler/test/Dialect/SCF/forallTiling.mlir | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/include/byteir/Dialect/SCF/Passes.td b/compiler/include/byteir/Dialect/SCF/Passes.td index 44309c9f3..1ce451e8d 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.td +++ b/compiler/include/byteir/Dialect/SCF/Passes.td @@ 
-77,7 +77,7 @@ def ForallCollapsing : Pass<"forall-collapsing", "mlir::func::FuncOp"> { //===----------------------------------------------------------------------===// def ForallTiling : Pass<"forall-tiling"> { - let summary = "tile forall Op with specific tileSize"; + let summary = "tile forall Op with given tileSize"; let constructor = "mlir::createForallTilingPass()"; let dependentDialects = [ "scf::SCFDialect", diff --git a/compiler/lib/Dialect/SCF/Transforms/PassDetail.h b/compiler/lib/Dialect/SCF/Transforms/PassDetail.h index f728745e4..8a56f8ed3 100644 --- a/compiler/lib/Dialect/SCF/Transforms/PassDetail.h +++ b/compiler/lib/Dialect/SCF/Transforms/PassDetail.h @@ -29,7 +29,7 @@ class SCFDialect; namespace affine { class AffineDialect; -} // namepsace affine +} // namespace affine #define GEN_PASS_CLASSES #include "byteir/Dialect/SCF/Passes.h.inc" diff --git a/compiler/test/Dialect/SCF/forallTiling.mlir b/compiler/test/Dialect/SCF/forallTiling.mlir index 6add041b5..2b5c0b63d 100644 --- a/compiler/test/Dialect/SCF/forallTiling.mlir +++ b/compiler/test/Dialect/SCF/forallTiling.mlir @@ -66,4 +66,4 @@ func.func @Elementwise(%arg0: memref<32x1024x?x30xf32>) -> memref<32768x?x30xf32 // CHECK-NEXT: %[[V5:.*]] = arith.remsi %[[V4]], %[[DIM]] : index // CHECK-NEXT: %[[V6:.*]] = arith.divsi %[[V4]], %[[DIM]] : index // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[COLLAPSE]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> - // CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> \ No newline at end of file + // CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>> From 2ad8917f2e76edd8ee648d789bfca0ea69c910c8 Mon Sep 17 00:00:00 2001 From: zhengxuegui Date: Tue, 25 Jun 2024 15:32:24 +0800 Subject: [PATCH 03/13] [compiler] add forall-normalize-pass --- compiler/include/byteir/Dialect/SCF/Passes.h | 1 + compiler/include/byteir/Dialect/SCF/Passes.td | 18 +++ .../Dialect/SCF/Transforms/ForallNormalize.h | 30 +++++ .../include/byteir/Dialect/SCF/Util/Util.h | 14 ++ .../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 + .../SCF/Transforms/ForallCollapsing.cpp | 74 +--------- .../SCF/Transforms/ForallNormalize.cpp | 101 ++++++++++++++ compiler/lib/Dialect/SCF/Util/Util.cpp | 127 ++++++++++++++++++ 8 files changed, 294 insertions(+), 72 deletions(-) create mode 100644 compiler/include/byteir/Dialect/SCF/Transforms/ForallNormalize.h create mode 100644 compiler/lib/Dialect/SCF/Transforms/ForallNormalize.cpp diff --git a/compiler/include/byteir/Dialect/SCF/Passes.h b/compiler/include/byteir/Dialect/SCF/Passes.h index dd6b5a5c8..9096af759 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.h +++ b/compiler/include/byteir/Dialect/SCF/Passes.h @@ -20,6 +20,7 @@ #include "byteir/Dialect/SCF/Transforms/ForallCollapsing.h" #include "byteir/Dialect/SCF/Transforms/ForallTiling.h" +#include "byteir/Dialect/SCF/Transforms/ForallNormalize.h" #include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" #include "byteir/Dialect/SCF/Transforms/InsertTrivialSCFLoop.h" diff --git a/compiler/include/byteir/Dialect/SCF/Passes.td b/compiler/include/byteir/Dialect/SCF/Passes.td index 1ce451e8d..20320991d 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.td +++ 
b/compiler/include/byteir/Dialect/SCF/Passes.td
@@ -96,4 +96,22 @@ def ForallTiling : Pass<"forall-tiling"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ForallNormalize
+//===----------------------------------------------------------------------===//
+
+def ForallNormalize : Pass<"forall-normalize"> {
+  let summary = "norlize forall";
+  let constructor = "mlir::createForallNormalizePass()";
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "affine::AffineDialect"
+  ];
+  let options = [
+    Option<"anchorTag", "anchor-tag", "std::string",
+            /*default=*/"",
+            "Optional unitAttr anchored tag to apply this pass">
+  ];
+}
+
 #endif // BYTEIR_DIALECT_SCF_PASSES
\ No newline at end of file
diff --git a/compiler/include/byteir/Dialect/SCF/Transforms/ForallNormalize.h b/compiler/include/byteir/Dialect/SCF/Transforms/ForallNormalize.h
new file mode 100644
index 000000000..c924511fe
--- /dev/null
+++ b/compiler/include/byteir/Dialect/SCF/Transforms/ForallNormalize.h
@@ -0,0 +1,30 @@
+//===- ForallNormalize.h ------------------------------------- C++ --===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLNORMALIZE_H
+#define BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLNORMALIZE_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+std::unique_ptr<Pass> createForallNormalizePass(llvm::StringRef anchorTag = "");
+
+} // namespace mlir
+
+#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLNORMALIZE_H
diff --git a/compiler/include/byteir/Dialect/SCF/Util/Util.h b/compiler/include/byteir/Dialect/SCF/Util/Util.h
index 2d039917b..d2bc86d8a 100644
--- a/compiler/include/byteir/Dialect/SCF/Util/Util.h
+++ b/compiler/include/byteir/Dialect/SCF/Util/Util.h
@@ -26,6 +26,20 @@
 namespace mlir {
 namespace scf {
 
+// This structure is to pass and return sets of loop parameters without
+// confusing the order.
+struct LoopParams {
+  Value lowerBound;
+  Value upperBound;
+  Value step;
+};
+
+/// Return the new lower bound, upper bound, and step in that order. Insert any
+/// additional bounds calculations before the given builder and any additional
+/// conversion back to the original loop induction value inside the given Block.
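+/// A sketch of the effect (values illustrative): a loop with lowerBound = 4,
+/// upperBound = 20, step = 3 becomes lowerBound = 0,
+/// upperBound = ceildiv(20 - 4, 3) = 6, step = 1, and the original induction
+/// value is recovered inside the body as iv * 3 + 4.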
+LoopParams normalizeLoop(OpBuilder &boundsBuilder, OpBuilder &insideLoopBuilder,
+                         Location loc, Value lowerBound, Value upperBound,
+                         Value step, Value inductionVar);
 
 SmallVector<scf::ForOp>
 createNestedEmptyScfForOps(OpBuilder &b, Location loc,
                            ArrayRef<Value> lowerBounds,
diff --git a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 4891d0741..c4a030f7f 100644
--- a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(ByteIRSCFPasses
   ForallCollapsing.cpp
+  ForallNormalize.cpp
   ForallTiling.cpp
   FuseNestedForall.cpp
   InsertTrivialSCFLoop.cpp
diff --git a/compiler/lib/Dialect/SCF/Transforms/ForallCollapsing.cpp b/compiler/lib/Dialect/SCF/Transforms/ForallCollapsing.cpp
index e151dcb63..137749ba1 100644
--- a/compiler/lib/Dialect/SCF/Transforms/ForallCollapsing.cpp
+++ b/compiler/lib/Dialect/SCF/Transforms/ForallCollapsing.cpp
@@ -14,16 +14,9 @@
 // limitations under the License.
 //
 //===----------------------------------------------------------------------===//
-// Some code comes from mlir/lib/Dialect/SCF/Utils/Utils.cpp in LLVM project
-// Orignal license:
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
 
 #include "byteir/Dialect/SCF/Transforms/ForallCollapsing.h"
+#include "byteir/Dialect/SCF/Util/Util.h"
 #include "byteir/Utils/LoopUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -43,69 +36,6 @@ using namespace mlir;
 using namespace mlir::scf;
 
 namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
-  Value lowerBound;
-  Value upperBound;
-  Value step;
-};
-
-/// Return the new lower bound, upper bound, and step in that order. Insert any
-/// additional bounds calculations before the given builder and any additional
-/// conversion back to the original loop induction value inside the given Block.
-static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
-                                OpBuilder &insideLoopBuilder, Location loc,
-                                Value lowerBound, Value upperBound, Value step,
-                                Value inductionVar) {
-  // Check if the loop is already known to have a constant zero lower bound or
-  // a constant one step.
-  bool isZeroBased = false;
-  if (auto ubCst = getConstantIntValue(lowerBound))
-    isZeroBased = ubCst.value() == 0;
-
-  bool isStepOne = false;
-  if (auto stepCst = getConstantIntValue(step))
-    isStepOne = stepCst.value() == 1;
-
-  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
-  // assuming the step is strictly positive. Update the bounds and the step
-  // of the loop to go from 0 to the number of iterations, if necessary.
-  if (isZeroBased && isStepOne)
-    return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
-            /*step=*/step};
-
-  Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
-  Value newUpperBound =
-      boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
-
-  Value newLowerBound =
-      isZeroBased ? lowerBound
-                  : boundsBuilder.create<arith::ConstantOp>(
-                        loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
-  Value newStep =
-      isStepOne ? step
-                : boundsBuilder.create<arith::ConstantOp>(
-                      loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
-
-  // Insert code computing the value of the original loop induction variable
-  // from the "normalized" one.
-  Value scaled =
-      isStepOne
-          ? inductionVar
-          : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
-  Value shifted =
-      isZeroBased
-          ? scaled
-          : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
-
-  SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
-                                       shifted.getDefiningOp()};
-  inductionVar.replaceAllUsesExcept(shifted, preserve);
-  return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
-          /*step=*/newStep};
-}
-
 void collapseForallImpl(scf::ForallOp forallOp) {
   OpBuilder outsideBuilder(forallOp);
   Location loc = forallOp.getLoc();
@@ -120,7 +50,7 @@ void collapseForallImpl(scf::ForallOp forallOp) {
 
   for (size_t i = 0, e = forallOp.getRank(); i < e; ++i) {
     OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(forallOp.getBody());
-    auto resultBounds = normalizeLoop(
+    auto resultBounds = mlir::scf::normalizeLoop(
         outsideBuilder, insideLoopBuilder, loc, oriLowerBounds[i],
         oriUpperBounds[i], oriSteps[i], forallOp.getBody()->getArgument(i));
 
diff --git a/compiler/lib/Dialect/SCF/Transforms/ForallNormalize.cpp b/compiler/lib/Dialect/SCF/Transforms/ForallNormalize.cpp
new file mode 100644
index 000000000..55eebc75d
--- /dev/null
+++ b/compiler/lib/Dialect/SCF/Transforms/ForallNormalize.cpp
@@ -0,0 +1,101 @@
+//===- ForallNormalize.cpp ------------------------------------ C++ --===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+// Some code comes from mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+// in LLVM project
+// Original license:
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "byteir/Dialect/SCF/Transforms/ForallNormalize.h"
+#include "byteir/Dialect/SCF/Util/Util.h"
+#include "byteir/Utils/LoopUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include <utility>
+
+#include "PassDetail.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+struct ForallNormalizePass : public ForallNormalizeBase<ForallNormalizePass> {
+  ForallNormalizePass(llvm::StringRef anchor) : ForallNormalizeBase() {
+    anchorTag = anchor.str();
+  }
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+
+    rootOp->walk([&](scf::ForallOp forallOp) {
+      // skip non-anchored
+      if (!anchorTag.empty() && !forallOp->hasAttr(anchorTag)) {
+        return;
+      }
+      SmallVector<Value> normalizedLowerBound, normalizedStep,
+          normalizedUpperBound;
+
+      OpBuilder outsideBuilder(forallOp);
+      SmallVector<Value> oriLowerBounds, oriSteps, oriUpperBounds;
+      oriLowerBounds = forallOp.getLowerBound(outsideBuilder);
+      oriSteps = forallOp.getStep(outsideBuilder);
+      oriUpperBounds = forallOp.getUpperBound(outsideBuilder);
+      for (size_t i = 0, e = forallOp.getRank(); i < e; ++i) {
+        OpBuilder insideLoopBuilder =
+            OpBuilder::atBlockBegin(forallOp.getBody());
+        auto resultBounds = mlir::scf::normalizeLoop(
+            outsideBuilder, insideLoopBuilder, forallOp.getLoc(),
+            oriLowerBounds[i], oriUpperBounds[i], oriSteps[i],
+            forallOp.getBody()->getArgument(i));
+
+        normalizedLowerBound.push_back(resultBounds.lowerBound);
+        normalizedUpperBound.push_back(resultBounds.upperBound);
+        normalizedStep.push_back(resultBounds.step);
+      }
+
+      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+      dispatchIndexOpFoldResults(getAsOpFoldResult(normalizedLowerBound),
+                                 dynamicLowerBound, staticLowerBound);
+      forallOp.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      forallOp.setStaticLowerBound(staticLowerBound);
+
+      dispatchIndexOpFoldResults(getAsOpFoldResult(normalizedUpperBound),
+                                 dynamicUpperBound, staticUpperBound);
+      forallOp.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      forallOp.setStaticUpperBound(staticUpperBound);
+
+      dispatchIndexOpFoldResults(getAsOpFoldResult(normalizedStep), dynamicStep,
+                                 staticStep);
+      forallOp.getDynamicStepMutable().assign(dynamicStep);
+      forallOp.setStaticStep(staticStep);
+    });
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallNormalizePass(llvm::StringRef anchor) {
+  return std::make_unique<ForallNormalizePass>(anchor);
+}
diff --git a/compiler/lib/Dialect/SCF/Util/Util.cpp b/compiler/lib/Dialect/SCF/Util/Util.cpp
index 440ad813b..f3270a573 100644
--- a/compiler/lib/Dialect/SCF/Util/Util.cpp
+++ b/compiler/lib/Dialect/SCF/Util/Util.cpp
@@ -14,16 +14,143 @@
 // limitations under the License.
 //
 //===----------------------------------------------------------------------===//
+// Some code comes from mlir/lib/Dialect/SCF/Utils/Utils.cpp in LLVM project
+// Original license:
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
 
 #include "byteir/Dialect/SCF/Util/Util.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace scf;
 
+mlir::scf::LoopParams mlir::scf::normalizeLoop(OpBuilder &boundsBuilder,
+                                               OpBuilder &insideLoopBuilder,
+                                               Location loc, Value lowerBound,
+                                               Value upperBound, Value step,
+                                               Value inductionVar) {
+  // Check if the loop is already known to have a constant zero lower bound or
+  // a constant one step.
+  bool isZeroBased = false;
+  if (auto ubCst = getConstantIntValue(lowerBound))
+    isZeroBased = ubCst.value() == 0;
+
+  bool isStepOne = false;
+  if (auto stepCst = getConstantIntValue(step))
+    isStepOne = stepCst.value() == 1;
+
+  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
+  // assuming the step is strictly positive. Update the bounds and the step
+  // of the loop to go from 0 to the number of iterations, if necessary.
+  if (isZeroBased && isStepOne)
+    return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
+            /*step=*/step};
+
+  Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+  Value newUpperBound =
+      boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
+
+  Value newLowerBound =
+      isZeroBased ? lowerBound
+                  : boundsBuilder.create<arith::ConstantOp>(
+                        loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
+  Value newStep =
+      isStepOne ? step
+                : boundsBuilder.create<arith::ConstantOp>(
+                      loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
+
+  // Insert code computing the value of the original loop induction variable
+  // from the "normalized" one.
+  Value scaled =
+      isStepOne
+          ? inductionVar
+          : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
+  Value shifted =
+      isZeroBased
+          ? scaled
+          : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
+
+  SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
+                                       shifted.getDefiningOp()};
+  inductionVar.replaceAllUsesExcept(shifted, preserve);
+  return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
+          /*step=*/newStep};
+}
+
+void collapseForallImpl(scf::ForallOp forallOp) {
+  OpBuilder outsideBuilder(forallOp);
+  Location loc = forallOp.getLoc();
+
+  // Normalize forallOp's iteration pattern.
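+  // Illustrative overall effect of this function (names made up for the
+  // example): scf.forall (%i, %j) in (%M, %N) collapses into
+  // scf.forall (%k) in (%M * %N), recovering %j = %k mod %N and
+  // %i = %k div %N below, once every dimension is normalized to lb = 0,
+  // step = 1.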
+  SmallVector<Value> normalizedLowerBounds, normalizedSteps,
+      normalizedUpperBounds;
+  SmallVector<Value> oriLowerBounds, oriSteps, oriUpperBounds;
+  oriLowerBounds = forallOp.getLowerBound(outsideBuilder);
+  oriSteps = forallOp.getStep(outsideBuilder);
+  oriUpperBounds = forallOp.getUpperBound(outsideBuilder);
+
+  for (size_t i = 0, e = forallOp.getRank(); i < e; ++i) {
+    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(forallOp.getBody());
+    auto resultBounds = normalizeLoop(
+        outsideBuilder, insideLoopBuilder, loc, oriLowerBounds[i],
+        oriUpperBounds[i], oriSteps[i], forallOp.getBody()->getArgument(i));
+
+    normalizedLowerBounds.push_back(resultBounds.lowerBound);
+    normalizedUpperBounds.push_back(resultBounds.upperBound);
+    normalizedSteps.push_back(resultBounds.step);
+  }
+  Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  // after normalize: lowerBound = 0, step = 1
+  auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
+  auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  for (size_t i = 0, e = forallOp.getRank(); i < e; ++i) {
+    newUpperBound = outsideBuilder.create<arith::MulIOp>(
+        loc, newUpperBound, normalizedUpperBounds[i]);
+  }
+
+  auto outputs = llvm::to_vector(forallOp.getOutputs());
+  auto newForall = outsideBuilder.create<scf::ForallOp>(
+      loc, ArrayRef<OpFoldResult>({cst0}),
+      ArrayRef<OpFoldResult>({newUpperBound}), ArrayRef<OpFoldResult>({cst1}),
+      outputs, std::nullopt,
+      [&](OpBuilder &insideBuilder, Location loc, ValueRange regionArgs) {
+        Value previous = regionArgs[0];
+        for (int64_t i = forallOp.getRank() - 1; i > 0; --i) {
+
+          Value iv = insideBuilder.create<arith::RemSIOp>(
+              loc, previous, normalizedUpperBounds[i]);
+          replaceAllUsesInRegionWith(forallOp.getBody()->getArgument(i), iv,
+                                     forallOp.getRegion());
+
+          previous = insideBuilder.create<arith::DivSIOp>(
+              loc, previous, normalizedUpperBounds[i]);
+        }
+
+        replaceAllUsesInRegionWith(forallOp.getBody()->getArgument(0), previous,
+                                   forallOp.getRegion());
+        insideBuilder.create<scf::InParallelOp>(loc);
+      });
+
+  // Replace the old forall with the new forall.
+  newForall.getBody()->getOperations().splice(
+      Block::iterator(newForall.getBody()->back()),
+      forallOp.getBody()->getOperations());
+  // erase redundant scf.forall.in_parallel
+  newForall.getBody()->back().erase();
+  // erase old forall
+  forallOp.erase();
+}
+
 SmallVector<scf::ForOp> mlir::scf::createNestedEmptyScfForOps(
     OpBuilder &b, Location loc, ArrayRef<Value> lowerBounds,
     ArrayRef<Value> upperBounds, ArrayRef<Value> steps) {
From 3187ffaaa418de418c3d01a70b73ace140bedca8 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Tue, 25 Jun 2024 16:40:09 +0800
Subject: [PATCH 04/13] add elementwise tiling pipeline

---
 .../byteir/Pipelines/GPU/ElementwiseCodegen.h | 14 ++++
 .../lib/Pipelines/GPU/ElementwiseCodegen.cpp  | 83 +++++++++++++++++++
 2 files changed, 97 insertions(+)

diff --git a/compiler/include/byteir/Pipelines/GPU/ElementwiseCodegen.h b/compiler/include/byteir/Pipelines/GPU/ElementwiseCodegen.h
index 87e81f40b..a18ea6c26 100644
--- a/compiler/include/byteir/Pipelines/GPU/ElementwiseCodegen.h
+++ b/compiler/include/byteir/Pipelines/GPU/ElementwiseCodegen.h
@@ -40,14 +40,28 @@ struct GPUTileElementwiseOptions
                             llvm::cl::init(256)};
 };
 
+struct GPUTileElementwiseInSCFOptions
+    : public PassPipelineOptions<GPUTileElementwiseInSCFOptions> {
+  Option<int64_t> maxBlockSize{*this, "max-block-size",
+                               llvm::cl::desc("max block size"),
+                               llvm::cl::init(256)};
+};
+
 void createGPUTileElementwiseTransform(
     OpPassManager &pm, const GPUTileElementwiseOptions &options);
 
+void createGPUTileElementwiseInSCF(
+    OpPassManager &pm, const GPUTileElementwiseInSCFOptions &options);
+
 inline void registerGPUElementwiseCodegenPipelines() {
   PassPipelineRegistration<GPUTileElementwiseOptions>(
       "insert-gpu-tile-elementwise-transform",
       "Insert transformation IR to tile linalg elementwise op",
       createGPUTileElementwiseTransform);
+
+  PassPipelineRegistration<GPUTileElementwiseInSCFOptions>(
+      "tile-elementwise-in-scf", "tile elementwise op with nested forallOp",
+      createGPUTileElementwiseInSCF);
 }
 
 } // namespace mlir
diff --git a/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp b/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp
index 3bb9c541f..79c286acf 100644
--- a/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp
+++ b/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp
@@ -19,14 +19,20 @@
 
 #include "byteir/Conversion/ToLLVM/ToLLVM.h"
 #include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h"
+#include "byteir/Dialect/SCF/Passes.h"
 #include "byteir/Dialect/Transform/IR/TransformExtOps.h"
 #include "byteir/Dialect/Transform/Transforms/TransformInsertion.h"
+#include "byteir/Dialect/mhlo/Transforms/HloFuser.h"
 #include "byteir/Pipelines/Common/Utils.h"
+#include "byteir/Transforms/AnchoredPipeline.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SmallSet.h"
 #include <utility>
 
@@ -163,6 +169,64 @@ bool isFusionTarget(linalg::GenericOp genericOp) {
   return true;
 }
 
+template <typename GPUAttrType>
+SmallVector<Attribute> getGPUMappingAttr(OpBuilder b, int64_t rank) {
+
+  SmallVector<gpu::MappingId> mapping;
+  mapping.reserve(rank);
+  int64_t dimMapping = static_cast<int64_t>(gpu::MappingId::LinearDim0);
+  for (int64_t i = 0; i < rank; ++i) {
+    mapping.emplace_back(static_cast<gpu::MappingId>(dimMapping++));
+  }
+
+  auto mappingAttrs = llvm::to_vector(
+      llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute {
+        return GPUAttrType::get(b.getContext(), dim);
+      }));
+  return mappingAttrs;
+}
+
+struct MappingElementwiseToGPUPass
+    : PassWrapper<MappingElementwiseToGPUPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MappingElementwiseToGPUPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<gpu::GPUDialect, scf::SCFDialect>();
+  }
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    scf::ForallOp topLevelForallOp;
+    OpBuilder b(funcOp);
+    funcOp->walk([&](scf::ForallOp forallOp) {
+      if (auto parentForallOp = forallOp->getParentOfType<scf::ForallOp>()) {
+        if (parentForallOp.getMapping().has_value()) {
+          auto parentMappingAttrs =
+              llvm::to_vector(parentForallOp.getMappingAttr());
+          bool hasBlockMapping =
+              llvm::any_of(parentMappingAttrs, [](Attribute attr) {
+                return isa<gpu::GPUBlockMappingAttr>(attr);
+              });
+          if (hasBlockMapping) {
+            int64_t rank = forallOp.getRank();
+            SmallVector<Attribute> curMapping =
+                getGPUMappingAttr<gpu::GPUThreadMappingAttr>(b, rank);
+            forallOp.setMappingAttr(b.getArrayAttr(curMapping));
+          }
+        }
+      } else {
+        // top level forall
+        if (!forallOp.getMapping().has_value()) {
+          int64_t rank = forallOp.getRank();
+          SmallVector<Attribute> curMapping =
+              getGPUMappingAttr<gpu::GPUBlockMappingAttr>(b, rank);
+          forallOp.setMappingAttr(b.getArrayAttr(curMapping));
+        }
+      }
+    });
+  }
+};
+
 void createGPUTileElementwiseTransformImpl(OpPassManager &pm,
                                            const std::string &anchor,
                                            const std::string &prefix,
@@ -198,6 +262,19 @@ void createGPUTileElementwiseTransformImpl(OpPassManager &pm,
 
   pm.addPass(createGenericTransformInsertionPass(config));
 }
+
+void createGPUTileElementwiseInSCFImpl(OpPassManager &pm,
+                                       int64_t maxBlockSize) {
+  auto elementwiseAnchor = getByteIRElementwiseFusionAttrName().str();
+  OpPassManager anchoredPM(func::FuncOp::getOperationName());
+  anchoredPM.addPass(createForallCollapsingPass());
+  anchoredPM.addPass(createForallTilingPass({maxBlockSize}));
+  anchoredPM.addPass(std::make_unique<MappingElementwiseToGPUPass>());
+  anchoredPM.addPass(createForallNormalizePass());
+  pm.addNestedPass<func::FuncOp>(
+      createAnchoredPipelinePass(elementwiseAnchor, anchoredPM));
+}
+
 } // namespace
 
 void mlir::createGPUTileElementwiseTransform(
@@ -206,3 +283,9 @@ void mlir::createGPUTileElementwiseTransform(
       options.funcAnchor, options.annotatePrefix, options.warpSize,
       options.blockSize);
 }
+
+void mlir::createGPUTileElementwiseInSCF(
+    OpPassManager &pm, const GPUTileElementwiseInSCFOptions &options) {
+  invokeOpPassPipelineBuilder(createGPUTileElementwiseInSCFImpl, pm,
+                              options.maxBlockSize);
+}
From 2ba4efb977f2d24fb5456da4cf0d88f2dd94020c Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Wed, 26 Jun 2024 11:25:35 +0800
Subject: [PATCH 05/13] [compiler] add for-to-forall pass

---
 compiler/include/byteir/Dialect/SCF/Passes.h  |  3 +-
 compiler/include/byteir/Dialect/SCF/Passes.td | 19 ++++-
 .../Dialect/SCF/Transforms/ForToForall.h      | 30 +++++++
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../Dialect/SCF/Transforms/ForToForall.cpp    | 85 +++++++++++++++++++
 5 files changed, 136 insertions(+), 2 deletions(-)
 create mode 100644 compiler/include/byteir/Dialect/SCF/Transforms/ForToForall.h
 create mode 100644 compiler/lib/Dialect/SCF/Transforms/ForToForall.cpp

diff --git a/compiler/include/byteir/Dialect/SCF/Passes.h b/compiler/include/byteir/Dialect/SCF/Passes.h
index 9096af759..811483a96 100644
--- a/compiler/include/byteir/Dialect/SCF/Passes.h
+++ b/compiler/include/byteir/Dialect/SCF/Passes.h
@@ -18,9 +18,10 @@
 #ifndef BYTEIR_DIALECT_SCF_PASSES_H
 #define BYTEIR_DIALECT_SCF_PASSES_H
 
+#include "byteir/Dialect/SCF/Transforms/ForToForall.h"
 #include "byteir/Dialect/SCF/Transforms/ForallCollapsing.h"
-#include "byteir/Dialect/SCF/Transforms/ForallTiling.h"
 #include "byteir/Dialect/SCF/Transforms/ForallNormalize.h"
"byteir/Dialect/SCF/Transforms/ForallNormalize.h" +#include "byteir/Dialect/SCF/Transforms/ForallTiling.h" #include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" #include "byteir/Dialect/SCF/Transforms/InsertTrivialSCFLoop.h" diff --git a/compiler/include/byteir/Dialect/SCF/Passes.td b/compiler/include/byteir/Dialect/SCF/Passes.td index 20320991d..f5f2fa3c6 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.td +++ b/compiler/include/byteir/Dialect/SCF/Passes.td @@ -101,7 +101,7 @@ def ForallTiling : Pass<"forall-tiling"> { //===----------------------------------------------------------------------===// def ForallNormalize : Pass<"forall-normalize"> { - let summary = "norlize forall"; + let summary = "normalize forall"; let constructor = "mlir::createForallNormalizePass()"; let dependentDialects = [ "scf::SCFDialect", @@ -114,4 +114,21 @@ def ForallNormalize : Pass<"forall-normalize"> { ]; } +//===----------------------------------------------------------------------===// +// ForallNormalize +//===----------------------------------------------------------------------===// + +def ForToForall : Pass<"for-to-forall"> { + let summary = "convert for to forall"; + let constructor = "mlir::createForToForallPass()"; + let dependentDialects = [ + "scf::SCFDialect", + ]; + let options = [ + Option<"anchorTag", "anchor-tag", "std::string", + /*default=*/"", + "Optional unitAttr anchored tag to apply this pass"> + ]; +} + #endif // BYTEIR_DIALECT_SCF_PASSES \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/SCF/Transforms/ForToForall.h b/compiler/include/byteir/Dialect/SCF/Transforms/ForToForall.h new file mode 100644 index 000000000..fced74e41 --- /dev/null +++ b/compiler/include/byteir/Dialect/SCF/Transforms/ForToForall.h @@ -0,0 +1,30 @@ +//===- ForToForall.h ------------------------------------- C++ --===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_FORTOFORALL_H
+#define BYTEIR_DIALECT_SCF_TRANSFORMS_FORTOFORALL_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+std::unique_ptr<Pass> createForToForallPass(llvm::StringRef anchorTag = "");
+
+} // namespace mlir
+
+#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_FORTOFORALL_H
diff --git a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
index c4a030f7f..058ba0579 100644
--- a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(ByteIRSCFPasses
   ForallCollapsing.cpp
   ForallNormalize.cpp
   ForallTiling.cpp
+  ForToForall.cpp
   FuseNestedForall.cpp
   InsertTrivialSCFLoop.cpp
   TilingInterfaceToSCFFor.cpp
diff --git a/compiler/lib/Dialect/SCF/Transforms/ForToForall.cpp b/compiler/lib/Dialect/SCF/Transforms/ForToForall.cpp
new file mode 100644
index 000000000..6e1128aed
--- /dev/null
+++ b/compiler/lib/Dialect/SCF/Transforms/ForToForall.cpp
@@ -0,0 +1,85 @@
+//===- ForToForall.cpp ------------------------------------ C++ --===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+// Some code comes from mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+// in LLVM project
+// Original license:
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "byteir/Dialect/SCF/Transforms/ForToForall.h"
+#include "byteir/Dialect/SCF/Util/Util.h"
+#include "byteir/Utils/LoopUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include <utility>
+
+#include "PassDetail.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+struct ForToForallPass : public ForToForallBase<ForToForallPass> {
+  ForToForallPass(llvm::StringRef anchor) : ForToForallBase() {
+    anchorTag = anchor.str();
+  }
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+
+    rootOp->walk([&](scf::ForOp forOp) {
+      // skip non-anchored
+      if (!anchorTag.empty() && !forOp->hasAttr(anchorTag)) {
+        return;
+      }
+      SmallVector<Value> initArgs = forOp.getInitArgs();
+      if (initArgs.size() > 0) {
+        return;
+      }
+
+      OpBuilder builder(forOp);
+      auto lb = forOp.getLowerBound();
+      auto ub = forOp.getUpperBound();
+      auto step = forOp.getStep();
+      auto forallOp = builder.create<scf::ForallOp>(
+          forOp.getLoc(), llvm::ArrayRef<OpFoldResult>{lb},
+          llvm::ArrayRef<OpFoldResult>{ub}, llvm::ArrayRef<OpFoldResult>{step},
+          initArgs, std::nullopt);
+      replaceAllUsesInRegionWith(forOp.getInductionVar(),
+                                 forallOp.getInductionVars()[0],
+                                 forOp.getRegion());
+      forOp.getBody()->back().erase();
+      forallOp.getBody()->getOperations().splice(
+          Block::iterator(forallOp.getBody()->back()),
+          forOp.getBody()->getOperations());
+      forOp.erase();
+    });
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForToForallPass(llvm::StringRef anchor) {
+  return std::make_unique<ForToForallPass>(anchor);
+}
From 564e7b18ff7cca856f8abced037e369d318bc0ad Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Wed, 26 Jun 2024 11:26:59 +0800
Subject: [PATCH 06/13] refactor elementwise-op codegen pipeline

---
 .../lib/Pipelines/GPU/ElementwiseCodegen.cpp  | 34 ++++++++++++++++++-
 compiler/lib/Pipelines/GPU/GPUOpt.cpp         | 27 +++-------------
 compiler/lib/Pipelines/SCFOpt.cpp             |  6 ++++
 compiler/python/byteir/compile.py             |  2 --
 4 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp b/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp
index 79c286acf..3f6e1c89a 100644
--- a/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp
+++ b/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp
@@ -227,6 +227,38 @@ struct MappingElementwiseToGPUPass
   }
 };
 
+struct elementwiseForallSpecializationPass
+    : PassWrapper<elementwiseForallSpecializationPass,
+                  OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(elementwiseForallSpecializationPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<scf::SCFDialect>();
+  }
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    bool hasForOp = (!funcOp.getOps<scf::ForOp>().empty());
+    PassManager pm(funcOp->getContext(), func::FuncOp::getOperationName());
+    if (!hasForOp) {
+      // Note: a trivial loop will be removed by canonicalizer
+      // so no canonicalizer before used
+      pm.addPass(createInsertTrivialSCFLoopPass());
+    }
+    pm.addPass(createForToForallPass());
+    pm.addPass(createFuseNestedForallPass());
+    pm.addPass(createForallCollapsingPass());
+
+    if (hasForOp) {
+      pm.addPass(createCanonicalizerPass());
+      pm.addPass(createCSEPass());
+    }
+    if (mlir::failed(runPipeline(pm, funcOp))) {
+      signalPassFailure();
+    }
+  }
+};
+
 void createGPUTileElementwiseTransformImpl(OpPassManager &pm,
                                            const std::string &anchor,
                                            const std::string &prefix,
@@ -299,7 +299,7 @@ void createGPUTileElementwiseInSCFImpl(OpPassManager &pm,
                                        int64_t maxBlockSize) {
   auto elementwiseAnchor = getByteIRElementwiseFusionAttrName().str();
   OpPassManager anchoredPM(func::FuncOp::getOperationName());
-  anchoredPM.addPass(createForallCollapsingPass());
+  anchoredPM.addPass(std::make_unique<elementwiseForallSpecializationPass>());
   anchoredPM.addPass(createForallTilingPass({maxBlockSize}));
   anchoredPM.addPass(std::make_unique<MappingElementwiseToGPUPass>());
   anchoredPM.addPass(createForallNormalizePass());
diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp
index 9b4ad345e..6d065e762 100644
--- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp
+++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp
@@ -58,28 +58,11 @@ void createElementwiseGPUOptPipelineImpl(OpPassManager &pm,
         getByteIRElementwiseFusionAttrName(), anchoredPM));
   }
 
-  // Note: a trivial loop will be removed by canonicalizer
-  // so no canonicalizer before used
-  pm.addNestedPass<func::FuncOp>(
-      createInsertTrivialSCFLoopPass(getByteIRElementwiseFusionAttrName()));
-
-  // attach ToGPUAttr
-  pm.addPass(createFuncTagPass(getByteIRElementwiseFusionAttrName(),
-                               getToGPUAttrName()));
-
-  std::string iteratorAttr =
-      getLoopToSIMTAttrName().str() + ":String:" + getLinearIdXName().str();
-
-  pm.addNestedPass<func::FuncOp>(
-      createLoopTagPass(getByteIRElementwiseFusionAttrName(), iteratorAttr));
-
-  pm.addNestedPass<func::FuncOp>(createLoopTagPass(
-      getByteIRElementwiseFusionAttrName(), getCoarsenSIMTAttrName().str()));
-
-  pm.addPass(createConvertFuncToGPUPass(/*bs=*/{256, 1, 1}));
-
-  addCleanUpExtPassPipeline(pm);
-  pm.addNestedPass<func::FuncOp>(createGenPTXConfigPass(useBarePtrCallConv));
+  // pm.addNestedPass<func::FuncOp>(createGenPTXConfigPass(useBarePtrCallConv));
+  GPUMappingForallOptions mappingOptions;
+  mappingOptions.funcAnchor = getByteIRElementwiseFusionAttrName().str();
+  mappingOptions.blockDimsHint = llvm::cl::KernelDims{256, 1, 1};
+  createGPUMappingForallTransform(pm, mappingOptions);
 }
 
 void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
diff --git a/compiler/lib/Pipelines/SCFOpt.cpp b/compiler/lib/Pipelines/SCFOpt.cpp
index bf08dcd8e..3abf9e1ea 100644
--- a/compiler/lib/Pipelines/SCFOpt.cpp
+++ b/compiler/lib/Pipelines/SCFOpt.cpp
@@ -22,6 +22,7 @@
 #include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h"
 #include "byteir/Dialect/mhlo/Passes.h"
 #include "byteir/Pipelines/Common/Utils.h"
+#include "byteir/Pipelines/GPU/ElementwiseCodegen.h"
 #include "byteir/Transforms/Passes.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Dialect/Affine/Passes.h"
@@ -36,6 +37,11 @@ using namespace mlir::affine;
 
 namespace {
 void addGenericSCFOptPasses(OpPassManager &pm) {
+  // for elementwise op
+  GPUTileElementwiseInSCFOptions tileOptions;
+  tileOptions.maxBlockSize = 256;
+  createGPUTileElementwiseInSCF(pm, tileOptions);
+
   pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
   pm.addNestedPass<func::FuncOp>(createConvertLinalgExtToLoopsPass());
   // lower affine.apply in case there is some
diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py
index 82acb7d79..55f0c7275 100644
--- a/compiler/python/byteir/compile.py
+++ b/compiler/python/byteir/compile.py
@@ -132,7 +132,6 @@ def _compile_cuda(
     PassManager.parse("builtin.module(gpu-opt)").run(module.operation)
     _print_verbose(module, "// IR Dump After GPU Opt:") if verbose else ...
     with context:
-        PassManager.parse("builtin.module(func.func(remove-func-body{anchor-attr=__byteir_elementwise_fusion__}))").run(module.operation)
         PassManager.parse("builtin.module(inline)").run(module.operation)
         PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation)
     if useBarePtrCallConv:
@@ -242,7 +241,6 @@ def _compile_cuda_with_ait(
     PassManager.parse("builtin.module(gpu-opt)").run(processor.module.operation)
     _print_verbose(processor.module, "// IR Dump After GPU Opt:") if verbose else ...
     with context:
-        PassManager.parse("builtin.module(func.func(remove-func-body{anchor-attr=__byteir_elementwise_fusion__}))").run(processor.module.operation)
         PassManager.parse("builtin.module(inline)").run(processor.module.operation)
         PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation)
     if useBarePtrCallConv:

From 550dc933c931f3b14fc85bc6e3508768ee6ce2cc Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Fri, 28 Jun 2024 15:44:58 +0800
Subject: [PATCH 07/13] [compiler] draft: naive horizontal fusion

---
 .../byteir/Transforms/HorizontalFusion.h      |  31 ++
 compiler/include/byteir/Transforms/Passes.h   |   1 +
 compiler/include/byteir/Transforms/Passes.td  |  12 +
 compiler/lib/Transforms/CMakeLists.txt        |   1 +
 compiler/lib/Transforms/HorizontalFusion.cpp  | 373 ++++++++++++++++++
 compiler/lib/Transforms/PassDetail.h          |   5 +
 compiler/lib/Utils/LoopUtils.cpp              |   2 +-
 scripts/prepare.sh                            |   4 +-
 8 files changed, 426 insertions(+), 3 deletions(-)
 create mode 100644 compiler/include/byteir/Transforms/HorizontalFusion.h
 create mode 100644 compiler/lib/Transforms/HorizontalFusion.cpp

diff --git a/compiler/include/byteir/Transforms/HorizontalFusion.h b/compiler/include/byteir/Transforms/HorizontalFusion.h
new file mode 100644
index 000000000..9302aaaff
--- /dev/null
+++ b/compiler/include/byteir/Transforms/HorizontalFusion.h
@@ -0,0 +1,31 @@
+//===- HorizontalFusion.h --------------------------------------*- C++ -*-===//
+//
+// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_TRANSFORMS_HORIZONTALFUSION_H
+#define BYTEIR_TRANSFORMS_HORIZONTALFUSION_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+
+std::unique_ptr<OperationPass<ModuleOp>> createHorizontalFusionPass();
+
+} // namespace mlir
+
+#endif // BYTEIR_TRANSFORMS_HORIZONTALFUSION_H
diff --git a/compiler/include/byteir/Transforms/Passes.h b/compiler/include/byteir/Transforms/Passes.h
index 03dfe839e..96d2b749e 100644
--- a/compiler/include/byteir/Transforms/Passes.h
+++ b/compiler/include/byteir/Transforms/Passes.h
@@ -38,6 +38,7 @@
 #include "byteir/Transforms/SetSpace.h"
 #include "byteir/Transforms/ShapeFuncOutlining.h"
 #include "byteir/Transforms/TryCatchModulePipeline.h"
+#include "byteir/Transforms/HorizontalFusion.h"
 
 namespace mlir {
 
diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index b92f1de90..06d4c2cdc 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -253,6 +253,18 @@ def GraphClusteringByDevice : Pass<"graph-clustering-by-device", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// HorizontalFusion
+//===----------------------------------------------------------------------===//
+def HorizontalFusion : Pass<"horizontal-fusion-on-scf", "ModuleOp"> {
+  let summary = "Horizontal fusion based on scf.forall and memref.";
+  let constructor = "mlir::createHorizontalFusionPass()";
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "memref::MemRefDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LoopTag
 //===----------------------------------------------------------------------===//
diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt
index f1881c90a..f09516233 100644
--- a/compiler/lib/Transforms/CMakeLists.txt
+++ b/compiler/lib/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(ByteIRTransforms
   FuncTag.cpp
   GenericDeviceConfig.cpp
   GraphClusteringByDevice.cpp
+  HorizontalFusion.cpp
   InsertUniqueId.cpp
   LoopTag.cpp
   LoopUnroll.cpp
diff --git a/compiler/lib/Transforms/HorizontalFusion.cpp b/compiler/lib/Transforms/HorizontalFusion.cpp
new file mode 100644
index 000000000..8dd4f0f4f
--- /dev/null
+++ b/compiler/lib/Transforms/HorizontalFusion.cpp
@@ -0,0 +1,373 @@
+//===- HorizontalFusion.cpp ------------------------------------*- C++ -*-===//
+//
+// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// +//===----------------------------------------------------------------------===// + +#include "byteir/Transforms/HorizontalFusion.h" + +#include "byteir/Dialect/Byre/ByreDialect.h" +#include "byteir/Dialect/mhlo/Transforms/HloFuser.h" +#include "byteir/Utils/IRRewrite.h" +#include "byteir/Utils/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SetVector.h" +#include + +#include "PassDetail.h" + +using namespace mlir; +using namespace llvm; + +namespace { + +constexpr StringRef kernelTypeNameAttr = "__byteir_forall_type_name"; +constexpr StringRef kernelFuncNameAttr = "__byteir_forall_kernel_name"; + +void setForallTagAndName(ModuleOp &m) { + static constexpr StringRef elementwiseAttrName = + getByteIRElementwiseFusionAttrName(); + static constexpr StringRef reductionAttrName = + getByteIRReductionFusionAttrName(); + + for (auto funcOp : m.getOps()) { + auto funcName = funcOp.getName(); + StringRef kernelType; + if (funcOp->hasAttr(elementwiseAttrName)) { + kernelType = elementwiseAttrName; + } else if (funcOp->hasAttr(reductionAttrName)) { + kernelType = reductionAttrName; + } else + continue; + + mlir::OpBuilder opBuilder(funcOp); + for (auto forallOp : funcOp.getOps()) { + forallOp->setAttr(kernelTypeNameAttr, + opBuilder.getStringAttr(kernelType)); + forallOp->setAttr(kernelFuncNameAttr, opBuilder.getStringAttr(funcName)); + } + } +} + +inline bool isByreEntry(func::FuncOp &funcOp) { + return funcOp->hasAttr(getAttrPlaceholderName( + byre::ByreDialect::getEntryPointFunctionAttrName())); +} + +void moveForwardAlloc(ModuleOp &m) { + for (auto funcOp : m.getOps()) { + if (!isByreEntry(funcOp)) + continue; + mlir::OpBuilder b(funcOp); + b.setInsertionPointToStart(&(funcOp.getBody().front())); + Block::iterator insertionPoint = b.getInsertionPoint(); + for (auto alloc : + llvm::make_early_inc_range(funcOp.getOps())) { + alloc->moveAfter(&(funcOp.getBody().front()), insertionPoint); + } + } +} + +// HorizontalFusionPass +using HFusionPattern = llvm::SmallVector; +using HFusionPlan = llvm::SmallVector; + +struct HorizontalFusionPass + : public HorizontalFusionBase { + explicit HorizontalFusionPass() + : HorizontalFusionBase::HorizontalFusionBase() {} + + void runOnOperation() override; + void getCandidates(ModuleOp &m, SmallVector &candidates); + void makeHorizontalFusionPlan(SmallVector &, HFusionPlan &); + void doHorizontalFusion(HFusionPlan &); + bool isFusibleAndBenefit(scf::ForallOp pre, scf::ForallOp cur); + void collectWRMemref(scf::ForallOp forallOp, SmallVector &w, + SmallVector &r); + void collectUsePointInBlock(Block *block, SmallVector &vals, + llvm::SetVector &usePoints); +}; // HorizontalFusionPass + +void HorizontalFusionPass::getCandidates(ModuleOp &m, + SmallVector &candidates) { + for (auto funcOp : m.getOps()) { + if (!isByreEntry(funcOp)) + continue; + + for (auto forallOp : funcOp.getOps()) { + // TODO(chhuang) (1) check instrs nums; (2) skip large shape; + // just pass all elementwise kernel as candidates. 
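+      // Keep exactly the foralls that stage 1 (setForallTagAndName) tagged
+      // as elementwise kernels; reduction kernels are left out of the
+      // candidate set for now.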
+ if (forallOp->hasAttr(kernelTypeNameAttr) && + forallOp->getAttr(kernelTypeNameAttr).cast().getValue() == + getByteIRElementwiseFusionAttrName()) { + candidates.push_back(forallOp); + } + } + } +} + +// traverse from top to down, greedy check whether fuseiable +void HorizontalFusionPass::makeHorizontalFusionPlan( + SmallVector &candidates, HFusionPlan &plan) { + Operation *head = nullptr; + HFusionPattern *pattern = nullptr; + for (auto cur : candidates) { + if (head && isFusibleAndBenefit(dyn_cast(head), + dyn_cast(cur))) { + pattern->push_back(cur); + continue; + } + head = cur; + HFusionPattern newPattern; + plan.push_back(newPattern); + pattern = &(plan.back()); + pattern->push_back(head); + } +} + +void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) { + OpBuilder builder(getOperation()); + for (auto pattern : plan) { + if (pattern.size() < 2) + continue; + // TODO sort forall with shape and instrs count + + // merge + auto root = cast(pattern.front()); + SmallVector blockNums; + for (auto op : pattern) { + auto forall = cast(op); + blockNums.push_back(forall.getStaticUpperBound().front()); + } + // TODO should we align blockNum to multiple 32 to reduce divergence + int64_t allBlockNums = std::accumulate(blockNums.begin(), blockNums.end(), + 1, std::plus()); + auto front = cast(pattern.front()); + auto loc = front.getLoc(); + builder.setInsertionPoint(front); + SmallVector bounds; + Value tempBound = builder.create(loc, 0); + bounds.push_back(tempBound); + for (auto num : blockNums) { + Value n = builder.create(loc, num); + tempBound = builder.create(loc, tempBound, n); + bounds.push_back(tempBound); + } + + auto cstIZero = builder.create(loc, 0); + auto cstIOne = builder.create(loc, 1); + auto cstNums = builder.create(loc, allBlockNums); + // FIXME not hack lb and step here + auto lb = builder.create(loc, 0); + auto step = builder.create(loc, 1); + + // create grid level forall + auto hFuseForall = + builder.create(loc, /*lb*/ ArrayRef({lb}), + /*ub*/ ArrayRef({cstNums}), + /*step*/ ArrayRef({step}), + ValueRange(), front.getMapping()); + builder.setInsertionPointToStart(hFuseForall.getBody()); + auto blockId = hFuseForall.getBody()->getArgument(0); + + // create condition br one by one + Value switchValue = builder.create(loc, 0); + for (int64_t i = 0; i < blockNums.size(); ++i) { + auto cmp = builder.create(loc, arith::CmpIPredicate::sgt, + blockId, bounds[i]); + auto selVal = + builder.create(loc, cmp, cstIOne, cstIZero); + switchValue = builder.create(loc, switchValue, selVal); + } + SmallVector cases = + llvm::to_vector(llvm::seq(0, blockNums.size())); + auto switchOp = builder.create( + loc, /*resultTypes*/ TypeRange{}, switchValue, cases, cases.size()); + + // default region + { + Block &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + builder.setInsertionPointToStart(&defaultBlock); + builder.create(loc); + } + + // case region + for (int64_t i = 0; i < blockNums.size(); ++i) { + auto orgForall = cast(pattern[i]); + Block &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + builder.setInsertionPointToStart(&caseBlock); + auto orgId = builder.create(loc, blockId, bounds[i]); + Block::iterator insertionPoint = builder.getInsertionPoint(); + replaceAllUsesInRegionWith(orgForall.getBody()->getArgument(0), orgId, + orgForall.getRegion()); + caseBlock.getOperations().splice(insertionPoint, + orgForall.getBody()->getOperations()); + caseBlock.back().erase(); + builder.create(loc); + orgForall.erase(); + } + } +} + +bool 
HorizontalFusionPass::isFusibleAndBenefit(scf::ForallOp pre, + scf::ForallOp cur) { + // TODO check whether benefit + // TODO check all has same mapping + + auto same_shape = [](ArrayRef a, ArrayRef b) { + if (a.size() != b.size()) + return false; + return llvm::all_of((llvm::zip(a, b)), [](std::tuple s) { + return std::get<0>(s) == std::get<1>(s); + }); + }; + + // check fusiable + SmallVector preWriteVals; + SmallVector preReadVals; + SmallVector curWriteVals; + SmallVector curReadVals; + // TODO include alias and collect all uses. + collectWRMemref(pre, preWriteVals, preReadVals); + collectWRMemref(cur, curWriteVals, curReadVals); + + llvm::SetVector usePoints; + collectUsePointInBlock(cur->getParentOp()->getBlock(), curWriteVals, + usePoints); + collectUsePointInBlock(cur->getParentOp()->getBlock(), curReadVals, + usePoints); + + auto &domInfo = getAnalysis(); + auto checkDominace = [&](Operation *op) { + // just skip checking viewlike ops + if (isa_and_nonnull(op)) + return true; + return domInfo.properlyDominates(op, pre) || domInfo.dominates(cur, op); + }; + bool fusiable = llvm::all_of(usePoints, checkDominace); + + return fusiable; +} + +void HorizontalFusionPass::collectWRMemref(scf::ForallOp forallOp, + SmallVector &w, + SmallVector &r) { + auto collect = [](TypedValue memref, SmallVector &chunk) { + Value root = memref; + while (true) { + if (auto defOp = + dyn_cast_if_present(root.getDefiningOp())) { + root = defOp->getOperand(0); + continue; + } + break; + } + llvm::SetVector alias; + SmallVector worklist; + alias.insert(root); + worklist.push_back(root); + while (!worklist.empty()) { + auto val = worklist.pop_back_val(); + for (auto user : val.getUsers()) { + if (auto viewlike = dyn_cast_if_present(user)) { + for (auto res : viewlike->getResults()) { + worklist.push_back(res); + alias.insert(res); + } + } + } + } + }; + + forallOp->walk([&](memref::LoadOp load) { + auto memref = load.getMemref(); + collect(memref, r); + }); + forallOp->walk([&](memref::StoreOp store) { + auto memref = store.getMemref(); + collect(memref, w); + }); +} + +void HorizontalFusionPass::collectUsePointInBlock( + Block *block, SmallVector &vals, + llvm::SetVector &usePoints) { + for (auto val : vals) { + for (auto user : val.getUsers()) { + Operation *inBlockUser = user; + while (inBlockUser->getParentOp() && + inBlockUser->getParentOp()->getBlock() != block) { + inBlockUser = inBlockUser->getParentOp(); + } + if (inBlockUser->getParentOp()) { + usePoints.insert(inBlockUser); + } + } + } +} + +void HorizontalFusionPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *context = &getContext(); + auto &domInfo = getAnalysis(); + + /// stage 1. inline all fused function back to entry function + { + setForallTagAndName(moduleOp); + + OpPassManager pm(moduleOp.getOperationName()); + pm.addPass(createInlinerPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (mlir::failed(runPipeline(pm, moduleOp))) { + signalPassFailure(); + } + } + /// stage 2. make fusion planing + // move forward alloc. + // TODO Infact, one can move more ops. + moveForwardAlloc(moduleOp); + + SmallVector candidateForallOps; + getCandidates(moduleOp, candidateForallOps); + HFusionPlan horiFusionPlan; + makeHorizontalFusionPlan(candidateForallOps, horiFusionPlan); + + /// stage 3. do horizontal fusion + doHorizontalFusion(horiFusionPlan); + + /// postprocess + // TODO lazy alloc + + /// [deprecated] stage 4. 
outline scf.forall back to func call + /// or, outline after gpu codegen. +} + +} // namespace + +std::unique_ptr> mlir::createHorizontalFusionPass() { + return std::make_unique(); +} diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h index 3b5126e3e..96dd8bc2e 100644 --- a/compiler/lib/Transforms/PassDetail.h +++ b/compiler/lib/Transforms/PassDetail.h @@ -59,6 +59,11 @@ namespace scf { class SCFDialect; } // namespace scf +namespace tensor { +class TensorDialect; +} // namespace tensor + + #define GEN_PASS_CLASSES #include "byteir/Transforms/Passes.h.inc" diff --git a/compiler/lib/Utils/LoopUtils.cpp b/compiler/lib/Utils/LoopUtils.cpp index 7cedd3532..1804fc79b 100644 --- a/compiler/lib/Utils/LoopUtils.cpp +++ b/compiler/lib/Utils/LoopUtils.cpp @@ -444,7 +444,7 @@ std::optional mlir::createTrivialSCFForIfHaveNone(func::FuncOp funcOp) { // if having scf::ForOp return nullopt - if (!funcOp.getOps().empty()) { + if (!funcOp.getOps().empty() || !funcOp.getOps().empty()) { return std::nullopt; } diff --git a/scripts/prepare.sh b/scripts/prepare.sh index bcb3f2631..233797355 100755 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -25,9 +25,9 @@ function install_mhlo_tools() { function copy_external_libs() { PREBUILT_FLASH_ATTN="/data00/external_libraries/libflash_attn.so" - mkdir $ROOT_PROJ_DIR/external_libs/libs + mkdir -p $ROOT_PROJ_DIR/external_libs/libs cp $PREBUILT_FLASH_ATTN $ROOT_PROJ_DIR/external_libs/libs - mkdir $ROOT_PROJ_DIR/runtime/test/test_files/external_libs/ + mkdir -p $ROOT_PROJ_DIR/runtime/test/test_files/external_libs/ cp $PREBUILT_FLASH_ATTN $ROOT_PROJ_DIR/runtime/test/test_files/external_libs/ } From 24ece270596f9644a83aaf35a082fe8bb31121f0 Mon Sep 17 00:00:00 2001 From: "huangchenhui.yellow" Date: Fri, 28 Jun 2024 15:51:51 +0800 Subject: [PATCH 08/13] [compiler] insert hfuse pass into scf-opt and add h-fuse demo --- compiler/lib/Pipelines/SCFOpt.cpp | 1 + .../test/Transforms/horizontalFusion.mlir | 89 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 compiler/test/Transforms/horizontalFusion.mlir diff --git a/compiler/lib/Pipelines/SCFOpt.cpp b/compiler/lib/Pipelines/SCFOpt.cpp index 3abf9e1ea..44f9fdfcb 100644 --- a/compiler/lib/Pipelines/SCFOpt.cpp +++ b/compiler/lib/Pipelines/SCFOpt.cpp @@ -53,6 +53,7 @@ void addGenericSCFOptPasses(OpPassManager &pm) { pm.addNestedPass( createFuseNestedForallPass(getByteIRReductionFusionAttrName())); addCleanUpExtPassPipeline(pm); + pm.addPass(createHorizontalFusionPass()); } void addCPUSCFOptPasses(OpPassManager &pm) { diff --git a/compiler/test/Transforms/horizontalFusion.mlir b/compiler/test/Transforms/horizontalFusion.mlir new file mode 100644 index 000000000..853ef6e7a --- /dev/null +++ b/compiler/test/Transforms/horizontalFusion.mlir @@ -0,0 +1,89 @@ +// RUN: byteir-opt %s --horizontal-fusion-on-scf --cse --canonicalize | FileCheck %s +#map = affine_map<(d0, d1) -> (d0)> +module { + func.func private @Unknown0(%arg0: memref<32x16xf32>, %arg1: memref<32x16xf32>, %arg2: memref<32x16xf32>) -> memref<32x16xf32> attributes {__byteir_elementwise_fusion__} { + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %alloc = memref.alloc() : memref<32x16xf32> + scf.forall (%arg3) in (2) { + %0 = arith.muli %arg3, %c256 : index + scf.forall (%arg4) in (256) { + %1 = arith.addi %arg4, %0 : index + %2 = arith.remsi %1, %c16 : index + %3 = arith.divsi %1, %c16 : index + %4 = memref.load %arg0[%3, %2] : memref<32x16xf32> + %5 = memref.load %arg1[%3, %2] : 
memref<32x16xf32> + %6 = memref.load %arg2[%3, %2] : memref<32x16xf32> + %7 = arith.mulf %4, %5 : f32 + %8 = arith.divf %7, %6 : f32 + memref.store %8, %alloc[%3, %2] : memref<32x16xf32> + } {mapping = [#gpu.thread]} + } {mapping = [#gpu.block]} + return %alloc : memref<32x16xf32> + } + func.func private @Unknown1(%arg0: memref<32x16xf32>, %arg1: memref<32x16xf32>) -> memref<16xf32> attributes {__byteir_reduction_fusion__} { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf32> + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<16xf32> + scf.forall (%arg2) in (16) { + scf.forall (%arg3) in (1) { + %0 = vector.transfer_read %arg0[%c0, %arg2], %cst {in_bounds = [true], permutation_map = #map} : memref<32x16xf32>, vector<32xf32> + %1 = vector.transfer_read %arg1[%c0, %arg2], %cst {in_bounds = [true], permutation_map = #map} : memref<32x16xf32>, vector<32xf32> + %2 = arith.mulf %0, %0 : vector<32xf32> + %3 = arith.subf %2, %1 : vector<32xf32> + %4 = vector.reduction , %3, %cst : vector<32xf32> into f32 + %5 = vector.insertelement %4, %cst_0[%c0 : index] : vector<1xf32> + %6 = vector.extract %5[0] : f32 from vector<1xf32> + %7 = vector.broadcast %6 : f32 to vector + vector.transfer_write %7, %alloc[%arg2] : vector, memref<16xf32> + } {mapping = [#gpu.warp]} + } {mapping = [#gpu.block]} + return %alloc : memref<16xf32> + } + func.func private @Unknown2(%arg0: memref<16xf32>) -> memref<16xf32> attributes {__byteir_elementwise_fusion__} { + %cst = arith.constant 0.000000e+00 : f32 + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %alloc = memref.alloc() : memref<16xf32> + scf.forall (%arg1) in (1) { + %0 = arith.muli %arg1, %c256 : index + %1 = arith.subi %c16, %0 : index + %2 = arith.cmpi sgt, %1, %c256 : index + %3 = arith.select %2, %c256, %1 : index + scf.forall (%arg2) in (%3) { + %4 = arith.addi %arg2, %0 : index + %5 = memref.load %arg0[%4] : memref<16xf32> + %6 = arith.maximumf %5, %cst : f32 + memref.store %6, %alloc[%4] : memref<16xf32> + } {mapping = [#gpu.thread]} + } {mapping = [#gpu.block]} + return %alloc : memref<16xf32> + } + func.func private @Unknown3(%arg0: memref<16xf32>, %arg1: memref<32x16xf32>) -> memref<32x16xf32> attributes {__byteir_elementwise_fusion__} { + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %alloc = memref.alloc() : memref<32x16xf32> + scf.forall (%arg2) in (2) { + %0 = arith.muli %arg2, %c256 : index + scf.forall (%arg3) in (256) { + %1 = arith.addi %arg3, %0 : index + %2 = arith.remsi %1, %c16 : index + %3 = arith.divsi %1, %c16 : index + %4 = memref.load %arg0[%2] : memref<16xf32> + %5 = memref.load %arg1[%3, %2] : memref<32x16xf32> + %6 = arith.addf %4, %5 : f32 + memref.store %6, %alloc[%3, %2] : memref<32x16xf32> + } {mapping = [#gpu.thread]} + } {mapping = [#gpu.block]} + return %alloc : memref<32x16xf32> + } + func.func @main(%arg0: memref<32x16xf32>, %arg1: memref<32x16xf32>, %arg2: memref<32x16xf32>) -> (memref<32x16xf32>, memref<32x16xf32>) attributes {__placeholder__byre.entry_point} { + %0 = call @Unknown0(%arg0, %arg1, %arg2) : (memref<32x16xf32>, memref<32x16xf32>, memref<32x16xf32>) -> memref<32x16xf32> + %1 = call @Unknown1(%0, %arg0) : (memref<32x16xf32>, memref<32x16xf32>) -> memref<16xf32> + %2 = call @Unknown2(%1) : (memref<16xf32>) -> memref<16xf32> + %3 = call @Unknown3(%2, %0) : (memref<16xf32>, memref<32x16xf32>) -> memref<32x16xf32> + return %0, %3 : memref<32x16xf32>, memref<32x16xf32> + } +} + From 
e503729879cb15e3aa483d43cb8e9866795a0da4 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Fri, 28 Jun 2024 16:46:21 +0800
Subject: [PATCH 09/13] add vector-opt pipeline and refactor gpu-opt

---
 .../byteir/Pipelines/InitAllPipelines.h       |  2 +
 compiler/include/byteir/Pipelines/VectorOpt.h | 46 ++++++++++++
 compiler/lib/Pipelines/CMakeLists.txt         |  1 +
 compiler/lib/Pipelines/GPU/GPUOpt.cpp         | 73 +++----------------
 compiler/lib/Pipelines/SCFOpt.cpp             |  1 -
 compiler/lib/Pipelines/VectorOpt.cpp          | 68 +++++++++++++++++
 .../test/Transforms/horizontalFusion.mlir     |  4 +
 7 files changed, 130 insertions(+), 65 deletions(-)
 create mode 100644 compiler/include/byteir/Pipelines/VectorOpt.h
 create mode 100644 compiler/lib/Pipelines/VectorOpt.cpp

diff --git a/compiler/include/byteir/Pipelines/InitAllPipelines.h b/compiler/include/byteir/Pipelines/InitAllPipelines.h
index 27dc678a9..0cc2c68b5 100644
--- a/compiler/include/byteir/Pipelines/InitAllPipelines.h
+++ b/compiler/include/byteir/Pipelines/InitAllPipelines.h
@@ -32,6 +32,7 @@
 #include "byteir/Pipelines/LinalgTensorOpt.h"
 #include "byteir/Pipelines/SCFOpt.h"
 #include "byteir/Pipelines/ShapeOpt.h"
+#include "byteir/Pipelines/VectorOpt.h"
 
 #include "byteir/Pipelines/GPU/ElementwiseCodegen.h"
 #include "byteir/Pipelines/GPU/GPUOpt.h"
@@ -58,6 +59,7 @@ inline void registerAllByteIRCommonPipelines() {
   registerLinalgMemrefOptPipeline();
   registerLinalgTensorOptPipeline();
   registerSCFOptPipeline();
+  registerVectorOptPipeline();
   registerShapeOptPipeline();
   registerByteIRBufferizeOptPipeline();
   registerByteIRAllOptPipeline();
diff --git a/compiler/include/byteir/Pipelines/VectorOpt.h b/compiler/include/byteir/Pipelines/VectorOpt.h
new file mode 100644
index 000000000..d5ded6dbb
--- /dev/null
+++ b/compiler/include/byteir/Pipelines/VectorOpt.h
@@ -0,0 +1,46 @@
+//===- VectorOpt.h --------------------------------------------------- C++ ---===//
+//
+// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_PIPELINES_VECTOROPT_H
+#define BYTEIR_PIPELINES_VECTOROPT_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include <memory>
+#include <string>
+
+namespace mlir {
+struct VectorOptPipelineOptions
+    : public PassPipelineOptions<VectorOptPipelineOptions> {
+  Option<std::string> target{
+      *this, "target",
+      llvm::cl::desc("An optional attribute to specify target."),
+      llvm::cl::init("")};
+};
+
+void createVectorOptPipeline(OpPassManager &pm,
+                             const VectorOptPipelineOptions &options);
+
+inline void registerVectorOptPipeline() {
+  PassPipelineRegistration("vector-opt", "Vector Opt Pipeline",
+                           createVectorOptPipeline);
+}
+
+} // namespace mlir
+
+#endif // BYTEIR_PIPELINES_VECTOROPT_H
diff --git a/compiler/lib/Pipelines/CMakeLists.txt b/compiler/lib/Pipelines/CMakeLists.txt
index 9626a5e6c..b4f350a90 100644
--- a/compiler/lib/Pipelines/CMakeLists.txt
+++ b/compiler/lib/Pipelines/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(ByteIRPipelines
   LinalgTensorOpt.cpp
   SCFOpt.cpp
   ShapeOpt.cpp
+  VectorOpt.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Pipelines
diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp
index 6d065e762..478fb789a 100644
--- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp
+++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp
@@ -23,16 +23,12 @@
 #include "byteir/Dialect/GPU/Passes.h"
 #include "byteir/Dialect/SCF/Passes.h"
 #include "byteir/Dialect/Transform/Transforms/TransformDialectInterpreter.h"
-#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h"
-#include "byteir/Dialect/Vector/Transforms/Passes.h"
-#include "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h"
 #include "byteir/Dialect/mhlo/Passes.h"
 #include "byteir/Pipelines/Common/Utils.h"
 #include "byteir/Pipelines/GPU/MappingForall.h"
 #include "byteir/Transforms/Passes.h"
 #include "byteir/Transforms/RemoveFuncBody.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
-#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
@@ -43,77 +39,26 @@ using namespace mlir;
 using namespace mlir::bufferization;
 
 namespace {
-void createElementwiseGPUOptPipelineImpl(OpPassManager &pm,
-                                         const bool &useBarePtrCallConv,
-                                         const std::string &target) {
-  // apply PromotoBufferStack to func's with
-  // getByteIRElementwiseFusionAttrName
-  {
-    OpPassManager anchoredPM(func::FuncOp::getOperationName());
-    anchoredPM.addPass(createPromoteBuffersToStackPass(
-        /*isSmallAlloc =*/[](Value) { return true; }));
-
-    pm.addNestedPass(createAnchoredPipelinePass(
-        getByteIRElementwiseFusionAttrName(), anchoredPM));
-  }
-
-  // pm.addNestedPass(createGenPTXConfigPass(useBarePtrCallConv));
-  GPUMappingForallOptions mappingOptions;
-  mappingOptions.funcAnchor = getByteIRElementwiseFusionAttrName().str();
-  mappingOptions.blockDimsHint = llvm::cl::KernelDims{256, 1, 1};
-  createGPUMappingForallTransform(pm, mappingOptions);
-}
-
-void createReductionGPUOptPipelineImpl(OpPassManager &pm) {
+void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv,
+                              const std::string &target) {
+  pm.addPass(createHorizontalFusionPass());
   GPUMappingForallOptions options;
-  options.funcAnchor = getByteIRReductionFusionAttrName().str();
   options.blockDimsHint = llvm::cl::KernelDims{256, 1, 1};
-  // vector redution to gpu
shuffle & lowering - { - OpPassManager anchoredPM(func::FuncOp::getOperationName()); - anchoredPM.addPass( - createMoveForallRegionIntoWarpOpPass(/* warpSize = */ 32)); - VectorWarpDistributePassOptions options; - options.warpOpToSCF = true; - options.distributeTransferWriteOps = true; - options.hoistUniform = true; - options.propagateDistribution = true; - anchoredPM.addPass(createVectorWarpDistributePass(options)); - anchoredPM.addPass(createCanonicalizerPass()); - anchoredPM.addPass(createCSEPass()); - anchoredPM.addPass(createScalarVectorLoweringPass()); - anchoredPM.addPass(createCanonicalizeExtPass()); - anchoredPM.addPass(createConvertVectorToSCFPass()); - pm.addNestedPass(createAnchoredPipelinePass( - getByteIRReductionFusionAttrName(), anchoredPM)); - } createGPUMappingForallTransform(pm, options); pm.addPass(createTransformDialectInterpreter(true)); pm.addPass(createCSEPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createGpuLauchSinkIndexComputationsPass()); - - { - OpPassManager anchoredPM(func::FuncOp::getOperationName()); - - anchoredPM.addPass(createPromoteBuffersToStackPass( - /*isSmallAlloc =*/[](Value value) { - return value.getParentRegion()->getParentOfType(); - })); - - pm.addNestedPass(createAnchoredPipelinePass( - getByteIRReductionFusionAttrName(), anchoredPM)); - } + pm.addPass(createPromoteBuffersToStackPass( + /*isSmallAlloc =*/[](Value value) { + return value.getParentRegion()->getParentOfType(); + })); pm.addPass(createGpuKernelOutliningPass()); -} - -void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv, - const std::string &target) { - createElementwiseGPUOptPipelineImpl(pm, useBarePtrCallConv, target); - createReductionGPUOptPipelineImpl(pm); pm.addPass(createCollectGPUKernelPass("unified", false)); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); } } // namespace diff --git a/compiler/lib/Pipelines/SCFOpt.cpp b/compiler/lib/Pipelines/SCFOpt.cpp index 44f9fdfcb..3abf9e1ea 100644 --- a/compiler/lib/Pipelines/SCFOpt.cpp +++ b/compiler/lib/Pipelines/SCFOpt.cpp @@ -53,7 +53,6 @@ void addGenericSCFOptPasses(OpPassManager &pm) { pm.addNestedPass( createFuseNestedForallPass(getByteIRReductionFusionAttrName())); addCleanUpExtPassPipeline(pm); - pm.addPass(createHorizontalFusionPass()); } void addCPUSCFOptPasses(OpPassManager &pm) { diff --git a/compiler/lib/Pipelines/VectorOpt.cpp b/compiler/lib/Pipelines/VectorOpt.cpp new file mode 100644 index 000000000..e0721f722 --- /dev/null +++ b/compiler/lib/Pipelines/VectorOpt.cpp @@ -0,0 +1,68 @@ +//===- VectorOpt.cpp --------------------------------------------- C++---*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#include "byteir/Pipelines/VectorOpt.h" + +#include "byteir/Dialect/Linalg/Passes.h" +#include "byteir/Dialect/Linalg/Transforms/LinalgExtToLoops.h" +#include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" +#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h" +#include "byteir/Dialect/Vector/Transforms/Passes.h" +#include "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h" +#include "byteir/Dialect/mhlo/Passes.h" +#include "byteir/Pipelines/Common/Utils.h" +#include "byteir/Transforms/Passes.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::affine; + +namespace { +void addGPUVectorOptPasses(OpPassManager &pm) { + // vector redution to gpu shuffle & lowering + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createMoveForallRegionIntoWarpOpPass(/* warpSize = */ 32)); + VectorWarpDistributePassOptions options; + options.warpOpToSCF = true; + options.distributeTransferWriteOps = true; + options.hoistUniform = true; + options.propagateDistribution = true; + anchoredPM.addPass(createVectorWarpDistributePass(options)); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createScalarVectorLoweringPass()); + anchoredPM.addPass(createCanonicalizeExtPass()); + anchoredPM.addPass(createConvertVectorToSCFPass()); + pm.addNestedPass(createAnchoredPipelinePass( + getByteIRReductionFusionAttrName(), anchoredPM)); +} + +void createVectorOptPipelineImpl(OpPassManager &pm, const std::string &target) { + if (target == "GPU") { + addGPUVectorOptPasses(pm); + } +} +} // namespace + +void mlir::createVectorOptPipeline(OpPassManager &pm, + const VectorOptPipelineOptions &options) { + invokeOpPassPipelineBuilder(createVectorOptPipelineImpl, pm, options.target); +} diff --git a/compiler/test/Transforms/horizontalFusion.mlir b/compiler/test/Transforms/horizontalFusion.mlir index 853ef6e7a..0e29b0345 100644 --- a/compiler/test/Transforms/horizontalFusion.mlir +++ b/compiler/test/Transforms/horizontalFusion.mlir @@ -87,3 +87,7 @@ module { } } +// CHECK-LABEL: func.func @main +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK-not: scf.forall From a44c04c3701b4817926deeb7d36a8501ec49eb09 Mon Sep 17 00:00:00 2001 From: "huangchenhui.yellow" Date: Tue, 2 Jul 2024 17:41:33 +0800 Subject: [PATCH 10/13] [compiler] fix hfusion case range error - as title - add hfuse test case - format --- compiler/include/byteir/Pipelines/VectorOpt.h | 9 +- compiler/include/byteir/Transforms/Passes.h | 2 +- .../Dialect/mhlo/Transforms/GenericFusion.cpp | 21 +-- compiler/lib/Pipelines/GPU/GPUOpt.cpp | 2 +- compiler/lib/Pipelines/SCFOpt.cpp | 13 ++ compiler/lib/Transforms/HorizontalFusion.cpp | 120 +++++++++++--- compiler/lib/Transforms/PassDetail.h | 1 - compiler/lib/Utils/LoopUtils.cpp | 3 +- compiler/python/byteir/compile.py | 3 + .../test/Transforms/horizontalFusion.mlir | 147 +++++++++++++----- 10 files changed, 242 insertions(+), 79 deletions(-) diff --git a/compiler/include/byteir/Pipelines/VectorOpt.h b/compiler/include/byteir/Pipelines/VectorOpt.h index d5ded6dbb..a51c6561e 100644 --- a/compiler/include/byteir/Pipelines/VectorOpt.h +++ b/compiler/include/byteir/Pipelines/VectorOpt.h @@ -1,4 +1,5 @@ 
-//===- VectorOpt.h --------------------------------------------------- C++ ---===// +//===- VectorOpt.h --------------------------------------------------- C++ +//---===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,11 +35,11 @@ struct VectorOptPipelineOptions }; void createVectorOptPipeline(OpPassManager &pm, - const VectorOptPipelineOptions &options); + const VectorOptPipelineOptions &options); inline void registerVectorOptPipeline() { - PassPipelineRegistration("vector-opt", "Vector Opt Pipeline", - createVectorOptPipeline); + PassPipelineRegistration( + "vector-opt", "Vector Opt Pipeline", createVectorOptPipeline); } } // namespace mlir diff --git a/compiler/include/byteir/Transforms/Passes.h b/compiler/include/byteir/Transforms/Passes.h index 96d2b749e..122ab7dcb 100644 --- a/compiler/include/byteir/Transforms/Passes.h +++ b/compiler/include/byteir/Transforms/Passes.h @@ -28,6 +28,7 @@ #include "byteir/Transforms/FuncTag.h" #include "byteir/Transforms/GenericDeviceConfig.h" #include "byteir/Transforms/GraphClusteringByDevice.h" +#include "byteir/Transforms/HorizontalFusion.h" #include "byteir/Transforms/InsertUniqueId.h" #include "byteir/Transforms/LoopTag.h" #include "byteir/Transforms/LoopUnroll.h" @@ -38,7 +39,6 @@ #include "byteir/Transforms/SetSpace.h" #include "byteir/Transforms/ShapeFuncOutlining.h" #include "byteir/Transforms/TryCatchModulePipeline.h" -#include "byteir/Transforms/HorizontalFusion.h" namespace mlir { diff --git a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp index 9aec234d9..922d1f7c6 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp @@ -59,12 +59,12 @@ namespace elementwise { // TODO: maybe we should support non-splat constant on device in future bool isFusibleCandidate(Operation *op) { - return isMhlo(op) && - (op->hasTrait<::mlir::OpTrait::Elementwise>() || - op->hasTrait() || - isSplatMhloConstantLike(op) || - isa(op) || - isCustomMhloRngOp(op)); + return isMhlo(op) && (op->hasTrait<::mlir::OpTrait::Elementwise>() || + op->hasTrait() || + isSplatMhloConstantLike(op) || + isa(op) || + isCustomMhloRngOp(op)); } // every candidate can start @@ -73,7 +73,7 @@ bool isFusibleStart(Operation *op) { return true; } bool isFusibleTrigger(Operation *op) { if (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op) || isCustomMhloRngOp(op)) { + isa(op) || isCustomMhloRngOp(op)) { return true; } @@ -97,8 +97,8 @@ bool isFusibleWith(Operation *target, Operation * /*start*/) { return target->hasTrait<::mlir::OpTrait::Elementwise>() || target->hasTrait() || isSplatMhloConstantLike(target) || - isa( - target) || + isa(target) || isCustomMhloRngOp(target); } @@ -111,7 +111,8 @@ bool isFusibleWithNoElementwiseFuse(Operation *target, Operation * /*start*/) { bool isValidSingleOp(Operation *op) { return op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op) || + isa(op) || isCustomMhloRngOp(op); } diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp index 478fb789a..2b6dee7ff 100644 --- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp +++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp @@ -51,7 +51,7 @@ void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv, pm.addPass(createCSEPass()); pm.addPass(createCanonicalizerPass()); 
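   // The change below scopes buffer promotion to each function: only
   // allocations whose region sits inside a GPU launch are promoted to stack
   // allocations, so module-level buffers stay untouched (an inference from
   // the isSmallAlloc predicate, which tests for an enclosing launch op).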
pm.addPass(createGpuLauchSinkIndexComputationsPass()); - pm.addPass(createPromoteBuffersToStackPass( + pm.addNestedPass(createPromoteBuffersToStackPass( /*isSmallAlloc =*/[](Value value) { return value.getParentRegion()->getParentOfType(); })); diff --git a/compiler/lib/Pipelines/SCFOpt.cpp b/compiler/lib/Pipelines/SCFOpt.cpp index 3abf9e1ea..4beb3162a 100644 --- a/compiler/lib/Pipelines/SCFOpt.cpp +++ b/compiler/lib/Pipelines/SCFOpt.cpp @@ -19,6 +19,7 @@ #include "byteir/Dialect/Linalg/Passes.h" #include "byteir/Dialect/Linalg/Transforms/LinalgExtToLoops.h" +#include "byteir/Dialect/SCF/Passes.h" #include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/Common/Utils.h" @@ -38,12 +39,16 @@ using namespace mlir::affine; namespace { void addGenericSCFOptPasses(OpPassManager &pm) { // for elementwise op +#if 0 GPUTileElementwiseInSCFOptions tileOptions; tileOptions.maxBlockSize = 256; createGPUTileElementwiseInSCF(pm, tileOptions); +#endif pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addNestedPass(createConvertLinalgExtToLoopsPass()); + // TODO fix scf.for in reduction kernel + // pm.addNestedPass(createForToForallPass()); // lower affine.apply in case there is some pm.addPass(memref::createFoldMemRefAliasOpsPass()); pm.addPass(createLowerAffinePass()); @@ -53,6 +58,14 @@ void addGenericSCFOptPasses(OpPassManager &pm) { pm.addNestedPass( createFuseNestedForallPass(getByteIRReductionFusionAttrName())); addCleanUpExtPassPipeline(pm); + // for copy op +#if 1 + { + GPUTileElementwiseInSCFOptions tileOptions; + tileOptions.maxBlockSize = 256; + createGPUTileElementwiseInSCF(pm, tileOptions); + } +#endif } void addCPUSCFOptPasses(OpPassManager &pm) { diff --git a/compiler/lib/Transforms/HorizontalFusion.cpp b/compiler/lib/Transforms/HorizontalFusion.cpp index 8dd4f0f4f..a5421712d 100644 --- a/compiler/lib/Transforms/HorizontalFusion.cpp +++ b/compiler/lib/Transforms/HorizontalFusion.cpp @@ -77,6 +77,16 @@ inline bool isByreEntry(func::FuncOp &funcOp) { byre::ByreDialect::getEntryPointFunctionAttrName())); } +bool isDefInParent(Value val, Operation *parent) { + auto defOp = val.getDefiningOp(); + while (defOp) { + if (defOp->getParentOp() == parent) + return true; + defOp = defOp->getParentOp(); + } + return false; +} + void moveForwardAlloc(ModuleOp &m) { for (auto funcOp : m.getOps()) { if (!isByreEntry(funcOp)) @@ -105,9 +115,11 @@ struct HorizontalFusionPass void makeHorizontalFusionPlan(SmallVector &, HFusionPlan &); void doHorizontalFusion(HFusionPlan &); bool isFusibleAndBenefit(scf::ForallOp pre, scf::ForallOp cur); - void collectWRMemref(scf::ForallOp forallOp, SmallVector &w, - SmallVector &r); - void collectUsePointInBlock(Block *block, SmallVector &vals, + void collectWRMemref(scf::ForallOp forallOp, llvm::SetVector &w, + llvm::SetVector &r); + void collectDirectOperands(scf::ForallOp forallOp, + llvm::SetVector &vals); + void collectUsePointInBlock(Block *block, llvm::SetVector &vals, llvm::SetVector &usePoints); }; // HorizontalFusionPass @@ -153,7 +165,7 @@ void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) { for (auto pattern : plan) { if (pattern.size() < 2) continue; - // TODO sort forall with shape and instrs count + // TODO sort forall with shape and insts count // merge auto root = cast(pattern.front()); @@ -164,7 +176,8 @@ void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) { } // TODO should we align blockNum to multiple 32 to reduce divergence int64_t allBlockNums = 
std::accumulate(blockNums.begin(), blockNums.end(), - 1, std::plus()); + 0, std::plus()); + auto front = cast(pattern.front()); auto loc = front.getLoc(); builder.setInsertionPoint(front); @@ -192,12 +205,12 @@ void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) { ValueRange(), front.getMapping()); builder.setInsertionPointToStart(hFuseForall.getBody()); auto blockId = hFuseForall.getBody()->getArgument(0); - + // create condition br one by one Value switchValue = builder.create(loc, 0); for (int64_t i = 0; i < blockNums.size(); ++i) { - auto cmp = builder.create(loc, arith::CmpIPredicate::sgt, - blockId, bounds[i]); + auto cmp = builder.create(loc, arith::CmpIPredicate::sge, + blockId, bounds[i + 1]); auto selVal = builder.create(loc, cmp, cstIOne, cstIZero); switchValue = builder.create(loc, switchValue, selVal); @@ -206,14 +219,14 @@ void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) { llvm::to_vector(llvm::seq(0, blockNums.size())); auto switchOp = builder.create( loc, /*resultTypes*/ TypeRange{}, switchValue, cases, cases.size()); - + // default region { Block &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); builder.setInsertionPointToStart(&defaultBlock); builder.create(loc); } - + // case region for (int64_t i = 0; i < blockNums.size(); ++i) { auto orgForall = cast(pattern[i]); @@ -246,14 +259,17 @@ bool HorizontalFusionPass::isFusibleAndBenefit(scf::ForallOp pre, }; // check fusiable - SmallVector preWriteVals; - SmallVector preReadVals; - SmallVector curWriteVals; - SmallVector curReadVals; + // llvm::SetVector preWriteVals; + // llvm::SetVector preReadVals; + llvm::SetVector curWriteVals; + llvm::SetVector curReadVals; // TODO include alias and collect all uses. - collectWRMemref(pre, preWriteVals, preReadVals); + // collectWRMemref(pre, preWriteVals, preReadVals); collectWRMemref(cur, curWriteVals, curReadVals); + llvm::SetVector directOperands; + // collectDirectOperands(cur, directOperands); + llvm::SetVector usePoints; collectUsePointInBlock(cur->getParentOp()->getBlock(), curWriteVals, usePoints); @@ -261,21 +277,28 @@ bool HorizontalFusionPass::isFusibleAndBenefit(scf::ForallOp pre, usePoints); auto &domInfo = getAnalysis(); - auto checkDominace = [&](Operation *op) { + + auto checkUsePointDominace = [&](Operation *op) { // just skip checking viewlike ops if (isa_and_nonnull(op)) return true; return domInfo.properlyDominates(op, pre) || domInfo.dominates(cur, op); }; - bool fusiable = llvm::all_of(usePoints, checkDominace); + + auto checkDirectOpdDominace = [&](Operation *op) { + return domInfo.properlyDominates(op, pre); + }; + + bool fusiable = llvm::all_of(usePoints, checkUsePointDominace); + fusiable &= llvm::all_of(directOperands, checkDirectOpdDominace); return fusiable; } void HorizontalFusionPass::collectWRMemref(scf::ForallOp forallOp, - SmallVector &w, - SmallVector &r) { - auto collect = [](TypedValue memref, SmallVector &chunk) { + llvm::SetVector &w, + llvm::SetVector &r) { + auto collect = [](Value memref, llvm::SetVector &chunk) { Value root = memref; while (true) { if (auto defOp = @@ -285,9 +308,8 @@ void HorizontalFusionPass::collectWRMemref(scf::ForallOp forallOp, } break; } - llvm::SetVector alias; SmallVector worklist; - alias.insert(root); + chunk.insert(root); worklist.push_back(root); while (!worklist.empty()) { auto val = worklist.pop_back_val(); @@ -295,7 +317,7 @@ void HorizontalFusionPass::collectWRMemref(scf::ForallOp forallOp, if (auto viewlike = dyn_cast_if_present(user)) { for (auto res : 
viewlike->getResults()) { worklist.push_back(res); - alias.insert(res); + chunk.insert(res); } } } @@ -310,10 +332,59 @@ void HorizontalFusionPass::collectWRMemref(scf::ForallOp forallOp, auto memref = store.getMemref(); collect(memref, w); }); + + forallOp->walk([&](Operation *op) { + if (!isa(op->getDialect())) + return WalkResult::advance(); + for (auto opd : op->getOperands()) { + if (isDefInParent(opd, forallOp.getOperation())) + continue; + bool readEffect = false; + if (auto opInter = dyn_cast(op)) { + if (opInter.getEffectOnValue(opd).has_value()) { + readEffect = true; + } + } + if (readEffect) + collect(opd, r); + else + collect(opd, w); + } + return WalkResult::advance(); + }); +} + +void HorizontalFusionPass::collectDirectOperands( + scf::ForallOp forallOp, llvm::SetVector &vals) { + auto collect = [](Value memref, llvm::SetVector &opds) { + if (auto def = memref.getDefiningOp()) + opds.insert(def); + }; + + forallOp->walk([&](Operation *op) { + if (!isa(op->getDialect())) + return WalkResult::advance(); + for (auto opd : op->getOperands()) { + if (isDefInParent(opd, forallOp.getOperation())) + continue; + if (!isa(opd.getType())) + continue; + collect(opd, vals); + } + return WalkResult::advance(); + }); + // forallOp->walk([&](memref::LoadOp load) { + // auto memref = load.getMemref(); + // collect(memref, vals); + //}); + // forallOp->walk([&](memref::StoreOp store) { + // auto memref = store.getMemref(); + // collect(memref, vals); + //}); } void HorizontalFusionPass::collectUsePointInBlock( - Block *block, SmallVector &vals, + Block *block, llvm::SetVector &vals, llvm::SetVector &usePoints) { for (auto val : vals) { for (auto user : val.getUsers()) { @@ -350,6 +421,7 @@ void HorizontalFusionPass::runOnOperation() { // move forward alloc. // TODO Infact, one can move more ops. moveForwardAlloc(moduleOp); + // llvm::dbgs() << "chh dbg alllll:\n" << moduleOp << "\n"; SmallVector candidateForallOps; getCandidates(moduleOp, candidateForallOps); diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h index 96dd8bc2e..ba48d2dec 100644 --- a/compiler/lib/Transforms/PassDetail.h +++ b/compiler/lib/Transforms/PassDetail.h @@ -63,7 +63,6 @@ namespace tensor { class TensorDialect; } // namespace tensor - #define GEN_PASS_CLASSES #include "byteir/Transforms/Passes.h.inc" diff --git a/compiler/lib/Utils/LoopUtils.cpp b/compiler/lib/Utils/LoopUtils.cpp index 1804fc79b..fa53b393c 100644 --- a/compiler/lib/Utils/LoopUtils.cpp +++ b/compiler/lib/Utils/LoopUtils.cpp @@ -444,7 +444,8 @@ std::optional mlir::createTrivialSCFForIfHaveNone(func::FuncOp funcOp) { // if having scf::ForOp return nullopt - if (!funcOp.getOps().empty() || !funcOp.getOps().empty()) { + if (!funcOp.getOps().empty() || + !funcOp.getOps().empty()) { return std::nullopt; } diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index 55f0c7275..390abd098 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -125,6 +125,9 @@ def _compile_cuda( with context: PassManager.parse("builtin.module(scf-opt)").run(module.operation) _print_verbose(module, "// IR Dump After SCF Opt:") if verbose else ... + with context: + PassManager.parse("builtin.module(vector-opt{target=GPU})").run(module.operation) + _print_verbose(module, "// IR Dump After Vec Opt:") if verbose else ... 
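        # The pipeline now runs in three stages: scf-opt (linalg-to-loops,
        # tiling, forall fusion), then vector-opt (GPU target: distribute
        # vector reductions across warp lanes), then gpu-opt (map foralls to
        # blocks/threads and outline GPU kernels).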
with context:
    if useBarePtrCallConv:
        PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true})").run(module.operation)

diff --git a/compiler/test/Transforms/horizontalFusion.mlir b/compiler/test/Transforms/horizontalFusion.mlir
index 0e29b0345..218deaa77 100644
--- a/compiler/test/Transforms/horizontalFusion.mlir
+++ b/compiler/test/Transforms/horizontalFusion.mlir
@@ -1,22 +1,35 @@
 // RUN: byteir-opt %s --horizontal-fusion-on-scf --cse --canonicalize | FileCheck %s
 #map = affine_map<(d0, d1) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
 module {
   func.func private @Unknown0(%arg0: memref<32x16xf32>, %arg1: memref<32x16xf32>, %arg2: memref<32x16xf32>) -> memref<32x16xf32> attributes {__byteir_elementwise_fusion__} {
     %c16 = arith.constant 16 : index
-    %c256 = arith.constant 256 : index
     %alloc = memref.alloc() : memref<32x16xf32>
-    scf.forall (%arg3) in (2) {
-      %0 = arith.muli %arg3, %c256 : index
-      scf.forall (%arg4) in (256) {
-        %1 = arith.addi %arg4, %0 : index
-        %2 = arith.remsi %1, %c16 : index
-        %3 = arith.divsi %1, %c16 : index
-        %4 = memref.load %arg0[%3, %2] : memref<32x16xf32>
-        %5 = memref.load %arg1[%3, %2] : memref<32x16xf32>
-        %6 = memref.load %arg2[%3, %2] : memref<32x16xf32>
-        %7 = arith.mulf %4, %5 : f32
-        %8 = arith.divf %7, %6 : f32
-        memref.store %8, %alloc[%3, %2] : memref<32x16xf32>
+    %c0 = arith.constant 0 : index
+    %c256 = arith.constant 256 : index
+    %c1 = arith.constant 1 : index
+    %c0_0 = arith.constant 0 : index
+    %c512 = arith.constant 512 : index
+    %0 = arith.muli %c1, %c256 : index
+    %c0_1 = arith.constant 0 : index
+    %c512_2 = arith.constant 512 : index
+    %1 = arith.subi %c512_2, %c0_1 : index
+    %2 = arith.ceildivsi %1, %0 : index
+    %c1_3 = arith.constant 1 : index
+    scf.forall (%arg3) in (%2) {
+      %3 = arith.muli %arg3, %0 : index
+      %c0_4 = arith.constant 0 : index
+      %c1_5 = arith.constant 1 : index
+      scf.forall (%arg4) in (%0) {
+        %4 = arith.addi %arg4, %3 : index
+        %5 = arith.remsi %4, %c16 : index
+        %6 = arith.divsi %4, %c16 : index
+        %7 = memref.load %arg0[%6, %5] : memref<32x16xf32>
+        %8 = memref.load %arg1[%6, %5] : memref<32x16xf32>
+        %9 = memref.load %arg2[%6, %5] : memref<32x16xf32>
+        %10 = arith.mulf %7, %8 : f32
+        %11 = arith.divf %10, %9 : f32
+        memref.store %11, %alloc[%6, %5] : memref<32x16xf32>
       } {mapping = [#gpu.thread<x>]}
     } {mapping = [#gpu.block<x>]}
     return %alloc : memref<32x16xf32>
@@ -43,37 +56,87 @@ module {
   }
   func.func private @Unknown2(%arg0: memref<16xf32>) -> memref<16xf32> attributes {__byteir_elementwise_fusion__} {
     %cst = arith.constant 0.000000e+00 : f32
-    %c16 = arith.constant 16 : index
-    %c256 = arith.constant 256 : index
     %alloc = memref.alloc() : memref<16xf32>
-    scf.forall (%arg1) in (1) {
-      %0 = arith.muli %arg1, %c256 : index
-      %1 = arith.subi %c16, %0 : index
-      %2 = arith.cmpi sgt, %1, %c256 : index
-      %3 = arith.select %2, %c256, %1 : index
-      scf.forall (%arg2) in (%3) {
-        %4 = arith.addi %arg2, %0 : index
-        %5 = memref.load %arg0[%4] : memref<16xf32>
-        %6 = arith.maximumf %5, %cst : f32
-        memref.store %6, %alloc[%4] : memref<16xf32>
+    %c0 = arith.constant 0 : index
+    %c256 = arith.constant 256 : index
+    %c1 = arith.constant 1 : index
+    %c0_0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %0 = arith.muli %c1, %c256 : index
+    %c0_1 = arith.constant 0 : index
+    %c16_2 = arith.constant 16 : index
+    %1 = arith.subi %c16_2, %c0_1 : index
+    %2 = arith.ceildivsi %1, %0 : index
+    %c1_3 = arith.constant 1 : index
+    scf.forall (%arg1) in (%2) {
+      %3 = arith.muli %arg1, %0 : index
+      %4 = affine.min #map1(%0, %c16, %3)
+      %c0_4 = arith.constant 0 : index
+      %c1_5 = arith.constant 1 : index
+      scf.forall (%arg2) in (%4) {
+        %5 = arith.addi %arg2, %3 : index
+        %6 = memref.load %arg0[%5] : memref<16xf32>
+        %7 = arith.maximumf %6, %cst : f32
+        memref.store %7, %alloc[%5] : memref<16xf32>
       } {mapping = [#gpu.thread<x>]}
     } {mapping = [#gpu.block<x>]}
     return %alloc : memref<16xf32>
   }
   func.func private @Unknown3(%arg0: memref<16xf32>, %arg1: memref<32x16xf32>) -> memref<32x16xf32> attributes {__byteir_elementwise_fusion__} {
     %c16 = arith.constant 16 : index
+    %alloc = memref.alloc() : memref<32x16xf32>
+    %c0 = arith.constant 0 : index
     %c256 = arith.constant 256 : index
+    %c1 = arith.constant 1 : index
+    %c0_0 = arith.constant 0 : index
+    %c512 = arith.constant 512 : index
+    %0 = arith.muli %c1, %c256 : index
+    %c0_1 = arith.constant 0 : index
+    %c512_2 = arith.constant 512 : index
+    %1 = arith.subi %c512_2, %c0_1 : index
+    %2 = arith.ceildivsi %1, %0 : index
+    %c1_3 = arith.constant 1 : index
+    scf.forall (%arg2) in (%2) {
+      %3 = arith.muli %arg2, %0 : index
+      %c0_4 = arith.constant 0 : index
+      %c1_5 = arith.constant 1 : index
+      scf.forall (%arg3) in (%0) {
+        %4 = arith.addi %arg3, %3 : index
+        %5 = arith.remsi %4, %c16 : index
+        %6 = arith.divsi %4, %c16 : index
+        %7 = memref.load %arg0[%5] : memref<16xf32>
+        %8 = memref.load %arg1[%6, %5] : memref<32x16xf32>
+        %9 = arith.addf %7, %8 : f32
+        memref.store %9, %alloc[%6, %5] : memref<32x16xf32>
+      } {mapping = [#gpu.thread<x>]}
+    } {mapping = [#gpu.block<x>]}
+    return %alloc : memref<32x16xf32>
+  }
+  func.func private @Unknown4(%arg0: memref<32x16xf32>) -> memref<32x16xf32> attributes {__byteir_elementwise_fusion__} {
+    %c16 = arith.constant 16 : index
     %alloc = memref.alloc() : memref<32x16xf32>
-    scf.forall (%arg2) in (2) {
-      %0 = arith.muli %arg2, %c256 : index
-      scf.forall (%arg3) in (256) {
-        %1 = arith.addi %arg3, %0 : index
-        %2 = arith.remsi %1, %c16 : index
-        %3 = arith.divsi %1, %c16 : index
-        %4 = memref.load %arg0[%2] : memref<16xf32>
-        %5 = memref.load %arg1[%3, %2] : memref<32x16xf32>
-        %6 = arith.addf %4, %5 : f32
-        memref.store %6, %alloc[%3, %2] : memref<32x16xf32>
+    %c0 = arith.constant 0 : index
+    %c256 = arith.constant 256 : index
+    %c1 = arith.constant 1 : index
+    %c0_0 = arith.constant 0 : index
+    %c512 = arith.constant 512 : index
+    %0 = arith.muli %c1, %c256 : index
+    %c0_1 = arith.constant 0 : index
+    %c512_2 = arith.constant 512 : index
+    %1 = arith.subi %c512_2, %c0_1 : index
+    %2 = arith.ceildivsi %1, %0 : index
+    %c1_3 = arith.constant 1 : index
+    scf.forall (%arg1) in (%2) {
+      %3 = arith.muli %arg1, %0 : index
+      %c0_4 = arith.constant 0 : index
+      %c1_5 = arith.constant 1 : index
+      scf.forall (%arg2) in (%0) {
+        %4 = arith.addi %arg2, %3 : index
+        %5 = arith.remsi %4, %c16 : index
+        %6 = arith.divsi %4, %c16 : index
+        %7 = memref.load %arg0[%6, %5] : memref<32x16xf32>
+        %8 = arith.mulf %7, %7 : f32
+        memref.store %8, %alloc[%6, %5] : memref<32x16xf32>
       } {mapping = [#gpu.thread<x>]}
     } {mapping = [#gpu.block<x>]}
     return %alloc : memref<32x16xf32>
@@ -83,11 +146,21 @@ module {
     %1 = call @Unknown1(%0, %arg0) : (memref<32x16xf32>, memref<32x16xf32>) -> memref<16xf32>
     %2 = call @Unknown2(%1) : (memref<16xf32>) -> memref<16xf32>
     %3 = call @Unknown3(%2, %0) : (memref<16xf32>, memref<32x16xf32>) -> memref<32x16xf32>
-    return %0, %3 : memref<32x16xf32>, memref<32x16xf32>
+    %4 = call @Unknown4(%arg1) : (memref<32x16xf32>) -> memref<32x16xf32>
+    return %3, %4 : memref<32x16xf32>, memref<32x16xf32>
   }
 }
 // CHECK-LABEL: func.func @main
 // CHECK: scf.forall
+// CHECK: scf.forall
 // CHECK: scf.forall
+// CHECK: scf.forall
 // CHECK: scf.forall
+// CHECK: scf.forall
 // CHECK: scf.forall
-// CHECK-not: scf.forall
+// CHECK: case 0
+// CHECK: scf.forall
+// CHECK: case 1
+// CHECK: scf.forall
+// CHECK: default

From 0ff9eaa69321dda22e8ccbef5ee07f1ec6ec03b2 Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Mon, 8 Jul 2024 17:07:53 +0800
Subject: [PATCH 11/13] [compiler] fix forall tiling test case

---
 compiler/test/Dialect/SCF/forallTiling.mlir | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/compiler/test/Dialect/SCF/forallTiling.mlir b/compiler/test/Dialect/SCF/forallTiling.mlir
index 2b5c0b63d..ef927bec9 100644
--- a/compiler/test/Dialect/SCF/forallTiling.mlir
+++ b/compiler/test/Dialect/SCF/forallTiling.mlir
@@ -37,9 +37,9 @@ func.func @Elementwise(%arg0: memref<32x1024x?x30xf32>) -> memref<32768x?x30xf32
       %2 = arith.divsi %arg1, %c30 : index
       %3 = arith.remsi %2, %dim : index
       %4 = arith.divsi %2, %dim : index
-      %subview = memref.subview %collapse_shape[%4, %3, %1] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>
-      %subview_0 = memref.subview %alloc[%4, %3, %1] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>
-      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%subview : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) outs(%subview_0 : memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} {
+      %subview = memref.subview %collapse_shape[%4, %3, %1] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>
+      %subview_0 = memref.subview %alloc[%4, %3, %1] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>
+      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%subview : memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>) outs(%subview_0 : memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>) attrs = {__byteir_gpu_tile_elementwise_0} {
       ^bb0(%in: f32, %out: f32):
         %5 = arith.mulf %in, %in : f32
         linalg.yield %5 : f32
@@ -65,5 +65,5 @@ func.func @Elementwise(%arg0: memref<32x1024x?x30xf32>) -> memref<32768x?x30xf32
   // CHECK-NEXT: %[[V4:.*]] = arith.divsi %[[V2]], %[[C30]] : index
   // CHECK-NEXT: %[[V5:.*]] = arith.remsi %[[V4]], %[[DIM]] : index
   // CHECK-NEXT: %[[V6:.*]] = arith.divsi %[[V4]], %[[DIM]] : index
-  // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[COLLAPSE]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>
-  // CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[300, 30, 1], offset: ?>>
+  // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[COLLAPSE]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>
+  // CHECK-NEXT: %[[SUBVIEW_0:.*]] = memref.subview %[[ALLOC]][%[[V6]], %[[V5]], %[[V3]]] [1, 1, 1] [1, 1, 1] : memref<32768x?x30xf32> to memref<1x1x1xf32, strided<[?, 30, 1], offset: ?>>

From ce5920a0b5134028353cb87e3cb3ca1275fa3379 Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Tue, 9 Jul 2024 15:34:40 +0800
Subject: [PATCH 12/13] [compiler] add hfusion strategy

- as title, implement a simple heuristic strategy to choose the fusion
  pattern
- fix the op-type update error in set-arg-space

---
 compiler/lib/Transforms/HorizontalFusion.cpp | 133 ++++++++++++--
 compiler/lib/Transforms/SetSpace.cpp         | 177 ++++++++++---------
 2 files changed, 213 insertions(+), 97 deletions(-)

diff --git a/compiler/lib/Transforms/HorizontalFusion.cpp b/compiler/lib/Transforms/HorizontalFusion.cpp
index a5421712d..e95f5c229 100644
--- a/compiler/lib/Transforms/HorizontalFusion.cpp
+++ b/compiler/lib/Transforms/HorizontalFusion.cpp
@@ -23,6 +23,7 @@
 #include "byteir/Utils/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -101,10 +102,76 @@ void moveForwardAlloc(ModuleOp &m) {
   }
 }
 
+template <typename MappingAttr>
+inline bool hasMappingAttr(scf::ForallOp forall) {
+  auto mapping = forall.getMapping();
+  bool hasAttr =
+      mapping.has_value() && llvm::any_of(mapping.value(), [&](Attribute attr) {
+        return isa<MappingAttr>(attr);
+      });
+  return hasAttr;
+}
+
+int64_t getForallLoopSize(scf::ForallOp &forall) {
+  auto ubs = forall.getStaticUpperBound();
+  auto lbs = forall.getStaticLowerBound();
+  auto steps = forall.getStaticStep();
+  int64_t loopSize = 1;
+  for (auto &&[ub, lb, s] : llvm::zip(ubs, lbs, steps)) {
+    loopSize *= (ub - lb) / s;
+  }
+  return loopSize;
+}
+
+int64_t extractForallBlockSize(Operation *op) {
+  int64_t blockSize = 0;
+  if (auto forall = dyn_cast_if_present<scf::ForallOp>(op)) {
+    // FIXME: assumes there is only one forall in the outer forall body.
+    forall->walk([&](scf::ForallOp innerForall) {
+      auto mapping = innerForall.getMapping();
+      if (!mapping.has_value())
+        return WalkResult::advance();
+      bool hasThreadMapping =
+          hasMappingAttr<gpu::GPUThreadMappingAttr>(innerForall);
+      if (hasThreadMapping) {
+        blockSize = getForallLoopSize(innerForall);
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+  }
+  return blockSize;
+}
+
+int64_t extractForallGridSize(Operation *op) {
+  int64_t gridSize = 0;
+  if (auto forall = dyn_cast_if_present<scf::ForallOp>(op)) {
+    if (hasMappingAttr<gpu::GPUBlockMappingAttr>(forall)) {
+      gridSize = getForallLoopSize(forall);
+    }
+  }
+  return gridSize;
+}
+
+int64_t extractForallInstrCount(Operation *op) {
+  int64_t instrCount = 0;
+  if (auto forall = dyn_cast_if_present<scf::ForallOp>(op)) {
+    // TODO: skip viewlike ops.
+    instrCount = std::distance(forall.getOps().begin(), forall.getOps().end());
+    for (auto innerForall : forall.getOps<scf::ForallOp>()) {
+      innerForall->walk([&](Operation *instr) { instrCount++; });
+    }
+  }
+  return instrCount;
+}
+
 // HorizontalFusionPass
 using HFusionPattern = llvm::SmallVector<Operation *>;
 using HFusionPlan = llvm::SmallVector<HFusionPattern>;
 
+// TODO(chhuang): tune these heuristic thresholds.
+constexpr int64_t kMAX_INSTR_COUNT = 1024;
+constexpr int64_t kMAX_BLOCK_SIZE = 8192 * 8192;
+constexpr int64_t kMAX_GRID_SIZE = 1024;
+
 struct HorizontalFusionPass
     : public HorizontalFusionBase<HorizontalFusionPass> {
   explicit HorizontalFusionPass()
@@ -125,16 +192,53 @@ struct HorizontalFusionPass
 
 void HorizontalFusionPass::getCandidates(ModuleOp &m,
                                          SmallVector<scf::ForallOp> &candidates) {
+  auto hasGPUMapping = [](scf::ForallOp forall) {
+    bool hasBlockMapping = hasMappingAttr<gpu::GPUBlockMappingAttr>(forall);
+    if (!hasBlockMapping)
+      return false;
+    bool hasThreadMapping = !forall.getOps<scf::ForallOp>().empty();
+    for (auto innerForall : forall.getOps<scf::ForallOp>()) {
+      hasThreadMapping &=
+          hasMappingAttr<gpu::GPUThreadMappingAttr>(innerForall);
+    }
+    return hasThreadMapping;
+  };
+
+  // Fusing too many instructions into one kernel may lead to register spills
+  // and excessive resource usage.
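+  // NOTE: the instruction count below is only a rough proxy for register
+  // pressure; the real limit depends on the target architecture and the
+  // desired occupancy.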
+  auto checkInstrCount = [](scf::ForallOp forall, const int64_t max_instr) {
+    return extractForallInstrCount(forall.getOperation()) < max_instr;
+  };
+
+  // It is not beneficial to fuse kernels that already have a large grid
+  // size, since they already have enough blocks to occupy the GPU SMs.
+  auto checkGridSize = [](scf::ForallOp forall, const int64_t max_grid) {
+    return extractForallGridSize(forall) < max_grid;
+  };
+
+  auto checkBlockSize = [](scf::ForallOp forall, const int64_t max_block) {
+    return extractForallBlockSize(forall) < max_block;
+  };
+
+  auto checkAll = [&](scf::ForallOp forall) {
+    bool check = true;
+    check &= hasGPUMapping(forall);
+    check &= checkInstrCount(forall, kMAX_INSTR_COUNT);
+    check &= checkGridSize(forall, kMAX_GRID_SIZE);
+    check &= checkBlockSize(forall, kMAX_BLOCK_SIZE);
+
+    return check;
+  };
+
   for (auto funcOp : m.getOps<func::FuncOp>()) {
     if (!isByreEntry(funcOp))
       continue;
     for (auto forallOp : funcOp.getOps<scf::ForallOp>()) {
       // TODO(chhuang) (1) check instrs nums; (2) skip large shape;
-      // just pass all elementwise kernel as candidates.
       if (forallOp->hasAttr(kernelTypeNameAttr) &&
           forallOp->getAttr(kernelTypeNameAttr).cast<StringAttr>().getValue() ==
-              getByteIRElementwiseFusionAttrName()) {
+              getByteIRElementwiseFusionAttrName() &&
+          checkAll(forallOp)) {
         candidates.push_back(forallOp);
       }
     }
@@ -162,10 +266,19 @@ void HorizontalFusionPass::makeHorizontalFusionPlan(
 
 void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) {
   OpBuilder builder(getOperation());
-  for (auto pattern : plan) {
+  for (auto &&pattern : plan) {
     if (pattern.size() < 2)
       continue;
-    // TODO sort forall with shape and insts count
+    // Sort the foralls by block size and instruction count, so that foralls
+    // with the same block size end up adjacent to each other and computations
+    // of similar size get fused together.
+    std::sort(pattern.begin(), pattern.end(), [&](Operation *a, Operation *b) {
+      auto aBlockSize = extractForallBlockSize(a);
+      auto bBlockSize = extractForallBlockSize(b);
+      if (aBlockSize != bBlockSize)
+        return aBlockSize < bBlockSize;
+      return extractForallInstrCount(a) < extractForallInstrCount(b);
+    });
 
     // merge
     auto root = cast<scf::ForallOp>(pattern.front());
@@ -248,7 +361,6 @@ void HorizontalFusionPass::doHorizontalFusion(HFusionPlan &plan) {
 bool HorizontalFusionPass::isFusibleAndBenefit(scf::ForallOp pre,
                                                scf::ForallOp cur) {
   // TODO check whether benefit
-  // TODO check all has same mapping
 
   auto same_shape = [](ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
     if (a.size() != b.size())
       return false;

   // check fusiable
-  // llvm::SetVector<Value> preWriteVals;
-  // llvm::SetVector<Value> preReadVals;
   llvm::SetVector<Value> curWriteVals;
   llvm::SetVector<Value> curReadVals;
-  // TODO include alias and collect all uses.
-  // collectWRMemref(pre, preWriteVals, preReadVals);
+  // include alias and collect all uses.
   collectWRMemref(cur, curWriteVals, curReadVals);
 
   llvm::SetVector<Value> directOperands;
@@ -419,9 +528,8 @@ void HorizontalFusionPass::runOnOperation() {
   }
   /// stage 2. make fusion planning
   // move forward alloc.
-  // TODO Infact, one can move more ops.
   moveForwardAlloc(moduleOp);
-  // llvm::dbgs() << "chh dbg alllll:\n" << moduleOp << "\n";
+  // TODO move forward viewlike ops.
 
   SmallVector<scf::ForallOp> candidateForallOps;
   getCandidates(moduleOp, candidateForallOps);
@@ -433,9 +541,6 @@
 
   /// postprocess
   // TODO lazy alloc
-
-  /// [deprecated] stage 4. outline scf.forall back to func call
-  /// or, outline after gpu codegen.
 }
 } // namespace

diff --git a/compiler/lib/Transforms/SetSpace.cpp b/compiler/lib/Transforms/SetSpace.cpp
index 4d19350e9..7e21a8c10 100644
--- a/compiler/lib/Transforms/SetSpace.cpp
+++ b/compiler/lib/Transforms/SetSpace.cpp
@@ -407,102 +407,113 @@ void updateFuncReturnTypes(
 void updateOpTypes(FuncOp func, ModuleOp m,
                    DenseMap<CopyType_t, Value> &copyPairToCopyTargets,
                    ArgSideEffectAnalysis *analysis) {
-  // rewrite all types
-  for (auto &block : func.getBlocks()) {
-    for (auto &op : block.without_terminator()) {
-      if (auto viewLikeOp = llvm::dyn_cast<ViewLikeOpInterface>(op)) {
-        auto src = viewLikeOp.getViewSource();
-        auto srcType = dyn_cast<MemRefType>(src.getType());
-        if (!srcType)
-          continue;
-        auto srcSpace = srcType.getMemorySpace();
-        if (!srcSpace)
-          continue;
-
-        auto currSpace = srcSpace;
-        // if op has space attribute, use it as memory space
-        if (auto opSpaceAttr = op.getAttrOfType<StringAttr>(SPACE_ATTR_NAME)) {
-          if (srcSpace != opSpaceAttr) {
-            // insert copy if src space is different with spaceAttr
-            auto newSrcType = cloneMemRefTypeWithMemSpace(srcType, opSpaceAttr);
-            auto newArg = createCopyInputArg(&op, src, newSrcType, opSpaceAttr,
-                                             copyPairToCopyTargets);
-            op.setOperand(0, newArg);
-            currSpace = opSpaceAttr;
-          }
-        }
-        // propagate memory space from currSpace to dest
-        for (auto result : op.getResults()) {
-          auto dstType = dyn_cast<MemRefType>(result.getType());
-          if (!dstType)
+  auto update_op_types = [&]() {
+    // rewrite all types
+    for (auto &block : func.getBlocks()) {
+      for (auto &op : block.without_terminator()) {
+        if (auto viewLikeOp = llvm::dyn_cast<ViewLikeOpInterface>(op)) {
+          auto src = viewLikeOp.getViewSource();
+          auto srcType = dyn_cast<MemRefType>(src.getType());
+          if (!srcType)
+            continue;
+          auto srcSpace = srcType.getMemorySpace();
+          if (!srcSpace)
             continue;
-          auto dstSpace = dstType.getMemorySpace();
-          if (dstSpace) {
-            if (dstSpace != currSpace) {
-              // insert copy if dst space was already set to different space
+          auto currSpace = srcSpace;
+          // if op has space attribute, use it as memory space
+          if (auto opSpaceAttr =
+                  op.getAttrOfType<StringAttr>(SPACE_ATTR_NAME)) {
+            if (srcSpace != opSpaceAttr) {
+              // insert copy if src space differs from spaceAttr
+              auto newSrcType =
+                  cloneMemRefTypeWithMemSpace(srcType, opSpaceAttr);
+              auto newArg = createCopyInputArg(
+                  &op, src, newSrcType, opSpaceAttr, copyPairToCopyTargets);
+              op.setOperand(0, newArg);
+              currSpace = opSpaceAttr;
+            }
+          }
+          // propagate memory space from currSpace to dest
+          for (auto result : op.getResults()) {
+            auto dstType = dyn_cast<MemRefType>(result.getType());
+            if (!dstType)
+              continue;
+            auto dstSpace = dstType.getMemorySpace();
+
+            if (dstSpace) {
+              if (dstSpace != currSpace) {
+                // insert copy if dst space was already set to different space
+                auto newDstType =
+                    cloneMemRefTypeWithMemSpace(dstType, currSpace);
+                result.setType(newDstType);
+                createCopyReturn(viewLikeOp, result, dstType,
+                                 copyPairToCopyTargets);
+              }
+            } else {
+              // set to spaceAttr if no space
               auto newDstType = cloneMemRefTypeWithMemSpace(dstType, currSpace);
               result.setType(newDstType);
-              createCopyReturn(viewLikeOp, result, dstType,
-                               copyPairToCopyTargets);
             }
-          } else {
-            // set to spaceAttr if no space
-            auto newDstType = cloneMemRefTypeWithMemSpace(dstType, currSpace);
-            result.setType(newDstType);
           }
-        }
-      } else if (auto opSpaceAttr =
-                     op.getAttrOfType<StringAttr>(SPACE_ATTR_NAME)) {
-        for (unsigned i = 0; i < op.getNumOperands(); ++i) {
-          auto operand = op.getOperand(i);
-          if (auto MemrefTy = dyn_cast<MemRefType>(operand.getType())) {
-            auto curSpace = MemrefTy.getMemorySpace();
-
-            if (curSpace == nullptr) {
-              // if no space, use opSpaceAttr
+        } else if (auto opSpaceAttr =
+                       op.getAttrOfType<StringAttr>(SPACE_ATTR_NAME)) {
+          for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+            auto operand = op.getOperand(i);
+            if (auto MemrefTy = dyn_cast<MemRefType>(operand.getType())) {
+              auto curSpace = MemrefTy.getMemorySpace();
+
+              if (curSpace == nullptr) {
+                // if no space, use opSpaceAttr
+                auto newOperandType =
+                    cloneMemRefTypeWithMemSpace(MemrefTy, opSpaceAttr);
+                operand.setType(newOperandType);
+              } else if (opSpaceAttr != curSpace) {
+                // insert copy when curSpace is not the desired opSpaceAttr
+                CopyType_t copyKey = {operand, opSpaceAttr};
+
+                if (copyPairToCopyTargets.count(copyKey) == 0) {
+                  // if the copy does not exist yet, insert one
+                  auto argSEType = analysis->getType(&op, i);
+                  auto newArg =
+                      createCopyArg(&op, operand, MemrefTy, opSpaceAttr,
+                                    copyPairToCopyTargets, argSEType);
+                  op.setOperand(i, newArg);
+                } else {
+                  // if the copy already exists, refer to it directly
+                  auto target = copyPairToCopyTargets[copyKey];
+                  op.setOperand(i, target);
+                }
+              } // if else
+            }   // if MemrefTy
+          }     // for i < op.getNumOperands()
+
+          // set operand type
+          for (auto operand : op.getOperands()) {
+            if (auto MemrefTy = dyn_cast<MemRefType>(operand.getType())) {
               auto newOperandType =
                   cloneMemRefTypeWithMemSpace(MemrefTy, opSpaceAttr);
               operand.setType(newOperandType);
-            } else if (opSpaceAttr != curSpace) {
-              // insert copy when curSpace is not desired opSpaceAttr
-              CopyType_t copyKey = {operand, opSpaceAttr};
-
-              if (copyPairToCopyTargets.count(copyKey) == 0) {
-                // if copy not exist, insert copy
-                auto argSEType = analysis->getType(&op, i);
-                auto newArg = createCopyArg(&op, operand, MemrefTy, opSpaceAttr,
-                                            copyPairToCopyTargets, argSEType);
-                op.setOperand(i, newArg);
-              } else {
-                // if copy already exist, directly refer it
-                auto taget = copyPairToCopyTargets[copyKey];
-                op.setOperand(i, taget);
-              }
-            } // if else
-          } // if MemrefTy
-        } // for i < op.getNumOperands()
-
-        // set operand type
-        for (auto operand : op.getOperands()) {
-          if (auto MemrefTy = dyn_cast<MemRefType>(operand.getType())) {
-            auto newOperandType =
-                cloneMemRefTypeWithMemSpace(MemrefTy, opSpaceAttr);
-            operand.setType(newOperandType);
+            }
           }
-        }
-        // set result type in case it has
-        for (auto result : op.getResults()) {
-          if (auto MemrefTy = dyn_cast<MemRefType>(result.getType())) {
-            auto newOperandType =
-                cloneMemRefTypeWithMemSpace(MemrefTy, opSpaceAttr);
-            result.setType(newOperandType);
+          // set the result type if present
+          for (auto result : op.getResults()) {
+            if (auto MemrefTy = dyn_cast<MemRefType>(result.getType())) {
+              auto newOperandType =
+                  cloneMemRefTypeWithMemSpace(MemrefTy, opSpaceAttr);
+              result.setType(newOperandType);
+            }
           }
         }
-      }
-    } // for op in block.without_terminator()
-  }
+      } // for op in block.without_terminator()
+    }
+  };
+
+  // Run update_op_types twice: a viewlike op's operand may not have been
+  // updated yet at the time the op itself is visited.
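+  // NOTE: two passes only reach a fixed point for single-level view chains;
+  // if views of views show up, this should become a loop that iterates until
+  // no type changes.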
+  update_op_types();
+  update_op_types();

   // respect to function return type
   for (auto &&retOp : func.getOps<func::ReturnOp>()) {

From 1a932c575dd38b92fc442261169d3b24fbbc8b73 Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Tue, 9 Jul 2024 20:09:41 +0800
Subject: [PATCH 13/13] [compiler] del invalid code

---
 compiler/lib/Pipelines/SCFOpt.cpp | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/compiler/lib/Pipelines/SCFOpt.cpp b/compiler/lib/Pipelines/SCFOpt.cpp
index 4beb3162a..1e92cd316 100644
--- a/compiler/lib/Pipelines/SCFOpt.cpp
+++ b/compiler/lib/Pipelines/SCFOpt.cpp
@@ -38,17 +38,9 @@ using namespace mlir::affine;
 namespace {
 void addGenericSCFOptPasses(OpPassManager &pm) {
-  // for elementwise op
-#if 0
-  GPUTileElementwiseInSCFOptions tileOptions;
-  tileOptions.maxBlockSize = 256;
-  createGPUTileElementwiseInSCF(pm, tileOptions);
-#endif
   pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
   pm.addNestedPass<func::FuncOp>(createConvertLinalgExtToLoopsPass());
-  // TODO fix scf.for in reduction kernel
-  // pm.addNestedPass<func::FuncOp>(createForToForallPass());
   // lower affine.apply in case there is some
   pm.addPass(memref::createFoldMemRefAliasOpsPass());
   pm.addPass(createLowerAffinePass());
@@ -59,13 +51,12 @@ void addGenericSCFOptPasses(OpPassManager &pm) {
       createFuseNestedForallPass(getByteIRReductionFusionAttrName()));
   addCleanUpExtPassPipeline(pm);
   // for copy op
-#if 1
+  // for elementwise op
   {
     GPUTileElementwiseInSCFOptions tileOptions;
    tileOptions.maxBlockSize = 256;
     createGPUTileElementwiseInSCF(pm, tileOptions);
   }
-#endif
 }

 void addCPUSCFOptPasses(OpPassManager &pm) {
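A quick way to exercise the new fusion strategy locally is the test's own RUN
invocation. The sketch below only reuses names already present in this series:
the --horizontal-fusion-on-scf flag and the test path come from the
horizontalFusion.mlir diff above, and piping into FileCheck on the same file is
the standard lit idiom.

  byteir-opt compiler/test/Transforms/horizontalFusion.mlir \
      --horizontal-fusion-on-scf --cse --canonicalize \
      | FileCheck compiler/test/Transforms/horizontalFusion.mlir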