diff --git a/compiler/include/byteir/Dialect/Vector/CMakeLists.txt b/compiler/include/byteir/Dialect/Vector/CMakeLists.txt index 5c919f7df..93279fab2 100644 --- a/compiler/include/byteir/Dialect/Vector/CMakeLists.txt +++ b/compiler/include/byteir/Dialect/Vector/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(TransformOps) add_subdirectory(Transforms) \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Vector/TransformOps/CMakeLists.txt b/compiler/include/byteir/Dialect/Vector/TransformOps/CMakeLists.txt new file mode 100644 index 000000000..ac409a542 --- /dev/null +++ b/compiler/include/byteir/Dialect/Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS VectorExtTransformOps.td) +mlir_tablegen(VectorExtTransformOps.h.inc -gen-op-decls) +mlir_tablegen(VectorExtTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRVectorExtTransformOpsIncGen) + +add_mlir_doc(VectorExtTransformOps VectorExtTransformOps Dialects/ -gen-op-doc) \ No newline at end of file diff --git a/compiler/include/byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.h b/compiler/include/byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.h new file mode 100644 index 000000000..8614f008d --- /dev/null +++ b/compiler/include/byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.h @@ -0,0 +1,52 @@ +//===- VectorExtTransformOps.h - Vector transform ops ----------*- C++ --*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_Vector_TRANSFORMOPS_VectorEXTTRANSFORMOPS_H +#define BYTEIR_DIALECT_Vector_TRANSFORMOPS_VectorEXTTRANSFORMOPS_H + +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +class TilingInterface; +class RewriterBase; +namespace Vector { +class GenericOp; +class VectorOp; +} // namespace Vector +} // namespace mlir + +//===----------------------------------------------------------------------===// +// VectorExt Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace vector_ext { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace vector_ext +} // namespace mlir + +#endif // BYTEIR_DIALECT_Vector_TRANSFORMOPS_VectorEXTTRANSFORMOPS_H diff --git a/compiler/include/byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.td b/compiler/include/byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.td new file mode 100644 index 000000000..bec4cd4a6 --- /dev/null +++ b/compiler/include/byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.td @@ -0,0 +1,49 @@ +//===-- VectorExtTransformOps.td ------------------------------------------===// +// +// Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// Some code comes from LinalgTransformOps.td in LLVM project +// Original license: +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_DIALECT_VECTOR_TRANSFORMOPS_VECTOR_EXT_TRANSFORMOPS +#define BYTEIR_DIALECT_VECTOR_TRANSFORMOPS_VECTOR_EXT_TRANSFORMOPS + +include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +def ConvertReductionToGPUShuffleOp : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Convert vector reduction op to a sequence of GPU shuffle ops. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs ); + + let assemblyFormat = "$target attr-dict `:` type($target)"; +} + + +#endif // BYTEIR_DIALECT_VECTOR_TRANSFORMOPS_VECTOR_EXT_TRANSFORMOPS \ No newline at end of file diff --git a/compiler/lib/CAPI/Dialects.cpp b/compiler/lib/CAPI/Dialects.cpp index 9325f00d6..623f033ce 100644 --- a/compiler/lib/CAPI/Dialects.cpp +++ b/compiler/lib/CAPI/Dialects.cpp @@ -45,5 +45,6 @@ void byteirRegisterDialectExtensions(MlirContext context) { linalg_ext::registerTransformDialectExtension(registry); transform_ext::registerTransformDialectExtension(registry); tensor_ext::registerTilingInterfaceExternalModels(registry); + vector_ext::registerTransformDialectExtension(registry); unwrap(context)->appendDialectRegistry(registry); } diff --git a/compiler/lib/Dialect/Vector/CMakeLists.txt b/compiler/lib/Dialect/Vector/CMakeLists.txt index e31af3266..bb2c33aaf 100644 --- a/compiler/lib/Dialect/Vector/CMakeLists.txt +++ b/compiler/lib/Dialect/Vector/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(TransformOps) add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/compiler/lib/Dialect/Vector/TransformOps/CMakeLists.txt new file mode 100644 index 000000000..8841ec047 --- /dev/null +++ b/compiler/lib/Dialect/Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,27 @@ +add_mlir_dialect_library(MLIRVectorExtTransformOps + VectorExtTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Dialect/Vector/TransformOps + + DEPENDS + MLIRVectorExtTransformOpsIncGen + ByteIRVectorPasses + + LINK_LIBS PUBLIC + ByteIRLinalgPasses + MLIRAffineDialect + MLIRArithDialect + MLIRCclDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRParser + MLIRPDLDialect + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRTensorTilingInterfaceImplExt + MLIRTransformDialect + MLIRVectorDialect + MLIRGPUDialect + ) diff --git a/compiler/lib/Dialect/Vector/TransformOps/VectorExtTransformOps.cpp b/compiler/lib/Dialect/Vector/TransformOps/VectorExtTransformOps.cpp new file mode 100644 index 000000000..0d1181243 --- /dev/null +++ b/compiler/lib/Dialect/Vector/TransformOps/VectorExtTransformOps.cpp @@ -0,0 +1,228 @@ +//===- VectorExtTransformOps.cpp - Implementation of Vector transform ops -===// +// +// Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// Some code comes from VectorExtTransformOps.cpp in IREE project +// Original license: +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Some code comes from DropUnitDims.cpp in LLVM project +// Original license: +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.h" +#include "byteir/Dialect/Ccl/IR/CclOps.h" +#include "byteir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "byteir/Utils/Hoist.h" +#include "byteir/Utils/TileUtils.h" +#include "byteir/Utils/Utils.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" +#include "mlir/Transforms/TopologicalSortUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" + +#include + +using namespace mlir; +using namespace mlir::Vector; +using namespace mlir::scf; +using namespace mlir::tensor; +using namespace mlir::transform; + +#define DEBUG_TYPE "Vector-ext-transforms" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +namespace { +/// A simple pattern rewriter that implements no special logic. +class SimpleRewriter : public PatternRewriter { +public: + SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// ConvertReductionToGPUShuffleOp +//===----------------------------------------------------------------------===// +static Value warpReduction(Location loc, OpBuilder &builder, Value input, + Value acc, vector::CombiningKind kind, + uint32_t size) { + // Parallel reduction using butterfly shuffles. + for (uint64_t i = 1; i < size; i <<= 1) { + Value shuffled = builder + .create(loc, input, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); + input = vector::makeArithReduction(builder, loc, kind, input, shuffled); + } + return input; +} + +DiagnosedSilenceableFailure transform::ConvertReductionToGPUShuffleOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + auto targets = SmallVector(state.getPayloadOps(getTarget())); + for (const auto &payloadOp : targets) { + payloadOp->walk([&](vector::ReductionOp reduceOp) { + Location loc = reduceOp.getLoc(); + auto vectorOp = reduceOp.getVector(); + // if (vectorOp.getType().getRank() > 1) { + // reduceOp->emitError() << "the rank of vector should equal to 1"; + // return WalkResult::interrupt(); + // } + rewriter.setInsertionPoint(reduceOp); + + Region *parentRegion = reduceOp->getParentRegion(); + auto argNum = parentRegion->getNumArguments(); + // if (argNum != 2) { + // reduceOp->emitError() << "the args of region should equal to 2"; + // return WalkResult::interrupt(); + // } + BlockArgument blockArg = parentRegion->getArgument(0); + llvm::ArrayRef argShape = + blockArg.getType().dyn_cast().getShape(); + + // auto tensorSize = rewriter.create( + // loc, rewriter.getI32IntegerAttr(argShape.back())); + auto laneId = rewriter.create(loc); + auto laneVal = rewriter.create( + loc, rewriter.getI32Type(), laneId); + + // Value cond = rewriter.create( + // loc, arith::CmpIPredicate::slt, laneVal, tensorSize); + // scf::IfOp scfIf = rewriter.create(loc, cond, false); + + // rewriter.setInsertionPointToStart(scfIf.getBody(0)); + // Block *parentBlock = reduceOp->getBlock(); + + SmallVector extractIndex; + for (int64_t i = 0; i < argShape.size() - 1; i++) { + extractIndex.push_back(rewriter.create(loc, 1)); + } + extractIndex.push_back(laneId); + + auto input = + rewriter.create(loc, blockArg, extractIndex); + // IRRewriter rewriter(reduceOp.getContext()); + Value reduce = + warpReduction(loc, rewriter, input, reduceOp.getAcc(), + reduceOp.getKind(), vectorOp.getType().getShape()[0]); + BlockArgument blockOutput = parentRegion->getArgument(1); + ShapedType outputShape = blockOutput.getType().dyn_cast(); + + // if (outputShape.getNumElements() > 1) { + // reduceOp->emitError() << "the shape of reduction output of should + // equal to 1"; return WalkResult::interrupt(); + // } + llvm::ArrayRef OutputShape = outputShape.getShape(); + + SmallVector insertIndex; + for (int64_t i = 0; i < OutputShape.size(); i++) { + insertIndex.push_back(rewriter.create(loc, 0)); + } + auto insertOp = rewriter.create( + loc, reduce, blockOutput, insertIndex); + + Operation *terminator = parentRegion->back().getTerminator(); + terminator->setOperands(ValueRange{insertOp}); + }); + } + return DiagnosedSilenceableFailure::success(); +} + +void transform::ConvertReductionToGPUShuffleOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate that the `call` handle is only read by this operation because the + // associated operation is not erased but rather modified in-place, so the + // reference to it remains valid. + // onlyReadsHandle(getTarget(), effects); + producesHandle(getODSResults(0), effects); + // consumesHandle(getODSResults(0), effects); + consumesHandle(getTarget(), effects); + + // Indicate that the payload is modified by this operation. + modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +/// Registers new ops and declares PDL as dependent dialect since the +/// additional ops are using PDL types for operands and results. +class VectorExtTransformDialectExtension + : public transform::TransformDialectExtension< + VectorExtTransformDialectExtension> { +public: + using Base::Base; + + void init() { + // TODO remove unused ones + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include "byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.cpp.inc" + +void mlir::vector_ext::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/compiler/lib/Pipelines/GPU/MappingForall.cpp b/compiler/lib/Pipelines/GPU/MappingForall.cpp index 6e0ef686a..58a6f7507 100644 --- a/compiler/lib/Pipelines/GPU/MappingForall.cpp +++ b/compiler/lib/Pipelines/GPU/MappingForall.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/SmallSet.h" +#include #include using namespace mlir; @@ -58,10 +59,11 @@ bool isMappedToGPUBlocks(scf::ForallOp forallOp) { return false; } -bool isMappedToGPUThreads(scf::ForallOp forallOp) { +bool isMappedToGPUThreadsOrWarps(scf::ForallOp forallOp) { if (auto mapping = forallOp.getMappingAttr()) { - if (llvm::any_of(mapping.getValue(), [](Attribute attr) { - return isa(attr); + if (llvm::all_of(mapping.getValue(), [](Attribute attr) { + return isa(attr) || + isa(attr); })) { return true; } @@ -70,15 +72,25 @@ bool isMappedToGPUThreads(scf::ForallOp forallOp) { return false; } -void updateBlockDims(scf::ForallOp forallOp, SmallVector &blockDims) { +void updateBlockDims(scf::ForallOp forallOp, SmallVector &blockDims, + int32_t warpSize) { for (auto &&[lb, ub, step, mappingAttr] : llvm::zip( forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), forallOp.getMappingAttr().getValue())) { - if (auto threadMapping = - llvm::dyn_cast_or_null(mappingAttr)) { - auto numIterations = constantTripCount(lb, ub, step); - auto threadIdx = threadMapping.getMappingId(); - if (numIterations.has_value()) { + auto numIterations = constantTripCount(lb, ub, step); + if (numIterations.has_value()) { + if (auto threadMapping = + llvm::dyn_cast_or_null(mappingAttr)) { + auto threadIdx = threadMapping.getMappingId(); + blockDims[threadIdx] = + std::max(blockDims[threadIdx], numIterations.value()); + } else if (auto threadMapping = + llvm::dyn_cast_or_null( + mappingAttr)) { + auto threadIdx = threadMapping.getMappingId(); + if (threadIdx == 0) { + *numIterations *= warpSize; + } blockDims[threadIdx] = std::max(blockDims[threadIdx], numIterations.value()); } @@ -86,19 +98,85 @@ void updateBlockDims(scf::ForallOp forallOp, SmallVector &blockDims) { } } +void updateMaxIterationSpace(scf::ForallOp forallOp, int64_t &maxIterationSpace, + int32_t warpSize) { + int64_t IterationSpace = 1; + for (auto &&[lb, ub, step, mappingAttr] : llvm::zip( + forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), + forallOp.getMixedStep(), forallOp.getMappingAttr().getValue())) { + auto numIterations = constantTripCount(lb, ub, step); + if (numIterations.has_value()) { + if (auto threadMapping = + llvm::dyn_cast_or_null(mappingAttr)) { + if (threadMapping.getWarp() == gpu::MappingId::LinearDim0) { + *numIterations *= warpSize; + } + IterationSpace *= numIterations.value(); + } else if (auto threadMapping = + llvm::dyn_cast_or_null( + mappingAttr)) { + IterationSpace *= numIterations.value(); + } + } + } + if (IterationSpace > maxIterationSpace) { + maxIterationSpace = IterationSpace; + } +} + +bool isLinearMappingMode(scf::ForallOp forallOp) { + return llvm::all_of(forallOp.getMapping()->getValue(), [](Attribute a) { + return cast(a).isLinearMapping(); + }); +} + +bool isNonLinearMappingMode(scf::ForallOp forallOp) { + return !llvm::any_of(forallOp.getMapping()->getValue(), [](Attribute a) { + return cast(a).isLinearMapping(); + }); +} + std::optional getMappingForallConfig(scf::ForallOp forallOp) { if (!isMappedToGPUBlocks(forallOp)) return std::nullopt; - + const int32_t warpSize = 32; SmallVector blockDims{1, 1, 1}; + int64_t maxIterationSpace = 0; + bool hasMappingToWarpAndNonLenearModeOp = false; + auto &&block = forallOp.getRegion().front(); for (auto &&nestedForall : block.getOps()) { - if (isMappedToGPUThreads(nestedForall)) { - updateBlockDims(nestedForall, blockDims); + if (isMappedToGPUThreadsOrWarps(nestedForall)) { + if (isLinearMappingMode(nestedForall)) { + updateMaxIterationSpace(nestedForall, maxIterationSpace, warpSize); + } else if (isNonLinearMappingMode(nestedForall)) { + if (llvm::all_of(nestedForall.getMapping()->getValue(), + [](Attribute attr) { + return isa(attr); + })) { + hasMappingToWarpAndNonLenearModeOp = true; + } + updateBlockDims(nestedForall, blockDims, warpSize); + } } } + int64_t blockSize = blockDims[0] * blockDims[1] * blockDims[2]; + // TODO: Nested Forall Op with both nonlinear and linear modes in a Forall Op + // is not supported yet + if (blockSize != 1 && maxIterationSpace != 1) { + return std::nullopt; + } + if (maxIterationSpace > 1) { + if (hasMappingToWarpAndNonLenearModeOp && + maxIterationSpace % warpSize != 0) { + blockDims[0] = warpSize; + blockDims[1] = ceil((double)maxIterationSpace / warpSize); + } else { + blockDims[0] = maxIterationSpace; + } + } if (blockDims[0] * blockDims[1] * blockDims[2] > kMaximumBlockDim) { return std::nullopt; } @@ -134,7 +212,7 @@ void createGPUMappingForallTransformImpl(OpPassManager &pm, /* target */ launchOp.getResult(), /* block_dims */ mappingConfig.blockDims, /* sync_after_distribute*/ true, - /* warp_dims */ 32); + /* warp_size */ 32); }; pm.addPass(createGenericTransformInsertionPass(config)); diff --git a/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp b/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp index fc3db3fb8..b94786ded 100644 --- a/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/ReductionCodegen.cpp @@ -34,8 +34,10 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/SmallSet.h" - -#include +#include +#include +#include +#include using namespace mlir; @@ -246,6 +248,10 @@ struct BlockTileConfig { SmallVector tileSizes; SmallVector mapping; std::vector fuseCandidates; + int64_t reductionSize; + int64_t remainBlockSize; + int64_t warpSize; + bool mapToWarp; void apply(ImplicitLocOpBuilder &b, Value pdlV, bool usingForall); }; @@ -276,11 +282,14 @@ void processProducerSelectors( } } -void tileToForallAndFuseImpl( - ImplicitLocOpBuilder &b, Value toTile, - const SmallVector &tileSizes, - const SmallVector &mapping, - const std::vector &fuseCandidates) { +transform::TileUsingForallOp +tileToForallAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, + const SmallVector &tileSizes, + const SmallVector &mapping, + const std::vector &fuseCandidates, + const int64_t reductionSize = 0, + const int64_t warpSize = 32, + const int64_t remainBlockSize = 512) { SmallVector toBeFused; processProducerSelectors(b, fuseCandidates, toTile, toBeFused); @@ -294,6 +303,18 @@ void tileToForallAndFuseImpl( /* producerOp */ producerOp, /* containingOp */ tileOp.getForallOp()); } + // if (reductionSize == warpSize && remainBlockSize >= warpSize) { + // auto warpMap = b.getArrayAttr( + // {gpu::GPUWarpMappingAttr::get(b.getContext(), gpu::Warps::DimX)}); + // auto tileSize = ArrayRef{0, 0}; + // tileOp = b.create( + // /* target */ tileOp.getTiledOp(), + // /* staticTileSizes */ tileSize, + // /* ctor tag */ transform::TileSizesSpec(), + // /* mapping */ + // warpMap); + // } + return tileOp; } void tileToSCFForAndFuseImpl(ImplicitLocOpBuilder &b, Value toTile, @@ -452,12 +473,36 @@ void BlockSplitConfig::apply(ImplicitLocOpBuilder &b, Value pdlV) { void BlockTileConfig::apply(ImplicitLocOpBuilder &b, Value pdlV, bool usingForall) { + static int32_t reduction_kernel_cnt = 0; if (usingForall) { auto mappingAttrs = llvm::to_vector( llvm::map_range(mapping, [&](gpu::MappingId dim) -> Attribute { + if (mapToWarp == true) { + return gpu::GPUWarpMappingAttr::get(b.getContext(), dim); + } return gpu::GPUThreadMappingAttr::get(b.getContext(), dim); })); - tileToForallAndFuseImpl(b, pdlV, tileSizes, mappingAttrs, fuseCandidates); + transform::TileUsingForallOp tileOp = tileToForallAndFuseImpl( + b, pdlV, tileSizes, mappingAttrs, fuseCandidates, reductionSize, + warpSize, remainBlockSize); + if (mapToWarp) { + auto pdlType = pdl::OperationType::get(b.getContext()); + std::string func_name = + "redunction_kernel_" + std::to_string(reduction_kernel_cnt++); + transform::LoopOutlineOp outlineOp = b.create( + pdlType, pdlType, tileOp.getTiledOp(), func_name); + auto vecChildOp = + b.create( + outlineOp.getFunction()); + b.create( + vecChildOp, [](OpBuilder &b, Location loc) { + b.create( + loc, vector::VectorMultiReductionLowering::InnerReduction); + }); + auto shuffleOp = + b.create(vecChildOp); + b.create(outlineOp.getCall()); + } } else { static constexpr std::array mappings{ getThreadIdXName(), getThreadIdYName(), getThreadIdZName()}; @@ -580,7 +625,6 @@ std::optional getGridTileConfig(linalg::GenericOp genericOp, SmallVector tileSizes(numLoops, 1); auto loopSizes = cast(genericOp.getOperation()).computeStaticLoopSizes(); - for (auto &&affineMap : genericOp.getIndexingMapsArray()) { if (affineMap.isPermutation()) { auto dim = affineMap.getDimPosition(numLoops - 1); @@ -600,6 +644,10 @@ std::optional getGridTileConfig(linalg::GenericOp genericOp, } auto numTiledLoops = getNumTiledLoops(tileSizes); + if (!numTiledLoops) { + tileSizes[redDim] = loopSizes[redDim]; + numTiledLoops = 1; + } if (numTiledLoops >= 1 && numTiledLoops <= 3) { SmallVector mapping(numLoops, -1); int64_t dimMapping = static_cast(gpu::MappingId::DimX); @@ -654,7 +702,6 @@ std::optional getBlockSplitConfig(linalg::GenericOp genericOp, splitFactor = newSplitFactor / 2; } } - if (staticLoopRanges[redDim] < splitFactor) { splitFactor = staticLoopRanges[redDim]; } else { @@ -674,8 +721,13 @@ std::optional getBlockSplitConfig(linalg::GenericOp genericOp, } } - for (; splitFactor > 2; splitFactor >>= 1) { - splitFactors.push_back(splitFactor / 2); + // for (; splitFactor > 2; splitFactor >>= 1) { + // splitFactors.push_back(splitFactor / 2); + // dimensions.push_back(redDim ? redDim - 1 : redDim); + // } + // haven't consider padding + if (splitFactor > warpSize) { + splitFactors.push_back(splitFactor / 32); dimensions.push_back(redDim ? redDim - 1 : redDim); } @@ -695,16 +747,27 @@ std::optional getBlockTileConfig(linalg::GenericOp genericOp, int64_t remainBlockSize = blockSize; auto redDim = getReductionDim(genericOp).value(); + bool mapToWarp = false; for (int64_t idx = 0; idx < numLoops && remainBlockSize > 1; ++idx) { - if (idx == redDim) - continue; - int64_t curLoopSize2 = nextPowerOf2(loopSizes[idx]); - int64_t curBlockSize = std::min(curLoopSize2, remainBlockSize); - tileSizes[idx] = curLoopSize2 / curBlockSize; - remainBlockSize /= curBlockSize; + if (idx == redDim) { + if ((loopSizes[idx] <= warpSize) && + (getOperandReductionDim(*genericOp.getDpsInputOperand(0)).value() == + numLoops - 1) && + (remainBlockSize >= loopSizes[idx])) { + tileSizes[idx] = loopSizes[idx]; + remainBlockSize /= loopSizes[idx]; + mapToWarp = true; + } + } else { + int64_t curLoopSize2 = nextPowerOf2(loopSizes[idx]); + int64_t curBlockSize = std::min(curLoopSize2, remainBlockSize); + tileSizes[idx] = curLoopSize2 / curBlockSize; + remainBlockSize /= curBlockSize; + } } - if (remainBlockSize == blockSize) { + if ((remainBlockSize == blockSize) || + (loopSizes[redDim] == warpSize && remainBlockSize >= warpSize)) { tileSizes[redDim] = loopSizes[redDim]; } @@ -719,7 +782,7 @@ std::optional getBlockTileConfig(linalg::GenericOp genericOp, auto numTiledLoops = getNumTiledLoops(tileSizes); if (numTiledLoops >= 1 && numTiledLoops <= 3) { SmallVector mapping(numLoops, -1); - int64_t dimMapping = static_cast(gpu::MappingId::DimX); + int64_t dimMapping = static_cast(gpu::MappingId::LinearDim0); for (auto &&affineMap : genericOp.getIndexingMapsArray()) { if (affineMap.isPermutation()) { for (int64_t i = numLoops - 1; i >= 0; i--) { @@ -740,7 +803,11 @@ std::optional getBlockTileConfig(linalg::GenericOp genericOp, tileSizes, llvm::to_vector(llvm::map_range( mapping, [](int64_t i) { return static_cast(i); })), - fuseCandidates}; + fuseCandidates, + loopSizes[redDim], + remainBlockSize, + warpSize, + mapToWarp}; } return std::nullopt; } diff --git a/compiler/tools/byteir-opt/byteir-opt.cpp b/compiler/tools/byteir-opt/byteir-opt.cpp index 760e6dd67..2e4612184 100644 --- a/compiler/tools/byteir-opt/byteir-opt.cpp +++ b/compiler/tools/byteir-opt/byteir-opt.cpp @@ -24,7 +24,6 @@ #include "byteir/Dialect/Ccl/TransformOps/CclTransformOps.h" #include "byteir/Dialect/GPU/Passes.h" #include "byteir/Dialect/Lace/LaceDialect.h" -#include "byteir/Dialect/Lccl/LcclOps.h" #include "byteir/Dialect/Linalg/IR/LinalgExtOps.h" #include "byteir/Dialect/Linalg/Passes.h" #include "byteir/Dialect/Linalg/TransformOps/LinalgExtTransformOps.h" @@ -36,6 +35,7 @@ #include "byteir/Dialect/Tensor/Passes.h" #include "byteir/Dialect/Transform/IR/TransformExtOps.h" #include "byteir/Dialect/Transform/Passes.h" +#include "byteir/Dialect/Vector/TransformOps/VectorExtTransformOps.h" #include "byteir/Dialect/Vector/Transforms/Passes.h" #include "byteir/Dialect/mhlo/Passes.h" #include "byteir/Pipelines/InitAllPipelines.h" @@ -153,7 +153,6 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -163,6 +162,7 @@ int main(int argc, char **argv) { linalg_ext::registerTransformDialectExtension(registry); transform_ext::registerTransformDialectExtension(registry); tensor_ext::registerTilingInterfaceExternalModels(registry); + vector_ext::registerTransformDialectExtension(registry); return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "ByteIR pass driver\n", registry)); diff --git a/tests/numerical_test/torch_e2e_testing/test_suite/basic.py b/tests/numerical_test/torch_e2e_testing/test_suite/basic.py index 483e434f7..4cdb6e154 100644 --- a/tests/numerical_test/torch_e2e_testing/test_suite/basic.py +++ b/tests/numerical_test/torch_e2e_testing/test_suite/basic.py @@ -86,3 +86,23 @@ def forward(self, a, b, c): @register_test_case(module_factory=lambda: BatchMatmulAddF32Module()) def BatchMatmulAddF32Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 6), tu.rand(2, 6, 10), tu.rand(2, 5, 10)) + + +class UniformStaticShapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + a = torch.ops.aten.uniform_(x, 1.0, 10.0) + std = torch.cat([ + torch.flatten(torch.std(a,dim=(1,2))), + ]) + mean = torch.cat([ + torch.flatten(torch.mean(a,dim=(0,1))), + ]) + return torch.cat([std, mean]) + +@register_test_case(module_factory=lambda: UniformStaticShapeModule()) +def UniformStaticShapeModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(32, 512, 256).float())