diff --git a/compiler/include/byteir/Conversion/Passes.h b/compiler/include/byteir/Conversion/Passes.h index 4f7aa00cd..3758a88c3 100644 --- a/compiler/include/byteir/Conversion/Passes.h +++ b/compiler/include/byteir/Conversion/Passes.h @@ -34,6 +34,7 @@ #include "byteir/Conversion/ToLLVM/ToLLVM.h" #include "byteir/Conversion/ToLinalg/ToLinalg.h" #include "byteir/Conversion/ToPTX/ToPTX.h" +#include "byteir/Conversion/VectorToGPU/GPUVectorToGPU.h" namespace mlir { diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td index 01dd52c7f..50b2454a5 100644 --- a/compiler/include/byteir/Conversion/Passes.td +++ b/compiler/include/byteir/Conversion/Passes.td @@ -45,6 +45,16 @@ def GPUToNVVMExt : Pass<"gpu-to-nvvm-ext", "gpu::GPUModuleOp"> { ]; } + +//===----------------------------------------------------------------------===// +// GPUVectorToGPU +//===----------------------------------------------------------------------===// +def GPUVectorToGPU : Pass<"gpu-vector-to-gpu", "func::FuncOp"> { + let summary = "Transform vector.contract to gpu.mma.sync."; + let constructor = "mlir::createGPUVectorToGPUPass()"; +} + + //===----------------------------------------------------------------------===// // ToLinalg //===----------------------------------------------------------------------===// diff --git a/compiler/include/byteir/Conversion/VectorToGPU/GPUVectorToGPU.h b/compiler/include/byteir/Conversion/VectorToGPU/GPUVectorToGPU.h new file mode 100644 index 000000000..de8de26dd --- /dev/null +++ b/compiler/include/byteir/Conversion/VectorToGPU/GPUVectorToGPU.h @@ -0,0 +1,34 @@ +//===- GPUVectorToGPU.h --------------------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_CONVERSION_VECTORTOGPU_GPUVECTORTOGPU_H +#define BYTEIR_CONVERSION_VECTORTOGPU_GPUVECTORTOGPU_H + +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +std::unique_ptr> createGPUVectorToGPUPass(); + +} // namespace mlir + +#endif // BYTEIR_CONVERSION_VECTORTOGPU_GPUVECTORTOGPU_H \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/GPU/Passes.h b/compiler/include/byteir/Dialect/GPU/Passes.h index f2107a9a1..0f140698a 100644 --- a/compiler/include/byteir/Dialect/GPU/Passes.h +++ b/compiler/include/byteir/Dialect/GPU/Passes.h @@ -21,8 +21,10 @@ #include "byteir/Dialect/GPU/Transforms/GPUBlockSwizzle.h" #include "byteir/Dialect/GPU/Transforms/GPUDistributeSharedMemoryCopy.h" #include "byteir/Dialect/GPU/Transforms/GPUDistributeToWarp.h" +#include "byteir/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.h" #include "byteir/Dialect/GPU/Transforms/GPUPackSharedMemoryAlloc.h" #include "byteir/Dialect/GPU/Transforms/GPUTensorCoreVectorization.h" +#include "byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h" #include "byteir/Dialect/GPU/Transforms/OptimizeVectorTransfer.h" #include "byteir/Dialect/GPU/Transforms/RemoveTrivialLoops.h" #include "mlir/Pass/Pass.h" diff --git a/compiler/include/byteir/Dialect/GPU/Passes.td b/compiler/include/byteir/Dialect/GPU/Passes.td index 0cc061190..4c38c9d72 100644 --- a/compiler/include/byteir/Dialect/GPU/Passes.td +++ b/compiler/include/byteir/Dialect/GPU/Passes.td @@ -101,6 +101,24 @@ def GPUTensorCoreVectorization : Pass<"gpu-tensorcore-vectorization", "func::Fun def GPUPackSharedMemoryAlloc : Pass<"gpu-pack-shared-memory-alloc", "func::FuncOp"> { let summary = "Analysis shared memory reuse and pack it into i8 alloc."; let constructor = "mlir::createGPUPackSharedMemoryAllocPass()"; + let dependentDialects = [ + "nvgpu::NVGPUDialect", + ]; } +//===----------------------------------------------------------------------===// +// LegalizeGPULaunch +//===----------------------------------------------------------------------===// +def LegalizeGPULaunch : Pass<"legalize-gpu-launch", "func::FuncOp"> { + let summary = "Legalize GPU launch ops."; + let constructor = "mlir::createLegalizeGPULaunchPass()"; +} + +//===----------------------------------------------------------------------===// +// GPUInputSharedMemorySwizzle +//===----------------------------------------------------------------------===// +def GPUInputSharedMemorySwizzle: Pass<"gpu-input-shared-memory-swizzle", "func::FuncOp"> { + let summary = "Swizzle shared memory for gemm's input to improve performance."; + let constructor = "mlir::createGPUInputSharedMemorySwizzlePass()"; +} #endif // BYTEIR_DIALECT_GPU_PASSES diff --git a/compiler/include/byteir/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.h b/compiler/include/byteir/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.h new file mode 100644 index 000000000..6dc6e1507 --- /dev/null +++ b/compiler/include/byteir/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.h @@ -0,0 +1,36 @@ +//===- GPUInputSharedMemorySwizzle.h ---------------------------------*--- +// C++-*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_GPU_TRANSFORMS_GPUINPUTSHAREDMEMORYSWIZZLE_H +#define BYTEIR_DIALECT_GPU_TRANSFORMS_GPUINPUTSHAREDMEMORYSWIZZLE_H + +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +std::unique_ptr> +createGPUInputSharedMemorySwizzlePass(); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_GPU_TRANSFORMS_GPUINPUTSHAREDMEMORYSWIZZLE_H \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h b/compiler/include/byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h new file mode 100644 index 000000000..8d07ec590 --- /dev/null +++ b/compiler/include/byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h @@ -0,0 +1,34 @@ +//===- LegalizeGPULaunch.h ---------------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_GPU_TRANSFORMS_LEGALIZEGPULAUNCH_H +#define BYTEIR_DIALECT_GPU_TRANSFORMS_LEGALIZEGPULAUNCH_H + +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +std::unique_ptr> createLegalizeGPULaunchPass(); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_GPU_TRANSFORMS_LEGALIZEGPULAUNCH_H \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h b/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h index a49fab4fa..79d3ef645 100644 --- a/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h +++ b/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h @@ -49,7 +49,43 @@ static constexpr StringRef getCopyRelatedToWorkgroupMemoryMarker() { return "__byteir_copy_related_to_workgroup_memory__"; } -static constexpr StringRef getVectorizeMarker() { return "vectorizeMarker"; } +static constexpr StringRef getVectorizeMarker() { return "vectorize"; } + +static constexpr StringRef getAllocSharedMemoryAMarker() { + return "__byteir_alloca_matrix_a__"; +}; + +static constexpr StringRef getAllocSharedMemoryBMarker() { + return "__byteir_alloca_matrix_b__"; +}; + +static constexpr StringRef getAllocSharedMemoryAccMarker() { + return "__byteir_alloca_accumulator__"; +}; + +static constexpr StringRef getCopyToSharedMemoryAMarker() { + return "__byteir_load_matrix_a__"; +}; + +static constexpr StringRef getCopyToSharedMemoryBMarker() { + return "__byteir_load_matrix_b__"; +}; + +static constexpr StringRef getCopyFromSharedMemoryAccMarker() { + return "__byteir_store_matrix_c__"; +}; + +static constexpr StringRef getMatmulMainLoopMarker() { + return "__byteir_main_loop__"; +} + +constexpr StringRef getLinalgMMALevelAttrName() { + return "__byteir_mma_level__"; +} + +constexpr StringRef getMMAPatternAttrName() { return "__byteir_mma__"; } + +static constexpr StringRef getEpilogueMarker() { return "__byteir_epilogue__"; } std::optional> getGemmTileSize(func::FuncOp funcOp); std::optional> getGemmBlockSize(func::FuncOp funcOp); @@ -72,7 +108,7 @@ bool isMappedToGPUThreads(Operation *op); // Get the ForallOp which mapped to threadblock level in a function. // There should be only one valid ForallOp, otherwise the function will return // std::nullopt; -std::optional getForallOpMappedTo2DBlock(func::FuncOp funcOp); +std::optional getForallOpMappedToBlock(func::FuncOp funcOp); // Set a marker attribute on the operation. // The marker is represented as a UnitAttr. 
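// --- Illustrative aside (not part of this patch) on the marker helpers added
// to Utils.h above. The gemm-codegen passes coordinate through these string
// markers: LinalgPromotion tags the shared-memory buffers it creates, and
// later passes such as GPUInputSharedMemorySwizzle query the tags. The sketch
// below mirrors the setMarker/hasMarker usage that appears in
// LinalgPromotion.cpp and GPUInputSharedMemorySwizzle.cpp later in this diff;
// the two helper functions here are hypothetical names for illustration.
#include "byteir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

namespace {
using namespace mlir;

// Producer side: tag the shared-memory buffer promoted for matrix A.
void tagMatrixABuffer(memref::AllocOp allocOp) {
  setMarker(allocOp, getAllocSharedMemoryAMarker());
}

// Consumer side: only the A/B input buffers are candidates for swizzling;
// the accumulator buffer keeps its natural layout.
bool isGemmInputBuffer(memref::AllocOp allocOp) {
  return hasMarker(allocOp, {getAllocSharedMemoryAMarker(),
                             getAllocSharedMemoryBMarker()});
}
} // namespace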
@@ -104,6 +140,8 @@ LogicalResult distributeLinalgOpsWithFilter(IRRewriter &rewriter, Operation *root, linalg::LinalgTilingOptions tilingOptions, linalg_ext::LinalgTransformationFilter filter); + +bool isLinalgOpMatmul(Operation *op); } // namespace mlir #endif // BYTEIR_UTILS_GPU_CODEGEN_UTILS_H \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Linalg/Passes.h b/compiler/include/byteir/Dialect/Linalg/Passes.h index 57455ac72..6f989b4b7 100644 --- a/compiler/include/byteir/Dialect/Linalg/Passes.h +++ b/compiler/include/byteir/Dialect/Linalg/Passes.h @@ -19,6 +19,7 @@ #define BYTEIR_DIALECT_LINALG_PASSES_H #include "byteir/Dialect/Linalg/Transforms/Bufferize.h" +#include "byteir/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.h" #include "byteir/Dialect/Linalg/Transforms/FuseElementwise.h" #include "byteir/Dialect/Linalg/Transforms/LinalgCollapseLoops.h" #include "byteir/Dialect/Linalg/Transforms/LinalgDataPlace.h" diff --git a/compiler/include/byteir/Dialect/Linalg/Passes.td b/compiler/include/byteir/Dialect/Linalg/Passes.td index dd8d2e098..33f73b1bd 100644 --- a/compiler/include/byteir/Dialect/Linalg/Passes.td +++ b/compiler/include/byteir/Dialect/Linalg/Passes.td @@ -198,4 +198,13 @@ def LinalgGeneralizationExt : Pass<"linalg-generalization-ext", "func::FuncOp"> ]; } +//===----------------------------------------------------------------------===// +// CanonicalizeMatmulEpilogue +//===----------------------------------------------------------------------===// + +def CanonicalizeMatmulEpilogue : Pass<"canonicalize-matmul-epilogue", "func::FuncOp"> { + let summary = "Canonicalize matmul epilogue"; + let constructor = "mlir::createCanonicalizeMatmulEpiloguePass()"; +} + #endif // BYTEIR_DIALECT_LINALG_PASSES \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.h b/compiler/include/byteir/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.h new file mode 100644 index 000000000..48e336f18 --- /dev/null +++ b/compiler/include/byteir/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.h @@ -0,0 +1,35 @@ +//===- LinalgPromote.h --------------------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_LINALG_TRANSFORMS_CANONICALIZEMATMULEPILOGUE_H +#define BYTEIR_DIALECT_LINALG_TRANSFORMS_CANONICALIZEMATMULEPILOGUE_H + +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +std::unique_ptr> +createCanonicalizeMatmulEpiloguePass(); + +} // namespace mlir + +#endif // BYTEIR_DIALECT_LINALG_TRANSFORMS_CANONICALIZEMATMULEPILOGUE_H \ No newline at end of file diff --git a/compiler/include/byteir/Pipelines/GPU/GemmCodegen.h b/compiler/include/byteir/Pipelines/GPU/GemmCodegen.h new file mode 100644 index 000000000..520f0208a --- /dev/null +++ b/compiler/include/byteir/Pipelines/GPU/GemmCodegen.h @@ -0,0 +1,88 @@ +//===- GemmCodegen.h -----------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_PIPELINES_GPU_GEMM_CODEGEN_H +#define BYTEIR_PIPELINES_GPU_GEMM_CODEGEN_H + +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { + +struct GPUGemmCodegenConfigOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_tile_gemm")}; + ListOption tileSizeConfig{ + *this, "tile-size-config", + llvm::cl::desc("An optional tile size config for tile matmul op.")}; + ListOption workgroupSize{ + *this, "workgroup-size", + llvm::cl::desc("An optional workgroup size config for tile matmul op.")}; + Option stages{ + *this, "stages", llvm::cl::desc("An optional stages for tile matmul op."), + llvm::cl::init(3)}; +}; + +struct GPUGemmGeneralOptions + : public PassPipelineOptions { + Option funcAnchor{ + *this, "func-anchor", + llvm::cl::desc( + "An optional Unit attribute anchoring on target functions."), + llvm::cl::init("")}; + Option annotatePrefix{ + *this, "annotate-prefix", + llvm::cl::desc("An optional annotate prefix attribute on target ops."), + llvm::cl::init("__byteir_gpu_tile_gemm")}; +}; + +void createGPUTileGemmTransform(OpPassManager &pm, + const GPUGemmGeneralOptions &options); + +void createGPUAddGemmCodegenLoweringConfigTransform( + OpPassManager &pm, const GPUGemmCodegenConfigOptions &options); + +void createGPUPipeliningTransform(OpPassManager &pm, + const GPUGemmGeneralOptions &options); + +inline void registerGPUGemmCodegenPipelines() { + PassPipelineRegistration( + "insert-gpu-tile-gemm-transform", + "Insert transformation IR to tile linalg matmul op", + createGPUTileGemmTransform); + 
PassPipelineRegistration( + "insert-gpu-gemm-codegen-transform", + "Insert transformation IR to tile linalg matmul op", + createGPUAddGemmCodegenLoweringConfigTransform); + PassPipelineRegistration( + "insert-gpu-pipelining-transform", + "Insert transformation IR to tile linalg matmul op", + createGPUPipeliningTransform); +} + +} // namespace mlir + +#endif // BYTEIR_PIPELINES_GPU_GEMM_CODEGEN_H diff --git a/compiler/include/byteir/Pipelines/HloFusionOpt.h b/compiler/include/byteir/Pipelines/HloFusionOpt.h index c918bcd3b..31b5437b3 100644 --- a/compiler/include/byteir/Pipelines/HloFusionOpt.h +++ b/compiler/include/byteir/Pipelines/HloFusionOpt.h @@ -47,6 +47,10 @@ struct HloFusionOptPipelineOptions *this, "outline-cat-op", llvm::cl::desc("whether to outline cat ops and AIT as an backend"), llvm::cl::init(false)}; + Option outlineDotOp{ + *this, "outline-dot-op", + llvm::cl::desc("whether to outline dot ops and use gemm codegen"), + llvm::cl::init(false)}; }; void createHloFusionOptPipeline(OpPassManager &pm, diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index c2d351a87..27a4f32b7 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -13,3 +13,4 @@ add_subdirectory(ToLinalg) add_subdirectory(ToLLVM) add_subdirectory(ToPTX) add_subdirectory(LcclToByre) +add_subdirectory(VectorToGPU) diff --git a/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp b/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp index fda8a3af4..7af99d824 100644 --- a/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp +++ b/compiler/lib/Conversion/FuncToByre/FuncToByre.cpp @@ -129,6 +129,14 @@ class ConvertGPULaunchFuncToByrePattern computeOp->setAttr("BlockSize.y", rewriter.getI32IntegerAttr(by)); computeOp->setAttr("BlockSize.z", rewriter.getI32IntegerAttr(bz)); + auto sharedMemorySize = launchOp.getDynamicSharedMemorySize(); + if (sharedMemorySize) { + auto sharedMemorySizeValue = + cast(sharedMemorySize.getDefiningOp()); + IntegerAttr smem = cast(sharedMemorySizeValue.getValue()); + computeOp->setAttr("DynamicSharedMemorySize", smem); + } + if (useBarePtrCallConv) { computeOp->setAttr(byre::getKernelCallConventionAttrName(), rewriter.getStringAttr("bare_ptr")); diff --git a/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp b/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp index 61f0ac02e..d8d6674b9 100644 --- a/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp +++ b/compiler/lib/Conversion/GPUToNVVM/GPUToNVVM.cpp @@ -39,6 +39,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" @@ -49,6 +50,9 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/IRMapping.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -64,6 +68,75 @@ using namespace mlir::NVVM; namespace { +static void ConvertToDynamicSharedMemory(GPUModuleOp moduleOp) { + SymbolTableCollection symbolTableCollection; + // Collect all the addressOfOps to static 
shared memory globals. + SmallVector addressOfOps; + moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) { + // Check that the global associated with this addressOfOp has shared memory + // space. + if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3) + addressOfOps.push_back(addressOfOp); + }); + if (addressOfOps.size() == 0) + return; + + uint32_t numberOfBytes = 0; + llvm::SmallDenseMap globalMemoryOffsetMap; + for (auto addressOfOp : addressOfOps) { + uint32_t offset = 0; + auto globalOp = addressOfOp.getGlobal(symbolTableCollection); + if (globalMemoryOffsetMap.count(globalOp)) { + offset = globalMemoryOffsetMap[globalOp]; + } else { + offset = numberOfBytes; + if (std::optional alignment = globalOp.getAlignment()) { + offset = llvm::alignTo(offset, *alignment); + } + globalMemoryOffsetMap[globalOp] = offset; + auto thisarray = globalOp.getType(); + DataLayout dataLayout = DataLayout::closest(addressOfOp); + numberOfBytes = offset + dataLayout.getTypeSizeInBits(thisarray) / 8; + } + } + + // Check if numberOfBytes is less than 48 * 1024 + if (numberOfBytes < 48 * 1024) { + return; + } + + OpBuilder builder(moduleOp); + builder.setInsertionPoint(&moduleOp.front()); + auto type = + LLVM::LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), 0); + LLVM::GlobalOp global = builder.create( + moduleOp.getLoc(), type, /*isConstant=*/false, LLVM::Linkage::External, + "__dynamic_shared_memory__", Attribute(), + /*alignment=*/16, /*addr_space=*/3); + + // Replace the addressOfOps with correctly offseted pointers to dynamic + // shared memory. + for (auto addressOfOp : addressOfOps) { + uint32_t offset = + globalMemoryOffsetMap[addressOfOp.getGlobal(symbolTableCollection)]; + auto loc = addressOfOp.getLoc(); + builder.setInsertionPoint(addressOfOp); + LLVM::AddressOfOp globalPtr = + builder.create(loc, global); + Value zero = builder.create( + loc, IntegerType::get(builder.getContext(), 64), + builder.getI64IntegerAttr(0)); + Value offsetValue = builder.create( + loc, IntegerType::get(builder.getContext(), 64), + builder.getI64IntegerAttr(offset)); + Value shiftedPtr = builder.create( + loc, globalPtr.getType(), global.getGlobalType(), globalPtr, + ValueRange({zero, offsetValue})); + addressOfOp.replaceAllUsesWith(shiftedPtr); + addressOfOp.erase(); + } +} + template struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { public: @@ -253,6 +326,30 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase { // Apply in-dialect lowering. In-dialect lowering will replace // ops which need to be lowered further, which is not supported by a // single conversion pass. + // Run Vector -> Vector transformations ahead of conversion to LLVM. + { + RewritePatternSet patterns(&getContext()); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + vector::populateVectorBroadcastLoweringPatterns(patterns); + vector::populateVectorContractLoweringPatterns( + patterns, + vector::VectorTransformsOptions().setVectorTransformsOptions( + vector::VectorContractLowering::OuterProduct)); + vector::populateVectorMaskOpLoweringPatterns(patterns); + // We currently always use 64 bit indices, thus ensure the bit width of + // the mask compare is consistent. + vector::populateVectorMaskMaterializationPatterns( + patterns, /*force32BitVectorIndices=*/false); + vector::populateVectorShapeCastLoweringPatterns(patterns); + // TODO: doubtful that the "default" does what one want here, it is likely + // better to use something else. 
+ vector::populateVectorTransposeLoweringPatterns( + patterns, vector::VectorTransformsOptions()); + vector::populateVectorTransferLoweringPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { + return signalPassFailure(); + } + } { RewritePatternSet patterns(m.getContext()); populateGpuRewritePatterns(patterns); @@ -289,13 +386,19 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase { converter.addConversion([&](gpu::MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); }); + // Convert dummy tokens. + converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { + return converter.convertType(IntegerType::get(type.getContext(), 32)); + }); RewritePatternSet llvmPatterns(m.getContext()); arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); + populateVectorToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); + populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns); populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); #if 0 // FIXME: enable if gpu arch >= sm_75 @@ -325,6 +428,7 @@ struct GPUToNVVMExtPass : public GPUToNVVMExtBase { } } }); + ConvertToDynamicSharedMemory(m); } }; diff --git a/compiler/lib/Conversion/PassDetail.h b/compiler/lib/Conversion/PassDetail.h index 5306989d8..41c489fac 100644 --- a/compiler/lib/Conversion/PassDetail.h +++ b/compiler/lib/Conversion/PassDetail.h @@ -81,6 +81,10 @@ namespace mhlo { class MhloDialect; } // namespace mhlo +namespace nvgpu { +class NVGPUDialect; +} // namespace nvgpu + namespace NVVM { class NVVMDialect; } // namespace NVVM diff --git a/compiler/lib/Conversion/VectorToGPU/CMakeLists.txt b/compiler/lib/Conversion/VectorToGPU/CMakeLists.txt new file mode 100644 index 000000000..ea200301b --- /dev/null +++ b/compiler/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_byteir_conversion_library(ByteIRVectorToGPU + GPUVectorToGPU.cpp + + ADDITIONAL_HEADER_DIRS + ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Conversion/VectorToGPU + + + DEPENDS + ByteIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMemRefDialect + MLIRGPUDialect + MLIRTransforms + ) diff --git a/compiler/lib/Conversion/VectorToGPU/GPUVectorToGPU.cpp b/compiler/lib/Conversion/VectorToGPU/GPUVectorToGPU.cpp new file mode 100644 index 000000000..feae05078 --- /dev/null +++ b/compiler/lib/Conversion/VectorToGPU/GPUVectorToGPU.cpp @@ -0,0 +1,109 @@ +//===- GPUVectorToGPU.cpp ------------------------------------*--- C++ -*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +#include "byteir/Conversion/VectorToGPU/GPUVectorToGPU.h" +#include "mlir/Pass/Pass.h" + +#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "../PassDetail.h" + +using namespace mlir; + +#define DEBUG_TYPE "gpuvector-to-gpu" + +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { +struct GPUVectorToGPUPass : public GPUVectorToGPUBase { + + void getDependentDialects(DialectRegistry ®istry) { + registry.insert(); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + + RewritePatternSet patterns(funcOp.getContext()); + mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + populatePrepareVectorToMMAPatterns(patterns, /*targetMmaSync*/ true); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + IRRewriter rewriter(&getContext()); + if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) { + return signalPassFailure(); + } + RewritePatternSet f32ToTF32Patterns(funcOp.getContext()); + // enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unkown = 2 }; + // Collect patterns to convert mma.sync on f32 input and rewrite + // to use tensor cores with user provided level of accuracy: + // (a) tf32 (1 mma.sync per warp-level matrix-multiply-accumulate) + // (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate) + // Typically, tf32 tensor core acceleration comes at a cost + // of accuracy from missing precision bits. While f32 has 23 precision + // bits, tf32 has only 10 precision bits. tf32x3 aims to recover the + // precision bits by spliting each operand into two tf32 values + // Note: we only support tf32 for now, because tf32x3 is not supported in + // upstream + // The trick is very simple + // a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x + // b_small + a_small x b_big + // big = convert_to_tf32(fp32) + // small = convert_to_tf32(fp32 - big) + // a_small x b_small is discarded because they are too small. + nvgpu::populateMmaSyncF32ToTF32Patterns(f32ToTF32Patterns, + nvgpu::MmaSyncF32Lowering::TF32); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(f32ToTF32Patterns)))) { + return signalPassFailure(); + } + // As we do linalg prefetch first, so problem maybe occurs here. So we + // didn't need to createAsyncGroups to support gpu async copy lowering. In + // this step, we lowering transfer read into cp.async + nvgpu::createAsyncGroups(rewriter, funcOp, /* bypassL1 */ true); + + // Last step: + // Fold subview on memory copy to enable the application of shared memory + // swizzling optimization. 
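// --- Illustrative aside on the tf32/tf32x3 comment earlier in this pass
// (GPUVectorToGPU.cpp). The sketch below models the "big + small" split on the
// host, approximating tf32 rounding by dropping the 13 low mantissa bits of an
// f32 (real hardware rounds to nearest; this is an assumption made purely for
// illustration and is not part of the pass).
#include <cstdint>
#include <cstring>

static float roundToTF32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u; // keep sign + exponent + 10 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

// tf32x3 approximates a * b with three tensor-core multiplies:
//   a_big * b_big + a_big * b_small + a_small * b_big
// a_small * b_small is dropped because it is negligibly small, which recovers
// most of the f32 precision lost by a single tf32 multiply.
static float mulTF32x3(float a, float b) {
  float aBig = roundToTF32(a), aSmall = roundToTF32(a - aBig);
  float bBig = roundToTF32(b), bSmall = roundToTF32(b - bBig);
  return aBig * bBig + aBig * bSmall + aSmall * bBig;
}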
+ RewritePatternSet pattern(funcOp.getContext()); + memref::populateFoldMemRefAliasOpPatterns(pattern); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(pattern)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> mlir::createGPUVectorToGPUPass() { + return std::make_unique(); +} diff --git a/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt index 74d693559..09aea7167 100644 --- a/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/GPU/Transforms/CMakeLists.txt @@ -1,7 +1,9 @@ add_mlir_dialect_library(ByteIRGPUPasses + LegalizeGPULaunch.cpp GPUBlockSwizzle.cpp GPUDistributeSharedMemoryCopy.cpp GPUDistributeToWarp.cpp + GPUInputSharedMemorySwizzle.cpp GPUTensorCoreVectorization.cpp GPUPackSharedMemoryAlloc.cpp OptimizeVectorTransfer.cpp @@ -17,11 +19,15 @@ add_mlir_dialect_library(ByteIRGPUPasses ByteIRGPUPassIncGen ByteIRUtils MLIRGPUDialect + MLIRNVGPUDialect + MLIRNVGPUTransforms LINK_LIBS PUBLIC ByteIRUtils MLIRIR MLIRGPUDialect + MLIRNVGPUDialect + MLIRNVGPUTransforms MLIRMemRefDialect MLIRSupport ) diff --git a/compiler/lib/Dialect/GPU/Transforms/GPUBlockSwizzle.cpp b/compiler/lib/Dialect/GPU/Transforms/GPUBlockSwizzle.cpp index bbaa10562..92d676284 100644 --- a/compiler/lib/Dialect/GPU/Transforms/GPUBlockSwizzle.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/GPUBlockSwizzle.cpp @@ -94,9 +94,9 @@ makeSwizzledIdsInTritonWay(Location loc, OpBuilder &b, Value x, Value y, return {pidN, pidM}; } -// Only support 2d grid. -static LogicalResult reorderForallOpMappedTo2DBlock(scf::ForallOp forallOp, - unsigned swizzleLogTile) { +// Only support 2d or 3d grid. +static LogicalResult reorderForallOpMappedToBlock(scf::ForallOp forallOp, + unsigned swizzleLogTile) { unsigned swizzleTile = 1 << swizzleLogTile; OpBuilder b(forallOp); @@ -110,30 +110,29 @@ static LogicalResult reorderForallOpMappedTo2DBlock(scf::ForallOp forallOp, auto loops = newforallOp.getInductionVars(); auto mapping = newforallOp.getMappingAttr().getValue(); - Value workgroupIdX, workgroupIdY, workgroupCountX, workgroupCountY; - // if mapping[0] == gpu::MappingId::DimX, workgroupIdx = loop[0], otherwise - // workgroupIdx = loop[1] - int64_t dimXMapping = static_cast(gpu::MappingId::DimX); - if (mapping[0].cast().getMappingId() == - dimXMapping) { - workgroupIdX = loops[0]; - workgroupIdY = loops[1]; - workgroupCountX = gridSize[0]; - workgroupCountY = gridSize[1]; - } else { - workgroupIdX = loops[1]; - workgroupIdY = loops[0]; - workgroupCountX = gridSize[1]; - workgroupCountY = gridSize[0]; + SmallVector workgroupCounts(3); + SmallVector workgroupIds(3); + + SmallVector originWorkgroupIds(3); + + for (auto [mappingId, workgroupCount, workgroupId, originWorkgroupId] : + llvm::zip(mapping, gridSize, loops, originLoops)) { + auto mappingIdInt = + mappingId.cast().getMappingId(); + workgroupCounts[mappingIdInt] = workgroupCount; + workgroupIds[mappingIdInt] = workgroupId; + originWorkgroupIds[mappingIdInt] = originWorkgroupId; } auto [swizzledIdX, swizzledIdY] = makeSwizzledIdsInTritonWay( - newforallOp.getLoc(), b, workgroupIdX, workgroupIdY, workgroupCountX, - workgroupCountY, swizzleTile); + newforallOp.getLoc(), b, workgroupIds[0], workgroupIds[1], + workgroupCounts[0], workgroupCounts[1], swizzleTile); IRMapping bvm; - bvm.map(originLoops[0], swizzledIdX); - bvm.map(originLoops[1], swizzledIdY); + bvm.map(originWorkgroupIds[0], swizzledIdX); + bvm.map(originWorkgroupIds[1], swizzledIdY); + if 
(mapping.size() == 3) + bvm.map(originWorkgroupIds[2], workgroupIds[2]); for (auto &op : forallOp.getBody()->getOperations()) { b.clone(op, bvm); } @@ -154,13 +153,13 @@ struct GPUBlockSwizzlePass : public GPUBlockSwizzleBase { return signalPassFailure(); } - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional.has_value()) { return signalPassFailure(); } scf::ForallOp forallOp = *forallOpOptional; - if (failed(reorderForallOpMappedTo2DBlock(forallOp, swizzleLogTile))) { + if (failed(reorderForallOpMappedToBlock(forallOp, swizzleLogTile))) { return signalPassFailure(); } } diff --git a/compiler/lib/Dialect/GPU/Transforms/GPUDistributeSharedMemoryCopy.cpp b/compiler/lib/Dialect/GPU/Transforms/GPUDistributeSharedMemoryCopy.cpp index af39511bd..5fa3b9e9a 100644 --- a/compiler/lib/Dialect/GPU/Transforms/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/GPUDistributeSharedMemoryCopy.cpp @@ -525,7 +525,7 @@ class GPUDistributeSharedMemoryCopyPass } SmallVector workgroupSize = optionalWorkgroupSize.value(); - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional.has_value()) { return signalPassFailure(); } diff --git a/compiler/lib/Dialect/GPU/Transforms/GPUDistributeToWarp.cpp b/compiler/lib/Dialect/GPU/Transforms/GPUDistributeToWarp.cpp index 0f7f0acee..28c635d2d 100644 --- a/compiler/lib/Dialect/GPU/Transforms/GPUDistributeToWarp.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/GPUDistributeToWarp.cpp @@ -54,7 +54,6 @@ namespace { static constexpr int32_t kWarpSize = 32; static constexpr int32_t kNumGPUDims = 3; -static constexpr StringRef getVectorizeMarker() { return "vectorize"; } /// Filters out dimensions in `parallelLoops` that have unit range in /// `loopRanges`. @@ -93,13 +92,20 @@ static SmallVector calculateDistributedTileSize(ArrayRef numElements, OpBuilder &builder, Operation *operation) { func::FuncOp funcOp = operation->getParentOfType(); - auto blockTileSizeOptional = getGemmTileSize(funcOp); - if (!blockTileSizeOptional.has_value()) + auto gemmTileSizeOptional = getGemmTileSize(funcOp); + if (!gemmTileSizeOptional.has_value()) return {}; - SmallVector blockTileSize = getGemmTileSize(funcOp).value(); + + SmallVector gemmTileSize = gemmTileSizeOptional.value(); + SmallVector blockTileSize; SmallVector tileSizesVal; auto linalgOp = cast(operation); + if (linalgOp.getNumParallelLoops() == 3) { // bmm + blockTileSize = {0, gemmTileSize[0], gemmTileSize[1]}; + } else { // matmul + blockTileSize = {gemmTileSize[0], gemmTileSize[1]}; + } // Use partitionedLoop to know what loop needs to be distributed. auto partitionedLoops = getPartitionableLoops(linalgOp, std::nullopt); @@ -118,6 +124,7 @@ calculateDistributedTileSize(ArrayRef numElements, OpBuilder &builder, for (unsigned depth : partitionedLoops) { if (depth >= blockTileSize.size()) continue; + // tileSize means a warp should handle. tileSizesVal[depth] = builder.create( operation->getLoc(), llvm::divideCeil(blockTileSize[depth], distributedDim[idIdx++])); @@ -164,8 +171,7 @@ static LogicalResult tileToWarp(scf::ForallOp forallOp, .addFilter([](Operation *op) { // linalg.copy will be handled by GPUDistributeSharedMemoryCopy pass. // So we should not tile it here. 
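// --- Illustrative aside on the GPUBlockSwizzle change above: "swizzle in
// Triton way" refers to the grouped block ordering from Triton's matmul
// tutorial, which improves L2 reuse by making consecutive block ids walk a
// tall, narrow group of output tiles. The scalar model below is an assumption
// about the intent of makeSwizzledIdsInTritonWay (the in-pass helper builds
// equivalent arithmetic on block-id Values); groupM plays the role of
// 1 << swizzleLogTile.
#include <algorithm>
#include <cstdint>
#include <utility>

static std::pair<int64_t, int64_t>
groupedBlockId(int64_t pid, int64_t gridM, int64_t gridN, int64_t groupM) {
  int64_t widthOfGroup = groupM * gridN;   // blocks in one group of M-rows
  int64_t groupId = pid / widthOfGroup;    // which group this block falls in
  int64_t firstPidM = groupId * groupM;
  int64_t groupSizeM = std::min(gridM - firstPidM, groupM);
  int64_t pidM = firstPidM + (pid % groupSizeM);
  int64_t pidN = (pid % widthOfGroup) / groupSizeM;
  // Consecutive pids now cover a groupM-tall column of output tiles before
  // moving to the next column.
  return {pidM, pidN};
}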
- return success( - isa(op)); + return success(isa(op) || isLinalgOpMatmul(op)); }) .setMatchByDefault(); return distributeLinalgOpsWithFilter(forallOp, tilingOptions, filter); @@ -195,7 +201,7 @@ struct GPUDistributeToWarpPass SmallVector workgroupSize = optionalWorkgroupSize.value(); - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional.has_value()) { return signalPassFailure(); } diff --git a/compiler/lib/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.cpp b/compiler/lib/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.cpp new file mode 100644 index 000000000..80d419431 --- /dev/null +++ b/compiler/lib/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.cpp @@ -0,0 +1,73 @@ +//===- GPUInputSharedMemorySwizzle.cpp -------------------------*--- C++-*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/GPU/Transforms/GPUInputSharedMemorySwizzle.h" +#include "byteir/Dialect/GPU/Transforms/Utils.h" + +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" + +#include "PassDetail.h" + +using namespace llvm; +using namespace mlir; + +namespace { + +static void swizzleSharedMemory(scf::ForallOp forallOp) { + SmallVector shmAllocOps; + forallOp->walk([&](memref::AllocOp allocOp) { + // Only apply it to shared memory of input operands. 
+ if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType())) { + return; + } + if (hasMarker(allocOp, {getAllocSharedMemoryAMarker(), + getAllocSharedMemoryBMarker()})) { + shmAllocOps.push_back(allocOp); + } + }); + for (auto allocOp : shmAllocOps) { + (void)nvgpu::optimizeSharedMemoryReadsAndWrites(forallOp, + allocOp.getMemref()); + } +} + +struct GPUInputSharedMemorySwizzlePass + : public GPUInputSharedMemorySwizzleBase { + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + OpBuilder builder(funcOp.getBody()); + + if (!hasGemmTileConfig(funcOp)) { + return; + } + + auto forallOpOptional = getForallOpMappedToBlock(funcOp); + if (!forallOpOptional.has_value()) { + return signalPassFailure(); + } + scf::ForallOp forallOp = *forallOpOptional; + swizzleSharedMemory(forallOp); + } +}; +} // namespace + +std::unique_ptr> +mlir::createGPUInputSharedMemorySwizzlePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/GPU/Transforms/GPUPackSharedMemoryAlloc.cpp b/compiler/lib/Dialect/GPU/Transforms/GPUPackSharedMemoryAlloc.cpp index ada584d85..fc5fbaa52 100644 --- a/compiler/lib/Dialect/GPU/Transforms/GPUPackSharedMemoryAlloc.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/GPUPackSharedMemoryAlloc.cpp @@ -45,7 +45,7 @@ using namespace mlir; namespace { static int64_t getAllocSize(Operation *op, DataLayout &dataLayout) { - auto allocOp = cast(op); + auto allocOp = cast(op); int64_t numElements = allocOp.getType().getNumElements(); return (dataLayout.getTypeSizeInBits(allocOp.getType().getElementType()) * numElements) / @@ -59,7 +59,7 @@ using AliasGroup = SmallVector; void analyseAllocsForPacking(scf::ForallOp forallOp, ArrayRef allocs, SmallVector &aliasGroups) { - // Represent of a group of allocations with overlapping liverange and the + // Represent of a group of allocOptions with overlapping liverange and the // liveness of the overall group. struct AllocGroup { SmallVector allocs; @@ -152,7 +152,7 @@ void packAllocs(OpBuilder &builder, scf::ForallOp forallOp, MemRefType allocType = MemRefType::get({maxAlloc}, builder.getI8Type(), AffineMap(), memorySpace); Value packedAlloc = - builder.create(forallOp.getLoc(), allocType); + builder.create(forallOp.getLoc(), allocType); for (size_t i = 0; i < aliasGroups.size(); i++) { int64_t offset = 0; for (Operation *alloc : aliasGroups[i]) { @@ -202,12 +202,48 @@ void sinkOpsInCFG(const SmallVector &allocs, } } +static void addBarrier(scf::ForallOp forallOp, Operation *alloc, + ArrayRef aliasGroup) { + Block *entryBlock = forallOp.getBody(); + bool needBarrier = false; + if (alloc->getBlock() != entryBlock) { + needBarrier = true; + } else { + for (Operation &op : entryBlock->getOperations()) { + if (&op == alloc) + break; + if (op.getNumRegions() != 0) { + needBarrier = true; + break; + } + if (isa(&op) && !llvm::is_contained(aliasGroup, &op)) { + needBarrier = true; + break; + } + } + } + if (!needBarrier) + return; + OpBuilder builder(alloc); + // TODO: make it a option if needed. 
+ bool hasAsyncCopies = true; + if (hasAsyncCopies) { + Value groupToken = builder.create( + forallOp.getLoc(), + nvgpu::DeviceAsyncTokenType::get(forallOp.getContext()), + SmallVector()); + builder.create(forallOp.getLoc(), groupToken, + builder.getI32IntegerAttr(0)); + } + builder.create(alloc->getLoc()); +} + void packSharedMemoryAlloc(scf::ForallOp forallOp) { DominanceInfo dominators(forallOp); SmallVector allocs; - forallOp.walk([&](memref::AllocaOp alloca) { - if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(alloca.getType())) { - allocs.push_back(alloca); + forallOp.walk([&](memref::AllocOp allocOp) { + if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType())) { + allocs.push_back(allocOp); } }); // First sink the alloc as low as possible in the CFG. @@ -215,8 +251,14 @@ void packSharedMemoryAlloc(scf::ForallOp forallOp) { SmallVector aliasGroups; analyseAllocsForPacking(forallOp, allocs, aliasGroups); // If there is 1 or less alias group there is nothing to do. - if (aliasGroups.size() <= 1) + if (aliasGroups.size() <= 1) { return; + } + for (size_t i = 0; i < aliasGroups.size(); i++) { + for (Operation *alloc : aliasGroups[i]) { + addBarrier(forallOp, alloc, aliasGroups[i]); + } + } OpBuilder builder(forallOp.getContext()); packAllocs(builder, forallOp, aliasGroups); @@ -228,9 +270,9 @@ struct GPUPackSharedMemoryAllocPass void runOnOperation() override { auto funcOp = getOperation(); if (!hasGemmTileConfig(funcOp)) { - return signalPassFailure(); + return; } - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional.has_value()) { return signalPassFailure(); } diff --git a/compiler/lib/Dialect/GPU/Transforms/GPUTensorCoreVectorization.cpp b/compiler/lib/Dialect/GPU/Transforms/GPUTensorCoreVectorization.cpp index aad361cf2..f77d16b66 100644 --- a/compiler/lib/Dialect/GPU/Transforms/GPUTensorCoreVectorization.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/GPUTensorCoreVectorization.cpp @@ -54,16 +54,15 @@ using namespace mlir::linalg; namespace { -// static void vectorizeLinalgOps(scf::ForallOp forallOp) { -static void vectorizeLinalgOps(func::FuncOp forallOp) { +static void vectorizeLinalgOps(scf::ForallOp forallOp) { MLIRContext *context = forallOp.getContext(); IRRewriter rewriter(context); forallOp.walk([&](Operation *op) { - if (!isa( + if (hasAnyLinalgTransformationMarker(op, ArrayRef{getVectorizeMarker()}) && + isa( op)) { - return WalkResult::advance(); + (void)linalg::vectorize(rewriter, op); } - (void)linalg::vectorize(rewriter, op); return WalkResult::advance(); }); } @@ -82,7 +81,7 @@ gpuMmaUnrollOrder(vector::ContractionOp contract) { llvm::SmallDenseSet dims; for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) { - dims.insert(expr.cast().getPosition()); + dims.insert(cast(expr).getPosition()); } // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { @@ -327,7 +326,7 @@ struct GPUTensorCoreVectorizationPass if (!hasGemmTileConfig(funcOp)) { return signalPassFailure(); } - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional.has_value()) { return signalPassFailure(); } @@ -335,10 +334,10 @@ struct GPUTensorCoreVectorizationPass { // Step 1(a). Vectorize (linalg to vector). 
- vectorizeLinalgOps(funcOp); + vectorizeLinalgOps(forallOp); LLVM_DEBUG({ llvm::dbgs() << "\nAfter vectorizeLinalgOps:\n"; - funcOp->dump(); + forallOp->dump(); }); RewritePatternSet contractionPatterns(context); @@ -354,6 +353,17 @@ struct GPUTensorCoreVectorizationPass funcOp->dump(); }); + // Step 1(b). Fold arithmetic extensions into vector contraction ops. + // Linalg to vector conversion introduces arithmetic extensions on the + // operands of vector contraction ops for mixed precision computation. + // This pattern folds the arithmetic extensions into the vector.contract. + RewritePatternSet foldArithExtPatterns(context); + vector::populateFoldArithExtensionPatterns(foldArithExtPatterns); + if (failed(applyPatternsAndFoldGreedily( + funcOp, std::move(foldArithExtPatterns)))) { + return signalPassFailure(); + } + // Step 3. Prepare vector operations to be lowered to native tensor core // operations (nvgpu.mmasync, nvgpu.ldmatrix). RewritePatternSet vectorContractPatterns(funcOp.getContext()); diff --git a/compiler/lib/Dialect/GPU/Transforms/LegalizeGPULaunch.cpp b/compiler/lib/Dialect/GPU/Transforms/LegalizeGPULaunch.cpp new file mode 100644 index 000000000..cc9ccb99a --- /dev/null +++ b/compiler/lib/Dialect/GPU/Transforms/LegalizeGPULaunch.cpp @@ -0,0 +1,83 @@ +//===- LegalizeGPULaunch.cpp --------------------------------------------*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/GPU/Transforms/LegalizeGPULaunch.h" +#include "byteir/Dialect/GPU/Transforms/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Visitors.h" +#include + +#include "PassDetail.h" + +using namespace llvm; +using namespace mlir; + +namespace { + +static int64_t getSharedMemorySizeInGPULaunch(gpu::LaunchOp op) { + int64_t sharedMemSizeInBytes = 0; + op->walk([&](memref::AllocaOp allocaOp) { + if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(allocaOp.getType())) { + + sharedMemSizeInBytes += + allocaOp.getType().getNumElements() * + allocaOp.getType().getElementType().getIntOrFloatBitWidth() / 8; + } + }); + op->walk([&](memref::AllocOp allocOp) { + if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType())) { + sharedMemSizeInBytes += + allocOp.getType().getNumElements() * + allocOp.getType().getElementType().getIntOrFloatBitWidth() / 8; + } + }); + return sharedMemSizeInBytes; +} + +struct LegalizeGPULaunchPass + : public LegalizeGPULaunchBase { + LegalizeGPULaunchPass() : LegalizeGPULaunchBase() {} + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + OpBuilder builder(funcOp.getContext()); + auto launchOps = funcOp.getOps(); + for (auto launchOp : launchOps) { + int64_t sharedMemSize = getSharedMemorySizeInGPULaunch(launchOp); + if (sharedMemSize < 48 * 1024) // 48kB + continue; + builder.setInsertionPoint(launchOp); + Value sharedMemSizeValue = builder.create( + launchOp.getLoc(), builder.getI32IntegerAttr(sharedMemSize)); + if (!launchOp.getDynamicSharedMemorySizeMutable().empty()) { + continue; + } + launchOp.getDynamicSharedMemorySizeMutable().append( + ValueRange{sharedMemSizeValue}); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::createLegalizeGPULaunchPass() { + return std::make_unique(); +} diff --git a/compiler/lib/Dialect/GPU/Transforms/OptimizeVectorTransfer.cpp b/compiler/lib/Dialect/GPU/Transforms/OptimizeVectorTransfer.cpp index a507adef6..e2106f977 100644 --- a/compiler/lib/Dialect/GPU/Transforms/OptimizeVectorTransfer.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/OptimizeVectorTransfer.cpp @@ -63,7 +63,7 @@ struct OptimizeVectorTransferPass return; } - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional) return; auto forallOp = forallOpOptional.value(); diff --git a/compiler/lib/Dialect/GPU/Transforms/PassDetail.h b/compiler/lib/Dialect/GPU/Transforms/PassDetail.h index 139a86534..ee30d3e9d 100644 --- a/compiler/lib/Dialect/GPU/Transforms/PassDetail.h +++ b/compiler/lib/Dialect/GPU/Transforms/PassDetail.h @@ -64,6 +64,14 @@ namespace vector { class VectorDialect; } +namespace NVVM { +class NVVMDialect; +} // namespace NVVM + +namespace nvgpu { +class NVGPUDialect; +} // namespace nvgpu + namespace transform { class TransformDialect; } // namespace transform diff --git a/compiler/lib/Dialect/GPU/Transforms/RemoveTrivialLoops.cpp b/compiler/lib/Dialect/GPU/Transforms/RemoveTrivialLoops.cpp index 77e4e4d9b..e25f818c9 100644 --- a/compiler/lib/Dialect/GPU/Transforms/RemoveTrivialLoops.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/RemoveTrivialLoops.cpp @@ -127,7 +127,7 @@ class RemoveTrivialLoopsPass final } 
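// --- Illustrative aside on LegalizeGPULaunch above and the
// ConvertToDynamicSharedMemory rewrite in GPUToNVVM.cpp: NVIDIA GPUs cap
// statically declared shared memory at 48KB per block, so anything larger has
// to be requested as dynamic shared memory and opted into at launch time.
// The host-side sketch below (kernelFn and the surrounding function are
// placeholders, not from this patch) shows the runtime calls that the
// gpu.launch dynamicSharedMemorySize operand and the byre
// DynamicSharedMemorySize attribute ultimately correspond to.
#include <cuda_runtime.h>

cudaError_t launchWithLargeSmem(const void *kernelFn, dim3 grid, dim3 block,
                                void **args, size_t smemBytes,
                                cudaStream_t stream) {
  if (smemBytes >= 48 * 1024) {
    // Required on sm_70+ before a kernel may use more than 48KB of dynamic
    // shared memory.
    cudaError_t err = cudaFuncSetAttribute(
        kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize,
        static_cast<int>(smemBytes));
    if (err != cudaSuccess)
      return err;
  }
  // The dynamic shared memory size travels with the launch itself.
  return cudaLaunchKernel(kernelFn, grid, block, args, smemBytes, stream);
}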
SmallVector workgroupSize = blockSizeOptional.value(); - auto forallOpOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOpOptional = getForallOpMappedToBlock(funcOp); if (!forallOpOptional) return; auto forallOp = forallOpOptional.value(); diff --git a/compiler/lib/Dialect/GPU/Transforms/Utils.cpp b/compiler/lib/Dialect/GPU/Transforms/Utils.cpp index 81816a08c..3603385ab 100644 --- a/compiler/lib/Dialect/GPU/Transforms/Utils.cpp +++ b/compiler/lib/Dialect/GPU/Transforms/Utils.cpp @@ -44,7 +44,7 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; - +using namespace llvm; namespace mlir { //===----------------------------------------------------------------------===// @@ -161,11 +161,10 @@ bool isMappedToGPUThreads(Operation *op) { // Get the scf.forall op mapped to threadblock. // Just for gemm codegen for now. //===----------------------------------------------------------------------===// -std::optional getForallOpMappedTo2DBlock(func::FuncOp funcOp) { +std::optional getForallOpMappedToBlock(func::FuncOp funcOp) { std::vector forallOps; funcOp.walk([&](scf::ForallOp forallOp) { - if (isMappedToGPUBlocks(forallOp) && - forallOp.getMappingAttr().getValue().size() == 2) + if (isMappedToGPUBlocks(forallOp)) forallOps.push_back(forallOp); }); if (forallOps.size() != 1) { @@ -294,4 +293,46 @@ bool hasAnyLinalgTransformationMarker(Operation *op, return attr.getValue() == markerValue; })); } + +// a helper function to judge if a linalg generic op do matmul +// Result should not be transposed +bool isLinalgOpMatmul(Operation *op) { + if (!llvm::isa(op)) + return false; + + linalg::LinalgOp linalgOp = cast(op); + if (!(isa(linalgOp) || + isa(linalgOp))) { + if (!(linalg::isaContractionOpInterface(linalgOp) && + linalgOp.getNumParallelLoops() >= 2 && + linalgOp.getNumParallelLoops() <= 3)) { + return false; + } + // If this is not a named op matmul check some properties to make sure that + // we can map it to tensorcore ops. We should have only mulAdd in the region + // and the output map should have no permutation and the last dimension + // should be a reduce. 
+ Region &body = linalgOp->getRegion(0); + Region::OpIterator it = body.op_begin(); + // jump two arith ext ops(optional) + while (it != body.op_end() && isa(*it)) + it++; + if (it == body.op_end() || !isa(*(it++))) + return false; + if (it == body.op_end() || !isa(*(it++))) + return false; + if (it == body.op_end() || !isa(*(it++))) + return false; + AffineMap outputMap = + linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0)); + if (outputMap.getNumResults() != outputMap.getNumDims() - 1) + return false; + OpBuilder b(linalgOp); + for (unsigned i = 0, e = outputMap.getNumResults(); i < e - 1; i++) { + if (outputMap.getResult(i) != b.getAffineDimExpr(i)) + return false; + } + } + return true; +} } // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/compiler/lib/Dialect/Linalg/Transforms/CMakeLists.txt index a58c7e0b6..f922246b1 100644 --- a/compiler/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(ByteIRLinalgPasses BufferizableOpInterfaceImpl.cpp Bufferize.cpp CanonicalizeExt.cpp + CanonicalizeMatmulEpilogue.cpp FuseElementwise.cpp HoistingExt.cpp LinalgCollapseLoops.cpp diff --git a/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp b/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp new file mode 100644 index 000000000..8ecdab386 --- /dev/null +++ b/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp @@ -0,0 +1,157 @@ +//===- CanonicalizeMatmulEpilogue.cpp ---------------------------*---C++-*-===// +// +// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.h" +#include "byteir/Dialect/GPU/Transforms/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" + +#include "PassDetail.h" + +using namespace llvm; +using namespace mlir; + +#define DEBUG_TYPE "canonicalize-matmul-epilogue" + +namespace { + +static LogicalResult +modifyUseToGetValueIntoStoreSet(RewriterBase &rewriter, + linalg::GenericOp genericOp) { + SmallVector newInputs; + SmallVector newOutputs; + SmallVector newResultTypes; + SmallVector maps; + OpOperand *inOperand = nullptr; + OpOperand *initOperand = nullptr; + for (OpOperand *in : genericOp.getDpsInputOperands()) { + // if operand is generated by a op which has MainLoop Marker or it's a + // linalg.matmul + if (hasMarker( + in->get().getDefiningOp(), + ArrayRef{getMatmulMainLoopMarker(), getMMAPatternAttrName()})) { + inOperand = in; + } else { + newInputs.push_back(in->get()); + maps.push_back(genericOp.getMatchingIndexingMap(in)); + } + } + // assert has only one dps init + if (genericOp.getNumDpsInits() != 1) + return failure(); + initOperand = genericOp.getDpsInitOperand(0); + + if (inOperand == nullptr || initOperand == nullptr) + return failure(); + maps.push_back(genericOp.getMatchingIndexingMap(inOperand)); + newOutputs.push_back(inOperand->get()); + newResultTypes.push_back(inOperand->get().getType()); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(genericOp); + + Location loc = genericOp.getLoc(); + SmallVector iterTypes(genericOp.getNumLoops(), + utils::IteratorType::parallel); + auto newOp = rewriter.create( + loc, newResultTypes, newInputs, newOutputs, maps, iterTypes, + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + rewriter.inlineRegionBefore(genericOp.getRegion(), newOp.getRegion(), + newOp.getRegion().begin()); + + // Repair the payload entry block. + Block &payload = newOp.getRegion().front(); + payload.getArgument(inOperand->getOperandNumber()) + .replaceAllUsesWith(payload.getArgument(initOperand->getOperandNumber())); + payload.eraseArgument(inOperand->getOperandNumber()); + + rewriter.replaceOp(genericOp, newOp.getResults()); + return success(); +} + +// This pass modify IR on linalg tensor level. +// 1. Modify epilogue linalg.generic to avoid write result to a new buffer. +// Actually we can reuse input buffer. +// 2. Use shared_outs argument to replace tensor.empty buffer in scf.forall. As +// the thread block will not modify different slice of tensor. 
+class CanonicalizeMatmulEpiloguePass + : public CanonicalizeMatmulEpilogueBase { +public: + CanonicalizeMatmulEpiloguePass() = default; + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + if (!hasGemmTileConfig(funcOp)) + return; + auto forallOptional = getForallOpMappedToBlock(funcOp); + if (!forallOptional) + return; + scf::ForallOp forallOp = *forallOptional; + + SmallVector epilogueOps; + + // find epilogue, linalg.generic with getEpilogueMarker + forallOp.walk([&](linalg::GenericOp genericOp) { + if (!hasMarker(genericOp, getEpilogueMarker())) + return; + epilogueOps.push_back(genericOp); + }); + + if (epilogueOps.empty()) { + return; + } + assert(epilogueOps.size() == 1); + linalg::GenericOp epilogueOp = epilogueOps[0]; + IRRewriter rewriter(epilogueOp); + + // modify the epilogue to get the value into the store set + if (failed(modifyUseToGetValueIntoStoreSet(rewriter, epilogueOp))) { + llvm::errs() << "failed in modifyUseToGetValueIntoStoreSet\n"; + return signalPassFailure(); + } + + // get scf.forall shared_outs + auto forallSharedOuts = forallOp.getRegionOutArgs(); + auto forallDpsInits = forallOp.getDpsInitsMutable(); + for (const auto &[sharedOut, dpsInit] : + llvm::zip(forallSharedOuts, forallDpsInits)) { + // Get sharedOut's defining op and replace defining op in scf.forall + Value emptyValueOptional = dpsInit.get(); + tensor::EmptyOp emptyOp = + emptyValueOptional.getDefiningOp(); + if (emptyOp == nullptr) + continue; + + emptyValueOptional.replaceUsesWithIf( + sharedOut, [&](OpOperand &opOperand) { + // Only replace uses in the forall block + return opOperand.getOwner()->getBlock() == forallOp.getBody(); + }); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::createCanonicalizeMatmulEpiloguePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp b/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp index 34973f04e..3600ac782 100644 --- a/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp +++ b/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp @@ -58,14 +58,12 @@ using namespace mlir; namespace { -constexpr StringRef allocMarker[3] = {"__byteir_alloca_matrix_a__", - "__byteir_alloca_matrix_b__", - "__byteir_alloca_accumulator__"}; -constexpr StringRef copyMarker[3] = { - "__byteir_load_matrix_a__", - "__byteir_load_matrix_b__", - "__byteir_store_matrix_c__", -}; +constexpr StringRef allocMarker[3] = {getAllocSharedMemoryAMarker(), + getAllocSharedMemoryBMarker(), + getAllocSharedMemoryAccMarker()}; +constexpr StringRef copyMarker[3] = {getCopyToSharedMemoryAMarker(), + getCopyToSharedMemoryBMarker(), + getCopyFromSharedMemoryAccMarker()}; namespace MatmulOperands { constexpr static int64_t A = 0; @@ -99,8 +97,8 @@ allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{}, gpu::AddressSpaceAttr::get(builder.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace())); - memref::AllocaOp buffer = - builder.create(forallOp.getLoc(), type); + memref::AllocOp buffer = + builder.create(forallOp.getLoc(), type); setMarker(buffer, allocMarker[OPERAND]); // To fix fill op. The FillOp operand `subview` should be rewrited to // `alloca` @@ -138,11 +136,13 @@ LogicalResult copyWorkgroupMemoryToGlobalMemory(OpBuilder &b, Value src, // get the only scf.for op inside the scf.forall op. 
scf::ForallOp forallOp = op->getParentOfType(); auto forOps = llvm::to_vector(forallOp.getOps()); - if (forOps.size() != 1) - return forallOp.emitError("expected a single scf.for op"); // copyWorkgroupMemoryToGlobalMemory after gemm compute ends. - b.setInsertionPointAfter(forOps[0]); + if (forOps.size() == 1) + b.setInsertionPointAfter(forOps[0]); + if (forOps.size() > 1) + return failure(); + b.create(src.getLoc()); Operation *copyOp = b.create(src.getLoc(), src, dst); setLinalgTransformationMarker(copyOp, getCopyRelatedToWorkgroupMemoryMarker()); @@ -266,13 +266,13 @@ struct LinalgPromotionPass : public LinalgPromotionBase { if (!hasGemmTileConfig(funcOp)) return; - auto forallOptional = getForallOpMappedTo2DBlock(funcOp); + auto forallOptional = getForallOpMappedToBlock(funcOp); if (!forallOptional) return; scf::ForallOp forallOp = *forallOptional; forallOp.walk([&](linalg::LinalgOp linalgOp) { - if (isa(linalgOp)) + if (isLinalgOpMatmul(linalgOp)) toPromote.push_back(linalgOp); }); if (toPromote.empty()) @@ -292,7 +292,10 @@ struct LinalgPromotionPass : public LinalgPromotionBase { // As we know linalg.matmul is in a scf.for, and the subview promotionImpl // inserts should be in the scf.forall op. auto forOp = linalgContractOp->getParentOfType(); - builder.setInsertionPoint(forOp); // before forOp + if (forOp) + builder.setInsertionPoint(forOp); // before forOp + else + builder.setInsertionPoint(linalgContractOp); // before linalgContractOp (void)promotionImpl(builder, linalgContractOp); // The linalg.copy should be fused with its consumer linalg.generic. @@ -310,6 +313,12 @@ struct LinalgPromotionPass : public LinalgPromotionBase { for (Operation *op : toDelete) op->erase(); } + // as we should do synchronization after linalg.copy and before + // linalg.matmul + builder.setInsertionPoint(linalgContractOp); + builder.create(linalgContractOp.getLoc()); + builder.setInsertionPointAfter(linalgContractOp); + builder.create(linalgContractOp.getLoc()); } }; diff --git a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp index a977413b2..53733b202 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp @@ -223,26 +223,111 @@ static GenericFuserConfig config_concat_slice_fuse{ namespace matmul_epilogue { +// Only support m % 128 == 0 & n % 128 == 0 & k % 32 == 0 for now. +static bool isValidShape(Operation *op) { + if (auto dotOp = dyn_cast(op)) { + auto lhsType = dyn_cast(dotOp.getLhs().getType()); + auto rhsType = dyn_cast(dotOp.getRhs().getType()); + if (!lhsType || !rhsType) + return false; + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + if (lhsShape.size() != 2 || rhsShape.size() != 2) + return false; + if (lhsShape[1] != rhsShape[0]) + return false; + if (lhsShape[0] % 128 != 0 || rhsShape[1] % 128 != 0 || + lhsShape[1] % 32 != 0) + return false; + return true; + } else if (auto dotGeneralOp = dyn_cast(op)) { + auto lhsType = dyn_cast(dotGeneralOp.getLhs().getType()); + auto rhsType = dyn_cast(dotGeneralOp.getRhs().getType()); + if (!lhsType || !rhsType) + return false; + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + int64_t lhsRank = lhsShape.size(); + int64_t rhsRank = rhsShape.size(); + // Only support matmul or batchmatmul for now. 
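As a concrete illustration of the divisibility constraint above, which mirrors the default tile configuration {128, 128, 32} chosen later in LinalgTensorOpt.cpp: the accepted shape below is the one used by the lit tests in this patch, while the rejected case is hypothetical.

    // Outlined for gemm codegen: 5376 % 128 == 0, 5376 % 128 == 0, 2048 % 32 == 0.
    %0 = "mhlo.dot"(%lhs0, %rhs0) : (tensor<5376x2048xf16>, tensor<2048x5376xf16>) -> tensor<5376x5376xf16>
    // Left to the existing lowering paths: K = 2000 is not a multiple of 32.
    %1 = "mhlo.dot"(%lhs1, %rhs1) : (tensor<5376x2000xf16>, tensor<2000x5376xf16>) -> tensor<5376x5376xf16>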
+ if (lhsRank < 2 || lhsRank > 3 || rhsRank < 2 || rhsRank > 3) + return false; + if (lhsRank != rhsRank) + return false; + mhlo::DotDimensionNumbersAttr dimensionAttr = + dotGeneralOp.getDotDimensionNumbersAttr(); + ArrayRef lhsBatchingDimensions = + dimensionAttr.getLhsBatchingDimensions(); + ArrayRef rhsBatchingDimensions = + dimensionAttr.getRhsBatchingDimensions(); + ArrayRef lhsContractingDimensions = + dimensionAttr.getLhsContractingDimensions(); + ArrayRef rhsContractingDimensions = + dimensionAttr.getRhsContractingDimensions(); + if (lhsContractingDimensions.size() != 1 || + rhsContractingDimensions.size() != 1) + return false; + int64_t lhsContractingDim = lhsContractingDimensions[0]; + int64_t rhsContractingDim = rhsContractingDimensions[0]; + if (lhsShape[lhsContractingDim] % 32 != 0 || + rhsShape[rhsContractingDim] % 32 != 0) { + return false; + } + // BatchMatmul + if (lhsBatchingDimensions.size() == 1 && + rhsBatchingDimensions.size() == 1) { + int64_t lhsSpatialDim = 3; + int64_t rhsSpatialDim = 3; + int64_t lhsBatchingDim = lhsBatchingDimensions[0]; + int64_t rhsBatchingDim = rhsBatchingDimensions[0]; + lhsSpatialDim -= (lhsBatchingDim + lhsContractingDim); + rhsSpatialDim -= (rhsBatchingDim + rhsContractingDim); + if (lhsShape[lhsSpatialDim] % 128 != 0 || + rhsShape[rhsSpatialDim] % 128 != 0) { + return false; + } + return true; + } else { + // Matmul + int64_t lhsSpatialDim = 1; + int64_t rhsSpatialDim = 1; + lhsSpatialDim -= lhsContractingDim; + rhsSpatialDim -= rhsContractingDim; + if (lhsShape[lhsSpatialDim] % 128 != 0 || + rhsShape[rhsSpatialDim] % 128 != 0) { + return false; + } + return true; + } + } + return false; +} + bool isFusibleCandidate(Operation *op) { return isMhlo(op) && (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || isMhloConstantLike(op) || - isa(op)); + isa(op)); } -bool isFusibleStart(Operation *op) { return isa(op); } +bool isFusibleStart(Operation *op) { + return isa(op) && isValidShape(op); +} bool isFusibleTrigger(Operation *op) { // trigger fuse for anything but dot - return !isa(op); + return !isa(op); } bool isFusibleWith(Operation * /*target*/, Operation * /*start*/) { return true; } -bool isValidSingleOp(Operation *op) { return false; } +bool isValidSingleOp(Operation *op) { + return isa(op) && isValidShape(op); +} bool isValidFusionPattern(const MhloFusionPattern &) { return true; } @@ -516,7 +601,7 @@ struct MatmulEpilogueFusionPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatmulEpilogueFusionPass) - MatmulEpilogueFusionPass() : GenericFusionPass(false) {} + MatmulEpilogueFusionPass() : GenericFusionPass(true) {} /// Returns the command-line argument attached to this pass. 
static constexpr ::llvm::StringLiteral getArgumentName() { diff --git a/compiler/lib/Pipelines/CMakeLists.txt b/compiler/lib/Pipelines/CMakeLists.txt index 9626a5e6c..5db9fcc8a 100644 --- a/compiler/lib/Pipelines/CMakeLists.txt +++ b/compiler/lib/Pipelines/CMakeLists.txt @@ -36,6 +36,7 @@ add_mlir_library(ByteIRPipelines LINK_LIBS PUBLIC ByteIRGPUPipelines ByteIRHloToCat + ByteIRVectorToGPU ByteIRHostPipelines ByteIRPipelineCommon ByteIRUtils diff --git a/compiler/lib/Pipelines/GPU/CMakeLists.txt b/compiler/lib/Pipelines/GPU/CMakeLists.txt index 8eea9ad17..a0ed5a769 100644 --- a/compiler/lib/Pipelines/GPU/CMakeLists.txt +++ b/compiler/lib/Pipelines/GPU/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(ByteIRGPUPipelines ElementwiseCodegen.cpp + GemmCodegen.cpp GPUOpt.cpp LinalgMemrefGPU.cpp MappingForall.cpp diff --git a/compiler/lib/Pipelines/GPU/GPUOpt.cpp b/compiler/lib/Pipelines/GPU/GPUOpt.cpp index 9b4ad345e..d038f4cf2 100644 --- a/compiler/lib/Pipelines/GPU/GPUOpt.cpp +++ b/compiler/lib/Pipelines/GPU/GPUOpt.cpp @@ -108,13 +108,13 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) { createGPUMappingForallTransform(pm, options); pm.addPass(createTransformDialectInterpreter(true)); - pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createGpuLauchSinkIndexComputationsPass()); { OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createGpuLauchSinkIndexComputationsPass()); anchoredPM.addPass(createPromoteBuffersToStackPass( /*isSmallAlloc =*/[](Value value) { return value.getParentRegion()->getParentOfType(); @@ -126,10 +126,42 @@ void createReductionGPUOptPipelineImpl(OpPassManager &pm) { pm.addPass(createGpuKernelOutliningPass()); } +void createGemmGPUOptPipelineImpl(OpPassManager &pm) { + GPUMappingForallOptions options; + options.funcAnchor = getByteIRMatmulEpilogueFusionAttrName().str(); + options.annotatePrefix = "__byteir_gpu_gemm_tile"; + createGPUMappingForallTransform(pm, options); + pm.addPass(createTransformDialectInterpreter(true)); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createGpuLauchSinkIndexComputationsPass()); + + anchoredPM.addPass(createPromoteBuffersToStackPass( + /*isSmallAlloc =*/[](Value value) { + return value.getParentRegion()->getParentOfType(); + })); + + pm.addNestedPass(createAnchoredPipelinePass( + getByteIRMatmulEpilogueFusionAttrName(), anchoredPM)); + } + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createLegalizeGPULaunchPass()); + + pm.addNestedPass(createAnchoredPipelinePass( + getByteIRMatmulEpilogueFusionAttrName(), anchoredPM)); + } + pm.addPass(createGpuKernelOutliningPass()); +} + void createGPUOptPipelineImpl(OpPassManager &pm, const bool &useBarePtrCallConv, const std::string &target) { createElementwiseGPUOptPipelineImpl(pm, useBarePtrCallConv, target); createReductionGPUOptPipelineImpl(pm); + createGemmGPUOptPipelineImpl(pm); pm.addPass(createCollectGPUKernelPass("unified", false)); } diff --git a/compiler/lib/Pipelines/GPU/GemmCodegen.cpp b/compiler/lib/Pipelines/GPU/GemmCodegen.cpp new file mode 100644 index 000000000..fdf6e1b19 --- /dev/null +++ b/compiler/lib/Pipelines/GPU/GemmCodegen.cpp @@ -0,0 +1,378 @@ +//===- GemmCodegen.cpp ---------------------------------------*--- C++ -*-===// +// +// Copyright 2024 
ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Pipelines/GPU/GemmCodegen.h" +#include "byteir/Conversion/ToGPU/ToGPU.h" +#include "byteir/Conversion/ToLLVM/ToLLVM.h" +#include "byteir/Dialect/GPU/Transforms/Utils.h" +#include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h" +#include "byteir/Dialect/Linalg/Transforms/LinalgPrefetch.h" +#include "byteir/Dialect/Transform/IR/TransformExtOps.h" +#include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" +#include "byteir/Pipelines/Common/Utils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallSet.h" + +#include + +using namespace mlir; + +namespace { + +constexpr StringRef getLinalgToGPUAttrName() { return "__byteir_to_gpu__"; } + +constexpr StringRef getLinalgTargetAttrName() { return "__byteir_target__"; } + +void createGPUTileGemmTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix) { + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + config.opFilter = [=](Operation *op) { + if (!isLinalgOpMatmul(op)) + return false; + return true; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + func::FuncOp funcOp = op->getParentOfType(); + linalg::LinalgOp linalgOp = cast(op); + Operation *user = *linalgOp->getUsers().begin(); + bool hasEpilogue = isa(user); + + if (hasEpilogue) { + setMarker(user, getEpilogueMarker()); + } + + bool isBMM = linalgOp.getNumParallelLoops() == 3; + + SmallVector tileSizeConfig = getGemmTileSize(funcOp).value(); + + auto func = b.create( + pdlV.getType(), pdlV, + /* isolated_from_above */ false, + /* allow_empty_results */ false, + /* op_name */ b.getStringAttr(func::FuncOp::getOperationName()), + /* deduplicate */ false, + /* nth_parent */ 1); + + auto anyType = transform::AnyOpType::get(b.getContext()); + auto linalgFillType = transform::OperationType::get( + b.getContext(), linalg::FillOp::getOperationName()); + auto linalgFill = b.create( + linalgFillType, func, linalg::FillOp::getOperationName()); + + Value mmaLevel = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ b.getStringAttr("Threadblock")); + Value target = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ b.getStringAttr("nv_sm_80")); + + SmallVector mappingIdx; + if (isBMM) { + mappingIdx = {2, 1, 0}; + } else { + mappingIdx = {1, 0}; + } + auto mapping = llvm::to_vector(llvm::map_range( + mappingIdx, 
[](int64_t i) { return static_cast(i); })); + auto mappingAttrs = llvm::to_vector( + llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { + return gpu::GPUBlockMappingAttr::get(b.getContext(), dim); + })); + + SmallVector parrallelTileSizes; + if (isBMM) { + parrallelTileSizes = {1, tileSizeConfig[0], tileSizeConfig[1]}; + } else { + parrallelTileSizes = {tileSizeConfig[0], tileSizeConfig[1]}; + } + Value tiledMatmulOp; + if (hasEpilogue) { + auto linalgGenericType = transform::OperationType::get( + b.getContext(), linalg::GenericOp::getOperationName()); + auto epilogue = b.create( + linalgGenericType, func, + b.getStrArrayAttr({linalg::GenericOp::getOperationName()}), + /*matchInterfaceEnum=*/transform::MatchInterfaceEnumAttr(), + /*opAttrs=*/ + b.getDictionaryAttr({NamedAttribute( + b.getStringAttr(getEpilogueMarker()), b.getUnitAttr())}), + /*filterResultType=*/TypeAttr(), + /*filterOperandTYpes=*/ArrayAttr()); + + transform::TileUsingForallOp tileOp = + b.create( + /* target */ epilogue, + /* staticTileSizes */ parrallelTileSizes, + /* ctor tag */ transform::TileSizesSpec(), + /* mapping */ b.getArrayAttr(mappingAttrs)); + transform::FuseIntoContainingOp fuse = + b.create( + /* producerOp */ pdlV, + /* containingOp */ tileOp.getForallOp()); + b.create( + /* producerOp */ linalgFill, + /* containingOp */ fuse.getNewContainingOp()); + tiledMatmulOp = fuse.getFusedOp(); + } else { + transform::TileUsingForallOp tileOp = + b.create( + /* target */ pdlV, + /* staticTileSizes */ parrallelTileSizes, + /* ctor tag */ transform::TileSizesSpec(), + /* mapping */ b.getArrayAttr(mappingAttrs)); + + b.create( + /* producerOp */ linalgFill, + /* containingOp */ tileOp.getForallOp()); + tiledMatmulOp = tileOp.getTiledOp(); + } + + SmallVector reductionTileSizes; + if (isBMM) + reductionTileSizes = {0, 0, 0, tileSizeConfig[2]}; + else + reductionTileSizes = {0, 0, tileSizeConfig[2]}; + auto tileKMatmulOp = + b.create(tiledMatmulOp, reductionTileSizes); + auto matmulKOp = tileKMatmulOp.getTiledLinalgOp(); + auto forLoops = tileKMatmulOp.getLoops(); + if (!forLoops.empty()) { + b.create(forLoops[0], getMatmulMainLoopMarker(), + Value()); + } else { + b.create(matmulKOp, getMatmulMainLoopMarker(), + Value()); + } + + b.create(matmulKOp, getLinalgMMALevelAttrName(), + mmaLevel); + b.create(matmulKOp, getLinalgTargetAttrName(), + target); + b.create(matmulKOp, getMMAPatternAttrName(), + Value()); + }; + + pm.addPass(createGenericTransformInsertionPass(config)); +} + +} // namespace + +void mlir::createGPUTileGemmTransform(OpPassManager &pm, + const GPUGemmGeneralOptions &options) { + invokeOpPassPipelineBuilder(createGPUTileGemmTransformImpl, pm, + options.funcAnchor, options.annotatePrefix); +} + +namespace { + +void createGPUAddGemmCodegenLoweringConfigTransformImpl( + OpPassManager &pm, const std::string &anchor, const std::string &prefix, + ArrayRef tileSizeConfig, ArrayRef workgroupSize, + int64_t stages) { + + SmallVector tileSizeConfigVec{tileSizeConfig}; + SmallVector workgroupSizeVec{workgroupSize}; + + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + + config.opFilter = [=](Operation *op) { + if (isLinalgOpMatmul(op)) { + // TODO: check if the matmul op is already annotated + // TODO: Add different lowering config for different matmul op size + return true; + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + // auto linalgOp = llvm::cast(op); + auto tileSizeConfigAttrs = 
b.getAttr(llvm::to_vector( + llvm::map_range(tileSizeConfigVec, [&](int64_t i) -> Attribute { + return b.getI64IntegerAttr(i); + }))); + auto workgroupSizeAttrs = b.getAttr(llvm::to_vector( + llvm::map_range(workgroupSizeVec, [&](int64_t i) -> Attribute { + return b.getI64IntegerAttr(i); + }))); + auto stagesAttr = b.getI64IntegerAttr(stages); + + auto func = b.create( + pdlV.getType(), pdlV, + /* isolated_from_above */ true, + /* allow_empty_results */ false, + /* op_name */ b.getStringAttr(func::FuncOp::getOperationName()), + /* deduplicate */ false, + /* nth_parent */ 1); + + Value tileSizeConfigValue = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ tileSizeConfigAttrs); + Value workgroupSizeValue = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ workgroupSizeAttrs); + Value stagesValue = b.create( + /* type */ pdl::AttributeType::get(b.getContext()), + /* value */ stagesAttr); + + b.create(func, getGemmTileConfigAttrName(), + tileSizeConfigValue); + b.create(func, getGemmBlockSizeAttrName(), + workgroupSizeValue); + b.create(func, getGemmPipelineDepthAttrName(), + stagesValue); + }; + pm.addPass(createGenericTransformInsertionPass(config)); +} +} // namespace + +void mlir::createGPUAddGemmCodegenLoweringConfigTransform( + OpPassManager &pm, const GPUGemmCodegenConfigOptions &options) { + invokeOpPassPipelineBuilder( + createGPUAddGemmCodegenLoweringConfigTransformImpl, pm, + options.funcAnchor, options.annotatePrefix, options.tileSizeConfig, + options.workgroupSize, options.stages); +} + +namespace { + +int numIterations(scf::ForOp forOp) { + Value lowerBound = forOp.getLowerBound(); + Value upperBound = forOp.getUpperBound(); + Value step = forOp.getStep(); + + // get def constant value + auto defLowerBound = lowerBound.getDefiningOp(); + auto defUpperBound = upperBound.getDefiningOp(); + auto defStep = step.getDefiningOp(); + + if (defLowerBound && defUpperBound && defStep) { + auto lowerBoundValue = defLowerBound.getValue(); + auto upperBoundValue = defUpperBound.getValue(); + auto stepValue = defStep.getValue(); + + auto lowerBoundInt = cast(lowerBoundValue).getInt(); + auto upperBoundInt = cast(upperBoundValue).getInt(); + auto stepInt = cast(stepValue).getInt(); + return (upperBoundInt - lowerBoundInt) / stepInt; + } + return -1; +} +void createGPUPipeliningTransformImpl(OpPassManager &pm, + const std::string &anchor, + const std::string &prefix) { + + TransformInsertionConfig config; + config.funcAnchor = anchor; + config.matchPrefix = prefix; + + config.opFilter = [=](Operation *op) { + if (auto forallOp = llvm::dyn_cast_or_null(op)) { + if (!isMappedToGPUBlocks(forallOp)) { + return false; + } + func::FuncOp funcOp = forallOp->getParentOfType(); + auto pipelineStageOptional = getGemmPipelineDepth(funcOp); + if (!pipelineStageOptional) { + return false; + } + SmallVector forOps; + forallOp.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); + if (forOps.size() != 1) + return false; + scf::ForOp forOp = forOps[0]; + if (numIterations(forOp) <= pipelineStageOptional.value()) + return false; + else + return true; + } + return false; + }; + + config.transformBuilder = [=](ImplicitLocOpBuilder &b, Operation *op, + Value pdlV) { + func::FuncOp funcOp = op->getParentOfType(); + auto pipelineStageOptional = getGemmPipelineDepth(funcOp); + if (!pipelineStageOptional) { + return; + } + int pipelineStage = *pipelineStageOptional; + auto anyType = transform::AnyOpType::get(b.getContext()); + + auto memrefAllocType = 
transform::OperationType::get( + b.getContext(), memref::AllocOp::getOperationName()); + auto memrefAllocMatrixLHS = b.create( + memrefAllocType, pdlV, + b.getStrArrayAttr({memref::AllocOp::getOperationName()}), + /*matchInterfaceEnum=*/transform::MatchInterfaceEnumAttr(), + /*opAttrs=*/ + b.getDictionaryAttr({NamedAttribute( + b.getStringAttr(getAllocSharedMemoryAMarker()), b.getUnitAttr())}), + /*filterResultType=*/TypeAttr(), + /*filterOperandTYpes=*/ArrayAttr()); + b.create( + anyType, memrefAllocMatrixLHS, pipelineStage, /* skip_analysis */ true); + + auto memrefAllocMatrixRHS = b.create( + memrefAllocType, pdlV, + b.getStrArrayAttr({memref::AllocOp::getOperationName()}), + /*matchInterfaceEnum=*/transform::MatchInterfaceEnumAttr(), + /*opAttrs=*/ + b.getDictionaryAttr({NamedAttribute( + b.getStringAttr(getAllocSharedMemoryBMarker()), b.getUnitAttr())}), + /*filterResultType=*/TypeAttr(), + /*filterOperandTYpes=*/ArrayAttr()); + b.create( + anyType, memrefAllocMatrixRHS, pipelineStage, /* skip_analysis */ true); + + // fold memref alias for subview of multi-buffers + b.create(pdlV, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + + // match scf::for op + auto scfForOpType = transform::OperationType::get( + b.getContext(), scf::ForOp::getOperationName()); + auto scfForOp = b.create( + scfForOpType, pdlV, scf::ForOp::getOperationName()); + b.create(anyType, scfForOp, + pipelineStage); + }; + pm.addPass(createGenericTransformInsertionPass(config)); +} +} // namespace + +void mlir::createGPUPipeliningTransform(OpPassManager &pm, + const GPUGemmGeneralOptions &options) { + invokeOpPassPipelineBuilder(createGPUPipeliningTransformImpl, pm, + options.funcAnchor, options.annotatePrefix); +} \ No newline at end of file diff --git a/compiler/lib/Pipelines/GPU/MappingForall.cpp b/compiler/lib/Pipelines/GPU/MappingForall.cpp index 0fd0a2343..dd5123529 100644 --- a/compiler/lib/Pipelines/GPU/MappingForall.cpp +++ b/compiler/lib/Pipelines/GPU/MappingForall.cpp @@ -107,6 +107,13 @@ getMappingForallConfig(scf::ForallOp forallOp, const int64_t warpSize, if (!isMappedToGPUBlocks(forallOp)) return std::nullopt; + if (func::FuncOp funcOp = forallOp->getParentOfType()) { + auto blockSizeOptional = getGemmBlockSize(funcOp); + if (blockSizeOptional.has_value()) { + return MappingForallConfig{SmallVector(blockSizeOptional.value())}; + } + } + SmallVector blockDims{1, 1, 1}; auto &&block = forallOp.getRegion().front(); auto hasDynamicDims = [&]() -> bool { diff --git a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp index b791546b8..1b1f904d3 100644 --- a/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/NVVMCodegen.cpp @@ -53,7 +53,7 @@ void createNVVMCodegenPipelineImpl(OpPassManager &pm, pm.addPass(createSimplifyLinearizedIndexPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addNestedPass(createConvertVectorToLLVMPass()); + // pm.addNestedPass(createConvertVectorToLLVMPass()); pm.addNestedPass(createGPUToNVVMExtPass( useBarePtrCallConv, mlir::kDeriveIndexBitwidthFromDataLayout, gpuArch)); pm.addPass(createCSEPass()); diff --git a/compiler/lib/Pipelines/HloFusionOpt.cpp b/compiler/lib/Pipelines/HloFusionOpt.cpp index c70b9ba4a..0f97c44a6 100644 --- a/compiler/lib/Pipelines/HloFusionOpt.cpp +++ b/compiler/lib/Pipelines/HloFusionOpt.cpp @@ -31,7 +31,8 @@ using namespace mlir::mhlo; namespace { void addGenericHloFusionPatterns(OpPassManager &pm, bool outlineSingleElemwiseOp, - bool disableFusion, bool 
outlineCatOp) { + bool disableFusion, bool outlineCatOp, + bool outlineDotOp) { // Fusion passes if (outlineCatOp) { pm.addNestedPass(createCatFusionPass()); @@ -42,6 +43,10 @@ void addGenericHloFusionPatterns(OpPassManager &pm, pm.addNestedPass(createIOConvertFusionPass()); pm.addNestedPass(createReductionFusionPass()); + // outline dot ops and use gemm codegen + if (outlineDotOp) { + pm.addNestedPass(createMatmulEpilogueFusionPass()); + } pm.addNestedPass(createConcatSliceFusionPass()); // Element fusion (always last?) // Note: if outlineSingleElemwiseOp is set, element fusion must be the last @@ -64,7 +69,8 @@ void createHloFusionOptPipelineImpl(OpPassManager &pm, const std::string &entryFunc, const std::string &target, bool outlineSingleElemwiseOp, - bool disableFusion, bool outlineCatOp) { + bool disableFusion, bool outlineCatOp, + bool outlineDotOp) { addCleanUpExtPassPipeline(pm); // add fusion patterns @@ -72,7 +78,7 @@ void createHloFusionOptPipelineImpl(OpPassManager &pm, addCPUHloFusionPatterns(pm, disableFusion); } else { addGenericHloFusionPatterns(pm, outlineSingleElemwiseOp, disableFusion, - outlineCatOp); + outlineCatOp, outlineDotOp); } // note don't apply sccp @@ -85,8 +91,8 @@ void createHloFusionOptPipelineImpl(OpPassManager &pm, void mlir::createHloFusionOptPipeline( OpPassManager &pm, const HloFusionOptPipelineOptions &options) { - invokeOpPassPipelineBuilder(createHloFusionOptPipelineImpl, pm, - options.entryFunc, options.target, - options.outlineSingleElemwiseOp, - options.disableFusion, options.outlineCatOp); + invokeOpPassPipelineBuilder( + createHloFusionOptPipelineImpl, pm, options.entryFunc, options.target, + options.outlineSingleElemwiseOp, options.disableFusion, + options.outlineCatOp, options.outlineDotOp); } diff --git a/compiler/lib/Pipelines/LinalgMemrefOpt.cpp b/compiler/lib/Pipelines/LinalgMemrefOpt.cpp index 6b11d2503..545d1bf9f 100644 --- a/compiler/lib/Pipelines/LinalgMemrefOpt.cpp +++ b/compiler/lib/Pipelines/LinalgMemrefOpt.cpp @@ -18,15 +18,82 @@ #include "byteir/Pipelines/LinalgMemrefOpt.h" #include "byteir/Conversion/ToLinalg/ToLinalg.h" +#include "byteir/Conversion/VectorToGPU/GPUVectorToGPU.h" #include "byteir/Dialect/Byre/ByreDialect.h" +#include "byteir/Dialect/GPU/Passes.h" +#include "byteir/Dialect/Linalg/Passes.h" +#include "byteir/Dialect/Transform/Transforms/TransformDialectInterpreter.h" +#include "byteir/Dialect/Transform/Transforms/TransformInsertion.h" #include "byteir/Dialect/mhlo/Transforms/HloFuser.h" #include "byteir/Pipelines/Common/Utils.h" +#include "byteir/Pipelines/GPU/GemmCodegen.h" +#include "byteir/Transforms/AnchoredPipeline.h" +#include "byteir/Transforms/CanonicalizeExt.h" #include "byteir/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/Passes.h" using namespace mlir; namespace { +void addGemmOptPasses(OpPassManager &pm) { + { + auto gemmAnchor = getByteIRMatmulEpilogueFusionAttrName().str(); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createLinalgPromotionPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createCanonicalizerPass()); + + anchoredPM.addPass(createGPUDistributeToWarpPass()); + anchoredPM.addPass(createRemoveTrivialLoopsPass()); + anchoredPM.addPass(createGPUTensorCoreVectorizationPass()); + anchoredPM.addPass(memref::createFoldMemRefAliasOpsPass()); + 
anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createOptimizeVectorTransferPass()); + anchoredPM.addPass(createGPUDistributeSharedMemoryCopyPass()); + anchoredPM.addPass(memref::createFoldMemRefAliasOpsPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + // tranfer_read -> nvgpu.async_copy + anchoredPM.addPass(createGPUVectorToGPUPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(memref::createFoldMemRefAliasOpsPass()); + // shared memory swizzle + anchoredPM.addPass(createGPUInputSharedMemorySwizzlePass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + pm.addNestedPass( + createAnchoredPipelinePass(gemmAnchor, anchoredPM)); + } + + // do multi-buffer and pipelining + { + GPUGemmGeneralOptions options; + options.funcAnchor = gemmAnchor; + createGPUPipeliningTransform(pm, options); + pm.addPass(createTransformDialectInterpreter(true)); + pm.addPass(memref::createFoldMemRefAliasOpsPass()); + } + + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + // Pack shared memory alloc to reuse it + anchoredPM.addPass(createGPUPackSharedMemoryAllocPass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + anchoredPM.addPass(createGPUBlockSwizzlePass(3)); + pm.addNestedPass( + createAnchoredPipelinePass(gemmAnchor, anchoredPM)); + } + } +} + void addGenericLinalgMemrefOptPasses(OpPassManager &pm) { // TODO: change getByteIRElementwiseFusionAttrName to GPU specific codegen // anchor tag @@ -41,6 +108,7 @@ void addGenericLinalgMemrefOptPasses(OpPassManager &pm) { void createLinalgMemrefOptPipelineImpl(OpPassManager &pm, const std::string & /* target */) { addGenericLinalgMemrefOptPasses(pm); + addGemmOptPasses(pm); } } // namespace diff --git a/compiler/lib/Pipelines/LinalgTensorOpt.cpp b/compiler/lib/Pipelines/LinalgTensorOpt.cpp index b1d75341c..23b9ee35e 100644 --- a/compiler/lib/Pipelines/LinalgTensorOpt.cpp +++ b/compiler/lib/Pipelines/LinalgTensorOpt.cpp @@ -25,6 +25,7 @@ #include "byteir/Dialect/mhlo/Transforms/HloFuser.h" #include "byteir/Pipelines/Common/Utils.h" #include "byteir/Pipelines/GPU/ElementwiseCodegen.h" +#include "byteir/Pipelines/GPU/GemmCodegen.h" #include "byteir/Pipelines/GPU/ReductionCodegen.h" #include "byteir/Pipelines/Host/Codegen.h" #include "byteir/Transforms/AnchoredPipeline.h" @@ -43,6 +44,8 @@ void addGenericLinalgPasses(OpPassManager &pm) { createHloFusionToLinalgPass(getByteIRElementwiseFusionAttrName())); pm.addNestedPass( createHloFusionToLinalgPass(getByteIRReductionFusionAttrName())); + pm.addNestedPass( + createHloFusionToLinalgPass(getByteIRMatmulEpilogueFusionAttrName())); pm.addNestedPass(createUnrealizedCastToLinalgPass()); pm.addPass(createLinalgElementwiseFusionExtPass( /*enableSharedInput*/ true, /*enableDiffShapes*/ false)); @@ -225,6 +228,42 @@ void addGenericLinalgPasses(OpPassManager &pm) { pm.addNestedPass( createAnchoredPipelinePass(reductionAnchor, anchoredPM)); } + { // gemm codegen + auto gemmAnchor = getByteIRMatmulEpilogueFusionAttrName().str(); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + // Try to fuse possible epilogue linalg elementwise ops + anchoredPM.addPass(createLinalgElementwiseOpFusionPass()); + pm.addNestedPass( + createAnchoredPipelinePass(gemmAnchor, anchoredPM)); + } + SmallVector tileSizeConfig = {128, 128, 32}; + SmallVector workgroupSize = 
{64, 2, 1}; + int64_t stages = 3; + // Annotate fusion with gemm config + GPUGemmCodegenConfigOptions gemmConfigOptions; + gemmConfigOptions.funcAnchor = gemmAnchor; + gemmConfigOptions.tileSizeConfig = tileSizeConfig; + gemmConfigOptions.workgroupSize = workgroupSize; + gemmConfigOptions.stages = stages; + createGPUAddGemmCodegenLoweringConfigTransform(pm, gemmConfigOptions); + pm.addPass(createTransformDialectInterpreter(true)); + + GPUGemmGeneralOptions options; + options.funcAnchor = gemmAnchor; + createGPUTileGemmTransform(pm, options); + pm.addPass(createTransformDialectInterpreter(true)); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + { + OpPassManager anchoredPM(func::FuncOp::getOperationName()); + anchoredPM.addPass(createCanonicalizeMatmulEpiloguePass()); + anchoredPM.addPass(createCanonicalizerPass()); + anchoredPM.addPass(createCSEPass()); + pm.addNestedPass( + createAnchoredPipelinePass(gemmAnchor, anchoredPM)); + } + } } } diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index 9b16e682f..c872efb64 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -28,6 +28,7 @@ def __init__(self, verbose: bool = False, name: str = "model", enable_tf32: bool = False, + enable_gemm_codegen: bool = False, parallelism: int = 1, disable_byteir_ait_cache: bool = False, **kwargs): @@ -43,6 +44,7 @@ def __init__(self, self.verbose = verbose self.name = name self.enable_tf32 = enable_tf32 + self.enable_gemm_codegen = enable_gemm_codegen self.parallelism = parallelism self.disable_byteir_ait_cache = disable_byteir_ait_cache self.kwargs = kwargs @@ -89,6 +91,7 @@ def _compile_cuda( entry_func = compile_options.entry_func gpu_arch = compile_options.gpu_arch verbose = compile_options.verbose + enable_gemm_codegen = compile_options.enable_gemm_codegen enable_tf32 = compile_options.enable_tf32 output_file_dir = compile_options.output_dir @@ -104,7 +107,10 @@ def _compile_cuda( PassManager().parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation) _print_verbose(module, "// IR Dump After Hlo Graph Opt:") if verbose else ... with context: - PassManager().parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op})").run(module.operation) + if enable_gemm_codegen: + PassManager().parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op outline-dot-op})").run(module.operation) + else: + PassManager().parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op})").run(module.operation) _print_verbose(module, "// IR Dump After Hlo Fusion Opt:") if verbose else ... 
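The new flag is plumbed through both the Python API and the CLI wrapper. A minimal usage sketch; only the keyword options shown here come from this patch, and the file paths and positional arguments are placeholders:

    import byteir
    # Opt the module into dot outlining plus the gemm codegen pipeline.
    byteir.compile(
        "model.mlir",              # input mhlo module (illustrative path)
        "model.rt.mlir",           # output path (illustrative)
        enable_gemm_codegen=True,  # adds outline-dot-op to hlo-fusion-opt
        enable_tf32=False,
        verbose=True,
    )
    # The CLI wrapper exposes the same switch via --enable_gemm_codegen.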
with context: PassManager.parse("builtin.module(linalg-tensor-opt)").run(module.operation) @@ -377,6 +383,7 @@ def compile( byre_serial_version: str = "1.0.0", verbose: bool = False, enable_tf32: bool = False, + enable_gemm_codegen: bool = False, parallelism: int = 1, disable_byteir_ait_cache: bool = False, **kwargs, @@ -434,6 +441,7 @@ def compile( byre_serial_version=byre_serial_version, verbose=verbose, enable_tf32=enable_tf32, + enable_gemm_codegen=enable_gemm_codegen, parallelism=parallelism, disable_byteir_ait_cache=disable_byteir_ait_cache, kwargs=kwargs) diff --git a/compiler/python/byteir/tools/compiler.py b/compiler/python/byteir/tools/compiler.py index 1c01e06bd..88c1ff35f 100644 --- a/compiler/python/byteir/tools/compiler.py +++ b/compiler/python/byteir/tools/compiler.py @@ -28,6 +28,7 @@ # gpu options parser.add_argument("--enable_tf32", default=False, action="store_true") + parser.add_argument("--enable_gemm_codegen", default=False, action="store_true") parser.add_argument("--ait_parallelism", type=int, default=1, help="number of processes to compile ait op") parser.add_argument("--disable_byteir_cache", default=False, action="store_true") @@ -41,5 +42,6 @@ byre_serial_version=args.serial_version, verbose=args.verbose, enable_tf32=args.enable_tf32, + enable_gemm_codegen=args.enable_gemm_codegen, parallelism=args.ait_parallelism, disable_byteir_ait_cache=args.disable_byteir_cache) diff --git a/compiler/test/Dialect/GPU/gpu-block-swizzle.mlir b/compiler/test/Dialect/GPU/gpu-block-swizzle.mlir index 3432ba8d1..3f937bff3 100644 --- a/compiler/test/Dialect/GPU/gpu-block-swizzle.mlir +++ b/compiler/test/Dialect/GPU/gpu-block-swizzle.mlir @@ -43,6 +43,6 @@ module { // CHECK-NEXT: %[[ADDI1:.*]] = arith.addi %[[MULI1]], %[[REMUI0]] : index // CHECK-NEXT: %[[REMUI1:.*]] = arith.remui %[[ADDI0]], %[[C168]] : index // CHECK-NEXT: %[[DIVUI1:.*]] = arith.divui %[[REMUI1]], %[[MINSI0]] : index -// CHECK-NEXT: %[[APPLY_MAP0:.*]] = affine.apply #[[MAP]](%[[DIVUI1]]) -// CHECK-NEXT: %[[APPLY_MAP1:.*]] = affine.apply #[[MAP]](%[[ADDI1]]) +// CHECK-NEXT: %[[APPLY_MAP0:.*]] = affine.apply #[[MAP]](%[[ADDI1]]) +// CHECK-NEXT: %[[APPLY_MAP1:.*]] = affine.apply #[[MAP]](%[[DIVUI1]]) // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[APPLY_MAP0]], %[[APPLY_MAP1]]] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>> \ No newline at end of file diff --git a/compiler/test/Dialect/GPU/gpu-pack-shared-memory-alloc.mlir b/compiler/test/Dialect/GPU/gpu-pack-shared-memory-alloc.mlir index ae85345b7..550bcdbf0 100644 --- a/compiler/test/Dialect/GPU/gpu-pack-shared-memory-alloc.mlir +++ b/compiler/test/Dialect/GPU/gpu-pack-shared-memory-alloc.mlir @@ -27,9 +27,9 @@ module { %c32 = arith.constant 32 : index %alloc = memref.alloc() : memref<5376x5376xf16> scf.forall (%arg2, %arg3) in (42, 42) { - %alloca = memref.alloca() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space> - %alloca_1 = memref.alloca() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space> - %alloca_2 = memref.alloca() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space> + %alloc_0 = memref.alloc() {__byteir_alloc_accumulator__} : memref<128x128xf16, #gpu.address_space> + %alloc_1 = memref.alloc() {__byteir_alloc_matrix_b__} : memref<32x128xf16, #gpu.address_space> + %alloc_2 = memref.alloc() {__byteir_alloc_matrix_a__} : memref<128x32xf16, #gpu.address_space> %0 = affine.apply #map(%arg2) %1 = affine.apply #map(%arg3) %subview = 
memref.subview %alloc[%0, %1] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>> @@ -50,24 +50,24 @@ module { %16:32 = scf.for %arg4 = %c0 to %c2048 step %c32 iter_args(%arg5 = %cst, %arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst) -> (vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>) { %subview_3 = memref.subview %arg0[%0, %arg4] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>> %subview_4 = memref.subview %arg1[%arg4, %1] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>> - linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%subview_3 : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%alloca_2 : memref<128x32xf16, #gpu.address_space>) - linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%subview_4 : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%alloca_1 : memref<32x128xf16, #gpu.address_space>) - %17 = vector.transfer_read %alloca_2[%4, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %18 = vector.transfer_read %alloca_2[%4, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %19 = vector.transfer_read %alloca_2[%13, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %20 = vector.transfer_read %alloca_2[%13, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %21 = vector.transfer_read %alloca_2[%14, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %22 = vector.transfer_read %alloca_2[%14, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %23 = vector.transfer_read %alloca_2[%15, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %24 = vector.transfer_read %alloca_2[%15, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> - %25 = vector.transfer_read %alloca_1[%c0, %5], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %26 = vector.transfer_read %alloca_1[%c16, %5], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %27 = 
vector.transfer_read %alloca_1[%c0, %7], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %28 = vector.transfer_read %alloca_1[%c16, %7], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %29 = vector.transfer_read %alloca_1[%c0, %9], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %30 = vector.transfer_read %alloca_1[%c16, %9], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %31 = vector.transfer_read %alloca_1[%c0, %11], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> - %32 = vector.transfer_read %alloca_1[%c16, %11], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%subview_3 : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%alloc_2 : memref<128x32xf16, #gpu.address_space>) + linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%subview_4 : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%alloc_1 : memref<32x128xf16, #gpu.address_space>) + %17 = vector.transfer_read %alloc_2[%4, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %18 = vector.transfer_read %alloc_2[%4, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %19 = vector.transfer_read %alloc_2[%13, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %20 = vector.transfer_read %alloc_2[%13, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %21 = vector.transfer_read %alloc_2[%14, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %22 = vector.transfer_read %alloc_2[%14, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %23 = vector.transfer_read %alloc_2[%15, %c0], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %24 = vector.transfer_read %alloc_2[%15, %c16], %cst_0 {in_bounds = [true, true]} : memref<128x32xf16, #gpu.address_space>, vector<16x16xf16> + %25 = vector.transfer_read %alloc_1[%c0, %5], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %26 = vector.transfer_read %alloc_1[%c16, %5], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %27 = vector.transfer_read %alloc_1[%c0, %7], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %28 = vector.transfer_read %alloc_1[%c16, %7], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %29 = vector.transfer_read %alloc_1[%c0, %9], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %30 = vector.transfer_read %alloc_1[%c16, %9], %cst_0 {in_bounds 
= [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %31 = vector.transfer_read %alloc_1[%c0, %11], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> + %32 = vector.transfer_read %alloc_1[%c16, %11], %cst_0 {in_bounds = [true, true], permutation_map = #map13} : memref<32x128xf16, #gpu.address_space>, vector<16x16xf16> %33 = vector.extract_strided_slice %25 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %34 = vector.contract {indexing_maps = [#map14, #map15, #map16], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %17, %33, %arg5 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> %35 = vector.extract_strided_slice %25 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> @@ -150,45 +150,45 @@ module { %112 = vector.contract {indexing_maps = [#map14, #map15, #map16], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %24, %87, %72 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> scf.yield %74, %76, %78, %80, %82, %84, %86, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112 : vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16>, vector<16x8xf16> } - vector.transfer_write %16#31, %alloca[%15, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#30, %alloca[%15, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#29, %alloca[%15, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#28, %alloca[%15, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#27, %alloca[%15, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#26, %alloca[%15, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#25, %alloca[%15, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#24, %alloca[%15, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#23, %alloca[%14, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#22, %alloca[%14, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#21, %alloca[%14, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#20, %alloca[%14, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, 
#gpu.address_space> - vector.transfer_write %16#19, %alloca[%14, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#18, %alloca[%14, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#17, %alloca[%14, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#16, %alloca[%14, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#15, %alloca[%13, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#14, %alloca[%13, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#13, %alloca[%13, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#12, %alloca[%13, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#11, %alloca[%13, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#10, %alloca[%13, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#9, %alloca[%13, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#8, %alloca[%13, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#7, %alloca[%4, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#6, %alloca[%4, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#5, %alloca[%4, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#4, %alloca[%4, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#3, %alloca[%4, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#2, %alloca[%4, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#1, %alloca[%4, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - vector.transfer_write %16#0, %alloca[%4, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> - linalg.copy {__byteir_store_matrix_c__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%alloca : memref<128x128xf16, #gpu.address_space>) outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) + vector.transfer_write %16#31, %alloc_0[%15, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#30, %alloc_0[%15, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#29, %alloc_0[%15, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#28, %alloc_0[%15, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#27, 
%alloc_0[%15, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#26, %alloc_0[%15, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#25, %alloc_0[%15, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#24, %alloc_0[%15, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#23, %alloc_0[%14, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#22, %alloc_0[%14, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#21, %alloc_0[%14, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#20, %alloc_0[%14, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#19, %alloc_0[%14, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#18, %alloc_0[%14, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#17, %alloc_0[%14, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#16, %alloc_0[%14, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#15, %alloc_0[%13, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#14, %alloc_0[%13, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#13, %alloc_0[%13, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#12, %alloc_0[%13, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#11, %alloc_0[%13, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#10, %alloc_0[%13, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#9, %alloc_0[%13, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#8, %alloc_0[%13, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#7, %alloc_0[%4, %12] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#6, %alloc_0[%4, %11] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#5, %alloc_0[%4, %10] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#4, %alloc_0[%4, %9] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#3, %alloc_0[%4, %8] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#2, %alloc_0[%4, %7] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write 
%16#1, %alloc_0[%4, %6] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + vector.transfer_write %16#0, %alloc_0[%4, %5] {in_bounds = [true, true]} : vector<16x8xf16>, memref<128x128xf16, #gpu.address_space> + linalg.copy {__byteir_store_matrix_c__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%alloc_0 : memref<128x128xf16, #gpu.address_space>) outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) } {mapping = [#gpu.block, #gpu.block]} return %alloc : memref<5376x5376xf16> } } -// CHECK: %alloca = memref.alloca() : memref<32768xi8, #gpu.address_space> -// CHECK: %{{.*}} = memref.view %alloca[%c0{{.*}}][] : memref<32768xi8, #gpu.address_space> to memref<32x128xf16, #gpu.address_space> -// CHECK: %{{.*}} = memref.view %alloca[%c8192{{.*}}][] : memref<32768xi8, #gpu.address_space> to memref<128x32xf16, #gpu.address_space> -// CHECK: %{{.*}} = memref.view %alloca[%c0{{.*}}][] : memref<32768xi8, #gpu.address_space> to memref<128x128xf16, #gpu.address_space> \ No newline at end of file +// CHECK: %[[ALLOC_PACK:.*]] = memref.alloc() : memref<32768xi8, #gpu.address_space> +// CHECK: %{{.*}} = memref.view %[[ALLOC_PACK]][%c0{{.*}}][] : memref<32768xi8, #gpu.address_space> to memref<32x128xf16, #gpu.address_space> +// CHECK: %{{.*}} = memref.view %[[ALLOC_PACK]][%c8192{{.*}}][] : memref<32768xi8, #gpu.address_space> to memref<128x32xf16, #gpu.address_space> +// CHECK: %{{.*}} = memref.view %[[ALLOC_PACK]][%c0{{.*}}][] : memref<32768xi8, #gpu.address_space> to memref<128x128xf16, #gpu.address_space> \ No newline at end of file diff --git a/compiler/test/Dialect/Linalg/canonicalize-matmul-epilogue.mlir b/compiler/test/Dialect/Linalg/canonicalize-matmul-epilogue.mlir new file mode 100644 index 000000000..4e4ea95f8 --- /dev/null +++ b/compiler/test/Dialect/Linalg/canonicalize-matmul-epilogue.mlir @@ -0,0 +1,59 @@ +// RUN: byteir-opt %s -canonicalize-matmul-epilogue --canonicalize -cse | FileCheck %s +#map = affine_map<(d0) -> (d0 * 128)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map4 = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func private @Unknown0(%arg0: tensor<5376x2048xf16>, %arg1: tensor<2048x5376xf16>, %arg2: tensor<5376x5376xf16>) -> tensor<5376x5376xf16> attributes {__byteir_gemm_block_size__ = [64, 2, 1], __byteir_gemm_pipeline_depth__ = 3 : i64, __byteir_gemm_tile_config__ = [128, 128, 32], __byteir_matmul_epilogue_fusion__} { + %c32 = arith.constant 32 : index + %c2048 = arith.constant 2048 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<5376x5376xf16> + %1 = scf.forall (%arg3, %arg4) in (42, 42) shared_outs(%arg5 = %0) -> (tensor<5376x5376xf16>) { + %2 = affine.apply #map(%arg3) + %3 = affine.apply #map(%arg4) + %extracted_slice = tensor.extract_slice %arg0[%2, 0] [128, 2048] [1, 1] : tensor<5376x2048xf16> to tensor<128x2048xf16> + %extracted_slice_0 = tensor.extract_slice %arg1[0, %3] [2048, 128] [1, 1] : tensor<2048x5376xf16> to tensor<2048x128xf16> + %extracted_slice_1 = tensor.extract_slice %0[%2, %3] [128, 128] [1, 1] : tensor<5376x5376xf16> to tensor<128x128xf16> + %4 = linalg.fill ins(%cst : f16) outs(%extracted_slice_1 : tensor<128x128xf16>) -> tensor<128x128xf16> + %5 = scf.for %arg6 = %c0 to %c2048 step %c32 iter_args(%arg7 = %4) -> (tensor<128x128xf16>) { + %extracted_slice_4 = tensor.extract_slice %extracted_slice[0, 
%arg6] [128, 32] [1, 1] : tensor<128x2048xf16> to tensor<128x32xf16> + %extracted_slice_5 = tensor.extract_slice %extracted_slice_0[%arg6, 0] [32, 128] [1, 1] : tensor<2048x128xf16> to tensor<32x128xf16> + %7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_4, %extracted_slice_5 : tensor<128x32xf16>, tensor<32x128xf16>) outs(%arg7 : tensor<128x128xf16>) attrs = {__byteir_gpu_tile_gemm_0, __byteir_gpu_tile_gemm_1, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} { + ^bb0(%in: f16, %in_6: f16, %out: f16): + %8 = arith.mulf %in, %in_6 : f16 + %9 = arith.addf %out, %8 : f16 + linalg.yield %9 : f16 + } -> tensor<128x128xf16> + scf.yield %7 : tensor<128x128xf16> + } + %extracted_slice_2 = tensor.extract_slice %arg2[%2, %3] [128, 128] [1, 1] : tensor<5376x5376xf16> to tensor<128x128xf16> + %extracted_slice_3 = tensor.extract_slice %arg5[%2, %3] [128, 128] [1, 1] : tensor<5376x5376xf16> to tensor<128x128xf16> + %6 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_2 : tensor<128x128xf16>, tensor<128x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf16>) attrs = {__byteir_epilogue__} { + ^bb0(%in: f16, %in_4: f16, %out: f16): + %7 = arith.addf %in, %in_4 : f16 + linalg.yield %7 : f16 + } -> tensor<128x128xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %6 into %arg5[%2, %3] [128, 128] [1, 1] : tensor<128x128xf16> into tensor<5376x5376xf16> + } + } {mapping = [#gpu.block, #gpu.block]} + return %1 : tensor<5376x5376xf16> + } + func.func @main(%arg0: tensor<5376x2048xf16>, %arg1: tensor<2048x5376xf16>, %arg2: tensor<5376x5376xf16>) -> tensor<5376x5376xf16> { + %0 = call @Unknown0(%arg0, %arg1, %arg2) : (tensor<5376x2048xf16>, tensor<2048x5376xf16>, tensor<5376x5376xf16>) -> tensor<5376x5376xf16> + return %0 : tensor<5376x5376xf16> + } +} + +// CHECK: scf.forall (%{{.*}}, %{{.*}}) in (42, 42) shared_outs(%[[V0:.*]] = %{{.*}}) +// CHECK: %[[V1:.*]] = tensor.extract_slice %[[V0]] +// CHECK: linalg.fill ins(%{{.*}}) outs(%[[V1]] : {{.*}}) +// CHECK: %[[MATMUL_RESULT:.*]] = scf.for +// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<128x128xf16>) outs(%[[MATMUL_RESULT]] : tensor<128x128xf16>) +// CHECK-NEXT: ^bb0(%in: f16, %out: f16): +// CHECK-NEXT: %[[T1:.*]] = arith.addf %out, %in : f16 +// CHECK-NEXT: linalg.yield %[[T1]] : f16 +// CHECK-NEXT: } -> tensor<128x128xf16> diff --git a/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir b/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir index 7d2443d24..f233a6881 100644 --- a/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir +++ b/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir @@ -36,21 +36,24 @@ module { // CHECK-NEXT: %[[C32:.*]] = arith.constant 32 : index // CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<5376x5376xf16> // CHECK-NEXT: scf.forall (%[[ARG2:.*]], %[[ARG3:.*]]) in (42, 42) { -// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space> -// CHECK-NEXT: %[[ALLOCA_0:.*]] = memref.alloca() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space> -// CHECK-NEXT: %[[ALLOCA_1:.*]] = memref.alloca() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space> +// CHECK-NEXT: %[[ALLOC_2:.*]] = memref.alloc() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space> +// 
CHECK-NEXT: %[[ALLOC_0:.*]] = memref.alloc() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space> +// CHECK-NEXT: %[[ALLOC_1:.*]] = memref.alloc() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space> // CHECK-NEXT: %[[APPLY_MAP0:.*]] = affine.apply #[[MAP]](%[[ARG2]]) // CHECK-NEXT: %[[APPLY_MAP1:.*]] = affine.apply #[[MAP]](%[[ARG3]]) // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[APPLY_MAP0]], %[[APPLY_MAP1]]] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>> -// CHECK-NEXT: linalg.fill ins(%[[CST]] : f16) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space>) +// CHECK-NEXT: linalg.fill ins(%[[CST]] : f16) outs(%[[ALLOC_2]] : memref<128x128xf16, #gpu.address_space>) // CHECK-NEXT: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2048]] step %[[C32]] { // CHECK-NEXT: %[[SUBVIEW_2:.*]] = memref.subview %[[ARG0]][%[[APPLY_MAP0]], %[[ARG4]]] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>> // CHECK-NEXT: %[[SUBVIEW_3:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[APPLY_MAP1]]] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>> -// CHECK-NEXT: linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_2]] : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%[[ALLOCA_1]] : memref<128x32xf16, #gpu.address_space>) -// CHECK-NEXT: linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_3]] : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%[[ALLOCA_0]] : memref<32x128xf16, #gpu.address_space>) -// CHECK-NEXT: linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%[[ALLOCA_1]], %[[ALLOCA_0]] : memref<128x32xf16, #gpu.address_space>, memref<32x128xf16, #gpu.address_space>) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space>) +// CHECK-NEXT: linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_2]] : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%[[ALLOC_1]] : memref<128x32xf16, #gpu.address_space>) +// CHECK-NEXT: linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_3]] : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%[[ALLOC_0]] : memref<32x128xf16, #gpu.address_space>) +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%[[ALLOC_1]], %[[ALLOC_0]] : memref<128x32xf16, #gpu.address_space>, memref<32x128xf16, #gpu.address_space>) outs(%[[ALLOC_2]] : memref<128x128xf16, #gpu.address_space>) +// CHECK-NEXT: gpu.barrier // CHECK-NEXT: } -// CHECK-NEXT: linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space>) outs(%[[SUBVIEW]] : memref<128x128xf16, strided<[5376, 1], offset: ?>>) attrs = {__internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} { +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ALLOC_2]] : memref<128x128xf16, #gpu.address_space>) outs(%[[SUBVIEW]] : 
memref<128x128xf16, strided<[5376, 1], offset: ?>>) attrs = {__internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} { // CHECK-NEXT: ^bb0(%in: f16, %out: f16): // CHECK-NEXT: %2 = arith.maximumf %in, %cst : f16 // CHECK-NEXT: linalg.yield %2 : f16 diff --git a/compiler/test/Dialect/Linalg/linalg-promotion.mlir b/compiler/test/Dialect/Linalg/linalg-promotion.mlir index 8524b5e22..74c684922 100644 --- a/compiler/test/Dialect/Linalg/linalg-promotion.mlir +++ b/compiler/test/Dialect/Linalg/linalg-promotion.mlir @@ -30,21 +30,24 @@ module { // CHECK-NEXT: %[[C32:.*]] = arith.constant 32 : index // CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<5376x5376xf16> // CHECK-NEXT: scf.forall (%[[ARG2:.*]], %[[ARG3:.*]]) in (42, 42) { -// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space> -// CHECK-NEXT: %[[ALLOCA_0:.*]] = memref.alloca() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space> -// CHECK-NEXT: %[[ALLOCA_1:.*]] = memref.alloca() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space> +// CHECK-NEXT: %[[ALLOC_2:.*]] = memref.alloc() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space> +// CHECK-NEXT: %[[ALLOC_0:.*]] = memref.alloc() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space> +// CHECK-NEXT: %[[ALLOC_1:.*]] = memref.alloc() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space> // CHECK-NEXT: %[[APPLY_MAP0:.*]] = affine.apply #[[MAP]](%[[ARG2]]) // CHECK-NEXT: %[[APPLY_MAP1:.*]] = affine.apply #[[MAP]](%[[ARG3]]) // CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[APPLY_MAP0]], %[[APPLY_MAP1]]] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>> -// CHECK-NEXT: linalg.fill ins(%[[CST]] : f16) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space>) +// CHECK-NEXT: linalg.fill ins(%[[CST]] : f16) outs(%[[ALLOC_2]] : memref<128x128xf16, #gpu.address_space>) // CHECK-NEXT: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2048]] step %[[C32]] { // CHECK-NEXT: %[[SUBVIEW_2:.*]] = memref.subview %[[ARG0]][%[[APPLY_MAP0]], %[[ARG4]]] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>> // CHECK-NEXT: %[[SUBVIEW_3:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[APPLY_MAP1]]] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>> -// CHECK-NEXT: linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_2]] : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%[[ALLOCA_1]] : memref<128x32xf16, #gpu.address_space>) -// CHECK-NEXT: linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_3]] : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%[[ALLOCA_0]] : memref<32x128xf16, #gpu.address_space>) -// CHECK-NEXT: linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%[[ALLOCA_1]], %[[ALLOCA_0]] : memref<128x32xf16, #gpu.address_space>, memref<32x128xf16, #gpu.address_space>) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space>) +// CHECK-NEXT: linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_2]] : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%[[ALLOC_1]] : 
memref<128x32xf16, #gpu.address_space>) +// CHECK-NEXT: linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_3]] : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%[[ALLOC_0]] : memref<32x128xf16, #gpu.address_space>) +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%[[ALLOC_1]], %[[ALLOC_0]] : memref<128x32xf16, #gpu.address_space>, memref<32x128xf16, #gpu.address_space>) outs(%[[ALLOC_2]] : memref<128x128xf16, #gpu.address_space>) +// CHECK-NEXT: gpu.barrier // CHECK-NEXT: } -// CHECK-NEXT: linalg.copy {__byteir_store_matrix_c__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space>) outs(%[[SUBVIEW]] : memref<128x128xf16, strided<[5376, 1], offset: ?>>) +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: linalg.copy {__byteir_store_matrix_c__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[ALLOC_2]] : memref<128x128xf16, #gpu.address_space>) outs(%[[SUBVIEW]] : memref<128x128xf16, strided<[5376, 1], offset: ?>>) // CHECK-NEXT: } {mapping = [#gpu.block, #gpu.block]} // CHECK-NEXT: return %[[ALLOC]] : memref<5376x5376xf16> // CHECK-NEXT: } diff --git a/compiler/test/Dialect/Mhlo/transforms/matmulEpilogueFusion.mlir b/compiler/test/Dialect/Mhlo/transforms/matmulEpilogueFusion.mlir index 2b5c185f7..1f9ecadc6 100644 --- a/compiler/test/Dialect/Mhlo/transforms/matmulEpilogueFusion.mlir +++ b/compiler/test/Dialect/Mhlo/transforms/matmulEpilogueFusion.mlir @@ -1,13 +1,13 @@ // RUN: byteir-opt %s -fuse-matmul-epilogue | FileCheck %s -func.func @dot_element_epilog(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<4x4xf32>, %arg3 : tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>) { - %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = "mhlo.add"(%arg2, %0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = "mhlo.abs"(%1) : (tensor<4x4xf32>) -> tensor<4x4xf32> - %3 = "mhlo.add"(%arg3, %2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %4 = "mhlo.dot"(%3, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %5 = "mhlo.add"(%3, %4) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - return %3, %5 : tensor<4x4xf32>, tensor<4x4xf32> +func.func @dot_element_epilog(%arg0 : tensor<128x128xf32>, %arg1 : tensor<128x128xf32>, %arg2 : tensor<128x128xf32>, %arg3 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %1 = "mhlo.add"(%arg2, %0) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %2 = "mhlo.abs"(%1) : (tensor<128x128xf32>) -> tensor<128x128xf32> + %3 = "mhlo.add"(%arg3, %2) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %4 = "mhlo.dot"(%3, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %5 = "mhlo.add"(%3, %4) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + return %3, %5 : tensor<128x128xf32>, tensor<128x128xf32> } // CHECK-LABEL: func.func @dot_element_epilog // CHECK-NEXT: mhlo.fusion @@ -24,11 +24,11 @@ func.func 
@dot_element_epilog(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>, // CHECK: {__byteir_matmul_epilogue_fusion__} // CHECK: return -func.func @dot_element_epilog_with_previous(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<4x4xf32>, %arg3 : tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>) { - %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = "mhlo.dot"(%arg2, %arg3) {precision_config = [#mhlo, #mhlo]} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = "mhlo.add"(%0, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - return %0, %2 : tensor<4x4xf32>, tensor<4x4xf32> +func.func @dot_element_epilog_with_previous(%arg0 : tensor<128x128xf32>, %arg1 : tensor<128x128xf32>, %arg2 : tensor<128x128xf32>, %arg3 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %1 = "mhlo.dot"(%arg2, %arg3) {precision_config = [#mhlo, #mhlo]} : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %2 = "mhlo.add"(%0, %1) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + return %0, %2 : tensor<128x128xf32>, tensor<128x128xf32> } // CHECK-LABEL: func.func @dot_element_epilog_with_previous // CHECK-NEXT: mhlo.add @@ -39,11 +39,11 @@ func.func @dot_element_epilog_with_previous(%arg0 : tensor<4x4xf32>, %arg1 : ten // CHECK: {__byteir_matmul_epilogue_fusion__} // CHECK: return -func.func @dot_element_epilog_with_next(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<4x4xf32>, %arg3 : tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>) { - %0 = "mhlo.dot"(%arg2, %arg3) {precision_config = [#mhlo, #mhlo]} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = "mhlo.add"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = "mhlo.add"(%1, %0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - return %1, %2 : tensor<4x4xf32>, tensor<4x4xf32> +func.func @dot_element_epilog_with_next(%arg0 : tensor<128x128xf32>, %arg1 : tensor<128x128xf32>, %arg2 : tensor<128x128xf32>, %arg3 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %0 = "mhlo.dot"(%arg2, %arg3) {precision_config = [#mhlo, #mhlo]} : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %1 = "mhlo.add"(%arg0, %arg1) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %2 = "mhlo.add"(%1, %0) : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + return %1, %2 : tensor<128x128xf32>, tensor<128x128xf32> } // CHECK-LABEL: func.func @dot_element_epilog_with_next // CHECK-NEXT: mhlo.add @@ -54,12 +54,15 @@ func.func @dot_element_epilog_with_next(%arg0 : tensor<4x4xf32>, %arg1 : tensor< // CHECK: {__byteir_matmul_epilogue_fusion__} // CHECK: return -func.func @dot_element_prolog(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<4x4xf32>, %arg3 : tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>) { - %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = "mhlo.dot"(%0, %arg2) {precision_config = [#mhlo, #mhlo]} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - return %0, %1 : tensor<4x4xf32>, tensor<4x4xf32> +func.func @dot_element_prolog(%arg0 : tensor<128x128xf32>, %arg1 : tensor<128x128xf32>, %arg2 : tensor<128x128xf32>, %arg3 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %0 = "mhlo.add"(%arg0, %arg1) : 
(tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + %1 = "mhlo.dot"(%0, %arg2) {precision_config = [#mhlo, #mhlo]} : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> + return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> } // CHECK-LABEL: func.func @dot_element_prolog // CHECK-NEXT: mhlo.add -// CHECK-NEXT: mhlo.dot +// CHECK-NEXT: mhlo.fusion +// CHECK-NEXT: mhlo.dot +// CHECK-NEXT: mhlo.return +// CHECK: {__byteir_matmul_epilogue_fusion__} // CHECK-NEXT: return diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index 1d9670c6a..8e111fec4 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -43,6 +43,7 @@ using namespace mlir; #define BLOCK_SIZE_X_ATTR "BlockSize.x" #define BLOCK_SIZE_Y_ATTR "BlockSize.y" #define BLOCK_SIZE_Z_ATTR "BlockSize.z" +#define SHARED_MEMORY_SIZE "DynamicSharedMemorySize" #define ARG_RANKS_ATTR "arg_ranks" #define CALL_CONVENTION_ATTR "call_convention" @@ -92,6 +93,12 @@ struct PTXImpl { CUfunction func; auto status_func = ptx_compiler->GetOrCreateFunction( func, kernel_info.kernel_name, kernel_info.file_name); BRT_ENFORCE(status_func.IsOK(), status_func.ErrorMessage()); + // Kernels that need more than the default 48KB of dynamic shared memory must opt in explicitly. + size_t max_shared_mem = 48 << 10; + if (shared_size > max_shared_mem) { + cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_size); + } device2func.emplace(device_id, func); return func; @@ -170,11 +176,17 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) ranks.push_back(GetRankFromOpArgIndex(info_, i)); } } + int64_t dynamic_shm_size = 0; + if (info.GetOperation()->hasAttrOfType<IntegerAttr>(SHARED_MEMORY_SIZE)) { + dynamic_shm_size = info.GetOperation() + ->getAttrOfType<IntegerAttr>(SHARED_MEMORY_SIZE) + .getInt(); + } auto num_arg = GetOpArgNum(info_); impl_->grid = dim3(gx, gy, gz); impl_->block = dim3(bx, by, bz); - impl_->shared_size = 0; + impl_->shared_size = dynamic_shm_size; impl_->arg_reserve_size = 3; // initial 3 for grid/block/shared_size // store tensor meta diff --git a/tests/numerical_test/execute.py b/tests/numerical_test/execute.py index 48be71f75..303c0a1ec 100644 --- a/tests/numerical_test/execute.py +++ b/tests/numerical_test/execute.py @@ -33,8 +33,10 @@ np.random.uniform(low=0.5, high=1.0, size=(256, 1)).astype(np.float16) ], "cpu@convert_f32_i32_special_val.mlir": [ - np.array([[np.inf, -np.inf, np.nan], [1., 999.999, -np.inf]], dtype=np.float32), - ] + np.array( + [[np.inf, -np.inf, np.nan], [1.0, 999.999, -np.inf]], dtype=np.float32 + ), + ], } @@ -53,7 +55,7 @@ def entry_func(self): @property def entry_func_name(self) -> str: - return self.entry_func.name.value + return self.entry_func.name.value def need_special_inputs(self) -> bool: key = self.target + "@" + self.file_base_name @@ -156,7 +158,16 @@ def profile(self, inputs, outputs, warmup_trials=10, run_trials=50): return ((end - start) * 1000) / run_trials -def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", unique_name=None, **kwargs): +def compile_and_run_mlir( + mhlo_file, + target, + workdir, + verbose, + mode="numerical", + enable_gemm_codegen=False, + unique_name=None, + **kwargs, +): if unique_name is None: unique_name = os.path.basename(mhlo_file).split(".")[0] + "."
+ target try: @@ -175,13 +186,23 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", os.makedirs(workdir, exist_ok=True) os.makedirs(workdir + f"/{unique_name}", exist_ok=True) output_mlir_file_name = f"{workdir}/{unique_name}/{unique_name}.rt.mlir" - byteir.compile( - mhlo_file, - output_mlir_file_name, - entry_func=entry_func_name, - target=target, - verbose=verbose, - ) + if enable_gemm_codegen: + byteir.compile( + mhlo_file, + output_mlir_file_name, + entry_func=entry_func_name, + target=target, + enable_gemm_codegen=True, + verbose=verbose, + ) + else: + byteir.compile( + mhlo_file, + output_mlir_file_name, + entry_func=entry_func_name, + target=target, + verbose=verbose, + ) except Exception as e: return TestResult( unique_name=unique_name, @@ -230,7 +251,10 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", # print("golden output: ", golden_output) # print("actual output: ", output.detach().cpu().numpy()) golden = torch.from_numpy(golden_output).contiguous().to(cur_device) - torch.testing.assert_close(golden, output) + if enable_gemm_codegen: + torch.testing.assert_close(golden, output, atol=1e-2, rtol=1e-2) + else: + torch.testing.assert_close(golden, output) except Exception as e: return TestResult( unique_name=unique_name, @@ -250,7 +274,9 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", ) -def compile_and_run_torch(test, target, workdir, verbose, mode="numerical"): +def compile_and_run_torch( + test, target, workdir, verbose, mode="numerical", enable_gemm_codegen=False +): from torch_e2e_testing.framework import generate_golden_trace import torch_frontend @@ -279,13 +305,23 @@ def compile_and_run_torch(test, target, workdir, verbose, mode="numerical"): output_mlir_file_name = f"{workdir}/{unique_name}/{unique_name}.rt.mlir" with open(mlir_file_name, "w+") as fout: compiled_graph.operation.print(file=fout, large_elements_limit=None) - byteir.compile( - mlir_file_name, - output_mlir_file_name, - entry_func="forward", - target=target, - verbose=verbose, - ) + if enable_gemm_codegen: + byteir.compile( + mlir_file_name, + output_mlir_file_name, + entry_func="forward", + target=target, + enable_gemm_codegen=True, + verbose=verbose, + ) + else: + byteir.compile( + mlir_file_name, + output_mlir_file_name, + entry_func="forward", + target=target, + verbose=verbose, + ) except Exception as e: return TestResult( unique_name=unique_name, @@ -325,7 +361,12 @@ def compile_and_run_torch(test, target, workdir, verbose, mode="numerical"): try: golden_output = trace_item.output.detach().cpu() actual_output = torch_outputs[0].detach().cpu() - torch.testing.assert_close(golden_output, actual_output) + if enable_gemm_codegen: + torch.testing.assert_close( + golden_output, actual_output, atol=1e-1, rtol=1e-2 + ) + else: + torch.testing.assert_close(golden_output, actual_output) except Exception as e: return TestResult( unique_name=unique_name, diff --git a/tests/numerical_test/main.py b/tests/numerical_test/main.py index a5103459e..ad86fa200 100644 --- a/tests/numerical_test/main.py +++ b/tests/numerical_test/main.py @@ -25,58 +25,95 @@ GLOBAL_TORCH_TEST_REGISTRY_NAMES, ) from testset import CPU_MLIR_TEST_DIR, CUDA_MLIR_TEST_DIR -from testset import CPU_ALL_SET, CUDA_ALL_SET, CUDA_AIT_ALL_SET, CUDA_AIT_SM80PLUS_SET +from testset import ( + CPU_ALL_SET, + CUDA_ALL_SET, + CUDA_AIT_ALL_SET, + CUDA_AIT_SM80PLUS_SET, + CUDA_WITH_GEMM_CODEGEN_SET, +) ##### TEST SET CONFIG ####### TEST_SET = { "cpu": 
CPU_ALL_SET, "cuda": CUDA_ALL_SET, "cuda_with_ait": CUDA_AIT_ALL_SET, + # As some features are still under development, + # we will merge this into the cuda test set later. + "cuda_with_gemm_codegen": CUDA_WITH_GEMM_CODEGEN_SET, } + def get_local_gpu_arch(): from byteir.utils import detect_gpu_arch_with_nvidia_smi + gpu_arch = detect_gpu_arch_with_nvidia_smi() assert gpu_arch != None assert gpu_arch.startswith("sm_") gpu_arch = int(gpu_arch[3:]) return gpu_arch + def run(target, filter, workdir, mode="numerical", verbose=False): if target == "dynamo": from torch_dynamo_e2e_testing.execute import run_torch_dynamo_tests + gpu_arch = get_local_gpu_arch() # TODO(zzk): use test infra for dynamo tests run_torch_dynamo_tests(gpu_arch) return [] + enable_gemm_codegen = target == "cuda_with_gemm_codegen" test_set = TEST_SET[target] + if target != "cpu": gpu_arch = get_local_gpu_arch() if target == "cuda_with_ait" and gpu_arch < 80: test_set -= CUDA_AIT_SM80PLUS_SET + # Gemm codegen is only supported on sm80+. + if target == "cuda_with_gemm_codegen" and gpu_arch < 80: + enable_gemm_codegen = False + + # cuda_with_gemm_codegen is a special case of cuda, so fall back to the cuda target. + if target == "cuda_with_gemm_codegen": + target = "cuda" results = [] for test in test_set: if not re.match(filter, test): continue if test in GLOBAL_TORCH_TEST_REGISTRY_NAMES: + print(test) results.append( compile_and_run_torch( - GLOBAL_TORCH_TEST_REGISTRY[test], target, workdir, verbose, mode + GLOBAL_TORCH_TEST_REGISTRY[test], + target, + workdir, + verbose, + mode, + enable_gemm_codegen, ) ) else: if target == "cpu": results.append( compile_and_run_mlir( - os.path.join(CPU_MLIR_TEST_DIR, test), target, workdir, verbose, mode + os.path.join(CPU_MLIR_TEST_DIR, test), + target, + workdir, + verbose, + mode, ) ) else: results.append( compile_and_run_mlir( - os.path.join(CUDA_MLIR_TEST_DIR, test), target, workdir, verbose, mode + os.path.join(CUDA_MLIR_TEST_DIR, test), + target, + workdir, + verbose, + mode, + enable_gemm_codegen, ) ) return results @@ -93,6 +130,7 @@ def parse_args(): "cpu", "cuda", "cuda_with_ait", + "cuda_with_gemm_codegen", "dynamo", "native_torch", ], @@ -141,10 +179,18 @@ def main(): results = [] if args.target == "all": - for target in ["cpu", "cuda", "cuda_with_ait", "dynamo"]: + for target in [ + "cpu", + "cuda", + "cuda_with_ait", + "dynamo", + "cuda_with_gemm_codegen", + ]: results += run(target, args.filter, args.workdir) else: - results += run(args.target, args.filter, args.workdir, mode=args.mode, verbose=args.verbose) + results += run( + args.target, args.filter, args.workdir, mode=args.mode, verbose=args.verbose + ) failed = report_results(results) sys.exit(1 if failed else 0) diff --git a/tests/numerical_test/mlir_tests/ops/bmm_rcr_f16f16f32.mlir b/tests/numerical_test/mlir_tests/ops/bmm_rcr_f16f16f32.mlir new file mode 100644 index 000000000..b77b511e7 --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/bmm_rcr_f16f16f32.mlir @@ -0,0 +1,5 @@ +func.func @bmm_rcr(%arg0 : tensor<32x256x128xf16>, %arg1 : tensor<32x256x128xf16>) -> tensor<32x256x256xf16> { + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<32x256x128xf16>) -> tensor<32x128x256xf16> + %1 = "mhlo.dot_general"(%arg1, %0) {dot_dimension_numbers = #mhlo.dot} : (tensor<32x256x128xf16>, tensor<32x128x256xf16>) -> tensor<32x256x256xf16> + return %1 : tensor<32x256x256xf16> +} diff --git a/tests/numerical_test/mlir_tests/ops/bmm_rrr_f16f16f32.mlir b/tests/numerical_test/mlir_tests/ops/bmm_rrr_f16f16f32.mlir new file mode
100644 index 000000000..bf4cfeea8 --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/bmm_rrr_f16f16f32.mlir @@ -0,0 +1,4 @@ +func.func @bmm_rrr(%arg0 : tensor<32x256x256xf16>, %arg1 : tensor<32x256x128xf16>) -> tensor<32x256x128xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<32x256x256xf16>, tensor<32x256x128xf16>) -> tensor<32x256x128xf32> + return %0 : tensor<32x256x128xf32> +} diff --git a/tests/numerical_test/mlir_tests/ops/gemm_crr_f16f16f32.mlir b/tests/numerical_test/mlir_tests/ops/gemm_crr_f16f16f32.mlir new file mode 100644 index 000000000..377f5d258 --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/gemm_crr_f16f16f32.mlir @@ -0,0 +1,7 @@ +func.func @bmm_crr(%arg0 : tensor<1x256x4096xf16>, %arg1 : tensor<256x11008xf16>) -> tensor<4096x11008xf32> { + %0 = mhlo.reshape %arg0 : (tensor<1x256x4096xf16>) -> tensor<256x4096xf16> + %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<256x4096xf16>) -> tensor<4096x256xf16> + %2 = "mhlo.dot"(%1, %arg1) : (tensor<4096x256xf16>, tensor<256x11008xf16>) -> tensor<4096x11008xf32> + return %2: tensor<4096x11008xf32> +} + diff --git a/tests/numerical_test/mlir_tests/ops/gemm_crr_f32.mlir b/tests/numerical_test/mlir_tests/ops/gemm_crr_f32.mlir new file mode 100644 index 000000000..7f20075a9 --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/gemm_crr_f32.mlir @@ -0,0 +1,7 @@ +func.func @bmm_crr(%arg0 : tensor<1x256x4096xf32>, %arg1 : tensor<256x11008xf32>) -> tensor<4096x11008xf32> { + %0 = mhlo.reshape %arg0 : (tensor<1x256x4096xf32>) -> tensor<256x4096xf32> + %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<256x4096xf32>) -> tensor<4096x256xf32> + %2 = "mhlo.dot"(%1, %arg1) : (tensor<4096x256xf32>, tensor<256x11008xf32>) -> tensor<4096x11008xf32> + return %2: tensor<4096x11008xf32> +} + diff --git a/tests/numerical_test/mlir_tests/ops/gemm_rrr_f16f16f32.mlir b/tests/numerical_test/mlir_tests/ops/gemm_rrr_f16f16f32.mlir new file mode 100644 index 000000000..7a652d8ba --- /dev/null +++ b/tests/numerical_test/mlir_tests/ops/gemm_rrr_f16f16f32.mlir @@ -0,0 +1,4 @@ +func.func @main(%arg0: tensor<256x128xf16>, %arg1: tensor<128x256xf16>) -> tensor<256x256xf32> { + %0 = "mhlo.dot"(%arg0, %arg1): (tensor<256x128xf16>, tensor<128x256xf16>)-> tensor<256x256xf32> + return %0 : tensor<256x256xf32> +} diff --git a/tests/numerical_test/testset.py b/tests/numerical_test/testset.py index 02585154e..66d80ee49 100644 --- a/tests/numerical_test/testset.py +++ b/tests/numerical_test/testset.py @@ -9,6 +9,7 @@ CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + def _get_test_files_from_dir(directory): test_files = [] for filename in os.listdir(directory): @@ -49,6 +50,10 @@ def _get_test_files_from_dir(directory): "transpose1203.mlir", "transpose2013.mlir", "transpose120.mlir", + "gemm_crr_f16f16f32.mlir", + "gemm_rrr_f16f16f32.mlir", + "bmm_rcr_f16f16f32.mlir", + "bmm_rrr_f16f16f32.mlir", } CUDA_ALL_SET = (CUDA_MLIR_TEST_SET | CUDA_TORCH_TEST_SET) - CUDA_XFAIL_SET @@ -91,3 +96,29 @@ def _get_test_files_from_dir(directory): } CUDA_AIT_ALL_SET = CUDA_AIT_MLIR_TEST_SET | CUDA_AIT_TORCH_TEST_SET + +##### CUDA WITH GEMMCODEGEN TEST SET ####### +CUDA_TORCH_MATMUL_TESTS = {test for test in CUDA_TORCH_TEST_SET if "Matmul" in test} + +CUDA_GEMMCODEGEN_TESTS = { + "gemm_crr_f16f16f32.mlir", + "gemm_crr_f32.mlir", + "gemm_rrr_f16f16f32.mlir", + "bmm_rcr_f16f16f32.mlir", + "bmm_rrr_f16f16f32.mlir", +} + +CUDA_WITH_GEMM_CODEGEN_XFAIL_SET = { + 
"MatmulTransposeAF16Module_basic", + # "MatmulTransposeBF16Module_basic", + # "MatmulTransposeModule_basic", + # TODO: Test passed on A10. But failed on CI machine. + # "BatchMatmulAddF32Module_basic", + # TODO: fix bug + "gemm_crr_f16f16f32.mlir", + "bmm_rcr_f16f16f32.mlir", +} + +CUDA_WITH_GEMM_CODEGEN_SET = ( + CUDA_TORCH_MATMUL_TESTS | CUDA_GEMMCODEGEN_TESTS +) - CUDA_WITH_GEMM_CODEGEN_XFAIL_SET diff --git a/tests/numerical_test/torch_e2e_testing/test_suite/basic.py b/tests/numerical_test/torch_e2e_testing/test_suite/basic.py index 6955b32b8..e224e27ca 100644 --- a/tests/numerical_test/torch_e2e_testing/test_suite/basic.py +++ b/tests/numerical_test/torch_e2e_testing/test_suite/basic.py @@ -44,9 +44,51 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: MatmulF16Module()) def MatmulF16Module_basic(module, tu: TestUtils): - module.forward(tu.rand(256, 512).to(torch.float16), - tu.rand(512, 1024).to(torch.float16)) + module.forward(tu.rand(128, 32).to(torch.float16), + tu.rand(32, 128).to(torch.float16)) + +class BatchMatmulF16Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, a, b): + return torch.bmm(a, b) + +@register_test_case(module_factory=lambda: BatchMatmulF16Module()) +def BatchMatmulF16Module_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 1024, 128).to(torch.float16), + tu.rand(2, 128, 1024).to(torch.float16)) + + +class MatmulTransposeAF16Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, a, b): + return torch.matmul(a.T, b) + +@register_test_case(module_factory=lambda: MatmulTransposeAF16Module()) +def MatmulTransposeAF16Module_basic(module, tu: TestUtils): + module.forward(tu.rand(64, 128).to(torch.float16), + tu.rand(64, 128).to(torch.float16)) + + +class MatmulTransposeBF16Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, a, b): + return torch.matmul(a, b.T) + +@register_test_case(module_factory=lambda: MatmulTransposeBF16Module()) +def MatmulTransposeBF16Module_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 64).to(torch.float32), + tu.rand(128, 64).to(torch.float32)) + class MatmulTransposeModule(torch.nn.Module): def __init__(self): @@ -71,7 +113,7 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: MatmulF32Module()) def MatmulF32Module_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 6), tu.rand(6, 10)) + module.forward(tu.rand(1024, 128), tu.rand(128, 1024)) class BatchMatmulF32Module(torch.nn.Module): @@ -85,9 +127,36 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: BatchMatmulF32Module()) def BatchMatmulF32Module_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 5, 6), tu.rand(2, 6, 10)) + module.forward(tu.rand(2, 128, 128), tu.rand(2, 128, 128)) +class MatmulAddF16Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + return c + torch.matmul(a, b) + +@register_test_case(module_factory=lambda: MatmulAddF16Module()) +def MatmulAddF16Module_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 32).to(torch.float16), + tu.rand(32, 128).to(torch.float16), + tu.rand(128, 128).to(torch.float16)) + +class MatmulF16ReluModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, a, b): + return torch.relu(torch.matmul(a, b)) + +@register_test_case(module_factory=lambda: MatmulF16ReluModule()) +def MatmulF16ReluModule_basic(module, tu: TestUtils): + 
module.forward(tu.rand(128, 256).to(torch.float16), + tu.rand(256, 128).to(torch.float16)) + class BatchMatmulAddF32Module(torch.nn.Module): def __init__(self): @@ -99,7 +168,7 @@ def forward(self, a, b, c): @register_test_case(module_factory=lambda: BatchMatmulAddF32Module()) def BatchMatmulAddF32Module_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 5, 6), tu.rand(2, 6, 10), tu.rand(2, 5, 10)) + module.forward(tu.rand(1, 128, 128), tu.rand(1, 128, 128), tu.rand(1, 128, 128)) # ==============================================================================