diff --git a/compiler/include/byteir/Dialect/SCF/Passes.h b/compiler/include/byteir/Dialect/SCF/Passes.h index ce65fa057..4e8636d5c 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.h +++ b/compiler/include/byteir/Dialect/SCF/Passes.h @@ -18,6 +18,7 @@ #ifndef BYTEIR_DIALECT_SCF_PASSES_H #define BYTEIR_DIALECT_SCF_PASSES_H +#include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" #include "byteir/Dialect/SCF/Transforms/InsertTrivialSCFLoop.h" namespace mlir { diff --git a/compiler/include/byteir/Dialect/SCF/Passes.td b/compiler/include/byteir/Dialect/SCF/Passes.td index 55253b987..17be7a018 100644 --- a/compiler/include/byteir/Dialect/SCF/Passes.td +++ b/compiler/include/byteir/Dialect/SCF/Passes.td @@ -38,4 +38,21 @@ def InsertTrivialSCFLoop : Pass<"insert-trivial-scf-loop", "mlir::func::FuncOp"> ]; } +//===----------------------------------------------------------------------===// +// FuseNestedForall +//===----------------------------------------------------------------------===// + +def FuseNestedForall : Pass<"fuse-nested-forall", "mlir::func::FuncOp"> { + let summary = "Fuse nested forall if possible"; + let constructor = "mlir::createFuseNestedForallPass()"; + let dependentDialects = [ + "scf::SCFDialect" + ]; + let options = [ + Option<"anchorTag", "anchor-tag", "std::string", + /*default=*/"", + "Optional unitAttr anchored tag to apply this pass"> + ]; +} + #endif // BYTEIR_DIALECT_SCF_PASSES \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/SCF/Transforms/FuseNestedForall.h b/compiler/include/byteir/Dialect/SCF/Transforms/FuseNestedForall.h new file mode 100644 index 000000000..88fe2bab4 --- /dev/null +++ b/compiler/include/byteir/Dialect/SCF/Transforms/FuseNestedForall.h @@ -0,0 +1,34 @@ +//===- FuseNestedForall.h ------------------------------------- C++ --===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_FUSENESTEDFORALL_H +#define BYTEIR_DIALECT_SCF_TRANSFORMS_FUSENESTEDFORALL_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +std::unique_ptr> +createFuseNestedForallPass(llvm::StringRef anchorTag = ""); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_FUSENESTEDFORALL_H diff --git a/compiler/include/byteir/Dialect/Transform/Passes.td b/compiler/include/byteir/Dialect/Transform/Passes.td index 49a471a71..3de8f8bc2 100644 --- a/compiler/include/byteir/Dialect/Transform/Passes.td +++ b/compiler/include/byteir/Dialect/Transform/Passes.td @@ -42,6 +42,9 @@ def DetensorizeTransformInsertion : Pass<"insert-detensorize-transform", "Module let summary = "Insert detensorize transform IR to functions."; let constructor = "mlir::createDetensorizeTransformInsertionPass()"; let options = [ + Option<"usingVectorizeOp", "using-vectorize-op", "bool", + /*default=*/"false", + "using vectorizeOp to detensorize linalg op.">, Option<"funcAnchorAttr", "func-anchor", "std::string", /*default=*/"", "An optional Unit attribute anchoring on target functions.">, diff --git a/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h b/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h index 5741663ac..f7fe61a12 100644 --- a/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h +++ b/compiler/include/byteir/Dialect/Transform/Transforms/TransformInsertion.h @@ -39,7 +39,7 @@ createGenericTransformInsertionPass(const TransformInsertionConfig &config); std::unique_ptr> createDetensorizeTransformInsertionPass( - const std::string &funcAnchor = "", + const bool usingVectorizeOp = false, const std::string &funcAnchor = "", const std::string &matchPrefix = "__byteir_detensorize"); std::unique_ptr> createFuseExtTransformInsertionPass( diff --git a/compiler/include/byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h b/compiler/include/byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h new file mode 100644 index 000000000..5fe350c2a --- /dev/null +++ b/compiler/include/byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h @@ -0,0 +1,39 @@ +//===- MoveForallRegionIntoWarpOp.h ---------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_MOVEFORALLREGIONINTOWARPOP_H +#define BYTEIR_DIALECT_SCF_TRANSFORMS_MOVEFORALLREGIONINTOWARPOP_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +constexpr StringRef getMoveForallRegionIntoWarpOpAttrName() { + return "__byteir_move_forall_region_into_warp_execute_on_lane0"; +} + +std::unique_ptr> +createMoveForallRegionIntoWarpOpPass(int64_t warpSize = 32, + llvm::StringRef anchorTag = ""); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_MOVEFORALLREGIONINTOWARPOP_H diff --git a/compiler/include/byteir/Dialect/Vector/Transforms/Passes.h b/compiler/include/byteir/Dialect/Vector/Transforms/Passes.h index 590c20105..a93aab21b 100644 --- a/compiler/include/byteir/Dialect/Vector/Transforms/Passes.h +++ b/compiler/include/byteir/Dialect/Vector/Transforms/Passes.h @@ -18,13 +18,20 @@ #ifndef BYTEIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H #define BYTEIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H +#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h" +#include "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h" #include "mlir/Pass/Pass.h" #include namespace mlir { +namespace func { +class FuncOp; +} // namespace func /// Generate the code for registering transforms passes. #define GEN_PASS_DECL_VECTORTRANSPOSELOWERINGPASS +#define GEN_PASS_DECL_MOVEFORALLREGIONINTOWARPOPPASS +#define GEN_PASS_DECL_SCALARVECTORLOWERINGPASS #define GEN_PASS_REGISTRATION #include "byteir/Dialect/Vector/Transforms/Passes.h.inc" diff --git a/compiler/include/byteir/Dialect/Vector/Transforms/Passes.td b/compiler/include/byteir/Dialect/Vector/Transforms/Passes.td index f0f306bba..2c9378823 100644 --- a/compiler/include/byteir/Dialect/Vector/Transforms/Passes.td +++ b/compiler/include/byteir/Dialect/Vector/Transforms/Passes.td @@ -36,5 +36,72 @@ def VectorTransposeLoweringPass : Pass<"vector-transpose-lowering", "func::FuncO ]; } +//===----------------------------------------------------------------------===// +// Move Forall Region Into WarpOp +//===----------------------------------------------------------------------===// + +def MoveForallRegionIntoWarpOpPass : Pass<"move-forall-region-into-warp-op", "mlir::func::FuncOp"> { + let summary = "move region of forall into warp_execute_on_lane_0 op"; + let constructor = "mlir::createMoveForallRegionIntoWarpOpPass()"; + let dependentDialects = [ + "memref::MemRefDialect", + "vector::VectorDialect", + "gpu::GPUDialect", + ]; + let options = [ + Option<"warpSize", "warp-size", "int64_t", "32", "warp size">, + Option<"anchorTag", "anchor-tag", "std::string", + /*default=*/"", + "Optional unitAttr anchored tag to apply this pass"> + ]; +} + +//===----------------------------------------------------------------------===// +// Vector Warp Distribute +//===----------------------------------------------------------------------===// +def VectorWarpDistributePass : Pass<"vector-warp-distribute", "mlir::func::FuncOp"> { + let summary = "vector warp distribute transformation"; + let constructor = "mlir::createVectorWarpDistributePass()"; + let dependentDialects = [ + "scf::SCFDialect", + "memref::MemRefDialect", + "vector::VectorDialect", + "gpu::GPUDialect", + "affine::AffineDialect", + ]; + let options = [ + Option<"warpOpToSCF", "rewrite-warp-ops-to-scf-if", "bool", + /*default=*/"false", + "Lower vector.warp_execute_on_lane0 to scf.if op">, + + 
Option<"distributeTransferWriteOps", "distribute-transfer-write", "bool", + /*default=*/"false", + "distribution of transfer write">, + + Option<"hoistUniform", "hoist-uniform", "bool", + /*default=*/"false", + "hoist-uniform">, + + Option<"propagateDistribution", "propagate-distribution", "bool", + /*default=*/"false", + "distribution propgation">, + + Option<"maxTransferWriteElements", "max-transfer-write-elements", "int64_t", + /*default=*/"1", + "Maximum number of transfer write elements to distribute">, + ]; +} + +//===----------------------------------------------------------------------===// +// Scalar Vector Lowering +//===----------------------------------------------------------------------===// + +def ScalarVectorLoweringPass : Pass<"scalar-vector-lowering", "func::FuncOp"> { + let summary = "Pass to lower scalar vector"; + let dependentDialects = [ + "memref::MemRefDialect", + "vector::VectorDialect" + ]; +} #endif // BYTEIR_DIALECT_VECTOR_TRANSFORMS_PASSES \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h b/compiler/include/byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h new file mode 100644 index 000000000..a5c6d6878 --- /dev/null +++ b/compiler/include/byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h @@ -0,0 +1,38 @@ +//===- VectorWarpDistribute.h ---------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_VECTORWARPDISTRIBUTE_H +#define BYTEIR_DIALECT_SCF_TRANSFORMS_VECTORWARPDISTRIBUTE_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +#define GEN_PASS_DECL_VECTORWARPDISTRIBUTEPASS +#include "byteir/Dialect/Vector/Transforms/Passes.h.inc" + +std::unique_ptr> +createVectorWarpDistributePass(const VectorWarpDistributePassOptions &options = + VectorWarpDistributePassOptions()); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_VECTORWARPDISTRIBUTE_H diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.h b/compiler/include/byteir/Dialect/mhlo/Passes.h index 313cfc9ae..1b08fc77a 100644 --- a/compiler/include/byteir/Dialect/mhlo/Passes.h +++ b/compiler/include/byteir/Dialect/mhlo/Passes.h @@ -62,6 +62,10 @@ inline void registerByteIRMhloPassesExt() { return mlir::createConcatSliceFusionPass(); }); + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::createInsertSliceWithElemwiseFusionPass(); + }); + // register createCatFusionPass ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { return mlir::createCatFusionPass(); diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h b/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h index 8a69bef8f..28013be0a 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/HloFuser.h @@ -95,6 +95,9 @@ createElementFusionPass(bool clusterSingleElemwiseOp = false, std::unique_ptr> createConcatSliceFusionPass(); +std::unique_ptr> +createInsertSliceWithElemwiseFusionPass(); + std::unique_ptr> createMatmulEpilogueFusionPass(); std::unique_ptr> createIOConvertFusionPass(); diff --git a/compiler/include/byteir/Pipelines/GPU/MappingForall.h b/compiler/include/byteir/Pipelines/GPU/MappingForall.h index 202cc1e40..34e40c805 100644 --- a/compiler/include/byteir/Pipelines/GPU/MappingForall.h +++ b/compiler/include/byteir/Pipelines/GPU/MappingForall.h @@ -18,6 +18,7 @@ #ifndef BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H #define BYTEIR_PIPELINES_GPU_MAPPING_FORALL_H +#include "byteir/Utils/OptionUtils.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Pass/PassRegistry.h" @@ -34,6 +35,12 @@ struct GPUMappingForallOptions *this, "annotate-prefix", llvm::cl::desc("An optional annotate prefix attribute on target ops."), llvm::cl::init("__byteir_gpu_split_grid_reduction")}; + Option warpSize{*this, "warp-size", llvm::cl::desc("warp size."), + llvm::cl::init(32)}; + Option blockDimsHint{ + *this, "block-size-hint", + llvm::cl::desc("block dims hint for dynamic shape."), + llvm::cl::init(llvm::cl::KernelDims{1024, 1, 1})}; // TODO: option for grid/block dims hint }; diff --git a/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h b/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h index 7aea80d51..648f0a968 100644 --- a/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h +++ b/compiler/include/byteir/Pipelines/GPU/ReductionCodegen.h @@ -54,9 +54,6 @@ struct GPUTileGridReductionOptions llvm::cl::init(32)}; Option blockSize{*this, "block-size", llvm::cl::desc("block size"), llvm::cl::init(256)}; - Option usingForall{*this, "using-forall", - llvm::cl::desc("using forall"), - llvm::cl::init(true)}; }; struct GPUSplitBlockReductionOptions @@ -92,9 +89,44 @@ struct GPUTileBlockReductionOptions 
llvm::cl::init(32)}; Option blockSize{*this, "block-size", llvm::cl::desc("block size"), llvm::cl::init(256)}; - Option usingForall{*this, "using-forall", - llvm::cl::desc("using forall"), - llvm::cl::init(true)}; +}; + +struct GPUTileSplitWarpReductionOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_split_warp_reduction")}; + Option blockSize{*this, "block-size", llvm::cl::desc("block size"), + llvm::cl::init(256)}; + Option warpSize{*this, "warp-size", llvm::cl::desc("warp size"), + llvm::cl::init(32)}; +}; + +struct GPUTileWarpReductionOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_warp_reduction")}; + Option splitFactor{*this, "split-factor", + llvm::cl::desc("split factor"), + llvm::cl::init(32)}; + Option warpSize{*this, "warp-size", llvm::cl::desc("warp size"), + llvm::cl::init(32)}; + Option usingGPUShuffle{*this, "using-gpu-shuffle", + llvm::cl::desc("using gpu shuffle"), + llvm::cl::init(true)}; }; struct GPUTileThreadReductionOptions @@ -118,6 +150,10 @@ void createGPUSplitBlockReductionTransform( OpPassManager &pm, const GPUSplitBlockReductionOptions &options); void createGPUTileBlockReductionTransform( OpPassManager &pm, const GPUTileBlockReductionOptions &options); +void createGPUTileSplitWarpReductionTransform( + OpPassManager &pm, const GPUTileSplitWarpReductionOptions &options); +void createGPUTileWarpReductionTransform( + OpPassManager &pm, const GPUTileWarpReductionOptions &options); void createGPUTileThreadReductionTransform( OpPassManager &pm, const GPUTileThreadReductionOptions &options); @@ -142,6 +178,16 @@ inline void registerGPUReductionCodegenPipelines() { "Insert transformation IR to tile linalg reduction op", createGPUTileBlockReductionTransform); + PassPipelineRegistration( + "insert-gpu-tile-split-warp-reduction-transform", + "Insert transformation IR to split block reduction to warp", + createGPUTileSplitWarpReductionTransform); + + PassPipelineRegistration( + "insert-gpu-tile-warp-reduction-transform", + "Insert transformation IR to vectorize warp redution", + createGPUTileWarpReductionTransform); + PassPipelineRegistration( "insert-gpu-tile-thread-reduction-transform", "Insert transformation IR to tile linalg reduction op", diff --git a/compiler/include/byteir/Utils/OptionUtils.h b/compiler/include/byteir/Utils/OptionUtils.h new file mode 100644 index 000000000..cd67e6e18 --- /dev/null +++ b/compiler/include/byteir/Utils/OptionUtils.h @@ -0,0 +1,45 @@ +//===- OptionUtils.h -------------------------------- -*- C++ ------*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +#pragma once + +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm::cl { + +struct KernelDims { + int64_t x; + int64_t y; + int64_t z; +}; + +template <> class parser : public basic_parser { +public: + parser(Option &O) : basic_parser(O) {} + bool parse(Option &O, StringRef ArgName, StringRef Arg, KernelDims &Val); + StringRef getValueName() const override { return "vector"; } + void printOptionDiff(const Option &O, KernelDims V, const OptVal &Default, + size_t GlobalWidth) const; + /// Print an instance of the underling option value to the given stream. + static void print(raw_ostream &os, const KernelDims &value); + + void anchor() override; +}; +} // namespace llvm::cl diff --git a/compiler/lib/CAPI/Passes.cpp b/compiler/lib/CAPI/Passes.cpp index 6707875c9..c94394a5a 100644 --- a/compiler/lib/CAPI/Passes.cpp +++ b/compiler/lib/CAPI/Passes.cpp @@ -28,6 +28,7 @@ #include "byteir/Dialect/Shape/Passes.h" #include "byteir/Dialect/Tensor/Passes.h" #include "byteir/Dialect/Transform/Passes.h" +#include "byteir/Dialect/Vector/Transforms/Passes.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/InitAllPipelines.h" #include "byteir/Target/CUDA/ToCUDA.h" @@ -52,6 +53,7 @@ void byteirRegisterAllPasses() { registerByteIRMemRefPasses(); registerByteIRMhloPassesExt(); registerByteIRSCFPasses(); + registerByteIRVectorPasses(); registerByteIRShapePasses(); registerByteIRTensorPasses(); registerByteIRTransformPasses(); diff --git a/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp b/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp index f759d2d0b..0016ee5f1 100644 --- a/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp +++ b/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp @@ -50,6 +50,54 @@ class ConvertReshapeLikeOpToByrePattern : public OpConversionPattern { } }; +class ConvertCastOpToByrePattern : public OpConversionPattern { +public: + ConvertCastOpToByrePattern(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(memref::CastOp op, memref::CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto subview = op.getSource().getDefiningOp()) { + if (!subview.getSource().getType().getLayout().isIdentity()) + return failure(); + if (!op.getType().cast().getLayout().isIdentity()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + subview.getSource(), 0); + return success(); + } + + return failure(); + } +}; + +class ConvertCollapseShapeOpToByrePattern + : public OpConversionPattern { +public: + ConvertCollapseShapeOpToByrePattern(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(memref::CollapseShapeOp op, + memref::CollapseShapeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto subview = op.getSrc().getDefiningOp()) { + if 
(!subview.getSource().getType().getLayout().isIdentity()) + return failure(); + if (!op.getType().getLayout().isIdentity()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + subview.getSource(), 0); + return success(); + } + + return failure(); + } +}; + class ConvertViewOpToByrePattern : public OpConversionPattern { public: ConvertViewOpToByrePattern(MLIRContext *ctx) @@ -196,7 +244,8 @@ void mlir::populateMemrefToByrePattern(RewritePatternSet &patterns) { ConvertGetGlobalOpToByrePattern, ConvertReshapeLikeOpToByrePattern, ConvertReshapeLikeOpToByrePattern, - ConvertSubViewOpToByrePattern>(patterns.getContext()); + ConvertSubViewOpToByrePattern, ConvertCastOpToByrePattern, + ConvertCollapseShapeOpToByrePattern>(patterns.getContext()); } std::unique_ptr> diff --git a/compiler/lib/Conversion/ToGPU/FuncToGPU.cpp b/compiler/lib/Conversion/ToGPU/FuncToGPU.cpp index 8210eb5f4..40b0a0a07 100644 --- a/compiler/lib/Conversion/ToGPU/FuncToGPU.cpp +++ b/compiler/lib/Conversion/ToGPU/FuncToGPU.cpp @@ -47,7 +47,7 @@ #define DEBUG_TYPE "func-to-gpu" // TODO: configurable coarsen factor -#define COARSEN_FACTOR 4 +#define COARSEN_FACTOR 1 using namespace llvm; using namespace mlir; @@ -55,6 +55,9 @@ using namespace mlir::arith; using namespace mlir::gpu; namespace { +constexpr int64_t kGridTileNumThreshold = 64; +constexpr int64_t kNumWave = 128; +constexpr int64_t kWarpSize = 32; static void creaetGuardedSIMT(OpBuilder &b, Value id, Value bound, LoopLikeOpInterface looplike, bool coarsen) { @@ -290,42 +293,86 @@ void setValidStaticGPUConfigAttr(func::FuncOp func, ArrayRef bs, if (toGPUSizes[i] <= 0) { toGPUSizes[i] = 1; } - - auto attr = IntegerAttr::get(IntegerType::get(ctx, 32), toGPUSizes[i]); - toGPUAttrs.push_back(attr); } // estimate maxGridSizes if possible SmallVector maxGridSizes = {0, 0, 0}; // collect loops from inner to outer - func.walk([&](LoopLikeOpInterface loopLike) { - if (loopLike->hasAttrOfType(getLoopToSIMTAttrName())) { - auto coarsen = - loopLike->hasAttrOfType(getCoarsenSIMTAttrName()); - int64_t factor = coarsen ? 
coarsenFactor : 1; - - auto strAttr = - loopLike->getAttrOfType(getLoopToSIMTAttrName()); - - if (strAttr.getValue() == getLinearIdXName()) { - maxGridSizes[0] = - estimateGridSize(loopLike, maxGridSizes[0], toGPUSizes[0] * factor); - } else if (strAttr.getValue() == getLinearIdYName()) { - maxGridSizes[1] = - estimateGridSize(loopLike, maxGridSizes[1], toGPUSizes[1] * factor); - } else if (strAttr.getValue() == getLinearIdZName()) { - maxGridSizes[2] = - estimateGridSize(loopLike, maxGridSizes[2], toGPUSizes[2] * factor); - } else if (strAttr.getValue() == getBlockIdXName()) { - maxGridSizes[0] = estimateGridSize(loopLike, maxGridSizes[0], factor); - } else if (strAttr.getValue() == getBlockIdYName()) { - maxGridSizes[1] = estimateGridSize(loopLike, maxGridSizes[1], factor); - } else if (strAttr.getValue() == getBlockIdZName()) { - maxGridSizes[2] = estimateGridSize(loopLike, maxGridSizes[2], factor); + bool firstCheck = true; + auto isSuitableConfig = [&]() -> bool { + if (llvm::all_of(maxGridSizes, [](int64_t val) { return val == 0; })) { + return false; + } + int64_t totalGridSize = 1; + for (auto v : maxGridSizes) { + if (v != 0) + totalGridSize *= v; + } + int64_t totalBlockSize = 1; + for (size_t i = 0; i < 3; ++i) { + totalBlockSize *= toGPUSizes[i]; + } + if (totalGridSize < kGridTileNumThreshold && + totalBlockSize >= kWarpSize * 2) { + return false; + } + return true; + }; + while (!isSuitableConfig()) { + if (!firstCheck) { + for (int64_t i = 2; i >= 0; --i) { + if (toGPUSizes[i] >= 2) { + toGPUSizes[i] /= 2; + break; + } } } - }); + firstCheck = false; + maxGridSizes = {0, 0, 0}; + + func.walk([&](LoopLikeOpInterface loopLike) { + if (loopLike->hasAttrOfType(getLoopToSIMTAttrName())) { + auto coarsen = + loopLike->hasAttrOfType(getCoarsenSIMTAttrName()); + int64_t factor = coarsen ? 
coarsenFactor : 1; + + auto strAttr = + loopLike->getAttrOfType(getLoopToSIMTAttrName()); + + if (strAttr.getValue() == getLinearIdXName()) { + maxGridSizes[0] = estimateGridSize(loopLike, maxGridSizes[0], + toGPUSizes[0] * factor); + } else if (strAttr.getValue() == getLinearIdYName()) { + maxGridSizes[1] = estimateGridSize(loopLike, maxGridSizes[1], + toGPUSizes[1] * factor); + } else if (strAttr.getValue() == getLinearIdZName()) { + maxGridSizes[2] = estimateGridSize(loopLike, maxGridSizes[2], + toGPUSizes[2] * factor); + } else if (strAttr.getValue() == getBlockIdXName()) { + maxGridSizes[0] = estimateGridSize(loopLike, maxGridSizes[0], factor); + } else if (strAttr.getValue() == getBlockIdYName()) { + maxGridSizes[1] = estimateGridSize(loopLike, maxGridSizes[1], factor); + } else if (strAttr.getValue() == getBlockIdZName()) { + maxGridSizes[2] = estimateGridSize(loopLike, maxGridSizes[2], factor); + } + } + }); + } + int64_t threshold = kGridTileNumThreshold * kNumWave; + for (size_t i = 0; i < maxGridSizes.size(); ++i) { + if (maxGridSizes[i] > threshold) { + maxGridSizes[i] = threshold; + } else { + threshold /= maxGridSizes[i]; + break; + } + } + + for (size_t i = 0; i < 3; ++i) { + auto attr = IntegerAttr::get(IntegerType::get(ctx, 32), toGPUSizes[i]); + toGPUAttrs.push_back(attr); + } for (size_t i = 0; i < 3; ++i) { size_t j = i + 3; // if no estimation use suggested attr value diff --git a/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp b/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp index 555553e25..fcb5fccd5 100644 --- a/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp +++ b/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp @@ -1116,6 +1116,90 @@ class ByteirRepeatCustomCallConverter } }; +/// Converts mhlo.concatenate operation to a linalg.generic op. +struct StaticConcatenateToLinalgWithIndexSwitch + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mhlo::ConcatenateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Shortcut the one-operand case, simplifies code below. + if (adaptor.getOperands().size() == 1) { + rewriter.replaceOp(op, adaptor.getOperands()[0]); + return success(); + } + + auto resultType = op.getResult().getType().dyn_cast(); + + if (!resultType) + return failure(); + + uint64_t dim = op.getDimension(); + for (auto operand : adaptor.getOperands()) { + auto operandType = operand.getType().dyn_cast(); + if (!operandType) + return failure(); + if (operandType.getShape()[dim] != 1) + return failure(); + } + + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0); + + // Allocate the output tensor with tensor.empty. 
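+    // The linalg.generic built below fills this tensor element-wise: an
+    // scf.index_switch on the linalg.index along `dim` selects which operand
+    // (each with extent 1 along `dim`) the value is extracted from, with the
+    // last operand handled in the default region.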
+ Value result = + getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); + + int64_t nloops = resultType.getRank(); + rewriter.replaceOpWithNewOp( + op, + /*resultTensorTypes=*/resultType, + /*inputs=*/ValueRange{}, /*outputBuffers=*/result, + llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), + getNParallelLoopsAttrs(nloops), + [&](OpBuilder &nestedBuilder, Location loc, ValueRange) { + OpBuilder b = nestedBuilder; + SmallVector cases; + int64_t resDimSize = resultType.getShape()[dim]; + for (int64_t i = 0; i < resDimSize - 1; ++i) + cases.emplace_back(i); + + SmallVector extractIndices; + extractIndices.reserve(nloops); + for (int64_t i = 0; i < nloops; i++) { + extractIndices.push_back(b.create(loc, i)); + } + + Value indexOp = b.create(loc, dim); + extractIndices[dim] = zero; + + auto indexSwitchOp = b.create( + loc, TypeRange{resultType.getElementType()}, indexOp, cases, + cases.size()); + + for (int64_t i = 0; i < resDimSize - 1; ++i) { + Block &curBlock = indexSwitchOp.getCaseRegions()[i].emplaceBlock(); + b.setInsertionPointToStart(&curBlock); + Value val = b.create( + loc, adaptor.getOperands()[i], extractIndices); + b.create(loc, val); + } + + Block &curBlock = indexSwitchOp.getDefaultRegion().emplaceBlock(); + b.setInsertionPointToStart(&curBlock); + Value val = b.create( + loc, adaptor.getOperands()[resDimSize - 1], extractIndices); + b.create(loc, val); + + nestedBuilder.create(loc, + indexSwitchOp.getResults()[0]); + }, + linalg::getPrunedAttributeList(op)); + return success(); + } +}; + struct HloFusionToLinalgPass : public HloFusionToLinalgBase { @@ -1173,6 +1257,8 @@ void mlir::populateHloToLinalgExtConversionPattern( patterns.add(typeConverter, ctx, PatternBenefit(2)); patterns.add(typeConverter, ctx, PatternBenefit(2)); + patterns.add(ctx, + PatternBenefit(2)); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/compiler/lib/Conversion/ToLinalg/TensorToLinalg.cpp b/compiler/lib/Conversion/ToLinalg/TensorToLinalg.cpp index 45a6fccac..99c242f14 100644 --- a/compiler/lib/Conversion/ToLinalg/TensorToLinalg.cpp +++ b/compiler/lib/Conversion/ToLinalg/TensorToLinalg.cpp @@ -19,6 +19,7 @@ #include "byteir/Utils/AffineUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -153,6 +154,124 @@ class CollapseShapeToLinalgGeneric } }; +class InsertSliceToLinalgGeneric + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tensor::InsertSliceOp op, + tensor::InsertSliceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto ctx = op.getContext(); + auto source = op.getSource(); + auto dest = op.getDest(); + auto resultTy = op.getResultType(); + auto inputTy = op.getSourceType(); + int64_t nloops = resultTy.getRank(); + + if (ShapedType::isDynamicShape(resultTy.getShape())) { + return failure(); + } + + // Find input/output values and types. + auto loc = op.getLoc(); + auto emptyOp = rewriter.create(loc, resultTy.getShape(), + resultTy.getElementType()); + Value output = emptyOp.getResult(); + + SmallVector maps; + maps.push_back(AffineMap::getMultiDimIdentityMap(nloops, ctx)); + + // Build `linalg.generic` op. 
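+    // Each output element at index `idx` is read from the inserted source
+    // when, for every non-dropped dimension, (idx - offset) % stride == 0
+    // and (idx - offset) / stride < size; otherwise it is read from `dest`.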
+ ValueRange inputs = adaptor.getOperands(); + auto ranges = op.getOrCreateRanges(rewriter, loc); + llvm::SmallBitVector droppedDims = op.getDroppedDims(); + + auto linalgOp = rewriter.create( + loc, resultTy, ValueRange(), output, maps, + getNParallelLoopsAttrs(nloops), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + SmallVector linalgIndices; + for (int64_t i = 0; i < nloops; ++i) { + linalgIndices.emplace_back( + nestedBuilder.create(loc, i)); + } + + SmallVector srcIndices; + Value predicate; + int64_t srcDimIdx = 0; + for (int64_t i = 0; i < resultTy.getRank(); ++i) { + if (droppedDims.test(i)) + continue; + srcIndices.emplace_back(linalgIndices[i]); + + AffineExpr offset, stride, idx; + SmallVector symbolVals = { + linalgIndices[i], + ranges[i].offset, + ranges[i].stride, + }; + + bindSymbols(nestedBuilder.getContext(), idx, offset, stride); + OpFoldResult remainder = affine::makeComposedFoldedAffineApply( + nestedBuilder, loc, + AffineMap::get(0, 3, (idx - offset) % stride), symbolVals); + + Value remainderVal = + getValueOrCreateConstantIndexOp(nestedBuilder, loc, remainder); + + OpFoldResult division = affine::makeComposedFoldedAffineApply( + nestedBuilder, loc, + AffineMap::get(0, 3, (idx - offset).floorDiv(stride)), + symbolVals); + Value divisionVal = + getValueOrCreateConstantIndexOp(nestedBuilder, loc, division); + + Value zero = nestedBuilder.create(loc, 0); + Value equalZero = nestedBuilder.create( + loc, arith::CmpIPredicate::eq, remainderVal, zero); + + Value size = getValueOrCreateConstantIndexOp(nestedBuilder, loc, + ranges[i].size); + Value inBound = nestedBuilder.create( + loc, arith::CmpIPredicate::ult, divisionVal, size); + Value curPredicate = + nestedBuilder.create(loc, equalZero, inBound); + + predicate = predicate ? nestedBuilder.create( + loc, predicate, curPredicate) + : curPredicate; + srcDimIdx += 1; + } + + auto ifPred = nestedBuilder.create( + loc, resultTy.getElementType(), predicate, + /*withElseRegion=*/true); + + // Pred == true, return source + { + OpBuilder ifPredThenB = ifPred.getThenBodyBuilder(); + Value val = + ifPredThenB.create(loc, source, srcIndices); + ifPredThenB.create(loc, val); + } + + // Pred == false, therefore return dest. 
+ { + OpBuilder ifPredElseB = ifPred.getElseBodyBuilder(); + Value val = + ifPredElseB.create(loc, dest, linalgIndices); + ifPredElseB.create(loc, val); + } + nestedBuilder.create(loc, ifPred.getResults()); + }); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + struct TensorToLinalgPass : public TensorToLinalgBase { TensorToLinalgPass() = default; @@ -173,8 +292,8 @@ struct TensorToLinalgPass : public TensorToLinalgBase { target.addLegalDialect(); - target.addLegalOp(); + shape::ShapeDialect, affine::AffineDialect>(); + target.addLegalOp(); target.addDynamicallyLegalOp( [&](tensor::ExpandShapeOp op) { @@ -184,7 +303,10 @@ struct TensorToLinalgPass : public TensorToLinalgBase { [&](tensor::CollapseShapeOp op) { return !op.getResultType().hasStaticShape(); }); - + target.addDynamicallyLegalOp( + [&](tensor::ExtractSliceOp op) { + return !op.getResultType().hasStaticShape(); + }); populateTensorToLinalgConversionPatterns(patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); if (failed(applyPartialConversion(func, target, frozenPatterns))) { @@ -228,8 +350,8 @@ LogicalResult mlir::simplifyTensorReshapeLikeOp(RewriterBase &rewriter, void mlir::populateTensorToLinalgConversionPatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns.add(patterns.getContext()); } std::unique_ptr> mlir::createTensorToLinalgPass() { diff --git a/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp b/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp index ebc98da05..80d3945b9 100644 --- a/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp +++ b/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp @@ -49,32 +49,6 @@ bool anyIncompatibleUse(Value oldValue, Value newValue) { }) && (oldValue.getType() != newValue.getType()); } - -bool anyIncompatibleUseExceptReshapeOp(Value oldValue, Value newValue) { - return llvm::any_of(oldValue.getUses(), - [](OpOperand &operand) { - Operation *op = operand.getOwner(); - Dialect *dialect = op->getDialect(); - return llvm::isa(op) || - (dialect && dialect->getNamespace() == "byre"); - }) && - (oldValue.getType() != newValue.getType()); -} - -SmallVector getReshapeOp(Value value) { - SmallVector reshapeOps; - auto operation = value.getDefiningOp(); - while (operation && - isa(operation)) { - reshapeOps.push_back(operation); - value = operation->getOperand(0); - operation = value.getDefiningOp(); - } - if (operation && isa(operation)) - return reshapeOps; - return {}; -} - class RemoveCopyPattern : public OpRewritePattern { public: RemoveCopyPattern(MLIRContext *context, DominanceInfo &dom) @@ -225,41 +199,51 @@ class RemoveCopyPattern : public OpRewritePattern { } } - if (auto &&reshapeOps = getReshapeOp(src); reshapeOps.size()) { - auto &&srcAllocOp = aliases[0][0].getDefiningOp(); - if (auto targetDef = target.getDefiningOp()) { - if (isa(targetDef)) - hoistUpOpInBlock(targetDef, domInfo); + if (llvm::isa(target) && + isa(copyOp->getParentOp())) { + memref::AllocOp srcAllocOp; + for (auto alias : aliases[0]) { + auto defOp = alias.getDefiningOp(); + if (!defOp) { + return failure(); + } + if (!llvm::isa(defOp)) { + return failure(); + } + if (auto allocOp = dyn_cast(defOp)) { + srcAllocOp = allocOp; + } } - if (!domInfo.properlyDominates(target, srcAllocOp)) { - LLVM_DEBUG(llvm::dbgs() << "failed at target " << target - << " not dominated by " << srcAllocOp << "\n"); + if (!srcAllocOp || target.getType() != src.getType()) { return failure(); } + // using CollapseShapeOp/ExpandShapeOp 
reshape target to src alloc. rewriter.setInsertionPoint(srcAllocOp); + Value alias = src; Value reshapeTarget = target; - for (size_t pos = 0; pos < reshapeOps.size(); ++pos) { - auto shapeOp = reshapeOps[pos]; - if (auto collapseShapeOp = dyn_cast(shapeOp)) { + while (!alias.getDefiningOp()) { + auto defOp = alias.getDefiningOp(); + if (auto collapseShapeOp = dyn_cast(defOp)) { + // FIXME: expandShape doesn't support expanding dynamic dims. reshapeTarget = rewriter.create( - srcAllocOp.getLoc(), collapseShapeOp.getSrcType(), reshapeTarget, + alias.getLoc(), collapseShapeOp.getSrcType(), reshapeTarget, collapseShapeOp.getReassociationIndices()); + alias = collapseShapeOp.getSrc(); } else if (auto expandShapeOp = - dyn_cast(shapeOp)) { + dyn_cast(defOp)) { reshapeTarget = rewriter.create( srcAllocOp.getLoc(), expandShapeOp.getSrcType(), reshapeTarget, expandShapeOp.getReassociationIndices()); + alias = expandShapeOp.getSrc(); } } - if (!anyIncompatibleUseExceptReshapeOp(srcAllocOp, target)) { - rewriter.replaceOp(srcAllocOp, {reshapeTarget}); - } - if (!anyIncompatibleUse(src, target)) - rewriter.replaceOp(src.getDefiningOp(), {target}); - return success(); + rewriter.replaceOp(srcAllocOp, {reshapeTarget}); + rewriter.eraseOp(copyOp); } + return failure(); } @@ -307,4 +291,4 @@ void mlir::populateRemoveCopyAfterBufferizationPattern( std::unique_ptr> mlir::createRemoveCopyPass() { return std::make_unique(); -} +} \ No newline at end of file diff --git a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt index 418a394ca..58fe58399 100644 --- a/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(ByteIRSCFPasses + FuseNestedForall.cpp InsertTrivialSCFLoop.cpp TilingInterfaceToSCFFor.cpp diff --git a/compiler/lib/Dialect/SCF/Transforms/FuseNestedForall.cpp b/compiler/lib/Dialect/SCF/Transforms/FuseNestedForall.cpp new file mode 100644 index 000000000..1b395ef81 --- /dev/null +++ b/compiler/lib/Dialect/SCF/Transforms/FuseNestedForall.cpp @@ -0,0 +1,243 @@ +//===- FuseNestedForall.cpp ------------------------------------ C++ --===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h" +#include "byteir/Utils/LoopUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "llvm/ADT/DenseSet.h" +#include + +#include "PassDetail.h" + +using namespace llvm; +using namespace mlir; +using namespace mlir::scf; + +namespace { + +static bool checkMappingAttributeTypes(SmallVector mapping) { + if (mapping.empty()) { + return true; + } + + bool hasBlockMapping = llvm::any_of(mapping, [](Attribute attr) { + return isa(attr); + }); + bool hasWarpgroupMapping = llvm::any_of(mapping, [](Attribute attr) { + return isa(attr); + }); + bool hasWarpMapping = llvm::any_of(mapping, [](Attribute attr) { + return isa(attr); + }); + bool hasThreadMapping = llvm::any_of(mapping, [](Attribute attr) { + return isa(attr); + }); + int64_t countMappingTypes = 0; + countMappingTypes += hasBlockMapping ? 1 : 0; + countMappingTypes += hasWarpgroupMapping ? 1 : 0; + countMappingTypes += hasWarpMapping ? 1 : 0; + countMappingTypes += hasThreadMapping ? 1 : 0; + if (countMappingTypes > 1) { + return false; + } + + llvm::DenseSet seen; + for (Attribute map : mapping) { + if (seen.contains(map)) { + return false; + } + seen.insert(map); + } + + auto isLinear = [](Attribute a) { + return cast(a).isLinearMapping(); + }; + + if (llvm::any_of(mapping, isLinear) && !llvm::all_of(mapping, isLinear)) { + return false; + } + + return true; +} + +bool isPerfectNestedForall(scf::ForallOp parentForall, + scf::ForallOp nestedForall) { + Block &body = parentForall.getRegion().front(); + scf::InParallelOp parentReturnOp = parentForall.getTerminator(); + scf::InParallelOp nestedReturnOp = nestedForall.getTerminator(); + + // InParallelOp has a single region with a single block + if (!parentReturnOp.getRegion().front().empty() || + !nestedReturnOp.getRegion().front().empty()) + return false; + + Operation *lastOp = &(*std::prev(body.end(), 2)); + + if (!llvm::isa(lastOp) || + lastOp != nestedForall.getOperation()) { + return false; + } + + SmallVector mixedLb = nestedForall.getMixedLowerBound(); + SmallVector mixedUb = nestedForall.getMixedUpperBound(); + SmallVector mixedStep = nestedForall.getMixedStep(); + + auto isValueInParentForallBody = + [&](const SmallVector &config) -> bool { + for (OpFoldResult ofr : config) { + auto maybeCst = getConstantIntValue(ofr); + if (!maybeCst.has_value()) { + Value v = ofr.get(); + if (v.getParentBlock() == &body) { + return true; + } + } + } + return false; + }; + + if (isValueInParentForallBody(mixedLb) || + isValueInParentForallBody(mixedUb) || + isValueInParentForallBody(mixedStep)) { + return false; + } + + SmallVector mapping; + auto mappingAttrs = llvm::to_vector(parentForall.getMappingAttr()); + mappingAttrs.append(nestedForall.getMappingAttr().begin(), + nestedForall.getMappingAttr().end()); + size_t numLoops = parentForall.getInductionVars().size() + + nestedForall.getInductionVars().size(); + if (numLoops != mappingAttrs.size() || + !checkMappingAttributeTypes(mappingAttrs)) { + return false; + } + + return true; +} + +scf::ForallOp fuseNestedForallImpl(scf::ForallOp parentForall, + scf::ForallOp nestedForall) { + IRRewriter rewriter(parentForall.getContext()); + Location loc = 
parentForall.getLoc(); + + auto outputs = llvm::to_vector(parentForall.getOutputs()); + auto nestOutpus = llvm::to_vector(nestedForall.getOutputs()); + outputs.append(nestOutpus.begin(), nestOutpus.end()); + + SmallVector mixedLb = parentForall.getMixedLowerBound(); + SmallVector mixedUb = parentForall.getMixedUpperBound(); + SmallVector mixedStep = parentForall.getMixedStep(); + SmallVector nestedLb = nestedForall.getMixedLowerBound(); + SmallVector nestedUb = nestedForall.getMixedUpperBound(); + SmallVector nestedStep = nestedForall.getMixedStep(); + + mixedLb.append(nestedLb.begin(), nestedLb.end()); + mixedUb.append(nestedUb.begin(), nestedUb.end()); + mixedStep.append(nestedStep.begin(), nestedStep.end()); + + auto mappingAttrs = llvm::to_vector(parentForall.getMappingAttr()); + auto nestMappingAttrs = llvm::to_vector(nestedForall.getMappingAttr()); + mappingAttrs.append(nestMappingAttrs.begin(), nestMappingAttrs.end()); + + rewriter.setInsertionPoint(parentForall); + auto newForallOp = + rewriter.create(loc, mixedLb, mixedUb, mixedStep, outputs, + rewriter.getArrayAttr(mappingAttrs)); + newForallOp.getTerminator()->erase(); + + Block *parentForallLoopBody = parentForall.getBody(); + Block *newLoopBody = newForallOp.getBody(); + rewriter.mergeBlocks(parentForallLoopBody, newLoopBody, + newLoopBody->getArguments().take_front( + parentForallLoopBody->getNumArguments())); + + scf::ForallOp clonedNestedForallOp = + cast(&(*std::prev(newLoopBody->end(), 2))); + + rewriter.setInsertionPoint(newForallOp.getTerminator()); + IRMapping bvm; + for (auto [oldIv, newIv] : + llvm::zip_equal(clonedNestedForallOp.getInductionVars(), + newLoopBody->getArguments().take_back( + clonedNestedForallOp.getInductionVars().size()))) { + bvm.map(oldIv, newIv); + } + + for (Operation &op : clonedNestedForallOp.getBody()->without_terminator()) + rewriter.clone(op, bvm); + + clonedNestedForallOp->erase(); + parentForall->erase(); + + return newForallOp; +} + +struct FuseNestedForallPass + : public FuseNestedForallBase { + FuseNestedForallPass(llvm::StringRef anchor) : FuseNestedForallBase() { + anchorTag = anchor.str(); + } + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + + // skip non-anchored + if (!anchorTag.empty() && !funcOp->hasAttr(anchorTag)) { + return; + } + + llvm::DenseMap> fuseCluster; + funcOp->walk([&](scf::ForallOp curForallOp) { + fuseCluster[curForallOp] = {curForallOp}; + }); + + funcOp->walk([&](scf::ForallOp curForallOp) { + if (auto parentForallOp = curForallOp->getParentOfType()) { + if (isPerfectNestedForall(parentForallOp, curForallOp)) { + fuseCluster[parentForallOp].append(fuseCluster[curForallOp].begin(), + fuseCluster[curForallOp].end()); + fuseCluster.erase(curForallOp); + } + } + }); + + for (const auto &cluster : fuseCluster) { + if (cluster.second.size() < 2) { + continue; + } + // from inside out + auto loops = llvm::to_vector(llvm::reverse(cluster.second)); + scf::ForallOp nestedForall = loops[0]; + for (size_t i = 1; i < loops.size(); ++i) { + nestedForall = fuseNestedForallImpl(loops[i], nestedForall); + } + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::createFuseNestedForallPass(llvm::StringRef anchor) { + return std::make_unique(anchor); +} diff --git a/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp b/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp index 4a3975262..e47e755d3 100644 --- a/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp +++ b/compiler/lib/Dialect/Tensor/Transforms/CanonicalizeExt.cpp @@ 
-28,7 +28,9 @@ #include "byteir/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -52,6 +54,19 @@ using namespace mlir; namespace { +static bool isNormalizeExtractSlice(tensor::ExtractSliceOp extractSliceOp) { + ArrayRef offsets = extractSliceOp.getStaticOffsets(); + ArrayRef strides = extractSliceOp.getStaticStrides(); + if (!llvm::all_of(offsets, [](int64_t v) { return v == 0; })) { + return false; + } + + if (!llvm::all_of(strides, [](int64_t v) { return v == 1; })) { + return false; + } + return true; +} + std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes) { SmallVector reassociation; @@ -74,6 +89,194 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes) { return reassociation; } +static std::optional +reifyDimsInSliceOp(Operation *op, RewriterBase &rewriter, int64_t dim) { + if (auto expandOp = llvm::dyn_cast(op)) { + auto ReassociationIndices = expandOp.getReassociationIndices(); + auto src = expandOp.getSrc(); + RankedTensorType resType = expandOp.getType(); + if (RankedTensorType srcType = src.getType().dyn_cast()) { + int64_t dynDimCount = 0; + + if (resType.getShape()[dim] != ShapedType::kDynamic) { + return std::nullopt; + } + int srcDim = expandOp.getCorrespondingSourceDim(dim); + for (auto idx : ReassociationIndices[srcDim]) { + + if (resType.getShape()[idx] == ShapedType::kDynamic) { + dynDimCount += 1; + } else if (resType.getShape()[idx] != 1) { + return std::nullopt; + } + } + if (src.getDefiningOp() && dynDimCount == 1) { + return reifyDimsInSliceOp(src.getDefiningOp(), rewriter, srcDim); + } + } + } else if (auto extractSliceOp = llvm::dyn_cast(op)) { + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(extractSliceOp.reifyResultShapes(rewriter, reifiedShapes))) { + return std::nullopt; + } + return reifiedShapes[0][dim]; + } + return std::nullopt; +} + +/// When the shape of extracted_slice equal to input tensor, +/// convert fill + extracted_slice to fill + collaspe_slice. 
+/// +/// Example: +/// %0 = tensor.extract_slice %arg2[0, %arg1] [%dim_0, 1] [1, 1] : +/// tensor to tensor %1 = tensor.expand_shape %0 [[0, 1]] : +/// tensor into tensor %2 = linalg.fill ins(%cst : f32) outs(%1 +/// : tensor) -> tensor %3 = tensor.extract_slice %2[0, 0] +/// [%dim_0, 1] [1, 1] : tensor to tensor +/// +/// will be converted to +/// +/// %0 = tensor.extract_slice %arg2[0, %arg1] [%dim_0, 1] [1, 1] : +/// tensor to tensor %1 = tensor.expand_shape %0 [[0, 1]] : +/// tensor into tensor %2 = linalg.fill ins(%cst : f32) outs(%1 +/// : tensor) -> tensor %3 = tensor.collapse_shape %2[0, 0] +/// [%dim_0, 1] [1, 1] : tensor to tensor +struct ExtractFullSliceFromLinalgFillOp + : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp, + PatternRewriter &rewriter) const override { + if (!isNormalizeExtractSlice(extractSliceOp)) { + return failure(); + } + + linalg::FillOp fillOp = + extractSliceOp.getSource().getDefiningOp(); + if (!fillOp || fillOp.getNumResults() != 1) { + return failure(); + } + + ReifiedRankedShapedTypeDims fillReifiedShapes; + if (failed(fillOp.reifyResultShapes(rewriter, fillReifiedShapes))) { + return failure(); + } + + // reifyDims if getDefiningOp is expandShape/extractSlice. + for (size_t i = 0; i < fillReifiedShapes[0].size(); ++i) { + auto maybeCst = getConstantIntValue(fillReifiedShapes[0][i]); + if (maybeCst.has_value()) + continue; + Value val = fillReifiedShapes[0][i].get(); + if (auto dimOp = val.getDefiningOp()) { + auto maybeCstIdx = dimOp.getConstantIndex(); + if (maybeCstIdx.has_value() && dimOp.getSource().getDefiningOp()) { + auto reifedDim = reifyDimsInSliceOp(dimOp.getSource().getDefiningOp(), + rewriter, maybeCstIdx.value()); + if (reifedDim.has_value()) + fillReifiedShapes[0][i] = reifedDim.value(); + } + } + } + + SmallVector sizes = extractSliceOp.getMixedSizes(); + if (!isEqualConstantIntOrValueArray(sizes, fillReifiedShapes[0])) { + return failure(); + } + + RankedTensorType resultType = extractSliceOp.getType(); + if (resultType.getRank() == + static_cast(fillReifiedShapes[0].size())) { + rewriter.replaceOp(extractSliceOp, fillOp.getResult(0)); + return success(); + } + + // convert extract_slice to collapse_shape + SmallVector reassociation; + int64_t srcIdx = 0; + for (int64_t i = 0; i < resultType.getRank(); ++i) { + if (resultType.getShape()[i] == 1) { + reassociation.emplace_back(ReassociationIndices{srcIdx}); + srcIdx += 1; + } else { + ReassociationIndices indices; + while (srcIdx < static_cast(sizes.size())) { + indices.emplace_back(srcIdx); + auto maybeCst = getConstantIntValue(sizes[srcIdx]); + srcIdx += 1; + if (!maybeCst.has_value() || maybeCst.value() != 1) { + break; + } + } + reassociation.emplace_back(indices); + } + } + + // reassociation is empty if tensor with rank 0 + if (resultType.getRank() > 0) { + while (srcIdx < static_cast(sizes.size())) { + reassociation.back().emplace_back(srcIdx); + srcIdx += 1; + } + } + + rewriter.replaceOpWithNewOp( + extractSliceOp, fillOp.getResult(0), reassociation); + + return success(); + } +}; + +struct ExtractFullSliceFromSlice + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp curSliceOp, + PatternRewriter &rewriter) const override { + if (!isNormalizeExtractSlice(curSliceOp)) { + return failure(); + } + + tensor::ExtractSliceOp preSliceOp = + curSliceOp.getSource().getDefiningOp(); + if (!preSliceOp) + return 
failure(); + RankedTensorType preResultType = preSliceOp.getType(); + RankedTensorType curResultType = curSliceOp.getType(); + if (preResultType.getRank() != curResultType.getRank()) { + return failure(); + } + + if (!isNormalizeExtractSlice(preSliceOp)) { + return failure(); + } + + SmallVector preSizes = preSliceOp.getMixedSizes(); + SmallVector curSizes = curSliceOp.getMixedSizes(); + preSizes.erase(std::remove_if(preSizes.begin(), preSizes.end(), + [](OpFoldResult ofr) { + auto maybeCst = getConstantIntValue(ofr); + return maybeCst.has_value() && + maybeCst.value() == 1; + }), + preSizes.end()); + + curSizes.erase(std::remove_if(curSizes.begin(), curSizes.end(), + [](OpFoldResult ofr) { + auto maybeCst = getConstantIntValue(ofr); + return maybeCst.has_value() && + maybeCst.value() == 1; + }), + curSizes.end()); + if (!isEqualConstantIntOrValueArray(preSizes, curSizes)) { + return failure(); + } + + rewriter.replaceOp(curSliceOp, preSliceOp.getResult()); + return success(); + } +}; + /// Fold extract_slice + collapse_shape into rank reduced extract_slice /// /// Example: @@ -162,6 +365,32 @@ struct FoldZeroRankFromElementsInsertSlice return success(); } }; + +struct EliminateTensorExtractFromInsert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + auto insertOp = extractOp.getTensor().getDefiningOp(); + if (!insertOp) { + return failure(); + } + + SmallVector insert_idx = insertOp.getIndices(); + SmallVector extract_idx = insertOp.getIndices(); + if (insert_idx.size() != extract_idx.size()) { + return failure(); + } + for (auto [x, y] : llvm::zip(insert_idx, extract_idx)) { + if (!x || x != y) { + return failure(); + } + } + rewriter.replaceOp(extractOp, insertOp.getScalar()); + return success(); + } +}; } // namespace void mlir::tensor::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, @@ -174,6 +403,9 @@ void mlir::tensor::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); } void mlir::tensor::getCanonicalizationExtPatterns(RewritePatternSet &patterns, diff --git a/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp b/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp index 470bdf567..8a751ff45 100644 --- a/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp +++ b/compiler/lib/Dialect/Transform/Transforms/TransformInsertion.cpp @@ -93,11 +93,13 @@ void insertTransformIR(ModuleOp m, const TransformInsertionConfig &config) { struct DetensorizeTransformInsertionPass : public DetensorizeTransformInsertionBase< DetensorizeTransformInsertionPass> { - explicit DetensorizeTransformInsertionPass(const std::string &funcAnchor, + explicit DetensorizeTransformInsertionPass(const bool usingVectorizeOp, + const std::string &funcAnchor, const std::string &matchPrefix) : DetensorizeTransformInsertionBase() { this->funcAnchorAttr = funcAnchor; this->matchPrefix = matchPrefix; + this->usingVectorizeOp = usingVectorizeOp; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -135,9 +137,18 @@ struct DetensorizeTransformInsertionPass return false; }; - auto transformBuilder = [](ImplicitLocOpBuilder &b, Operation *, - Value pdlValue) { - b.create(pdlValue); + auto transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlValue) { + if (usingVectorizeOp && llvm::isa(op)) { + 
b.create( + /* target */ pdlValue, + /* vector_sizes */ ValueRange{}, + /* vectorize_nd_extract */ b.getUnitAttr(), + /* scalable_sizes */ SmallVector{}, + /* static_vector_sizes */ SmallVector{}); + } else { + b.create(pdlValue); + } }; insertTransformIR(getOperation(), {funcAnchorAttr, matchPrefix, opFilter, @@ -256,10 +267,11 @@ struct RewriteInDPSTransformInsertionPass } // namespace std::unique_ptr> -mlir::createDetensorizeTransformInsertionPass(const std::string &funcAnchor, +mlir::createDetensorizeTransformInsertionPass(const bool usingVectorizeOp, + const std::string &funcAnchor, const std::string &matchPrefix) { - return std::make_unique(funcAnchor, - matchPrefix); + return std::make_unique( + usingVectorizeOp, funcAnchor, matchPrefix); } std::unique_ptr> @@ -283,4 +295,4 @@ mlir::createRewriteInDPSTransformInsertionPass(const std::string &funcAnchor, const std::string &matchPrefix) { return std::make_unique(funcAnchor, matchPrefix); -} \ No newline at end of file +} diff --git a/compiler/lib/Dialect/Vector/Transforms/CMakeLists.txt b/compiler/lib/Dialect/Vector/Transforms/CMakeLists.txt index 51d02b3cd..01ae7f877 100644 --- a/compiler/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -1,6 +1,8 @@ add_mlir_dialect_library(ByteIRVectorPasses CanonicalizeExt.cpp VectorLowerings.cpp + VectorWarpDistribute.cpp + MoveForallRegionIntoWarpOp.cpp ADDITIONAL_HEADER_DIRS ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Dialect/Vector/Transforms diff --git a/compiler/lib/Dialect/Vector/Transforms/CanonicalizeExt.cpp b/compiler/lib/Dialect/Vector/Transforms/CanonicalizeExt.cpp index b5b88580b..cf3c7823e 100644 --- a/compiler/lib/Dialect/Vector/Transforms/CanonicalizeExt.cpp +++ b/compiler/lib/Dialect/Vector/Transforms/CanonicalizeExt.cpp @@ -88,9 +88,78 @@ struct CoalecsedForExtractFromShapeCast return success(); } }; + +struct EliminateExtractElementFromInsertElement + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractElementOp extractEleOp, + PatternRewriter &rewriter) const override { + auto insertEleOp = + extractEleOp.getVector().getDefiningOp(); + if (!insertEleOp) { + return failure(); + } + + VectorType vectorType = extractEleOp.getSourceVectorType(); + Value insertPos = insertEleOp.getPosition(); + Value extractPos = extractEleOp.getPosition(); + + if ((!extractPos && !insertPos) || + (extractPos && extractPos == insertPos)) { + rewriter.replaceOp(extractEleOp, insertEleOp.getSource()); + return success(); + } + + return failure(); + } +}; + +struct EliminateExtractElementFromSplat + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractElementOp extractEleOp, + PatternRewriter &rewriter) const override { + auto splatOp = extractEleOp.getVector().getDefiningOp(); + if (!splatOp) { + return failure(); + } + rewriter.replaceOp(extractEleOp, splatOp.getInput()); + return success(); + } +}; + +struct EliminateExtractElementFromConstant + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractElementOp extractEleOp, + PatternRewriter &rewriter) const override { + auto constantOp = + extractEleOp.getVector().getDefiningOp(); + if (!constantOp) { + return failure(); + } + auto attr = llvm::dyn_cast(constantOp.getValue()); + if (!attr || !attr.isSplat()) { + return failure(); + } + + auto splatValue = attr.getSplatValue(); + 
rewriter.replaceOpWithNewOp( + extractEleOp, extractEleOp.getType(), splatValue); + + return success(); + } +}; + } // namespace void mlir::vector::populateCanonicalizeExtPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } diff --git a/compiler/lib/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.cpp b/compiler/lib/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.cpp new file mode 100644 index 000000000..606b46dcb --- /dev/null +++ b/compiler/lib/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.cpp @@ -0,0 +1,176 @@ +//===- MoveForallRegionIntoWarpOp.cpp ------------------------------------ C++ +//--===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/Vector/Transforms/Passes.h" + +#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; +using namespace mlir; +using namespace mlir::scf; +namespace mlir { +#define GEN_PASS_DEF_MOVEFORALLREGIONINTOWARPOPPASS +#include "byteir/Dialect/Vector/Transforms/Passes.h.inc" +} // namespace mlir + +namespace { +static std::optional getLogicalWarpSize(scf::ForallOp forallOp, + int64_t warpSize) { + int64_t newWarpSize = 1; + bool hasVectorOp = false; + Block *loopBody = forallOp.getBody(); + for (auto &op : loopBody->without_terminator()) { + if (llvm::isa(op.getDialect())) { + hasVectorOp = true; + for (auto result : op.getResults()) { + if (VectorType vecType = result.getType().dyn_cast()) { + if (vecType.getRank() > 2) { + return std::nullopt; + } + + if (vecType.getRank() == 1) { + int64_t vectorSize = vecType.getShape()[0]; + if (vectorSize > warpSize && vectorSize % warpSize == 0) { + return std::nullopt; + } + + if (vectorSize <= 0 || __builtin_popcount(vectorSize) != 1) { + return std::nullopt; + } + + if (vectorSize <= warpSize) { + newWarpSize = std::max(vectorSize, newWarpSize); + } + } + } + } + } + } + if (!hasVectorOp) { + return std::nullopt; + } + return newWarpSize; +} +static bool isDistributedToWarp(scf::ForallOp forallOp, 
int64_t warpSize) { + bool onlyMapToWarp = + llvm::all_of(forallOp.getMappingAttr(), [](Attribute attr) { + return isa(attr); + }); + + if (!onlyMapToWarp) + return false; + if (!getLogicalWarpSize(forallOp, warpSize).has_value()) + return false; + return true; +} + +struct MoveForallRegionIntoWarpOpPass + : public impl::MoveForallRegionIntoWarpOpPassBase< + MoveForallRegionIntoWarpOpPass> { + MoveForallRegionIntoWarpOpPass(int64_t warpSize, llvm::StringRef anchor) + : MoveForallRegionIntoWarpOpPassBase() { + anchorTag = anchor.str(); + this->warpSize = warpSize; + } + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + + // skip non-anchored + if (!anchorTag.empty() && !funcOp->hasAttr(anchorTag)) { + return; + } + + funcOp->walk([&](scf::ForallOp forallOp) { + if (isDistributedToWarp(forallOp, warpSize)) { + int64_t logicalWarpSize = + getLogicalWarpSize(forallOp, warpSize).value(); + + // save original op in forall loop body + Block &sourceBlock = forallOp.getRegion().front(); + SmallVector opsInLoopBody; + for (auto &op : sourceBlock.without_terminator()) { + opsInLoopBody.emplace_back(&op); + } + + Location loc = forallOp.getLoc(); + mlir::OpBuilder builder(forallOp); + Block *targetBlock = forallOp.getBody(); + Block::iterator insertionPoint = forallOp.getBody()->begin(); + + // create laneid + builder.setInsertionPointToStart(forallOp.getBody()); + auto laneId = builder.create(loc); + + // create guard + if (logicalWarpSize < warpSize) { + auto predicate = builder.create( + loc, arith::CmpIPredicate::ult, laneId, + builder.create(loc, logicalWarpSize)); + auto ifOp = builder.create(loc, predicate, + /*withElseRegion=*/false); + targetBlock = ifOp.thenBlock(); + insertionPoint = ifOp.thenBlock()->begin(); + } + + // create WarpExecuteOnLane0Op and terminator + builder.setInsertionPoint(targetBlock, insertionPoint); + auto warpOp = builder.create( + loc, TypeRange{}, laneId, logicalWarpSize); + builder.setInsertionPointToStart(warpOp.getBody()); + builder.create(warpOp.getLoc()); + + // clone loop body into WarpExecuteOnLane0Op + builder.setInsertionPoint(warpOp.getBody()->getTerminator()); + IRMapping bvm; + for (auto op : opsInLoopBody) { + builder.clone(*op, bvm); + } + + // remove ops in loop body + for (auto op : llvm::reverse(opsInLoopBody)) { + op->erase(); + } + } + }); + } +}; + +} // namespace + +std::unique_ptr> +mlir::createMoveForallRegionIntoWarpOpPass(int64_t warpSize, + llvm::StringRef anchor) { + return std::make_unique(warpSize, anchor); +} diff --git a/compiler/lib/Dialect/Vector/Transforms/PassDetail.h b/compiler/lib/Dialect/Vector/Transforms/PassDetail.h new file mode 100644 index 000000000..926ee01be --- /dev/null +++ b/compiler/lib/Dialect/Vector/Transforms/PassDetail.h @@ -0,0 +1,35 @@ +//===- PassDetail.h -------------------------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_PASSDETAIL_H +#define BYTEIR_DIALECT_SCF_TRANSFORMS_PASSDETAIL_H + +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" + +// forward dialects for conversions +namespace mlir { +namespace vector { +class VectorDialect; +} // namespace vector + +#define GEN_PASS_CLASSES +#include "byteir/Dialect/Vector/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_PASSDETAIL_H diff --git a/compiler/lib/Dialect/Vector/Transforms/VectorLowerings.cpp b/compiler/lib/Dialect/Vector/Transforms/VectorLowerings.cpp index 30171718d..21d09998a 100644 --- a/compiler/lib/Dialect/Vector/Transforms/VectorLowerings.cpp +++ b/compiler/lib/Dialect/Vector/Transforms/VectorLowerings.cpp @@ -19,7 +19,9 @@ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -35,6 +37,7 @@ namespace mlir { #define GEN_PASS_DEF_VECTORTRANSPOSELOWERINGPASS +#define GEN_PASS_DEF_SCALARVECTORLOWERINGPASS #include "byteir/Dialect/Vector/Transforms/Passes.h.inc" } // namespace mlir @@ -70,4 +73,22 @@ struct VectorTransposeLoweringPass return signalPassFailure(); } }; + +struct ScalarVectorLoweringPass + : public impl::ScalarVectorLoweringPassBase { + using ScalarVectorLoweringPassBase::ScalarVectorLoweringPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + vector::populateVectorBroadcastLoweringPatterns(patterns); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + vector::populateBubbleVectorBitCastOpPatterns(patterns); + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + vector::populateScalarVectorTransferLoweringPatterns(patterns, + /*benefit=*/1, true); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; } // namespace diff --git a/compiler/lib/Dialect/Vector/Transforms/VectorWarpDistribute.cpp b/compiler/lib/Dialect/Vector/Transforms/VectorWarpDistribute.cpp new file mode 100644 index 000000000..577332f99 --- /dev/null +++ b/compiler/lib/Dialect/Vector/Transforms/VectorWarpDistribute.cpp @@ -0,0 +1,206 @@ +//===- VectorWarpDistribute.cpp -----------------------------*--- C++ -*-= == // +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// Some code comes from TestVectorTransforms.cpp in LLVM project +// Original license: +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/Vector/Transforms/Passes.h" + +#include "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#include "./PassDetail.h" + +using namespace llvm; +using namespace mlir; +using namespace mlir::vector; + +namespace { +/// Allocate shared memory for a single warp to test lowering of +/// WarpExecuteOnLane0Op. +static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, + WarpExecuteOnLane0Op warpOp, + Type type) { + static constexpr int64_t kSharedMemorySpace = 3; + // Compute type of shared memory buffer. + MemRefType memrefType; + if (auto vectorType = dyn_cast(type)) { + memrefType = + MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, + kSharedMemorySpace); + } else { + memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); + } + + // Get symbol table holding all shared memory globals. + ModuleOp moduleOp = warpOp->getParentOfType(); + SymbolTable symbolTable(moduleOp); + + // Create a pretty name. + SmallString<64> buf; + llvm::raw_svector_ostream os(buf); + interleave(memrefType.getShape(), os, "x"); + os << "x" << memrefType.getElementType(); + std::string symbolName = (Twine("__shared_") + os.str()).str(); + + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(moduleOp); + auto global = builder.create( + loc, + /*sym_name=*/symbolName, + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/memrefType, + /*initial_value=*/Attribute(), + /*constant=*/false, + /*alignment=*/IntegerAttr()); + symbolTable.insert(global); + // The symbol table inserts at the end of the module, but globals are a bit + // nicer if they are at the beginning. + global->moveBefore(&moduleOp.front()); + + builder.restoreInsertionPoint(ip); + return builder.create(loc, memrefType, symbolName); +} + +static Value warpReduction(Location loc, OpBuilder &builder, Value input, + CombiningKind kind, uint32_t size) { + // First reduce on a single thread to get per lane reduction value. + Value laneVal = builder.create(loc, kind, input); + // Parallel reduction using butterfly shuffles. 
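+ // e.g. with size = 32 the XOR lane offsets are 1, 2, 4, 8 and 16; after the
+ // final step every lane in the warp holds the fully reduced value.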
+ for (uint64_t i = 1; i < size; i <<= 1) { + Value shuffled = builder + .create(loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); + laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); + } + return laneVal; +} + +struct VectorWarpDistributePass + : public VectorWarpDistributePassBase { + VectorWarpDistributePass(const VectorWarpDistributePassOptions &options) + : VectorWarpDistributePassBase() { + warpOpToSCF = options.warpOpToSCF; + distributeTransferWriteOps = options.distributeTransferWriteOps; + hoistUniform = options.hoistUniform; + propagateDistribution = options.propagateDistribution; + maxTransferWriteElements = options.maxTransferWriteElements; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + + getOperation().walk([&](Operation *op) { + if (auto warpOp = dyn_cast(op)) { + if (hoistUniform) { + moveScalarUniformCode(warpOp); + } + WalkResult::interrupt(); + } + }); + MLIRContext *ctx = &getContext(); + auto distributionFn = [](Value val) { + // Create an identity dim map of the same rank as the vector. + VectorType vecType = dyn_cast(val.getType()); + int64_t vecRank = vecType ? vecType.getRank() : 0; + OpBuilder builder(val.getContext()); + if (vecRank == 0) + return AffineMap::get(val.getContext()); + return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); + }; + auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, + Value srcIdx, int64_t warpSz) { + assert((val.getType().isF32() || val.getType().isInteger(32)) && + "unsupported shuffle type"); + Type i32Type = builder.getIntegerType(32); + Value srcIdxI32 = + builder.create(loc, i32Type, srcIdx); + Value warpSzI32 = builder.create( + loc, builder.getIntegerAttr(i32Type, warpSz)); + Value result = builder + .create(loc, val, srcIdxI32, warpSzI32, + gpu::ShuffleMode::IDX) + .getResult(0); + return result; + }; + if (distributeTransferWriteOps && propagateDistribution) { + RewritePatternSet patterns(ctx); + vector::populatePropagateWarpVectorDistributionPatterns( + patterns, distributionFn, shuffleFn, /*benefit=*/1, + /*readBenefit=*/0); + vector::populateDistributeReduction(patterns, warpReduction, 1); + populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } else if (distributeTransferWriteOps) { + RewritePatternSet patterns(ctx); + populateDistributeTransferWriteOpPatterns(patterns, distributionFn, + maxTransferWriteElements); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } else if (propagateDistribution) { + RewritePatternSet patterns(ctx); + vector::populatePropagateWarpVectorDistributionPatterns( + patterns, distributionFn, shuffleFn); + vector::populateDistributeReduction(patterns, warpReduction); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + WarpExecuteOnLane0LoweringOptions options; + options.warpAllocationFn = allocateGlobalSharedMemory; + options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, + WarpExecuteOnLane0Op warpOp) { + builder.create(loc); + }; + // Test on one pattern in isolation. 
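+ // With warpOpToSCF set, lower any remaining WarpExecuteOnLane0Op into code
+ // guarded on lane 0, using the shared-memory allocation and gpu.barrier
+ // synchronization hooks installed in `options` above. Rough shape of the
+ // rewrite (a sketch, not the exact IR produced by the upstream pattern):
+ //   %c0 = arith.constant 0 : index
+ //   %is_lane0 = arith.cmpi eq, %laneid, %c0 : index
+ //   scf.if %is_lane0 {
+ //     ... single-lane body, values staged through the __shared_* buffers ...
+ //   }
+ //   gpu.barrier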
+ if (warpOpToSCF) { + populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + return; + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::createVectorWarpDistributePass( + const VectorWarpDistributePassOptions &options) { + return std::make_unique(options); +} diff --git a/compiler/lib/Dialect/mhlo/Transforms/FusionOutlining.cpp b/compiler/lib/Dialect/mhlo/Transforms/FusionOutlining.cpp index 89b5ca11d..43b1b262f 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/FusionOutlining.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/FusionOutlining.cpp @@ -38,6 +38,20 @@ namespace { static unsigned cnt = 0; +static std::string getNameByDominateOp(mhlo::FusionOp fusionOp) { + Block &block = fusionOp.getFusedComputation().front(); + for (mlir::Operation &op : block.getOperations()) { + if (isa(op)) { + return "Transpose"; + } else if (isa(op)) { + return "Concat"; + } else if (isa(op)) { + return "Reduce"; + } + } + return "Elementwise"; +} + static std::string getOutlineFuncitonName(mhlo::FusionOp fusionOp, unsigned &cnt) { StringAttr nameAttr = @@ -45,7 +59,7 @@ static std::string getOutlineFuncitonName(mhlo::FusionOp fusionOp, std::string funcName; if (nameAttr == nullptr) { - funcName = "Unknown" + Twine(cnt++).str(); + funcName = getNameByDominateOp(fusionOp) + Twine(cnt++).str(); } else { funcName = nameAttr.getValue().str() + Twine(cnt++).str(); } diff --git a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp index 459f67f49..653041f2f 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp @@ -25,6 +25,7 @@ #include "byteir/Utils/Utils.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" @@ -61,7 +62,8 @@ bool isFusibleCandidate(Operation *op) { (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || isSplatMhloConstantLike(op) || - isa(op) || + isa(op) || isCustomMhloRngOp(op)); } @@ -71,7 +73,7 @@ bool isFusibleStart(Operation *op) { return true; } bool isFusibleTrigger(Operation *op) { if (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op) || isCustomMhloRngOp(op)) { + isa(op) || isCustomMhloRngOp(op)) { return true; } @@ -91,12 +93,16 @@ bool isFusibleTrigger(Operation *op) { return false; } -bool isFusibleWith(Operation *target, Operation * /*start*/) { +bool isFusibleWith(Operation *target, Operation *start) { + if (isa(target)) { + return start->hasTrait<::mlir::OpTrait::Elementwise>(); + } + return target->hasTrait<::mlir::OpTrait::Elementwise>() || target->hasTrait() || isSplatMhloConstantLike(target) || - isa( - target) || + isa(target) || isCustomMhloRngOp(target); } @@ -109,7 +115,8 @@ bool isFusibleWithNoElementwiseFuse(Operation *target, Operation * /*start*/) { bool isValidSingleOp(Operation *op) { return op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op) || + isa(op) || isCustomMhloRngOp(op); } @@ -197,6 +204,40 @@ static GenericFuserConfig config_concat_slice_fuse{ elementwise::concat_slice::isValidSingleOp, elementwise::concat_slice::isValidFusionPattern}; } // namespace concat_slice +namespace insert_slice_with_elementwise { +bool isFusibleCandidate(Operation *op) { + return llvm::isa(op) || + op->hasTrait<::mlir::OpTrait::Elementwise>(); +} + +bool 
isFusibleTrigger(Operation *op) { + return llvm::isa(op) || + op->hasTrait<::mlir::OpTrait::Elementwise>(); +} + +bool isFusibleStart(Operation *op) { return llvm::isa(op); } + +bool isFusibleWith(Operation *target, Operation *start) { + if (llvm::isa(start)) { + return llvm::isa(target); + } + return llvm::isa(target) || + target->hasTrait<::mlir::OpTrait::Elementwise>(); +} + +bool isValidSingleOp(Operation *op) { return false; } + +bool isValidFusionPattern(const MhloFusionPattern &pattern) { return true; } + +static GenericFuserConfig config_insert_slice_with_elementwise_fuse{ + getByteIRElementwiseFusionAttrName(), + elementwise::insert_slice_with_elementwise::isFusibleCandidate, + elementwise::insert_slice_with_elementwise::isFusibleStart, + elementwise::insert_slice_with_elementwise::isFusibleTrigger, + elementwise::insert_slice_with_elementwise::isFusibleWith, + elementwise::insert_slice_with_elementwise::isValidSingleOp, + elementwise::insert_slice_with_elementwise::isValidFusionPattern}; +} // namespace insert_slice_with_elementwise } // namespace elementwise //===----------------------------------------------------------------------===// @@ -446,6 +487,40 @@ struct ConcatSliceFusionPass : public GenericFusionPass { } }; +struct InsertSliceWithElementwiseFusionPass + : public GenericFusionPass { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + InsertSliceWithElementwiseFusionPass) + + InsertSliceWithElementwiseFusionPass() : GenericFusionPass(true) {} + + /// Returns the command-line argument attached to this pass. + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral("fuse-insert-slice-with-elemwise"); + } + ::llvm::StringRef getArgument() const override { + return "fuse-insert-slice-with-elemwise"; + } + + ::llvm::StringRef getDescription() const override { + return "Fuse insertSliceOp with elementwiseOp"; + } + + /// Returns the derived pass name. 
+ static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral("InsertSliceWithElemwiseFusion"); + } + ::llvm::StringRef getName() const override { + return "InsertSliceWithElemwiseFusion"; + } + + const GenericFuserConfig &getConfig() { + return elementwise::insert_slice_with_elementwise:: + config_insert_slice_with_elementwise_fuse; + } +}; + // a derived fusion pass for matmul epilogue fusion struct MatmulEpilogueFusionPass : public GenericFusionPass { @@ -559,6 +634,11 @@ mlir::createConcatSliceFusionPass() { return std::make_unique(); } +std::unique_ptr> +mlir::createInsertSliceWithElemwiseFusionPass() { + return std::make_unique(); +} + std::unique_ptr> mlir::createMatmulEpilogueFusionPass() { return std::make_unique(); diff --git a/compiler/lib/Pipelines/ByreOpt.cpp b/compiler/lib/Pipelines/ByreOpt.cpp index 111429f1e..e11a48a4b 100644 --- a/compiler/lib/Pipelines/ByreOpt.cpp +++ b/compiler/lib/Pipelines/ByreOpt.cpp @@ -18,6 +18,7 @@ #include "byteir/Pipelines/ByreOpt.h" #include "byteir/Conversion/MemrefToByre/MemrefToByre.h" +#include "byteir/Dialect/MemRef/Transforms/RemoveCopy.h" #include "byteir/Conversion/ToByre/ToByre.h" #include "byteir/Dialect/Byre/ByreDialect.h" @@ -43,6 +44,7 @@ void createByreOptPipelineImpl(OpPassManager &pm, const std::string &entryFunc, entryFunc)); pm.addPass(createConvertFuncAndCallToByrePass(appendArgTypes)); + pm.addNestedPass(createRemoveCopyPass()); // only applied on entry point function OpPassManager anchoredPM(func::FuncOp::getOperationName()); diff --git a/compiler/lib/Pipelines/CMakeLists.txt b/compiler/lib/Pipelines/CMakeLists.txt index 69d24e566..2bd772b9e 100644 --- a/compiler/lib/Pipelines/CMakeLists.txt +++ b/compiler/lib/Pipelines/CMakeLists.txt @@ -30,6 +30,7 @@ add_mlir_library(ByteIRPipelines ByteIRMemRefPasses ByteIRTransforms ByteIRUtils + MLIRBufferTransforms LINK_LIBS PUBLIC ByteIRGPUPipelines @@ -52,4 +53,5 @@ add_mlir_library(ByteIRPipelines ByteIRToByre ByteIRToGPU ByteIRToLinalg + MLIRBufferTransforms ) \ No newline at end of file diff --git a/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp b/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp index 3e389a18f..e9adedc2a 100644 --- a/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/ElementwiseCodegen.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/SmallSet.h" @@ -181,18 +182,73 @@ void createGPUTileElementwiseTransformImpl(OpPassManager &pm, config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, Value pdlV) { - auto tileConfig = - getTileConfig(llvm::cast(op), warpSize, blockSize) - .value(); + auto genericOp = llvm::cast(op); + auto tileConfig = getTileConfig(genericOp, warpSize, blockSize).value(); + int64_t numTiled = getNumTiledLoops(tileConfig.tileSizes); auto pdlType = pdl::OperationType::get(b.getContext()); + int64_t numLoops = genericOp.getNumLoops(); + SmallVector interchange; + for (size_t i = 0; i < numTiled; ++i) + interchange.emplace_back(i); + + auto isPermute = [&](const SmallVector &arr) { + llvm::DenseSet elems(arr.begin(), arr.end()); + if (elems.size() != arr.size()) + return false; + int64_t maxElement = *std::max_element(arr.begin(), arr.end()); + if (maxElement != elems.size() - 1) + return false; + return true; + }; + + 
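// Best-effort heuristic: when the tiled generic op gathers from a tensor
+ // whose extract indices are all loop-index values forming a permutation of
+ // the tiled loops, swap the interchange entries so the innermost non-unit
+ // tensor dimension is driven by the innermost tiled loop, which should give
+ // more coalesced accesses after tiling.
+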
genericOp.walk([&](tensor::ExtractOp extractOp) { + if (llvm::all_of(extractOp.getResult().getUsers(), [](Operation *user) { + return llvm::isa(user); + })) { + if (llvm::all_of(extractOp.getIndices(), [](Value idx) { + if (auto defOp = idx.getDefiningOp()) { + return true; + } + return false; + })) { + SmallVector extractIndices; + for (auto idx : extractOp.getIndices()) { + auto defOp = idx.getDefiningOp(); + extractIndices.emplace_back(defOp.getDim()); + } + + if (extractIndices.size() != numTiled || !isPermute(extractIndices)) { + return WalkResult::advance(); + } + SmallVector tensorShape = + llvm::to_vector(extractOp.getTensor().getType().getShape()); + int64_t lastDim = -1; + for (int64_t i = tensorShape.size() - 1; i >= 0; --i) { + if (tensorShape[i] > 1) { + lastDim = i; + break; + } + } + + if (lastDim != -1) { + int64_t len = tensorShape.size(); + std::swap(interchange[extractIndices[lastDim]], + interchange[extractIndices[len - 1]]); + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + b.create( /* tiledOp type*/ pdlType, /* loops type */ - SmallVector(getNumTiledLoops(tileConfig.tileSizes), pdlType), + SmallVector(numTiled, pdlType), /* target */ pdlV, /* stop */ Value(), /* tillSizes */ b.getI64ArrayAttr(tileConfig.tileSizes), - /* interchange */ b.getI64ArrayAttr({}), + /* interchange */ b.getI64ArrayAttr({interchange}), /* keep_intermediate*/ false); }; diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp index 842901b04..9b4ad345e 100644 --- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp +++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp @@ -23,12 +23,16 @@ #include "byteir/Dialect/GPU/Passes.h" #include "byteir/Dialect/SCF/Passes.h" #include "byteir/Dialect/Transform/Transforms/TransformDialectInterpreter.h" +#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h" +#include "byteir/Dialect/Vector/Transforms/Passes.h" +#include "byteir/Dialect/Vector/Transforms/VectorWarpDistribute.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/Common/Utils.h" #include "byteir/Pipelines/GPU/MappingForall.h" #include "byteir/Transforms/Passes.h" #include "byteir/Transforms/RemoveFuncBody.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" @@ -81,6 +85,27 @@ void createElementwiseGPUOptPipelineImpl(OpPassManager &pm, void createReductionGPUOptPipelineImpl(OpPassManager &pm) { GPUMappingForallOptions options; options.funcAnchor = getByteIRReductionFusionAttrName().str(); + options.blockDimsHint = llvm::cl::KernelDims{256, 1, 1}; + // vector redution to gpu shuffle & lowering + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass( + createMoveForallRegionIntoWarpOpPass(/* warpSize = */ 32)); + VectorWarpDistributePassOptions options; + options.warpOpToSCF = true; + options.distributeTransferWriteOps = true; + options.hoistUniform = true; + options.propagateDistribution = true; + anchoredPM.addPass(createVectorWarpDistributePass(options)); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createScalarVectorLoweringPass()); + anchoredPM.addPass(createCanonicalizeExtPass()); + anchoredPM.addPass(createConvertVectorToSCFPass()); + pm.addNestedPass(createAnchoredPipelinePass( + 
getByteIRReductionFusionAttrName(), anchoredPM)); + } + createGPUMappingForallTransform(pm, options); pm.addPass(createTransformDialectInterpreter(true)); pm.addPass(createCSEPass()); diff --git a/compiler/lib/Pipelines/GPU/MappingForall.cpp b/compiler/lib/Pipelines/GPU/MappingForall.cpp index 6e0ef686a..1eadc7768 100644 --- a/compiler/lib/Pipelines/GPU/MappingForall.cpp +++ b/compiler/lib/Pipelines/GPU/MappingForall.cpp @@ -19,6 +19,7 @@ #include "byteir/Conversion/ToGPU/ToGPU.h" #include "byteir/Conversion/ToLLVM/ToLLVM.h" +#include "byteir/Dialect/GPU/TransformOps/GPUExtTransformOps.h" #include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h" #include "byteir/Dialect/Transform/IR/TransformExtOps.h" #include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" @@ -40,6 +41,7 @@ using namespace mlir; namespace { static constexpr int64_t kMaximumBlockDim = 1024; +static constexpr int64_t kNumGroup = 4; struct MappingForallConfig { SmallVector blockDims; @@ -60,7 +62,7 @@ bool isMappedToGPUBlocks(scf::ForallOp forallOp) { bool isMappedToGPUThreads(scf::ForallOp forallOp) { if (auto mapping = forallOp.getMappingAttr()) { - if (llvm::any_of(mapping.getValue(), [](Attribute attr) { + if (llvm::all_of(mapping.getValue(), [](Attribute attr) { return isa(attr); })) { return true; @@ -70,71 +72,167 @@ bool isMappedToGPUThreads(scf::ForallOp forallOp) { return false; } -void updateBlockDims(scf::ForallOp forallOp, SmallVector &blockDims) { +bool isMappedToGPUWarps(scf::ForallOp forallOp) { + if (auto mapping = forallOp.getMappingAttr()) { + if (llvm::all_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +bool isMappedToGPUWarpGroups(scf::ForallOp forallOp) { + if (auto mapping = forallOp.getMappingAttr()) { + if (llvm::all_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +bool isNonLinearMappingMode(scf::ForallOp forallOp) { + return !llvm::any_of(forallOp.getMapping()->getValue(), [](Attribute a) { + return cast(a).isLinearMapping(); + }); +} + +SmallVector getForallMappingSize(scf::ForallOp forallOp, + const int64_t warpSize) { + int64_t scale = 1; + + if (isMappedToGPUWarps(forallOp)) + scale = warpSize; + if (isMappedToGPUWarpGroups(forallOp)) + scale = warpSize * kNumGroup; + SmallVector mappingSizes; for (auto &&[lb, ub, step, mappingAttr] : llvm::zip( forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), forallOp.getMappingAttr().getValue())) { - if (auto threadMapping = - llvm::dyn_cast_or_null(mappingAttr)) { - auto numIterations = constantTripCount(lb, ub, step); - auto threadIdx = threadMapping.getMappingId(); - if (numIterations.has_value()) { - blockDims[threadIdx] = - std::max(blockDims[threadIdx], numIterations.value()); - } + auto numIterations = constantTripCount(lb, ub, step); + if (numIterations.has_value()) { + mappingSizes.emplace_back(numIterations.value()); + } else { + mappingSizes.emplace_back(ShapedType::kDynamic); } } + mappingSizes[0] *= scale; + return mappingSizes; } std::optional -getMappingForallConfig(scf::ForallOp forallOp) { +getMappingForallConfig(scf::ForallOp forallOp, const int64_t warpSize, + const SmallVector &blockDimsHint) { if (!isMappedToGPUBlocks(forallOp)) return std::nullopt; SmallVector blockDims{1, 1, 1}; auto &&block = forallOp.getRegion().front(); - for (auto &&nestedForall : block.getOps()) { - if (isMappedToGPUThreads(nestedForall)) { - 
updateBlockDims(nestedForall, blockDims); + auto hasDynamicDims = [&]() -> bool { + return llvm::any_of(blockDims, + [](int64_t x) { return x == ShapedType::kDynamic; }); + }; + forallOp->walk([&](scf::ForallOp nestedForall) { + if (!isMappedToGPUBlocks(nestedForall) && + isNonLinearMappingMode(nestedForall)) { + SmallVector mappingSizes = + getForallMappingSize(nestedForall, warpSize); + for (auto &&[val, mappingAttr] : + llvm::zip(mappingSizes, nestedForall.getMappingAttr().getValue())) { + auto threadIdx = + cast(mappingAttr).getMappingId(); + if (val == ShapedType::kDynamic) { + blockDims[threadIdx] = ShapedType::kDynamic; + break; + } else { + blockDims[threadIdx] = std::max(blockDims[threadIdx], val); + } + } } + }); + + if (hasDynamicDims()) { + return MappingForallConfig{blockDimsHint}; + } + + forallOp->walk([&](scf::ForallOp nestedForall) { + if (!isMappedToGPUBlocks(nestedForall) && + !isNonLinearMappingMode(nestedForall)) { + SmallVector mappingSizes = + getForallMappingSize(nestedForall, warpSize); + int64_t mul = 1; + for (size_t i = 0; i < mappingSizes.size(); ++i) { + if (mappingSizes[i] == ShapedType::kDynamic) { + mul = ShapedType::kDynamic; + break; + } + mul *= mappingSizes[i]; + } + + if (mul == ShapedType::kDynamic) { + blockDims[0] = ShapedType::kDynamic; + } else if (!hasDynamicDims()) { + for (size_t i = 0; i < blockDims.size(); ++i) { + mul = (mul + blockDims[i] - 1) / blockDims[i]; + } + blockDims[0] *= mul; + } + } + }); + + if (hasDynamicDims()) { + return MappingForallConfig{blockDimsHint}; } if (blockDims[0] * blockDims[1] * blockDims[2] > kMaximumBlockDim) { return std::nullopt; } + while (blockDims[0] * blockDims[1] * blockDims[2] * 2 <= warpSize) { + blockDims[0] *= 2; + } return MappingForallConfig{blockDims}; } -void createGPUMappingForallTransformImpl(OpPassManager &pm, - const std::string &anchor, - const std::string &prefix) { +void createGPUMappingForallTransformImpl( + OpPassManager &pm, const std::string &anchor, const std::string &prefix, + const int64_t &warpSize, const llvm::cl::KernelDims &blockDimsHint) { TransformInsertionConfig config; config.funcAnchor = anchor; config.matchPrefix = prefix; + SmallVector blockDimsHintVec{blockDimsHint.x, blockDimsHint.y, + blockDimsHint.z}; config.opFilter = [=](Operation *op) { if (auto forallOp = llvm::dyn_cast_or_null(op)) { - return getMappingForallConfig(forallOp).has_value(); + return getMappingForallConfig(forallOp, warpSize, blockDimsHintVec) + .has_value(); } return false; }; config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, Value pdlV) { - auto mappingConfig = - getMappingForallConfig(llvm::cast(op)).value(); + auto mappingConfig = getMappingForallConfig(llvm::cast(op), + warpSize, blockDimsHintVec) + .value(); auto pdlType = pdl::OperationType::get(b.getContext()); - auto launchOp = b.create( + auto launchOp = b.create( /* result type */ pdlType, /* target */ pdlV, /* grid_dims */ llvm::ArrayRef{}, /* generate_gpu_launch */ true); - b.create( + b.create( /* result type*/ pdlType, /* target */ launchOp.getResult(), - /* block_dims */ mappingConfig.blockDims, + /* block_dims */ + mappingConfig.blockDims, /* sync_after_distribute*/ true, - /* warp_dims */ 32); + /* warp_size */ warpSize); }; pm.addPass(createGenericTransformInsertionPass(config)); @@ -144,5 +242,6 @@ void createGPUMappingForallTransformImpl(OpPassManager &pm, void mlir::createGPUMappingForallTransform( OpPassManager &pm, const GPUMappingForallOptions &options) { 
invokeOpPassPipelineBuilder(createGPUMappingForallTransformImpl, pm, - options.funcAnchor, options.annotatePrefix); + options.funcAnchor, options.annotatePrefix, + options.warpSize, options.blockDimsHint); } diff --git a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp index 5dce9f2bc..b791546b8 100644 --- a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp @@ -27,6 +27,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" @@ -52,6 +53,7 @@ void createNVVMCodegenPipelineImpl(OpPassManager &pm, pm.addPass(createSimplifyLinearizedIndexPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); + pm.addNestedPass(createConvertVectorToLLVMPass()); pm.addNestedPass(createGPUToNVVMExtPass( useBarePtrCallConv, mlir::kDeriveIndexBitwidthFromDataLayout, gpuArch)); pm.addPass(createCSEPass()); diff --git a/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp b/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp index b602f77be..f930c513e 100644 --- a/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp @@ -22,6 +22,7 @@ #include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h" #include "byteir/Dialect/Transform/IR/TransformExtOps.h" #include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" +#include "byteir/Dialect/Vector/Transforms/MoveForallRegionIntoWarpOp.h" #include "byteir/Pipelines/Common/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" @@ -32,6 +33,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/SmallSet.h" @@ -45,6 +47,9 @@ namespace { //----------------------------------------------------------------------------// // TODO: move to common header +static constexpr int64_t kGridSplitThreshold = 4096; +static constexpr int64_t kGridTileNumThreshold = 64; + constexpr bool isPowerOf2(int64_t n) { return (!(n & (n - 1))); } constexpr int64_t nextPowerOf2(int64_t n) { @@ -85,6 +90,28 @@ bool isMappedToGPUBlocks(Operation *op) { return false; } +bool isMappedToGPUWarps(scf::ForallOp forallOp) { + if (auto mapping = forallOp.getMappingAttr()) { + if (llvm::any_of(mapping.getValue(), [](Attribute attr) { + return isa(attr); + })) { + return true; + } + } + + return false; +} + +bool isMappedToGPUWarps(Operation *op) { + if (auto forOp = llvm::dyn_cast_or_null(op)) { + return isMappedToGPUWarps(forOp); + } + if (auto forallOp = llvm::dyn_cast_or_null(op)) { + return isMappedToGPUWarps(forallOp); + } + return false; +} + bool isMappedToGPUThreads(scf::ForOp forOp) { if (auto loopToSIMTAttr = forOp->getAttrOfType(getLoopToSIMTAttrName())) { @@ -133,6 +160,23 @@ std::optional getReductionDim(linalg::GenericOp genericOp) { return std::nullopt; } +int64_t getParallelism(linalg::GenericOp genericOp) { + SmallVector parallelDims; + genericOp.getParallelDims(parallelDims); + auto staticLoopRanges = 
genericOp.getStaticLoopRanges(); + if (parallelDims.size() == 0) { + return 1; + } + int64_t parallelism = 1; + for (auto idx : parallelDims) { + if (ShapedType::isDynamic(staticLoopRanges[idx])) { + return ShapedType::kDynamic; + } + parallelism *= staticLoopRanges[idx]; + } + return parallelism; +} + std::optional getOperandReductionDim(OpOperand &operand) { auto genericOp = llvm::dyn_cast(operand.getOwner()); if (!genericOp) @@ -160,7 +204,7 @@ std::optional getOperandReductionDim(OpOperand &operand) { SmallVector getDynamicDims(linalg::GenericOp genericOp) { auto staticLoopRanges = genericOp.getStaticLoopRanges(); SmallVector ret; - for (int64_t i = 0; i < staticLoopRanges.size(); ++i) { + for (size_t i = 0; i < staticLoopRanges.size(); ++i) { if (ShapedType::isDynamic(staticLoopRanges[i])) { ret.push_back(i); } @@ -168,15 +212,46 @@ SmallVector getDynamicDims(linalg::GenericOp genericOp) { return ret; } +static void promoteAllTensorsWithinOp(ImplicitLocOpBuilder &b, Value parentOp, + gpu::AddressSpaceAttr memAddrSpace) { + // get corresponding empty tensor + auto emptyTensorType = transform::OperationType::get( + b.getContext(), tensor::EmptyOp::getOperationName()); + auto emptyTensor = b.create( + emptyTensorType, parentOp, tensor::EmptyOp::getOperationName()); + + // // empty tensor to alloc tensor + auto allocTensorType = transform::OperationType::get( + b.getContext(), bufferization::AllocTensorOp::getOperationName()); + auto allocTensor = b.create( + allocTensorType, emptyTensor); + auto memorySpaceAttrName = + bufferization::AllocTensorOp::getMemorySpaceAttrName(OperationName( + bufferization::AllocTensorOp::getOperationName(), b.getContext())); + + Value paramV = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ memAddrSpace); + b.create( + /* target */ allocTensor, + /* name */ memorySpaceAttrName, + /* param */ paramV); +} + //----------------------------------------------------------------------------// // configuration structs //----------------------------------------------------------------------------// +// tag for linalg operation static constexpr StringLiteral kGridReduction = "__grid_reduction__"; static constexpr StringLiteral kBlockReduction = "__block_reduction__"; static constexpr StringLiteral kWarpReduction = "__warp_reduction__"; static constexpr StringLiteral kThreadReduction = "__thread_reduction__"; +// tag for forall operation +static constexpr StringLiteral kMapInnerLinalgReductionDimToThread = + "__map_inner_linalg_reduction_dim_to_thread__"; + struct ProducerSelector { uint64_t operandNumber; llvm::StringRef opName; @@ -220,7 +295,9 @@ struct ProducerSelector { struct GridSplitConfig { int64_t splitFactor; - int64_t dimension; + int64_t redDim; + int64_t numLoops; + gpu::MappingId mapping; void apply(ImplicitLocOpBuilder &b, Value pdlV); }; @@ -229,11 +306,11 @@ struct GridTileConfig { SmallVector tileSizes; SmallVector mapping; std::vector fuseCandidates; - int64_t padDim; - SmallVector padValues; - int64_t warpSize; + int64_t parallelismPerBlock; + bool asNumThreads; + bool mapReductionDimToThread; - void apply(ImplicitLocOpBuilder &b, Value pdlV, bool usingForall); + void apply(ImplicitLocOpBuilder &b, Value pdlV); }; struct BlockSplitConfig { @@ -246,11 +323,15 @@ struct BlockSplitConfig { }; struct BlockTileConfig { + bool usingTileReduction; + bool mappingToWarp; + int64_t numLoops; + int64_t redDim; SmallVector tileSizes; SmallVector mapping; std::vector fuseCandidates; - void apply(ImplicitLocOpBuilder &b, Value pdlV, 
bool usingForall); + void apply(ImplicitLocOpBuilder &b, Value pdlV); }; struct ThreadTileConfig { @@ -283,15 +364,25 @@ transform::TileUsingForallOp tileToForallAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, const SmallVector &tileSizes, const SmallVector &mapping, - const std::vector &fuseCandidates) { + const std::vector &fuseCandidates, + bool asNumThreads) { SmallVector toBeFused; processProducerSelectors(b, fuseCandidates, toTile, toBeFused); - auto tileOp = b.create( - /* target */ toTile, - /* staticTileSizes */ tileSizes, - /* ctor tag */ transform::TileSizesSpec(), - /* mapping */ b.getArrayAttr(mapping)); + transform::TileUsingForallOp tileOp; + if (asNumThreads) { + tileOp = b.create( + /* target */ toTile, + /* numThreads */ tileSizes, + /* ctor tag */ transform::NumThreadsSpec(), + /* mapping */ b.getArrayAttr(mapping)); + } else { + tileOp = b.create( + /* target */ toTile, + /* staticTileSizes */ tileSizes, + /* ctor tag */ transform::TileSizesSpec(), + /* mapping */ b.getArrayAttr(mapping)); + } for (auto &&producerOp : toBeFused) { b.create( /* producerOp */ producerOp, @@ -300,6 +391,31 @@ tileToForallAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, return tileOp; } +void tileReductionToForallAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, + const SmallVector &numThreads, + const Attribute mapping, + StringRef annotation) { + auto tileReductionOp = b.create( + /* target */ toTile, + /* num_threads */ numThreads, + /*staticTileSizes*/ SmallVector{}, + /*mapping*/ b.getArrayAttr(mapping)); + + b.create( + /* target */ tileReductionOp.getSplitLinalgOp(), + /* name */ annotation, + /* param */ Value()); + + b.create( + /* target */ tileReductionOp.getCombiningLinalgOp(), + /* name */ annotation, + /* param */ Value()); + + b.create( + /* producerOp */ tileReductionOp.getFillOp(), + /* containingOp */ tileReductionOp.getForallOp()); +} + void tileToSCFForAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, const SmallVector &tileSizes, const SmallVector &mapping) { @@ -324,21 +440,13 @@ void tileToSCFForAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, void GridSplitConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { if (splitFactor) { - auto splitted = b.create( - /* target */ pdlV, - /* splitFactor */ splitFactor, - /* insertSplitDimension */ dimension, - /* innerParallel */ false, - /* useScalingAlgorithm */ false, - /* useAlloc */ false); - b.create( - /* target */ splitted.getSplitLinalgOp(), - /* name */ kGridReduction, - /* param */ Value()); - b.create( - /* target */ splitted.getCombiningLinalgOp(), - /* name */ kGridReduction, - /* param */ Value()); + auto mappingAttr = gpu::GPUBlockMappingAttr::get(b.getContext(), mapping); + + SmallVector numThreads(numLoops, 0); + numThreads[redDim] = splitFactor; + + tileReductionToForallAndFuseImpl(b, pdlV, numThreads, mappingAttr, + kGridReduction); } else { b.create( /* target */ pdlV, @@ -347,36 +455,46 @@ void GridSplitConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { } } -void GridTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV, - bool usingForall) { - if (usingForall) { - auto mappingAttrs = llvm::to_vector( - llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { - return gpu::GPUBlockMappingAttr::get(b.getContext(), dim); - })); - auto tiledOp = tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, - fuseCandidates); - if (padDim >= 0) { - b.create( - TypeRange{pdlV.getType(), pdlV.getType(), pdlV.getType()}, - tiledOp.getTiledOp(), - /*padding_values=*/b.getArrayAttr(padValues), - 
/*padding_dimensions=*/ - b.getI64ArrayAttr({padDim}), - /*padToMultipleOf=*/b.getArrayAttr({b.getI64IntegerAttr(warpSize)}), - /*pack_paddings=*/ArrayAttr{}, - /*transpose_paddings=*/ArrayAttr{}, - /*copyBack=*/transform::PadOp::kCopyOpNone); +void GridTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { + auto mappingAttrs = llvm::to_vector( + llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { + return gpu::GPUBlockMappingAttr::get(b.getContext(), dim); + })); + auto tiledOp = + tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, fuseCandidates, + /* asNumThreads = */ asNumThreads); + + if (mapReductionDimToThread) { + b.create( + /* target */ tiledOp.getForallOp(), + /* name */ kMapInnerLinalgReductionDimToThread, + /* param */ Value()); + } else if (!asNumThreads && parallelismPerBlock > 1) { + SmallVector forTileSizes = tileSizes; + for (size_t i = 0; i < forTileSizes.size(); ++i) { + if (forTileSizes[i]) + forTileSizes[i] = 1; } - } else { - static constexpr std::array mappings{ - getBlockIdXName(), getBlockIdYName(), getBlockIdZName()}; - auto mappingAttrs = llvm::to_vector( - llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { - return b.getStringAttr(mappings[static_cast(dim)]); - })); - tileToSCFForAndFuseImpl(b, pdlV, tileSizes, mappingAttrs); + auto pdlType = pdl::OperationType::get(b.getContext()); + auto fuseOp = b.create( + /* transformed */ pdlType, + /* loops */ + SmallVector(getNumTiledLoops(forTileSizes), pdlType), + /* target */ tiledOp.getTiledOp(), + /* tile_sizes */ b.getI64ArrayAttr(forTileSizes), + /* tile_interchange */ ArrayAttr()); + + b.create( + fuseOp.getLoops()[0], [](OpBuilder &b, Location loc) { + b.create(loc); + b.create( + loc); + }); + b.create( + /* target */ fuseOp.getTransformed(), + /* name */ kBlockReduction, + /* param */ Value()); } } @@ -468,22 +586,80 @@ void BlockSplitConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { /* param */ paramV); } -void BlockTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV, - bool usingForall) { - if (usingForall) { - auto mappingAttrs = llvm::to_vector( - llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { - return gpu::GPUThreadMappingAttr::get(b.getContext(), dim); - })); - tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, fuseCandidates); +void BlockTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { + + if (mappingToWarp) { + b.create( + /* target */ pdlV, + /* name */ kWarpReduction, + /* param */ Value()); } else { - static constexpr std::array mappings{ - getThreadIdXName(), getThreadIdYName(), getThreadIdZName()}; auto mappingAttrs = llvm::to_vector( llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { - return b.getStringAttr(mappings[static_cast(dim)]); + return gpu::GPUThreadMappingAttr::get(b.getContext(), dim); })); - tileToSCFForAndFuseImpl(b, pdlV, tileSizes, mappingAttrs); + if (usingTileReduction) { + SmallVector numThreads = tileSizes; + SmallVector staticTileSizes = llvm::to_vector(llvm::map_range( + tileSizes, [](int64_t val) -> int64_t { return val != 0; })); + + auto tiledRedutionOp = b.create( + /* target */ pdlV, + /* num_threads */ numThreads, + /*staticTileSizes*/ staticTileSizes, + /*mapping*/ b.getArrayAttr(mappingAttrs)); + + b.create( + /* producerOp */ tiledRedutionOp.getFillOp(), + /* containingOp */ tiledRedutionOp.getForallOp()); + + // attch block_redution to combineOp + b.create( + /* target */ tiledRedutionOp.getCombiningLinalgOp(), + /* name */ kBlockReduction, + /* param */ Value()); + + if (numLoops 
> 1) { + SmallVector combineTileSizes(numLoops, 1); + // excluding reduction dim. + combineTileSizes[redDim] = 0; + auto tileCombineOp = b.create( + /* target */ tiledRedutionOp.getCombiningLinalgOp(), + /* staticTileSizes */ combineTileSizes); + + b.create( + tileCombineOp.getLoops()[0], [](OpBuilder &b, Location loc) { + b.create(loc); + b.create( + loc); + }); + + b.create( + /* target */ tileCombineOp.getTiledLinalgOp(), + /* name */ kBlockReduction, + /* param */ Value()); + } + + { + // get corresponding empty tensor + auto forall = b.create( + tiledRedutionOp.getForallOp().getType(), + tiledRedutionOp.getForallOp(), + /* isolated_from_above */ false, + /* allow_empty_results */ false, + /* op_name */ b.getStringAttr(scf::ForallOp::getOperationName()), + /* deduplicate */ false, + /* nth_parent */ 1); + auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get( + b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); + + promoteAllTensorsWithinOp(b, forall, workgroupMemoryAddressSpace); + } + } else { + auto tiledOp = + tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, + fuseCandidates, /* asNumThreads = */ false); + } } } @@ -521,13 +697,35 @@ bool isReductionOp(linalg::GenericOp genericOp) { return false; if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap affineMap) { - return affineMap.isProjectedPermutation(/* allowZeroInResults */ false); + return affineMap.isProjectedPermutation( + /* allowZeroInResults */ false); })) return false; return true; } +bool isRedDimInInnermostLoop(linalg::GenericOp genericOp) { + if (!isReductionOp(genericOp)) + return false; + int64_t numLoops = genericOp.getNumLoops(); + auto maybeRedDim = getReductionDim(genericOp); + if (!maybeRedDim.has_value()) { + return false; + } + int64_t redDim = maybeRedDim.value(); + for (auto &&affineMap : genericOp.getIndexingMapsArray()) { + if (affineMap.isPermutation()) { + auto dim = affineMap.getDimPosition(numLoops - 1); + if (dim == redDim) { + return true; + } + break; + } + } + return false; +} + bool isGridReductionOp(linalg::GenericOp genericOp) { if (!isReductionOp(genericOp)) return false; @@ -558,6 +756,32 @@ bool isBlockReductionOp(linalg::GenericOp genericOp) { return false; } +bool isMappedReductionToThread(linalg::GenericOp genericOp) { + if (!isReductionOp(genericOp)) + return false; + + if (auto forallOp = + llvm::dyn_cast_or_null(genericOp->getParentOp())) { + return forallOp->hasAttr(kMapInnerLinalgReductionDimToThread); + } + return false; +} + +bool isWarpReductionOp(linalg::GenericOp genericOp) { + if (!isReductionOp(genericOp)) + return false; + + // early return for manual tag + if (genericOp->hasAttr(kWarpReduction)) + return true; + + // nested in op which is mapped to GPU warp + if (isMappedToGPUWarps(genericOp->getParentOp())) + return true; + + return false; +} + bool isThreadReductionOp(linalg::GenericOp genericOp) { if (!isReductionOp(genericOp)) return false; @@ -578,14 +802,33 @@ std::optional getGridSplitConfig(linalg::GenericOp genericOp, if (!isGridReductionOp(genericOp)) return std::nullopt; + int64_t numLoops = genericOp.getNumLoops(); auto redDim = *getReductionDim(genericOp); auto staticLoopRanges = genericOp.getStaticLoopRanges(); - if (ShapedType::isDynamic(staticLoopRanges[redDim]) || - staticLoopRanges[redDim] % splitFactor != 0 || - staticLoopRanges[redDim] <= 1024) + int64_t parallelism = getParallelism(genericOp); + + if (parallelism > 1 || parallelism == ShapedType::kDynamic) { return std::nullopt; + } + + if 
(!isRedDimInInnermostLoop(genericOp)) { + return std::nullopt; + } - return GridSplitConfig{splitFactor, redDim ? redDim - 1 : redDim}; + int64_t redDimSize = staticLoopRanges[redDim]; + if (isRedDimInInnermostLoop(genericOp)) { + if (ShapedType::isDynamic(redDimSize) || + staticLoopRanges[redDim] <= kGridSplitThreshold) { + return std::nullopt; + } + } + + // at least 2: split reduction & grid tile + int64_t blockMappingNum = std::max(numLoops, static_cast(2)); + return GridSplitConfig{splitFactor, redDim, numLoops, + static_cast( + static_cast(gpu::MappingId::LinearDim0) + + blockMappingNum - 1)}; } std::optional getGridTileConfig(linalg::GenericOp genericOp, @@ -596,37 +839,100 @@ std::optional getGridTileConfig(linalg::GenericOp genericOp, int64_t numLoops = genericOp.getNumLoops(); SmallVector tileSizes(numLoops, 1); + auto redDim = getReductionDim(genericOp).value(); + int64_t totalParallelism = getParallelism(genericOp); + tileSizes[redDim] = 0; + + bool asNumThreads = false; + bool mapReductionDimToThread = true; auto loopSizes = - cast(genericOp.getOperation()).computeStaticLoopSizes(); + cast(genericOp.getOperation()).getStaticLoopRanges(); + + int64_t parallelismPerBlock = blockSize; + int64_t redDimSize = loopSizes[redDim]; + + if (isRedDimInInnermostLoop(genericOp)) { + if (!ShapedType::isDynamic(redDimSize) && redDimSize < warpSize && + totalParallelism / blockSize >= kGridTileNumThreshold) { + parallelismPerBlock = blockSize; + mapReductionDimToThread = true; + } else { + parallelismPerBlock = 1; + mapReductionDimToThread = false; + } + } else { + if (!ShapedType::isDynamic(totalParallelism)) { + while (totalParallelism / parallelismPerBlock < kGridTileNumThreshold && + parallelismPerBlock > 1) { + parallelismPerBlock /= 2; + } + } + } - int64_t padDim = -1; + if (parallelismPerBlock == 1) { + mapReductionDimToThread = false; + } + + SmallVector parallelDims; + genericOp.getParallelDims(parallelDims); + int64_t remainParallelism = parallelismPerBlock; + int64_t lastTilingDim = -1; for (auto &&affineMap : genericOp.getIndexingMapsArray()) { if (affineMap.isPermutation()) { - auto dim = affineMap.getDimPosition(numLoops - 1); - padDim = dim; - if (loopSizes[dim] > warpSize) { - tileSizes[dim] *= warpSize; - break; + for (int64_t i = numLoops - 1; i >= 0; --i) { + if (remainParallelism == 1) { + break; + } + auto dim = affineMap.getDimPosition(i); + if (llvm::find(parallelDims, dim) == parallelDims.end()) + continue; + if (ShapedType::isDynamic(loopSizes[dim])) { + tileSizes[dim] = remainParallelism; + remainParallelism = 1; + } else { + int64_t dimSize = nextPowerOf2(loopSizes[dim]); + if (dimSize <= remainParallelism) { + tileSizes[dim] = 0; + remainParallelism /= dimSize; + } else { + tileSizes[dim] = remainParallelism; + remainParallelism = 1; + } + } + lastTilingDim = dim; } + break; } } - auto redDim = getReductionDim(genericOp).value(); - tileSizes[redDim] = 0; - std::vector fuseCandidates; for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) { ProducerSelector::detectFillOperand(&opOperand, fuseCandidates); } auto numTiledLoops = getNumTiledLoops(tileSizes); - if (!numTiledLoops) { - tileSizes[redDim] = loopSizes[redDim]; + if (numTiledLoops == 0) { numTiledLoops = 1; + if (lastTilingDim != -1) { + // parallelism is too small. + // using last tiling dimension to generate forallOp with unit mapping + // size. 
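+ // If there is no such dimension, the branches below either tile the single
+ // reduction loop itself (using num_threads when its extent is dynamic) or
+ // give up on grid tiling for this op.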
+ tileSizes[lastTilingDim] = loopSizes[lastTilingDim]; + } else if (genericOp.hasSingleReductionLoop()) { + if (ShapedType::isDynamic(loopSizes[redDim])) { + asNumThreads = true; + tileSizes[redDim] = 1; + } else { + tileSizes[redDim] = loopSizes[redDim]; + } + } else { + return std::nullopt; + } } + if (numTiledLoops >= 1 && numTiledLoops <= 3) { SmallVector mapping(numLoops, -1); - int64_t dimMapping = static_cast(gpu::MappingId::DimX); + int64_t dimMapping = static_cast(gpu::MappingId::LinearDim0); for (auto &&affineMap : genericOp.getIndexingMapsArray()) { if (affineMap.isPermutation()) { for (int64_t i = numLoops - 1; i >= 0; i--) { @@ -643,24 +949,14 @@ std::optional getGridTileConfig(linalg::GenericOp genericOp, if (mapping.size() != numTiledLoops) return std::nullopt; - SmallVector padValues; - mlir::Builder b(genericOp.getContext()); - for (auto &&operand : genericOp->getOperands()) { - if (auto shapedType = llvm::dyn_cast(operand.getType())) { - padValues.push_back(b.getZeroAttr(shapedType.getElementType())); - } else { - return std::nullopt; - } - } - return GridTileConfig{ tileSizes, llvm::to_vector(llvm::map_range( mapping, [](int64_t i) { return static_cast(i); })), fuseCandidates, - padDim, - padValues, - warpSize}; + parallelismPerBlock, + asNumThreads, + mapReductionDimToThread}; } return std::nullopt; } @@ -728,27 +1024,49 @@ std::optional getBlockTileConfig(linalg::GenericOp genericOp, int64_t numLoops = genericOp.getNumLoops(); SmallVector tileSizes(numLoops, 0); auto loopSizes = - cast(genericOp.getOperation()).computeStaticLoopSizes(); + cast(genericOp.getOperation()).getStaticLoopRanges(); int64_t remainBlockSize = blockSize; auto redDim = getReductionDim(genericOp).value(); - for (int64_t idx = 0; idx < numLoops && remainBlockSize > 1; ++idx) { - if (idx == redDim) - continue; - int64_t curLoopSize2 = nextPowerOf2(loopSizes[idx]); - int64_t curBlockSize = std::min(curLoopSize2, remainBlockSize); - tileSizes[idx] = curLoopSize2 / curBlockSize; - remainBlockSize /= curBlockSize; - } - if (remainBlockSize == blockSize) { - tileSizes[redDim] = loopSizes[redDim]; + bool usingTileReduction = false; + bool mappingToWarp = false; + // mapping to warp redution directly + int64_t redDimSize = loopSizes[redDim]; + if (numLoops == 1 && redDimSize != ShapedType::kDynamic && + redDimSize <= warpSize && isPowerOf2(redDimSize)) { + mappingToWarp = true; + return BlockTileConfig{usingTileReduction, + mappingToWarp, + numLoops, + redDim, + tileSizes, + SmallVector{}, + std::vector{}}; + } else if (isMappedReductionToThread(genericOp)) { + for (int64_t i = 0; i < numLoops; ++i) + tileSizes[i] = 1; + tileSizes[redDim] = 0; + } else { + usingTileReduction = true; + tileSizes[redDim] = blockSize; + if (!ShapedType::isDynamic(redDimSize)) { + if (redDimSize <= blockSize) { + tileSizes[redDim] = std::max(nextPowerOf2(redDimSize), warpSize); + } + + while (tileSizes[redDim] * 2 > redDimSize && + tileSizes[redDim] / 2 >= warpSize) { + tileSizes[redDim] /= 2; + } + } } std::vector fuseCandidates; for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { ProducerSelector::detectPadOperand(opOperand, fuseCandidates); } + for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) { ProducerSelector::detectFillOperand(&opOperand, fuseCandidates); } @@ -761,7 +1079,7 @@ std::optional getBlockTileConfig(linalg::GenericOp genericOp, if (affineMap.isPermutation()) { for (int64_t i = numLoops - 1; i >= 0; i--) { auto dim = affineMap.getDimPosition(i); - if (tileSizes[dim] > 0) { + if 
(tileSizes[dim] > 0 || usingTileReduction) { mapping[dim] = dimMapping++; } } @@ -770,10 +1088,17 @@ std::optional getBlockTileConfig(linalg::GenericOp genericOp, } mapping.erase(std::remove(mapping.begin(), mapping.end(), -1), mapping.end()); - if (mapping.size() != numTiledLoops) + if (usingTileReduction && mapping.size() != numLoops) + return std::nullopt; + + if (!usingTileReduction && mapping.size() != numTiledLoops) return std::nullopt; return BlockTileConfig{ + usingTileReduction, + mappingToWarp, + numLoops, + redDim, tileSizes, llvm::to_vector(llvm::map_range( mapping, [](int64_t i) { return static_cast(i); })), @@ -787,6 +1112,10 @@ getThreadTileConfig(linalg::GenericOp genericOp) { if (!isThreadReductionOp(genericOp)) return std::nullopt; + if (genericOp.hasDynamicShape()) { + return std::nullopt; + } + int64_t numLoops = genericOp.getNumLoops(); SmallVector parallelTileSizes(numLoops, 1); SmallVector reductionTileSizes(numLoops, 0); @@ -836,9 +1165,11 @@ void createGPUSplitGridReductionTransformImpl(OpPassManager &pm, pm.addPass(createGenericTransformInsertionPass(config)); } -void createGPUTileGridReductionTransformImpl( - OpPassManager &pm, const std::string &anchor, const std::string &prefix, - int64_t warpSize, int64_t blockSize, bool usingForall) { +void createGPUTileGridReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix, + int64_t warpSize, + int64_t blockSize) { TransformInsertionConfig config; config.funcAnchor = anchor; config.matchPrefix = prefix; @@ -854,7 +1185,7 @@ void createGPUTileGridReductionTransformImpl( auto tileConfig = getGridTileConfig(llvm::cast(op), warpSize, blockSize) .value(); - tileConfig.apply(b, pdlV, usingForall); + tileConfig.apply(b, pdlV); }; pm.addPass(createGenericTransformInsertionPass(config)); @@ -886,9 +1217,11 @@ void createGPUSplitBlockReductionTransformImpl(OpPassManager &pm, pm.addPass(createGenericTransformInsertionPass(config)); } -void createGPUTileBlockReductionTransformImpl( - OpPassManager &pm, const std::string &anchor, const std::string &prefix, - int64_t warpSize, int64_t blockSize, bool usingForall) { +void createGPUTileBlockReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix, + int64_t warpSize, + int64_t blockSize) { TransformInsertionConfig config; config.funcAnchor = anchor; config.matchPrefix = prefix; @@ -907,7 +1240,7 @@ void createGPUTileBlockReductionTransformImpl( auto tileConfig = getBlockTileConfig(llvm::cast(op), warpSize, blockSize) .value(); - tileConfig.apply(b, pdlV, usingForall); + tileConfig.apply(b, pdlV); } else if (auto copyOp = llvm::dyn_cast_or_null(op)) { auto tileOp = b.create( /* target */ pdlV, @@ -922,6 +1255,166 @@ void createGPUTileBlockReductionTransformImpl( pm.addPass(createGenericTransformInsertionPass(config)); } +void createGPUTileSplitWarpReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix, + int64_t blockSize, + int64_t warpSize) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + if (isBlockReductionOp(genericOp)) { + int64_t numLoops = genericOp.getNumLoops(); + int64_t redDim = -1; + SmallVector reductionDims; + genericOp.getReductionDims(reductionDims); + if (reductionDims.size() == 1) { + redDim = reductionDims[0]; + } + auto staticLoopRanges = genericOp.getStaticLoopRanges(); + int64_t redDimSize = 
staticLoopRanges[redDim]; + + if (numLoops == 1 && redDim != -1 && redDimSize % warpSize == 0 && + redDimSize > warpSize) + return true; + } + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto genericOp = llvm::dyn_cast_or_null(op); + ::mlir::OperandRange initRange = genericOp.getDpsInits(); + int64_t redDim = *getReductionDim(genericOp); + auto staticLoopRanges = genericOp.getStaticLoopRanges(); + int64_t redDimSize = staticLoopRanges[redDim]; + int64_t splitFactor = redDimSize / warpSize; + // tile redution dim & mapping parallel dim to warp + auto mappingWarpAttr = gpu::GPUWarpMappingAttr::get( + b.getContext(), gpu::MappingId::LinearDim0); + + SmallVector numThreads{splitFactor}; + + auto tileReductionOp = b.create( + /* target */ pdlV, + /* num_threads */ numThreads, + /*staticTileSizes*/ SmallVector{}, + /*mapping*/ b.getArrayAttr(mappingWarpAttr)); + + b.create( + /* target */ tileReductionOp.getSplitLinalgOp(), + /* name */ kWarpReduction, + /* param */ Value()); + + b.create( + /* target */ tileReductionOp.getCombiningLinalgOp(), + /* name */ kWarpReduction, + /* param */ Value()); + + int64_t initStart = initRange.getBeginOperandIndex(); + int64_t initEnd = initStart + initRange.size(); + for (int64_t i = initStart; i < initEnd; ++i) { + // get the neutral tensor.empty() + // FIXME(zxg): if fillOp has multi output, this may be buggy. + auto producer = b.create( + /* producer type */ transform::OperationType::get( + b.getContext(), tensor::EmptyOp::getOperationName()), + /* target */ tileReductionOp.getFillOp(), + /* operand number */ 1); + + // promote to WorkGroup + auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get( + b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); + + promoteAllTensorsWithinOp(b, producer, workgroupMemoryAddressSpace); + } + + // fuse fill + b.create( + /* producerOp */ tileReductionOp.getFillOp(), + /* containingOp */ tileReductionOp.getForallOp()); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} + +void createGPUTileWarpReductionTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix, + int64_t warpSize) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (auto genericOp = llvm::dyn_cast_or_null(op)) { + if (isWarpReductionOp(genericOp) || isBlockReductionOp(genericOp)) { + int64_t numLoops = genericOp.getNumLoops(); + int64_t redDim = -1; + SmallVector reductionDims; + genericOp.getReductionDims(reductionDims); + if (reductionDims.size() == 1) { + redDim = reductionDims[0]; + } else { + return false; + } + auto staticLoopRanges = genericOp.getStaticLoopRanges(); + if (staticLoopRanges[redDim] != ShapedType::kDynamic && numLoops == 1 && + redDim != -1 && staticLoopRanges[redDim] <= warpSize) + return true; + } + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + auto genericOp = llvm::cast(op); + scf::ForallOp parentOp = genericOp->getParentOfType(); + Value toVectorize = pdlV; + Value forall = b.create( + pdlV.getType(), pdlV, + /* isolated_from_above */ false, + /* allow_empty_results */ false, + /* op_name */ b.getStringAttr(scf::ForallOp::getOperationName()), + /* deduplicate */ false, + /* nth_parent */ 1); + if (!parentOp || !isMappedToGPUWarps(parentOp)) { + std::vector fuseCandidates; + for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) { + 
ProducerSelector::detectFillOperand(&opOperand, fuseCandidates); + } + + SmallVector mapping{gpu::GPUWarpMappingAttr::get( + b.getContext(), gpu::MappingId::LinearDim0)}; + SmallVector numThreads(1, 1); + auto tileOp = + tileToForallAndFuseImpl(b, pdlV, numThreads, mapping, fuseCandidates, + /* asNumThreads = */ true); + forall = tileOp.getForallOp(); + toVectorize = tileOp.getTiledOp(); + } + + // convert inner redution to vector multi_reduction + b.create( + /* target */ toVectorize, + /* vector_sizes */ ValueRange{}, + /* vectorize_nd_extract */ b.getUnitAttr(), + /* scalable_sizes */ SmallVector{}, + /* static_vector_sizes */ SmallVector{}); + + // lower vector.multi_reduction to vector.reduction + b.create( + forall, [](OpBuilder &b, Location loc) { + b.create( + loc, vector::VectorMultiReductionLowering::InnerReduction); + }); + }; + pm.addPass(createGenericTransformInsertionPass(config)); +} + void createGPUTileThreadReductionTransformImpl(OpPassManager &pm, const std::string &anchor, const std::string &prefix) { @@ -957,8 +1450,7 @@ void mlir::createGPUTileGridReductionTransform( OpPassManager &pm, const GPUTileGridReductionOptions &options) { invokeOpPassPipelineBuilder(createGPUTileGridReductionTransformImpl, pm, options.funcAnchor, options.annotatePrefix, - options.warpSize, options.blockSize, - options.usingForall); + options.warpSize, options.blockSize); } void mlir::createGPUSplitBlockReductionTransform( @@ -972,8 +1464,21 @@ void mlir::createGPUTileBlockReductionTransform( OpPassManager &pm, const GPUTileBlockReductionOptions &options) { invokeOpPassPipelineBuilder(createGPUTileBlockReductionTransformImpl, pm, options.funcAnchor, options.annotatePrefix, - options.warpSize, options.blockSize, - options.usingForall); + options.warpSize, options.blockSize); +} + +void mlir::createGPUTileSplitWarpReductionTransform( + OpPassManager &pm, const GPUTileSplitWarpReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUTileSplitWarpReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, + options.blockSize, options.warpSize); +} + +void mlir::createGPUTileWarpReductionTransform( + OpPassManager &pm, const GPUTileWarpReductionOptions &options) { + invokeOpPassPipelineBuilder(createGPUTileWarpReductionTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, + options.warpSize); } void mlir::createGPUTileThreadReductionTransform( diff --git a/compiler/lib/Pipelines/HloOpt.cpp b/compiler/lib/Pipelines/HloOpt.cpp index 2f99f749d..67304133b 100644 --- a/compiler/lib/Pipelines/HloOpt.cpp +++ b/compiler/lib/Pipelines/HloOpt.cpp @@ -44,7 +44,9 @@ void addGenericHloFusionPatterns(OpPassManager &pm, pm.addNestedPass(createDotTransposeFusionPass()); pm.addNestedPass(createReductionFusionPass()); + pm.addNestedPass(createConcatSliceFusionPass()); + pm.addNestedPass(createInsertSliceWithElemwiseFusionPass()); // Element fusion (always last?) 
// Note: if outlineSingleElemwiseOp is set, element fusion must be the last // pass, since it will cluster every elemenwise op which is not fused yet into diff --git a/compiler/lib/Pipelines/LinalgTensorOpt.cpp b/compiler/lib/Pipelines/LinalgTensorOpt.cpp index 96a78e9c1..7c291948e 100644 --- a/compiler/lib/Pipelines/LinalgTensorOpt.cpp +++ b/compiler/lib/Pipelines/LinalgTensorOpt.cpp @@ -33,11 +33,21 @@ #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Transforms/Passes.h" +#include "transforms/passes.h" using namespace mlir; namespace { void addGenericLinalgPasses(OpPassManager &pm) { + { // for insert slice fusion + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createTensorToLinalgPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + pm.addNestedPass(createAnchoredPipelinePass( + getByteIRElementwiseFusionAttrName(), anchoredPM)); + } + pm.addNestedPass( createHloFusionToLinalgPass(getByteIRElementwiseFusionAttrName())); pm.addNestedPass( @@ -66,6 +76,7 @@ void addGenericLinalgPasses(OpPassManager &pm) { /*enableSharedInput*/ true, /*enableDiffShapes*/ false)); anchoredPM.addPass(createCSEPass()); anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(bufferization::createEmptyTensorEliminationPass()); pm.addNestedPass( createAnchoredPipelinePass(elementwiseAnchor, anchoredPM)); } @@ -87,32 +98,36 @@ void addGenericLinalgPasses(OpPassManager &pm) { createGPUSplitGridReductionTransform(pm, splitGridRedOptions); pm.addPass(createTransformDialectInterpreter(true)); pm.addPass(createCanonicalizerPass()); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); + anchoredPM.addPass(createCanonicalizeExtPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); + pm.addNestedPass( + createAnchoredPipelinePass(reductionAnchor, anchoredPM)); + } GPUTileGridReductionOptions tileGridRedOptions; tileGridRedOptions.funcAnchor = reductionAnchor; - tileGridRedOptions.blockSize = 512; + tileGridRedOptions.blockSize = 256; pm.addPass(createLinalgFoldUnitExtentDimsPass()); createGPUTileGridReductionTransform(pm, tileGridRedOptions); pm.addPass(createTransformDialectInterpreter(true)); { OpPassManager anchoredPM(func::FuncOp::getOperationName()); - anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); anchoredPM.addPass(createCanonicalizerPass()); anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createCanonicalizeExtPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); pm.addNestedPass( createAnchoredPipelinePass(reductionAnchor, anchoredPM)); } - GPUSplitBlockReductionOptions splitBlockRedOptions; - splitBlockRedOptions.funcAnchor = reductionAnchor; - splitBlockRedOptions.splitFactor = 16; - createGPUSplitBlockReductionTransform(pm, splitBlockRedOptions); - pm.addPass(createTransformDialectInterpreter(true)); - pm.addPass(createCanonicalizerPass()); - GPUTileBlockReductionOptions tileBlockRedOptions; tileBlockRedOptions.funcAnchor = reductionAnchor; - tileBlockRedOptions.blockSize = 512; + tileBlockRedOptions.blockSize = 256; createGPUTileBlockReductionTransform(pm, tileBlockRedOptions); pm.addPass(createTransformDialectInterpreter(true)); { @@ -133,11 +148,50 @@ void addGenericLinalgPasses(OpPassManager &pm) { anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass()); 
      anchoredPM.addPass(createCanonicalizerPass());
      anchoredPM.addPass(createCSEPass());
+      anchoredPM.addPass(createCanonicalizeExtPass());
+      anchoredPM.addPass(createCSEPass());
+      anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass());
      pm.addNestedPass(
          createAnchoredPipelinePass(reductionAnchor, anchoredPM));
    }
-    pm.addPass(createDetensorizeTransformInsertionPass(reductionAnchor));
+    // Combine block reduction
+    // step 1: per-warp reduction
+    GPUTileSplitWarpReductionOptions splitWarpRedOptions;
+    splitWarpRedOptions.funcAnchor = reductionAnchor;
+    splitWarpRedOptions.blockSize = 256;
+    createGPUTileSplitWarpReductionTransform(pm, splitWarpRedOptions);
+    pm.addPass(createTransformDialectInterpreter(true));
+    {
+      OpPassManager anchoredPM(func::FuncOp::getOperationName());
+      anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass());
+      anchoredPM.addPass(createCanonicalizerPass());
+      anchoredPM.addPass(createCanonicalizeExtPass());
+      anchoredPM.addPass(createCSEPass());
+      anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass());
+      pm.addNestedPass(
+          createAnchoredPipelinePass(reductionAnchor, anchoredPM));
+    }
+
+    // step 2: reduce in the first warp
+    GPUTileWarpReductionOptions warpRedOptions;
+    warpRedOptions.funcAnchor = reductionAnchor;
+    warpRedOptions.warpSize = 32;
+    createGPUTileWarpReductionTransform(pm, warpRedOptions);
+    pm.addPass(createTransformDialectInterpreter(true));
+    {
+      OpPassManager anchoredPM(func::FuncOp::getOperationName());
+      anchoredPM.addPass(createCanonicalizerPass());
+      anchoredPM.addPass(createCanonicalizeExtPass());
+      anchoredPM.addPass(createCSEPass());
+      anchoredPM.addPass(createLinalgFoldUnitExtentDimsPass());
+      pm.addNestedPass(
+          createAnchoredPipelinePass(reductionAnchor, anchoredPM));
+    }
+
+    pm.addNestedPass(mlir::createDetensorizeScfOpsPass());
+    pm.addPass(createCanonicalizeExtPass());
+    pm.addPass(createDetensorizeTransformInsertionPass(true, reductionAnchor));
    pm.addPass(createTransformDialectInterpreter(true));
    pm.addPass(createCanonicalizeExtPass());
    pm.addPass(createRewriteInDPSTransformInsertionPass(reductionAnchor));
diff --git a/compiler/lib/Pipelines/SCFOpt.cpp b/compiler/lib/Pipelines/SCFOpt.cpp
index e9d06b184..2a86b4140 100644
--- a/compiler/lib/Pipelines/SCFOpt.cpp
+++ b/compiler/lib/Pipelines/SCFOpt.cpp
@@ -19,10 +19,12 @@
 #include "byteir/Dialect/Linalg/Passes.h"
 #include "byteir/Dialect/Linalg/Transforms/LinalgExtToLoops.h"
+#include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h"
 #include "byteir/Dialect/mhlo/Passes.h"
 #include "byteir/Pipelines/Common/Utils.h"
 #include "byteir/Transforms/Passes.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -42,6 +44,9 @@ void addGenericSCFOptPasses(OpPassManager &pm) {
   pm.addPass(createLowerAffinePass());
   pm.addNestedPass(createLoopCoalescingPass());
   pm.addPass(arith::createIntRangeOptimizationsPass());
+  // for reduction
+  pm.addNestedPass(
+      createFuseNestedForallPass(getByteIRReductionFusionAttrName()));
   addCleanUpExtPassPipeline(pm);
 }
diff --git a/compiler/lib/Utils/CMakeLists.txt b/compiler/lib/Utils/CMakeLists.txt
index 32ef0a184..a5aeacaa8 100644
--- a/compiler/lib/Utils/CMakeLists.txt
+++ b/compiler/lib/Utils/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(ByteIRUtils
   MemUtils.cpp
   ModuleUtils.cpp
   OpInterfaceUtils.cpp
+  OptionUtils.cpp
   PipelineUtils.cpp
   TileUtils.cpp
   TypeUtils.cpp
diff --git
a/compiler/lib/Utils/OptionUtils.cpp b/compiler/lib/Utils/OptionUtils.cpp new file mode 100644 index 000000000..0dfa0b1a6 --- /dev/null +++ b/compiler/lib/Utils/OptionUtils.cpp @@ -0,0 +1,109 @@ +//===- OptionUtils.cpp ------------------------------ -*- C++ ------*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Utils/OptionUtils.h" + +using KernelDims = llvm::cl::KernelDims; +bool llvm::cl::parser::parse(Option &O, StringRef ArgName, + StringRef Arg, KernelDims &Val) { + SmallVector integerVals; + if (Arg.size() <= 0) + return true; + int64_t idx = 0; + int64_t len = Arg.size(); + auto parseInteger = [&]() -> std::optional { + int64_t sgn = 1; + if (idx < len && Arg[idx] == '-') { + sgn = -1; + idx += 1; + } + int64_t val = 0; + int64_t start = idx; + while (idx < len && Arg[idx] <= '9' && Arg[idx] >= '0') { + val = val * 10 + Arg[idx] - '0'; + idx += 1; + } + + if (idx == start) + return std::nullopt; + val *= sgn; + return val; + }; + + auto consumeIf = [&](char ch) -> bool { + if (idx < len && Arg[idx] == ch) { + idx += 1; + return true; + } + return false; + }; + + auto curInt = parseInteger(); + if (curInt.has_value()) { + integerVals.emplace_back(curInt.value()); + } + while (consumeIf(' ')) { + } + while (consumeIf(',')) { + while (consumeIf(' ')) { + } + auto curInt = parseInteger(); + if (!curInt.has_value()) + return true; + integerVals.emplace_back(curInt.value()); + while (consumeIf(' ')) { + } + } + if (static_cast(integerVals.size()) != 3 || idx != len) + return true; + for (auto v : integerVals) { + if (v < 0) + return true; + } + Val.x = integerVals[0]; + Val.y = integerVals[1]; + Val.z = integerVals[2]; + return false; +} + +void llvm::cl::parser::printOptionDiff(const Option &O, + KernelDims V, + const OptVal &Default, + size_t GlobalWidth) const { + printOptionName(O, GlobalWidth); + std::string Str; + { + llvm::raw_string_ostream SS(Str); + SS << "{" << V.x << ", " << V.y << ", " << V.z << "}"; + } + outs() << "= " << Str; + outs().indent(2) << " (default: "; + if (Default.hasValue()) { + outs() << "{" << Default.getValue().x << ", " << Default.getValue().y + << ", " << Default.getValue().z << "}"; + } else { + outs() << "*no default*"; + } + outs() << ")\n"; +} + +void llvm::cl::parser::print(raw_ostream &os, + const KernelDims &value) { + os << "{" << value.x << ", " << value.y << ", " << value.z << "}"; +} + +void llvm::cl::parser::anchor() {} diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index d45100d96..a0c948f90 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -375,7 +375,7 @@ def compile( ### legalize stablehlo to mhlo with context: - PassManager.parse("builtin.module(canonicalize,stablehlo-legalize-to-hlo,canonicalize)").run(module.operation) + 
PassManager.parse("builtin.module(canonicalize,stablehlo-legalize-to-hlo,canonicalize-ext,canonicalize)").run(module.operation) _print_verbose(module, "// IR Dump After Legalize to HLO:") if verbose else ... ### parse output options from output_file_path diff --git a/compiler/test/Dialect/MemRef/removeCopy.mlir b/compiler/test/Dialect/MemRef/removeCopy.mlir index 3302a4026..15befc9fe 100644 --- a/compiler/test/Dialect/MemRef/removeCopy.mlir +++ b/compiler/test/Dialect/MemRef/removeCopy.mlir @@ -638,4 +638,3 @@ module attributes {byre.container_module} { // CHECK: memref.copy // CHECK: memref.copy - diff --git a/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt b/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt index 9fa0c18a8..a7baa3fba 100644 --- a/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt +++ b/frontends/torch-frontend/torch-frontend/python/CMakeLists.txt @@ -26,6 +26,7 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel byteir_backend/compiled_function.py byteir_backend/compiler.py byteir_backend/config.py + byteir_backend/debug.py byteir_backend/inner_compile.py byteir_backend/utils.py byteir_backend/byteir_fusible_pattern.py diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/README.md b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/README.md index 5997dc9e2..4c59e938f 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/README.md +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/README.md @@ -17,7 +17,8 @@ class NaiveModel(torch.nn.Module): r1 = torch.ops.aten.div(r0, x2) x0 = torch.ops.aten.mul(r1, r1) - x0 r2 = torch.ops.aten.slice(x0, 1, 1, 3, 1) - return r1, r2 + r3 = torch.ops.aten.slice(x0, 1, 1, 3, 1) + return r1, r2, r3 model = NaiveModel() opt_mod = torch.compile(model, backend="byteir") diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/__init__.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/__init__.py index ec2281f43..5923504df 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/__init__.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/__init__.py @@ -7,6 +7,12 @@ def byteir(*args, **kwargs): return byteir_compiler(*args, **kwargs) +@register_backend +def byteir_debug(*args, **kwargs): + from .debug import debug_backend + + return debug_backend(*args, **kwargs) + def set_cache_dir(path: str): from .compilation_cache import ByteIRFxGraphCache diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compilation_cache.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compilation_cache.py index fb86d9b59..628b79d4c 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compilation_cache.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compilation_cache.py @@ -21,24 +21,18 @@ import logging import shutil from copy import copy -import dataclasses from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union from filelock import FileLock import torch from torch._inductor.codecache import ( - BypassFxGraphCache, LOCK_TIMEOUT, - sha256_hash, - OrderedSetHolder, write_atomic, - _reduce_fake_tensor, - _reduce_symint, ) from torch._dynamo.utils import counters -from 
torch.fx.experimental.symbolic_shapes import ShapeEnv, has_hint, hint_int -from torch._subclasses.fake_tensor import extract_tensor_metadata, FakeTensor +from torch.fx.experimental.symbolic_shapes import ShapeEnv, has_hint, hint_int, SYMPY_INTERP +from torch._subclasses.fake_tensor import FakeTensor from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete try: @@ -47,10 +41,22 @@ ... from .compiled_function import (CompiledArtifact, ByteIRFunction) +from .utils import ( + dump_tensors_meta_info, + BypassFxGraphCache, + OrderedSetHolder, + TensorMetadata, + extract_tensor_metadata, + maybe_get_fake_mode, + _reduce_fake_tensor, + _reduce_symint, + sha256_hash, +) log = logging.getLogger(__name__) + def get_system_info() -> Dict[str, Any]: try: system: Dict[str, Any] = { @@ -88,7 +94,7 @@ def __init__( example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], ): - self.gm = gm + self.gm = gm.__str__() self.example_inputs = example_inputs # Order kwargs so hashing is stable to changes in kwarg order. @@ -102,13 +108,12 @@ def __init__( else: self.fx_kwargs[k] = fx_kwargs[k] - # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. - self.deterministic_algorithms_settings = ( - torch.are_deterministic_algorithms_enabled(), - torch.is_deterministic_algorithms_warn_only_enabled(), - torch.utils.deterministic. - fill_uninitialized_memory, # type: ignore[attr-defined] - ) + # # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. + # self.deterministic_algorithms_settings = ( + # torch.are_deterministic_algorithms_enabled(), + # torch.is_deterministic_algorithms_warn_only_enabled(), + # byteir_backend.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + # ) # Global settings affecting matmul codegen. self.cuda_matmul_settings = ( @@ -146,6 +151,9 @@ def get_str(obj) -> str: for k, v in obj.items(): h = ByteIRFxGraphCachePickler.get_hash(v) lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + elif isinstance(obj, torch.fx.GraphModule): + h = ByteIRFxGraphCachePickler.get_hash(obj.__str__()) + lines.append(f"[{h}] {attr}: {get_str(obj)}") else: h = ByteIRFxGraphCachePickler.get_hash(obj) lines.append(f"[{h}] {attr}: {get_str(obj)}") @@ -288,11 +296,34 @@ def _get_shape_env() -> Optional[ShapeEnv]: """ Helper to get the shape env from the tracing context. """ - ctx = torch._guards.TracingContext.try_get() + ctx = torch._guards.TracingContext.get() if not ctx: return None return ctx.fake_mode.shape_env + @staticmethod + def _produce_guards_expression(shape_env, placeholders, ignore_static=True): + """ + Expected to be used with evaluate_guards_expression(). Produces the guards + for the given placeholders and returns a string expression to be evaluated + by evaluate_guards_expression given concrete values for the placeholders. + """ + from torch._dynamo.source import LocalSource + arg_names = [f"t{i}" for i in range(len(placeholders))] + guards = shape_env.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static) + if guards: + return " and ".join(guards) + return None + + @staticmethod + def _evaluate_guards_expression(code, args): + """ + Expected to be used with produce_guards_expression(). Evaluates an expression + generated by produce_guards_expression for the given concrete args. 
+ """ + arg_names = [f"t{i}" for i in range(len(args))] + return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) + @staticmethod def _lookup_compiled_artifact( key: str, @@ -335,8 +366,8 @@ def _lookup_compiled_artifact( # affect the current env, e.g., cause the creation of new guards, # so we evaluate with the hints instead of the symbols. hit = bool( - shape_env.evaluate_guards_expression(candidate.guards_expr, - hints)) + ByteIRFxGraphCache._evaluate_guards_expression( + candidate.guards_expr, hints)) log.debug( "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s", key, @@ -354,8 +385,8 @@ def _lookup_compiled_artifact( # Now re-evaluate with the symints to add any guards to the current env. if artifact.guards_expr: check = bool( - shape_env.evaluate_guards_expression(artifact.guards_expr, - symints)) + ByteIRFxGraphCache._evaluate_guards_expression( + artifact.guards_expr, symints)) assert check is True log.debug("fx graph cache key %s post-load guards: %s", key, shape_env.guards) @@ -377,8 +408,8 @@ def _save_compiled_artifact(key: str, compiled_artifact: CompiledArtifact, shape_env = ByteIRFxGraphCache._get_shape_env() assert shape_env is not None symints = ByteIRFxGraphCache._filter_symints(example_inputs) - compiled_artifact.guards_expr = shape_env.produce_guards_expression( - symints) + compiled_artifact.guards_expr = ByteIRFxGraphCache._produce_guards_expression( + shape_env, symints) try: # FIXME compiled_artifact is not serializable. @@ -452,7 +483,9 @@ def Load(compile_fn: Callable, gm: torch.fx.GraphModule, free_func=caching_allocator_delete) byre_session.load(compiled_artifact.byre_file) compiled_func = ByteIRFunction(byre_session, - compiled_artifact.none_indices) + compiled_artifact.none_indices, + compiled_artifact.aliased_out_indices, + compiled_artifact.output_meta_info,) # save `ByteIRFunction` obj. ByteIRFxGraphCache._save_func(key, compiled_func) diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py index a81885699..99bcf4aba 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py @@ -1,6 +1,8 @@ import dataclasses +import functools import logging from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union +from brt.utils import brt_dtype_to_torch_dtype import torch @@ -18,6 +20,8 @@ class CompiledArtifact: # TODO. serialize Session object. #byre_session: object none_indices: List[int] + aliased_out_indices: Optional[List[int]] = None # fw only. + output_meta_info: Optional[List[torch.Tensor]] = None hash_key: Optional[str] = None # This is a string representation of an expression we serialize # with the object so the guards can be evaluated in a different @@ -32,7 +36,7 @@ class ByteIRFunction: Wrap the byteir compiled function and runtime as a callable object for dynamo, as dynamo caches callable in guards. 
""" - def __init__(self, module_path_or_session, none_indices): + def __init__(self, module_path_or_session, none_indices, aliased_out_indices=None, output_meta_info=None): if isinstance(module_path_or_session, brt.Session): self._session = module_path_or_session else: @@ -40,58 +44,118 @@ def __init__(self, module_path_or_session, none_indices): free_func=caching_allocator_delete) self._session.load(module_path_or_session) self._none_indices = none_indices + self._aliased_out_indices = aliased_out_indices + self._output_meta_info = output_meta_info self._req = self._session.new_request_context( torch.cuda.current_stream()._as_parameter_.value) + self.input_arg_offsets = self._session.get_input_arg_offsets() + self.output_arg_offsets = self._session.get_output_arg_offsets() + + self.output_shape_and_dtype = [( + self._session.get_static_shape(offset), + brt_dtype_to_torch_dtype(self._session.get_data_type(offset)), + ) for offset in self._session.get_output_arg_offsets()] + + self._outs_len = len(self.output_arg_offsets) + self.static_shape_and_dtype = [ + (self._session.get_static_shape(offset), + brt_dtype_to_torch_dtype(self._session.get_data_type(offset))) + for offset in self.output_arg_offsets + ] + + self.real_outs_index_map = self._get_outputs_index_map( + self._outs_len, self._none_indices) + self.strited_inputs_index = None + + def _get_outputs_index_map(self, out_lens: int, none_indices: List[int]): + res = [] + none_lens = len(none_indices) + none_cnt = 0 + for idx in range(out_lens + none_lens): + if none_cnt < none_lens and idx == none_indices[none_cnt]: + none_cnt += 1 + continue + res.append(idx) + + return res + + @functools.lru_cache + def get_output_storage_size(self, fake_t): + _size = fake_t.size() + _stride = fake_t.stride() + _offset = fake_t.storage_offset() + sz = _offset + for d,s in zip(_size, _stride): + sz += (d-1) * s + sz += 1 + return sz + + @functools.lru_cache + def get_out_tensors(self, device): + outputs_ptr = [None] * self._outs_len + results = [None] * (self._outs_len + len(self._none_indices)) + + # for idx, shape_dty in enumerate(self.static_shape_and_dtype): + # _out = torch.empty(shape_dty[0], dtype=shape_dty[1], device=device) + # results[self.real_outs_index_map[idx]] = _out + # outputs_ptr[idx] = _out.data_ptr() + + _visited_aliased_out_cnt = 0 + for idx, shape_dty in enumerate(self.static_shape_and_dtype): + fake_t = self._output_meta_info[idx] + _aliased_out_len = 0 if self._aliased_out_indices is None else len(self._aliased_out_indices) + if _visited_aliased_out_cnt < _aliased_out_len and idx == self._aliased_out_indices[_visited_aliased_out_cnt]: + _visited_aliased_out_cnt += 1 + _out = torch.empty((1, self.get_output_storage_size(fake_t)), dtype=fake_t.dtype, device=device) + _out = _out.as_strided(size=fake_t.size(), stride=fake_t.stride(), storage_offset=fake_t.storage_offset()) + else: + _out = torch.empty(shape_dty[0], dtype=shape_dty[1], device=device) + results[self.real_outs_index_map[idx]] = _out + outputs_ptr[idx] = _out.data_ptr() + return results, outputs_ptr def __call__(self, *inputs): - from brt.utils import brt_dtype_to_torch_dtype log.debug(f"***** Run function compiled through byteir ******") + log.debug(f"_aliased_out_indices={self._aliased_out_indices}") + #log.debug(f"_output_meta_info={self._output_meta_info}") # FIXME. byteir requires all inputs on device side, move host side tensor to device. # Preprocess the strided tensor as byteir does not support yet. 
- new_inputs = [] - - for i in range(0, len(inputs)): - _t = inputs[i] - if not _t.is_cuda: - log.warning(f"device error: type={type(_t)}, {_t.device}") - _t = _t.to("cuda") - new_inputs.append(_t.contiguous()) - - device = new_inputs[0].device - - results = [ - torch.empty( - self._session.get_static_shape(offset), - dtype=brt_dtype_to_torch_dtype( - self._session.get_data_type(offset)), - device=device, - ) for offset in self._session.get_output_arg_offsets() - ] - - for offset, input in zip(self._session.get_input_arg_offsets(), - new_inputs): - self._req.bind_arg(offset, input.data_ptr()) - for offset, output in zip(self._session.get_output_arg_offsets(), - results): - self._req.bind_arg(offset, output.data_ptr()) + new_inputs = [None] * len(inputs) + + if self.strited_inputs_index is None: + self.strited_inputs_index = [] + for i in range(0, len(inputs)): + _t = inputs[i] + if not _t.is_contiguous(): + _t = _t.contiguous() + self.strited_inputs_index.append(i) + new_inputs[i] = _t + else: + for i in range(0, len(inputs)): + new_inputs[i] = inputs[i] + for i in self.strited_inputs_index: + new_inputs[i] = inputs[i].contiguous() + + device = inputs[0].device + + results, outputs_ptr = self.get_out_tensors(device) + + inputOffsetAndArg = [None] * len(new_inputs) + outputOffsetAndArg = [None] * len(outputs_ptr) + for idx, (offset, inp) in enumerate(zip(self.input_arg_offsets, new_inputs)): + inputOffsetAndArg[idx] = (offset, inp.data_ptr()) + for idx, (offset, output_ptr) in enumerate(zip(self.output_arg_offsets, outputs_ptr)): + outputOffsetAndArg[idx] = (offset, output_ptr) + self._req.bind_args(inputOffsetAndArg) + self._req.bind_args(outputOffsetAndArg) self._req.finish_io_binding() self._req.run() self._req.sync() - # add None results to return values - rets = [] - none_cnt = 0 - result_cnt = 0 - for i in range(len(results) + len(self._none_indices)): - if none_cnt < len( - self._none_indices) and i == self._none_indices[none_cnt]: - rets.append(None) - none_cnt += 1 - else: - rets.append(results[result_cnt]) - result_cnt += 1 + rets = results + if len(rets) == 1: return rets[0] return rets diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py index e8ba0283a..276aa3d93 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py @@ -3,6 +3,7 @@ import torch from functorch.compile import min_cut_rematerialization_partition, default_partition +from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode from torch._decomp import register_decomposition, get_decompositions, core_aten_decompositions from torch._dynamo.backends.common import aot_autograd from torch._dynamo.utils import ( @@ -12,6 +13,7 @@ from .inner_compile import (byteir_fx_compiler) from .partitioners import fuse_aware_min_cut_partition +from .utils import collect_outputs_aliased_inputs def byteir_decompositions(): @@ -41,13 +43,24 @@ def byteir_compiler( model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor], ): + """ + # dump module before aot-autograd + with maybe_disable_fake_tensor_mode(): + import tempfile + temp_dir = tempfile.mkdtemp(dir='/tmp/FxModule/') + model_.to_folder(folder=temp_dir, module_name="FxModule") + """ + # analysis output aliased to inputs in *fw* + aliased_out_info = 
collect_outputs_aliased_inputs(model_, example_inputs_)
+
     _byteir_compiler = aot_autograd(
-        fw_compiler=functools.partial(byteir_fx_compiler, is_backward=False),
+        fw_compiler=functools.partial(byteir_fx_compiler, is_backward=False, aliased_out_info=aliased_out_info),
         bw_compiler=functools.partial(byteir_fx_compiler, is_backward=True),
         decompositions=byteir_decompositions,
         partition_fn=byteir_partition_fn,
         #partition_fn=min_cut_rematerialization_partition,
         #partition_fn=default_partition,
+        keep_inference_input_mutations=False,
     )
     fake_mode = detect_fake_mode(
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/debug.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/debug.py
new file mode 100644
index 000000000..07b21e5bc
--- /dev/null
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/debug.py
@@ -0,0 +1,41 @@
+import functools
+import logging
+from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union
+
+import torch
+
+from .compiler import byteir_compiler
+
+log = logging.getLogger(__name__)
+
+
+def debug_backend(gm: torch.fx.GraphModule,
+                  example_inputs: List[torch.Tensor]):
+    """
+    Compare results between the byteir-compiled function and the eager-mode graph.
+    """
+    _opt_gm = byteir_compiler(gm, example_inputs)
+
+    def f(*inputs):
+        opt_inputs = [
+            inp.as_strided(size=inp.size(),
+                           stride=inp.stride(),
+                           storage_offset=inp.storage_offset())
+            for inp in inputs
+        ]
+        eager_inputs = inputs
+        eager_res = gm(*eager_inputs)
+        opt_res = _opt_gm(*opt_inputs)
+
+        # compare results
+        # TODO: check meta info as well as numerical values.
+        try:
+            torch.testing.assert_close(eager_res, opt_res)
+        except Exception as e:
+            log.error(f"******* debug backend failed *******")
+            raise e
+
+        print(f"******* debug backend pass *******")
+        return eager_res
+
+    return f
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/inner_compile.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/inner_compile.py
index aba69e419..8b07d4bb9 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/inner_compile.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/inner_compile.py
@@ -9,16 +9,13 @@
 from torch._dynamo import (
     utils as dynamo_utils,
 )
 from torch._dynamo.utils import counters
+from torch._dynamo.utils import detect_fake_mode
 from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete
 from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch._subclasses.fake_tensor import (
     FakeTensorMode,
     FakeTensor,
     FakeTensorConverter,
-    TensorMetadata,
-    extract_tensor_metadata,
-    maybe_get_fake_mode,
-    unset_fake_temporarily,
 )
 import torch_frontend
@@ -40,7 +37,13 @@
     ByteIRFunction,
 )
 from .utils import (
-    dump_tensors_meta_info, )
+    dump_tensors_meta_info,
+    BypassFxGraphCache,
+    OrderedSetHolder,
+    TensorMetadata,
+    extract_tensor_metadata,
+    maybe_get_fake_mode,
+)
 from .
import config log = logging.getLogger(__name__) @@ -49,15 +52,16 @@ BACKEND_LEGAL_OPS = ["aten.max.dim"] -@dynamo_utils.dynamo_timed(phase_name="byteir_compile") +#@dynamo_utils.dynamo_timed(phase_name="byteir_compile") def inner_compile(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], workdir: str = None, compiler_type: str = "forward", + aliased_out_indices: List[int] = None, **kwargs) -> CompiledArtifact: graph_id = next(g_graph_counter) - log.debug(f"byteir compiling {compiler_type} graph {graph_id}") + log.info(f"byteir compiling {compiler_type} graph {graph_id}") if workdir is None: key = compiled_fx_graph_hash(gm, example_inputs, kwargs) @@ -75,10 +79,15 @@ def inner_compile(gm: torch.fx.GraphModule, fxg_dir_name = f"fx_graph_{compiler_type}_{graph_id}" fx_graph_folder = f"{workdir}/{fxg_dir_name}/" os.makedirs(fx_graph_folder, exist_ok=True) - with unset_fake_temporarily(): + with maybe_disable_fake_tensor_mode(): gm.to_folder(folder=fx_graph_folder, module_name="FxModule") - with FakeTensorMode(allow_non_fake_inputs=True): + with detect_fake_mode(example_inputs): + #with FakeTensorMode(allow_non_fake_inputs=True): fake_outs = gm(*example_inputs) + if isinstance(fake_outs, tuple): + fake_outs = list(fake_outs) + elif isinstance(fake_outs, torch.Tensor): + fake_outs = [fake_outs] dump_tensors_meta_info( example_inputs, os.path.join(fx_graph_folder, "inputs_meta_info.pkl")) @@ -108,34 +117,46 @@ def inner_compile(gm: torch.fx.GraphModule, log.debug("#### byteir compile success") none_indices = get_none_indices(gm) - compiled_artifact = CompiledArtifact(byre_file, none_indices) + compiled_artifact = CompiledArtifact( + byre_file, + none_indices, + aliased_out_indices, + fake_outs, + ) return compiled_artifact def byteir_fx_compiler(gm: torch.fx.GraphModule, example_inputs, - is_backward=False): + is_backward=False, + aliased_out_info: List[Any]=None, +): """ The main entry function of byteir torch compiler backend. 
""" compiler_type = "backward" if is_backward else "forward" + aliased_out_indices = None + if aliased_out_info is not None: + aliased_out_indices = [idx for idx, _ in aliased_out_info] log.info( f"########################### {'FORWARD' if not is_backward else 'BACKWARD'} ###########################" ) - log.info(torch._guards.TracingContext.try_get()) + log.info(torch._guards.TracingContext.get()) if config.byteir_not_use_cache: - compiled_artifact = inner_compile(gm, example_inputs) + compiled_artifact = inner_compile(gm, example_inputs, aliased_out_indices=aliased_out_indices) byre_session = brt.Session(alloc_func=caching_allocator_alloc, free_func=caching_allocator_delete) byre_session.load(compiled_artifact.byre_file) byre_func = ByteIRFunction(byre_session, - compiled_artifact.none_indices) + compiled_artifact.none_indices, + compiled_artifact.aliased_out_indices, + compiled_artifact.output_meta_info,) else: byre_func = ByteIRFxGraphCache.Load( - functools.partial(inner_compile, compiler_type=compiler_type), gm, + functools.partial(inner_compile, compiler_type=compiler_type, aliased_out_indices=aliased_out_indices), gm, example_inputs) log.debug(f"Counters:\n{counters}") diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/partitioners.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/partitioners.py index 580a8af26..ba5f733a5 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/partitioners.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/partitioners.py @@ -3,7 +3,7 @@ is_symbol_binding_fx_node, find_symbol_binding_fx_nodes ) -from torch.fx.experimental.sym_node import ( +from torch.fx.experimental.symbolic_shapes import ( magic_methods, method_to_operator, ) diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/utils.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/utils.py index ed35b95b1..5183e14a9 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/utils.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/utils.py @@ -1,15 +1,180 @@ +import base64 +import contextlib +import dataclasses +import hashlib import os import functools import time +from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union import pickle +from dataclasses import dataclass + import torch from torch._subclasses.fake_tensor import ( FakeTensorMode, FakeTensor, - extract_tensor_metadata, ) +from torch._prims_common import suggest_memory_format + +# fx graph outputs analysis +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code, preserve_rng_state, detect_fake_mode +from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions +from torch._functorch.aot_autograd import run_functionalized_fw_and_collect_metadata, OutputType +import torch.utils._pytree as pytree +from unittest.mock import patch +from contextlib import contextmanager, nullcontext + + +## Helper classes portting from torch2.4 +class BypassFxGraphCache(Exception): + """ + Exception to indicate that the FxGraphCache should be bypassed. + """ + + pass + + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. 
+ """ + + items: List[Any] + + +@dataclass(frozen=True) +class TensorMetadata: + """ + The Tensor metadata relevant to hashing FakeTensors when caching. + """ + + dtype: torch.dtype + shape: torch.Size + stride: Tuple[Any, ...] + device: torch.device + layout: torch.layout + memory_format: Optional[torch.memory_format] + storage_offset: int + storage_bytes: Optional[int] + requires_grad: bool + is_quantized: bool + is_conj: bool + is_neg: bool + is_inference: bool + is_sparse: bool # read: is sparse COO + is_coalesced: Optional[bool] + dense_dim: Optional[int] + sparse_dim: Optional[int] + + +## Helper functions portting from torch2.4 +def is_sparse_coo(t): + return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo + + +def is_sparse_compressed_layout(layout): + return layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + + +def is_sparse_compressed(t): + return isinstance(t, torch.Tensor) and is_sparse_compressed_layout( + t.layout) + + +def is_sparse_any(t): + return is_sparse_coo(t) or is_sparse_compressed(t) + + +def extract_tensor_metadata(t: torch.Tensor) -> "TensorMetadata": + """ + Extract the TensorMetadata of a tensor. + """ + memory_format: Optional[torch.memory_format] = suggest_memory_format(t) + if is_sparse_any(t) or not t.is_contiguous(memory_format=memory_format): + memory_format = None + + return TensorMetadata( + dtype=t.dtype, + shape=t.shape, + stride=t.stride() if t.layout == torch.strided else (), + device=t.device, + layout=t.layout, + memory_format=memory_format, + storage_offset=t.storage_offset(), + # Only set storage_bytes for tensors that have storage (not sparse) + storage_bytes=t.untyped_storage().nbytes() + if not t.is_sparse else None, + requires_grad=t.requires_grad, + is_quantized=t.is_quantized, + is_conj=t.is_conj(), + is_neg=t.is_neg(), + is_inference=t.is_inference(), + is_sparse=t.is_sparse, + is_coalesced=t.is_coalesced() if t.is_sparse else None, + dense_dim=t.dense_dim() if t.is_sparse else None, + sparse_dim=t.sparse_dim() if t.is_sparse else None, + ) + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. + return base64.b32encode( + hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + + +def _ident(x: Any) -> Any: + return x + + +def _reduce_fake_tensor(t): + """ + See FxGraphCachePickler. Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata(t) + return (_ident, (metadata, )) + + +def _reduce_symint(s): + """ + See FxGraphCachePickler. Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and + # not the backed value. We evaluate guards stored with a cached graph + # to ensure a cached entity with SymInt args is safe to reuse. 
+ return (_ident, (str(s), )) + + +def maybe_get_fake_mode(t): + if isinstance(t, FakeTensor): + return t.fake_mode + if is_traceable_wrapper_subclass(t): + inner_tensor_names, _ = t.__tensor_flatten__() + modes = [ + maybe_get_fake_mode(getattr(t, t_name)) + for t_name in inner_tensor_names + ] + m = modes[0] + assert all(m is x for x in modes) + return m + elif isinstance(t, torch.Tensor) and torch._is_functional_tensor(t): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor( + t, reapply_views) + return maybe_get_fake_mode(unwrapped) + elif isinstance(t, torch.Tensor) and is_functorch_wrapped_tensor(t): + unwrapped = torch._C._functorch.get_unwrapped(t) + return maybe_get_fake_mode(unwrapped) + return None def record_execution_time(stage: str = "Unknown"): @@ -41,6 +206,7 @@ def dump_tensors_meta_info(tensors: list[torch.Tensor, FakeTensor], with open(save_path, "wb") as f: pickle.dump(_meta_infos, f) + def cal_storage_size(size, stride, storage_offset): _size = storage_offset assert len(size) == len(stride) @@ -49,6 +215,7 @@ def cal_storage_size(size, stride, storage_offset): _size = _size + 1 return _size + def create_real_tensor(size, dtype, layout, device, requires_grad, stride, storage_offset): storage_size = cal_storage_size(size, stride, storage_offset) @@ -118,3 +285,77 @@ def create_real_tensor_from_fake(ft: FakeTensor): stride=ft.stride(), storage_offset=ft.storage_offset()) return rt + + +def create_fw_runtime_metadata(fn, args): + + flat_args, _ = pytree.tree_flatten((args)) + #TODO(chh) we should pass fake tensor in this util function. + #fake_mode = detect_fake_mode(flat_args) + #shape_env = fake_mode.shape_env + shape_env = None + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + + python_dispatcher_mode = (enable_python_dispatcher() + if shape_env is not None else nullcontext()) + with torch.autograd.set_multithreading_enabled(False), preserve_rng_state( + ), fake_mode, python_dispatcher_mode, PhiloxStateTracker(): + + def process_inputs(flat_args): + + def convert(idx, x): + if shape_env is not None: + from torch._dynamo.source import ConstantSource + if isinstance(x, int): + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), + hint=x, + source=source) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=False) + + return [convert(idx, x) for idx, x in enumerate(flat_args)] + + fake_flat_args = process_inputs(flat_args) + + needs_autograd = (any(x.requires_grad for x in fake_flat_args + if isinstance(x, torch.Tensor)) + and torch.is_grad_enabled()) + + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. 
+ with patch("torch.cuda.set_rng_state", lambda *flat_args: None): + fw_metadata = run_functionalized_fw_and_collect_metadata( + fn, + keep_input_mutations=False and not needs_autograd, + )(*fake_flat_args) + return fw_metadata + + +def collect_outputs_aliased_inputs(fn, args): + fw_metadata = create_fw_runtime_metadata(fn, args) + _mutated_inputs_len = fw_metadata.num_mutated_inputs + _org_graph_outputs_len = len(fw_metadata.output_info) + _num_intermidate_bases = fw_metadata.num_intermediate_bases + _compiled_fn_outputs_len = _mutated_inputs_len + _org_graph_outputs_len + _num_intermidate_bases + + outputs_aliased = [] + + for idx, info in enumerate(fw_metadata.output_info): + # NB. idx is the index of the output in original graph(before functionalization) + if info.output_type in [ + OutputType.alias_of_input, OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + OutputType.alias_of_intermediate_base_is_user_output + ]: + real_idx = idx + _mutated_inputs_len + outputs_aliased.append((real_idx, info.base_idx)) + return outputs_aliased diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu index 9436bafc2..07a8f7668 100644 --- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu +++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu @@ -16,11 +16,12 @@ //===----------------------------------------------------------------------===// #include +#include namespace brt { namespace cuda { namespace kernel { - +constexpr int32_t kMaxGridDim = 65535; template __global__ void transpose_naive_2d_kernel(const T *input, T *output, int m, int n) { @@ -40,12 +41,164 @@ void transpose_naive_2d(const T *input, T *output, int m, int n, dim3 grid, transpose_naive_2d_kernel<<>>(input, output, m, n); } +// inner dim small than 4 +template +__global__ void batch_transpose_with_small_inner_dim_kernel( + const int32_t total_row, const int32_t row, void *__restrict__ inp_ptr, + void *__restrict__ out_ptr) { + + for (int32_t i = threadIdx.x + blockIdx.x * BlockSize, + step_tile = gridDim.x * BlockSize; + i < total_row; i += step_tile) { + const int32_t batch_idx = i / row; + const int32_t row_idx = i - row * batch_idx; + + T *inp_tile_gmem = reinterpret_cast(inp_ptr); + T *out_tile_gmem = reinterpret_cast(out_ptr); + inp_tile_gmem += batch_idx * row * Col + row_idx * Col; + out_tile_gmem += batch_idx * row * Col + row_idx; + + union PackType { + typename std::aligned_storage::type t; + T a[PackSize]; + }; + + for (int32_t j = 0; j < Col; j += PackSize) { + PackType val = (reinterpret_cast(inp_tile_gmem))[j]; + for (int32_t k = 0; k < PackSize; ++k) { + out_tile_gmem[k * row] = val.a[k]; + } + } + } +} + +// (batch, dim0, dim1) => (batch, dim1, dim0) +template +__global__ void batch_transpose_kernel(const int32_t total_tile_num, + const int32_t tile_num_in_dim0, + const int32_t tile_num_in_dim1, + const int32_t tile_per_sample, + const int32_t row, const int32_t col, + void *__restrict__ inp_ptr, + void *__restrict__ out_ptr) { + __shared__ T tile_in_shmem[TileSizeX][TileSizeY]; + for (int32_t i = blockIdx.x, step_tile = gridDim.x; i < total_tile_num; + i += step_tile) { + const int32_t batch_idx = i / tile_per_sample; + const int32_t remainder = i - batch_idx * tile_per_sample; + const int32_t dim0_idx = remainder / tile_num_in_dim1; + const int32_t dim1_idx = remainder - dim0_idx * 
+    const int32_t dim1_idx = remainder - dim0_idx * tile_num_in_dim1;
+
+    T *inp_tile_gmem = reinterpret_cast<T *>(inp_ptr);
+    T *out_tile_gmem = reinterpret_cast<T *>(out_ptr);
+    inp_tile_gmem += batch_idx * row * col + dim0_idx * TileSizeX * col +
+                     dim1_idx * TileSizeY;
+    out_tile_gmem += batch_idx * row * col + dim1_idx * TileSizeY * row +
+                     dim0_idx * TileSizeX;
+
+    int32_t range_0 = dim0_idx < tile_num_in_dim0 - 1
+                          ? TileSizeX
+                          : row - dim0_idx * TileSizeX;
+    int32_t range_1 = dim1_idx < tile_num_in_dim1 - 1
+                          ? TileSizeY
+                          : col - dim1_idx * TileSizeY;
+    constexpr int32_t row_num_per_iter = BlockSize / TileSizeY;
+    constexpr int32_t col_num_per_iter = BlockSize / TileSizeX;
+
+    int32_t tile_row_idx = threadIdx.x / TileSizeY;
+    int32_t tile_col_idx = threadIdx.x - tile_row_idx * TileSizeY;
+    for (int32_t j = tile_row_idx; j < range_0; j += row_num_per_iter) {
+      if (tile_col_idx < range_1) {
+        tile_in_shmem[j][tile_col_idx ^ j] =
+            inp_tile_gmem[j * col + tile_col_idx];
+      }
+    }
+    __syncthreads();
+    tile_row_idx = threadIdx.x / TileSizeX;
+    tile_col_idx = threadIdx.x - tile_row_idx * TileSizeX;
+    for (int32_t j = tile_row_idx; j < range_1; j += col_num_per_iter) {
+      if (tile_col_idx < range_0) {
+        out_tile_gmem[j * row + tile_col_idx] =
+            tile_in_shmem[tile_col_idx][j ^ tile_col_idx];
+      }
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, int32_t Col, int32_t PackSize, int32_t BlockSize>
+void LaunchBatchTransposeWithSmallInnerDim(int32_t total_row, const int32_t row,
+                                           void *inp_ptr, void *out_ptr,
+                                           cudaStream_t stream) {
+  dim3 block(BlockSize);
+  const int32_t blockNum = (total_row - 1) / BlockSize + 1;
+  dim3 grid(blockNum >= kMaxGridDim ? kMaxGridDim : blockNum);
+  batch_transpose_with_small_inner_dim_kernel<T, Col, PackSize, BlockSize>
+      <<<grid, block, 0, stream>>>(total_row, row, inp_ptr, out_ptr);
+}
+
+template <typename T>
+void batch_transpose(int32_t batch, int32_t row, int32_t col, const T *inp_ptr,
+                     T *out_ptr, cudaStream_t stream) {
+  constexpr int32_t kTileSize = 32;
+
+  const int32_t tile_num_in_dim0 = (row - 1) / kTileSize + 1;
+  const int32_t tile_num_in_dim1 = (col - 1) / kTileSize + 1;
+  const int32_t tile_per_sample = tile_num_in_dim0 * tile_num_in_dim1;
+  const int32_t total_tile_num = batch * tile_per_sample;
+  dim3 grid(total_tile_num >= kMaxGridDim ? kMaxGridDim : total_tile_num);
+  if (col <= 4) {
+    const int32_t total_row = batch * row;
+    constexpr int32_t kBlockSize = 256;
+    switch (col) {
+    case 2:
+      LaunchBatchTransposeWithSmallInnerDim<T, 2, 2, kBlockSize>(
+          total_row, row, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+          reinterpret_cast<void *>(out_ptr), stream);
+      break;
+    case 3:
+      LaunchBatchTransposeWithSmallInnerDim<T, 3, 3, kBlockSize>(
+          total_row, row, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+          reinterpret_cast<void *>(out_ptr), stream);
+      break;
+    case 4:
+      LaunchBatchTransposeWithSmallInnerDim<T, 4, 4, kBlockSize>(
+          total_row, row, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+          reinterpret_cast<void *>(out_ptr), stream);
+      break;
+    default:
+      break;
+    }
+  } else if (row < 8 || col < 8) {
+    constexpr int32_t kBlockSize = 64;
+    dim3 block(kBlockSize);
+    batch_transpose_kernel<T, kTileSize, kTileSize, kBlockSize>
+        <<<grid, block, 0, stream>>>(
+            total_tile_num, tile_num_in_dim0, tile_num_in_dim1, tile_per_sample,
+            row, col, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+            reinterpret_cast<void *>(out_ptr));
+  } else {
+    constexpr int32_t kBlockSize = 256;
+    dim3 block(kBlockSize);
+    batch_transpose_kernel<T, kTileSize, kTileSize, kBlockSize>
+        <<<grid, block, 0, stream>>>(
+            total_tile_num, tile_num_in_dim0, tile_num_in_dim1, tile_per_sample,
+            row, col, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+            reinterpret_cast<void *>(out_ptr));
+  }
+}
+
 // instantiate
 template void transpose_naive_2d<float>(const float *, float *, int, int, dim3,
                                         dim3, cudaStream_t);
 template void transpose_naive_2d<__half>(const __half *, __half *, int, int,
                                          dim3, dim3, cudaStream_t);
+template void batch_transpose<float>(int32_t, int32_t, int32_t, const float *,
+                                     float *, cudaStream_t);
+template void batch_transpose<__half>(int32_t, int32_t, int32_t, const __half *,
+                                      __half *, cudaStream_t);
 } // namespace kernel
 } // namespace cuda
 } // namespace brt
diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h
index bc3fb9f54..3c9281688 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h
@@ -26,6 +26,9 @@ template <typename T>
 void transpose_naive_2d(const T *input, T *output, int m, int n, dim3 grid,
                         dim3 block, cudaStream_t stream);
 
+template <typename T>
+void batch_transpose(int32_t batch, int32_t row, int32_t col, const T *inp_ptr,
+                     T *out_ptr, cudaStream_t stream);
 } // namespace kernel
 } // namespace cuda
 } // namespace brt
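As a reference for what the kernels above compute, the batched transpose is simply a swap of the last two axes per batch; a small NumPy oracle along these lines (illustrative only, not taken from the patch) can be used to sanity-check kernel output on the host:

    import numpy as np

    def batch_transpose_ref(inp):
        # (batch, row, col) -> (batch, col, row), the contract of kernel::batch_transpose
        return np.ascontiguousarray(np.transpose(inp, (0, 2, 1)))

    x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
    y = batch_transpose_ref(x)
    assert y.shape == (2, 4, 3)
    assert y[1, 2, 0] == x[1, 0, 2]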
diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc
index d9929e569..2d34c21ca 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc
@@ -40,11 +40,11 @@ using namespace brt::ir;
 namespace brt {
 namespace cuda {
 
-template <typename T> Transpose2D<T>::Transpose2D(const OpAccessor &accessor) {
+template <typename T>
+BatchTranspose<T>::BatchTranspose(const OpAccessor &accessor) {
   auto shape_input = accessor.GetArgShape(0);
   auto shape_output = accessor.GetArgShape(1);
-  BRT_ENFORCE(shape_input.size() == 2);
   BRT_ENFORCE(shape_output ==
               transpose::DeduceOutputShape(
                   shape_input, accessor.GetAttrAsIntArray("permutation")));
@@ -52,18 +52,21 @@ template <typename T> Transpose2D<T>::Transpose2D(const OpAccessor &accessor) {
 }
 
 template <typename T>
-void Transpose2D<T>::Execute(const T *input, T *output,
-                             cudnnHandle_t /*handle*/, cudaStream_t stream) {
+void BatchTranspose<T>::Execute(const T *input, T *output,
+                                cudnnHandle_t /*handle*/, cudaStream_t stream) {
   auto p = MakeCUDAGridAndBlock(input_shape[1], input_shape[0]);
-  kernel::transpose_naive_2d(input, output, static_cast<int>(input_shape[0]),
-                             static_cast<int>(input_shape[1]), p.first,
-                             p.second, stream);
+  int32_t rank = input_shape.size();
+  int32_t batch = 1, m = input_shape[rank - 2], n = input_shape[rank - 1];
+  for (int32_t i = 0; i < rank - 2; ++i) {
+    batch *= input_shape[i];
+  }
+  kernel::batch_transpose(batch, m, n, input, output, stream);
   BRT_CUDA_CHECK(cudaGetLastError());
 }
 
 // instantiate
-template class Transpose2D<float>;
-template class Transpose2D<__half>;
+template class BatchTranspose<float>;
+template class BatchTranspose<__half>;
 
 template <typename T> Transpose4D<T>::Transpose4D(const OpAccessor &accessor) {
   auto shape_input = accessor.GetArgShape(0);
@@ -134,8 +137,17 @@ template class Transpose4D<__half>;
 
 template <typename T>
 TransposeImpl<T>::TransposeImpl(const OpAccessor &accessor) {
   auto shape_input = accessor.GetArgShape(0);
-  if (shape_input.size() == 2) {
-    this->impl = new Transpose2D<T>(accessor);
+  bool identity = true;
+  auto permutation = accessor.GetAttrAsIntArray("permutation");
+  for (size_t i = 0; i < permutation.size() - 2; i++) {
+    if (permutation[i] != i) {
+      identity = false;
+    }
+  }
+  if (permutation[permutation.size() - 2] == permutation.size() - 1 &&
+      permutation[permutation.size() - 1] == permutation.size() - 2 &&
+      identity) {
+    this->impl = new BatchTranspose<T>(accessor);
   } else if (shape_input.size() == 4) {
     this->impl = new Transpose4D<T>(accessor);
   } else {
diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h
index ab36caccb..8d9c11883 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h
@@ -34,11 +34,11 @@ template <typename T> class TransposeBase {
 };
 
 /**
- * Transpose2D
+ * BatchTranspose
  */
-template <typename T> class Transpose2D : public TransposeBase<T> {
+template <typename T> class BatchTranspose : public TransposeBase<T> {
 public:
-  explicit Transpose2D(const OpAccessor &accessor);
+  explicit BatchTranspose(const OpAccessor &accessor);
 
   virtual void Execute(const T *input, T *output, cudnnHandle_t handle,
                        cudaStream_t stream) override;
diff --git a/runtime/python/src/module.cc b/runtime/python/src/module.cc
index 799a812f3..55e3ae1d6 100644
--- a/runtime/python/src/module.cc
+++ b/runtime/python/src/module.cc
@@ -256,6 +256,30 @@ PYBIND11_MODULE(MODULE_NAME, m) {
             THROW_ON_FAIL(
                 req.Context().BindArg(offset, reinterpret_cast<void *>(ptr)));
           })
+      .def("bind_args",
+           [](ReqeustContextWithSession &req, py::list offset_and_args) {
+             for (auto handle : offset_and_args) {
+               PyObject *obj = handle.ptr();
+               if (!PyTuple_Check(obj) || PyTuple_Size(obj) != 2) {
+                 PyErr_SetString(PyExc_TypeError,
+                                 "expect a pair of (offset, arg)");
+                 return;
+               }
+
+               PyObject *offset = PyTuple_GetItem(obj, 0);
+               PyObject *arg = PyTuple_GetItem(obj, 1);
+               if (!PyLong_Check(offset)) {
+                 PyErr_SetString(PyExc_TypeError, "offset should be an integer");
+                 return;
+               }
+               if (!PyLong_Check(arg)) {
+                 PyErr_SetString(PyExc_TypeError, "arg should be an integer");
+                 return;
+               }
+               THROW_ON_FAIL(req.Context().BindArg(PyLong_AsSize_t(offset),
+                                                   PyLong_AsVoidPtr(arg)));
+             }
+           })
       .def("get_arg",
           [](ReqeustContextWithSession &req, size_t offset) {
            void *ptr = req.Context().GetArg(offset);
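The new bind_args binding accepts a list of (offset, pointer) pairs in a single call instead of repeated bind_arg calls. A rough sketch of how it might be driven from Python follows (the request-context construction and the choice of offsets are assumptions for illustration, not part of this patch):

    import torch

    # req is assumed to be an already-created request context exposing the
    # binding above (obtained from the brt Python session API); the offsets
    # would normally come from the loaded model's argument layout.
    tensors = [torch.empty(128, device="cuda") for _ in range(3)]
    req.bind_args([(offset, t.data_ptr()) for offset, t in enumerate(tensors)])
    # data_ptr() passes the device address as an integer, which matches
    # PyLong_AsVoidPtr on the C++ side.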
diff --git a/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc b/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc
index 5ada9095e..f35bf78bf 100644
--- a/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc
+++ b/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc
@@ -63,6 +63,38 @@ static void CheckTranspose2D(T *input, T *output,
   free(h_output);
 }
 
+template <typename T>
+static void CheckTranspose3D(T *input, T *output,
+                             const std::vector<int64_t> &input_shape) {
+  T *h_input =
+      (T *)malloc(input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T));
+  T *h_output =
+      (T *)malloc(input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T));
+  cudaMemcpy(h_input, input,
+             input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  cudaMemcpy(h_output, output,
+             input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  int B = input_shape[0];
+  int m = input_shape[1];
+  int n = input_shape[2];
+  for (int64_t t = 0; t < B; ++t) {
+    for (int64_t i = 0; i < m; ++i) {
+      for (int64_t j = 0; j < n; ++j) {
+        int in_idx = t * m * n + i * n + j;
+        int out_idx = t * m * n + j * m + i;
+        EXPECT_EQ(h_output[out_idx], h_input[in_idx]);
+      }
+    }
+  }
+
+  free(h_input);
+  free(h_output);
+}
+
 template <typename T>
 static void CheckTranspose4D(T *input, T *output,
                              const std::vector<int64_t> &input_shape,
@@ -142,6 +174,8 @@ static void TestTranspose(std::vector<int64_t> shape_input,
 
   if (shape_input.size() == 2) {
     CheckTranspose2D(d_input, d_output, shape_input);
+  } else if (shape_input.size() == 3) {
+    CheckTranspose3D(d_input, d_output, shape_input);
   } else if (shape_input.size() == 4) {
     CheckTranspose4D(d_input, d_output, shape_input, perm);
   } else {
@@ -150,8 +184,18 @@
 }
 
 TEST(CUDAOpKerenlTest, TransposeOp) {
+  // 2D transpose
   TestTranspose<float>({32, 64}, {64, 32}, {1, 0});
+  TestTranspose<float>({2, 1}, {1, 2}, {1, 0});
+  TestTranspose<float>({1007, 13}, {13, 1007}, {1, 0});
+  TestTranspose<float>({2007, 4339}, {4339, 2007}, {1, 0});
   TestTranspose<float>({1000, 512}, {512, 1000}, {1, 0});
+  // 3D batch transpose
+  TestTranspose<float>({13, 789, 1234}, {13, 1234, 789}, {0, 2, 1});
+  TestTranspose<float>({65536, 32, 50}, {65536, 50, 32}, {0, 2, 1});
+  TestTranspose<float>({127, 2, 50}, {127, 50, 2}, {0, 2, 1});
+  TestTranspose<float>({127, 3, 50}, {127, 50, 3}, {0, 2, 1});
+  TestTranspose<float>({127, 4, 50}, {127, 50, 4}, {0, 2, 1});
   // NCHW 2 NHWC
   TestTranspose<float>({10, 20, 30, 40}, {10, 30, 40, 20}, {0, 2, 3, 1});
   // NHWC 2 NCHW
@@ -159,8 +203,15 @@
 }
 
 TEST(CUDAOpKerenlTest, TransposeOpFp16) {
+  // 2D transpose
   TestTranspose<__half>({32, 64}, {64, 32}, {1, 0});
+  TestTranspose<__half>({2, 1}, {1, 2}, {1, 0});
+  TestTranspose<__half>({1007, 13}, {13, 1007}, {1, 0});
+  TestTranspose<__half>({2007, 4339}, {4339, 2007}, {1, 0});
   TestTranspose<__half>({1000, 512}, {512, 1000}, {1, 0});
+  // 3D batch transpose
+  TestTranspose<__half>({13, 789, 1234}, {13, 1234, 789}, {0, 2, 1});
+  TestTranspose<__half>({65536, 50, 2}, {65536, 2, 50}, {0, 2, 1});
   // NCHW 2 NHWC
   TestTranspose<__half>({10, 20, 30, 40}, {10, 30, 40, 20}, {0, 2, 3, 1});
   // NHWC 2 NCHW