From f1321e440859d3f3ddf7e0db201d1365b219b006 Mon Sep 17 00:00:00 2001 From: enjustli <798634436@qq.com> Date: Wed, 3 Dec 2025 10:41:25 +0800 Subject: [PATCH 1/3] rm triton-to-linalg pass --- CMakeLists.txt | 2 +- README.md | 2 +- .../triton-shared/Conversion/CMakeLists.txt | 1 - .../Conversion/TritonToLinalg/CMakeLists.txt | 9 - .../Conversion/TritonToLinalg/Passes.h | 22 - .../Conversion/TritonToLinalg/Passes.td | 18 - .../TritonToLinalg/TritonToLinalg.h | 33 - lib/Conversion/CMakeLists.txt | 1 - lib/Conversion/TritonToLinalg/CMakeLists.txt | 27 - .../TritonToLinalg/TritonToLinalg.cpp | 95 -- .../TritonToLinalg/TritonToLinalgPass.cpp | 229 ----- .../TritonToLinalg/addptr_2d_example.mlir | 69 -- .../TritonToLinalg/addptr_add_value.mlir | 68 -- .../TritonToLinalg/addptr_dim1.mlir | 113 --- .../addptr_for_accumulation.mlir | 92 -- .../TritonToLinalg/addptr_for_expand_ptr.mlir | 73 -- .../addptr_for_more_init_args.mlir | 71 -- .../addptr_for_used_after_update.mlir | 98 --- .../addptr_for_used_before_update.mlir | 55 -- .../TritonToLinalg/addptr_loopback.mlir | 53 -- .../addptr_mul_const_const.mlir | 49 -- .../addptr_mul_value_const.mlir | 51 -- .../TritonToLinalg/addptr_nested.mlir | 73 -- .../addptr_reshape_broadcast.mlir | 43 - .../addptr_scalar_broadcast.mlir | 65 -- .../TritonToLinalg/addptr_scalar_for.mlir | 70 -- .../TritonToLinalg/addptr_scalar_for_2d.mlir | 92 -- .../addptr_scalar_loopback.mlir | 27 - .../TritonToLinalg/addptr_scalar_nested.mlir | 57 -- .../TritonToLinalg/addptr_scalar_splat.mlir | 45 - .../addptr_scalar_splat_2d.mlir | 56 -- .../TritonToLinalg/arith_not_ptr_arith.mlir | 39 - test/Conversion/TritonToLinalg/bitcast.mlir | 44 - .../TritonToLinalg/block_ptr_advance.mlir | 90 -- .../convert_1d_elemwise_arith_binary.mlir | 72 -- .../convert_1d_elemwise_arith_ternary.mlir | 49 -- .../convert_1d_elemwise_arith_unary.mlir | 88 -- .../convert_2d_elemwise_arith_binary.mlir | 55 -- .../convert_2d_elemwise_arith_ternary.mlir | 55 -- .../convert_2d_elemwise_arith_unary.mlir | 94 -- .../TritonToLinalg/convert_addi_reduce.mlir | 32 - .../TritonToLinalg/convert_argmin_argmax.mlir | 141 --- .../convert_argmin_argmax_2d.mlir | 215 ----- .../convert_extern_elementwise.mlir | 809 ------------------ .../TritonToLinalg/convert_minmax.mlir | 50 -- .../convert_minmax_fp_reduce.mlir | 68 -- .../TritonToLinalg/convert_minmax_reduce.mlir | 126 --- .../TritonToLinalg/convert_splat_float.mlir | 23 - .../convert_tensor_reshape.mlir | 45 - test/Conversion/TritonToLinalg/cumsum.mlir | 68 -- test/Conversion/TritonToLinalg/dot.mlir | 84 -- .../TritonToLinalg/get_num_programs.mlir | 45 - .../TritonToLinalg/reducemax_32_256_bf16.mlir | 58 -- .../reducesum_512_256_bf16_axis0.mlir | 51 -- .../reducesum_512_256_bf16_axis1.mlir | 53 -- .../reducesum_512_256_f32_axis0.mlir | 51 -- .../reducesum_512_256_f32_axis1.mlir | 53 -- .../TritonToLinalg/reducesum_middle_dim.mlir | 60 -- .../TritonToLinalg/reducesum_scalar.mlir | 38 - .../TritonToLinalg/triton_assert.mlir | 50 -- .../unsupported_extern_elementwise.mlir | 35 - .../TritonToLinalg/use_dot_opc.mlir | 76 -- .../TritonToLinalg/use_end_chain.mlir | 95 -- .../TritonToLinalg/use_mid_chain.mlir | 64 -- .../wraparound_side_by_side.mlir | 133 --- .../TritonToLinalg/wraparound_stacked.mlir | 129 --- .../wraparound_unsupported_add_offset.mlir | 57 -- .../RegisterTritonSharedDialects.h | 2 - 68 files changed, 2 insertions(+), 5054 deletions(-) delete mode 100644 include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt delete mode 100644 include/triton-shared/Conversion/TritonToLinalg/Passes.h delete mode 100644 include/triton-shared/Conversion/TritonToLinalg/Passes.td delete mode 100644 include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h delete mode 100644 lib/Conversion/TritonToLinalg/CMakeLists.txt delete mode 100644 lib/Conversion/TritonToLinalg/TritonToLinalg.cpp delete mode 100644 lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp delete mode 100644 test/Conversion/TritonToLinalg/addptr_2d_example.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_add_value.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_dim1.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_loopback.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_nested.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_for.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir delete mode 100644 test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir delete mode 100644 test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir delete mode 100644 test/Conversion/TritonToLinalg/bitcast.mlir delete mode 100644 test/Conversion/TritonToLinalg/block_ptr_advance.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_addi_reduce.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_minmax.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_splat_float.mlir delete mode 100644 test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir delete mode 100644 test/Conversion/TritonToLinalg/cumsum.mlir delete mode 100644 test/Conversion/TritonToLinalg/dot.mlir delete mode 100644 test/Conversion/TritonToLinalg/get_num_programs.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir delete mode 100644 test/Conversion/TritonToLinalg/reducesum_scalar.mlir delete mode 100644 test/Conversion/TritonToLinalg/triton_assert.mlir delete mode 100644 test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir delete mode 100644 test/Conversion/TritonToLinalg/use_dot_opc.mlir delete mode 100644 test/Conversion/TritonToLinalg/use_end_chain.mlir delete mode 100644 test/Conversion/TritonToLinalg/use_mid_chain.mlir delete mode 100644 test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir delete mode 100644 test/Conversion/TritonToLinalg/wraparound_stacked.mlir delete mode 100644 test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ae68d38..bc827b37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ add_subdirectory(test) add_subdirectory(tools/triton-shared-opt) if (TRITON_SHARED_BUILD_CPU_BACKEND) - add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonToLinalg TritonTilingExtIR) + add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonTilingExtIR) target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers) endif() diff --git a/README.md b/README.md index 8d405534..9802ee7d 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ As part of the conversion process, there are three important analyses: ### Conversion strategy -We introduce the `TritonToLinalg` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the `Pointer analysis`'s description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like: +We introduce the `TritonToLinalgExperimental` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the `Pointer analysis`'s description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like: ```mlir tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr) { diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt index a4a03949..f8e180a7 100644 --- a/include/triton-shared/Conversion/CMakeLists.txt +++ b/include/triton-shared/Conversion/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(TritonToLinalg) add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonArithToLinalg) diff --git a/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt b/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt deleted file mode 100644 index 74ccdd39..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -#===------------------------------------------------------------------------===# -# -# Copyright (c) Triton Project Contributors. -# -#===------------------------------------------------------------------------===# - -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg) -add_public_tablegen_target(TritonToLinalgConversionPassIncGen) diff --git a/include/triton-shared/Conversion/TritonToLinalg/Passes.h b/include/triton-shared/Conversion/TritonToLinalg/Passes.h deleted file mode 100644 index 404af080..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/Passes.h +++ /dev/null @@ -1,22 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_TO_LINALG_CONVERSION_PASSES_H -#define TRITON_TO_LINALG_CONVERSION_PASSES_H - -#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" - -namespace mlir { -namespace triton { - -#define GEN_PASS_REGISTRATION -#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" - -} // namespace triton -} // namespace mlir - -#endif diff --git a/include/triton-shared/Conversion/TritonToLinalg/Passes.td b/include/triton-shared/Conversion/TritonToLinalg/Passes.td deleted file mode 100644 index 627077e3..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/Passes.td +++ /dev/null @@ -1,18 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_TO_LINALG_CONVERSION_PASSES -#define TRITON_TO_LINALG_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> { - let summary = "Convert Triton to Linalg dialect"; - let constructor = "triton::createTritonToLinalgPass()"; -} - -#endif diff --git a/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h b/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h deleted file mode 100644 index 4c58e992..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H -#define TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace mlir { -namespace triton { - -std::unique_ptr> createTritonToLinalgPass(); - -void populateTritonToLinalgCanonicalizationPatterns( - RewritePatternSet &patterns); - -void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns, - unsigned int launchGridRank); - -} // namespace triton -} // namespace mlir - -#endif // TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 358b4f92..2a591e97 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(TritonToLinalg) add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonToUnstructured) diff --git a/lib/Conversion/TritonToLinalg/CMakeLists.txt b/lib/Conversion/TritonToLinalg/CMakeLists.txt deleted file mode 100644 index acc3c4fb..00000000 --- a/lib/Conversion/TritonToLinalg/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -#===------------------------------------------------------------------------===# -# -# Copyright (c) Triton Project Contributors. -# -#===------------------------------------------------------------------------===# - -add_triton_library(TritonToLinalg - TritonToLinalg.cpp - TritonToLinalgPass.cpp - - DEPENDS - TritonToLinalgConversionPassIncGen - - LINK_LIBS PUBLIC - TritonTilingExtIR - MLIRArithDialect - MLIRDialectUtils - MLIRIR - MLIRMathDialect - MLIRPass - MLIRTensorDialect - MLIRTransforms - MLIRSupport - TritonIR - TritonTransforms - TritonSharedAnalysis -) diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp deleted file mode 100644 index 1c8ed9cf..00000000 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ /dev/null @@ -1,95 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" - -#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" - -#define DEBUG_TYPE "triton-to-linalg" -#include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp" - -using namespace mlir; -using namespace triton; - -#define GEN_PASS_CLASSES -#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" - -void mlir::triton::populateTritonToLinalgCanonicalizationPatterns( - RewritePatternSet &patterns) { - patterns.add, MinMaxConverter>( - patterns.getContext()); -} - -void mlir::triton::populateTritonToLinalgConversionPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns, - unsigned int launchGridRank) { - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - populateExternElementwiseOpToMLIROps(patterns); - - // Reduce converters - // Triton's reduce op is idential to linalg.reduce op, so we can clone - // `tt.reduce` body to `linalg.reduce`. Unfortunately, we still need to - // perform pattern matching to know what reduce ops we are dealing with - // so that we know how to initialize the initial reduce values correctly. - // - // We can do this in a generic way without pattern matching by always using - // the first elements along the reduction axis and perform the reduction on - // the remaining elements. However, this results in creatings sub-tensors that - // aren't always multiple of 2s, which are sub-optimal for certain hardwares. - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - // Note: the ordering here matters! - // MetaOpConverter has PatternBenefit == 10 which should take precedence over - // these linalg patterns, but to be safe, add these patterns last so that they - // will be tried last. Incorrect ordering or having MetaOpConverter has lower - // PatternBenefit will result in element-wise meta ops being converted to - // linalg.generic ops. - linalg::populateElementwiseToLinalgConversionPatterns(patterns); -} diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp deleted file mode 100644 index 25b7db85..00000000 --- a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp +++ /dev/null @@ -1,229 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "triton-shared/Analysis/UseAnalysis.h" -#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" -#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" - -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "triton-to-linalg" - -using namespace mlir; -using namespace triton; - -#define GEN_PASS_CLASSES -#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" - -namespace { - -class TritonTypeConverter : public TypeConverter { -public: - TritonTypeConverter() { - // The order of type conversion is important: later ones are tried earlier. - addConversion([](Type type) { return type; }); - addConversion([](triton::PointerType ptrType) { - return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); - }); - addConversion([](TensorType tensorType) -> Type { - auto elemType = tensorType.getElementType(); - if (auto ptrType = dyn_cast(elemType)) { - elemType = ptrType.getPointeeType(); - } - return MemRefType::get(tensorType.getShape(), elemType); - }); - } -}; - -class TritonToLinalgPass : public TritonToLinalgBase { - - static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; - static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = - LAUNCH_GRID_RANK * 2; - - // Add additional I32 arguments to represent: - // - num_programs, 3 in total, one for each axis of the launch grid - // - program_id, 3 in total, one for each axis of the launch grid - static void addProgramInfo(triton::FuncOp func) { - OpBuilder b(func); - - auto origFuncType = func.getFunctionType(); - auto origInputTypes = origFuncType.getInputs(); - SmallVector newInputTypes(origInputTypes); - newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); - - auto newFuncType = - b.getFunctionType(newInputTypes, origFuncType.getResults()); - - func.setFunctionType(newFuncType); - - // Add empty attributes for each new argument if needed - if (func.getAllArgAttrs()) { - SmallVector newArgAttrs; - func.getAllArgAttrs(newArgAttrs); - newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); - func.setAllArgAttrs(newArgAttrs); - } - - // Add the corresponding arguments to function body - for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { - func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); - } - } - -public: - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - - void runOnOperation() override { - auto moduleOp = getOperation(); - - { - RewritePatternSet patterns(&getContext()); - populateTritonToLinalgCanonicalizationPatterns(patterns); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { - signalPassFailure(); - } - } - - moduleOp.walk([this](triton::FuncOp op) { - if (failed(runUseAnalysis(op))) { - signalPassFailure(); - } - }); - - RewritePatternSet patterns(&getContext()); - ConversionTarget target(getContext()); - TritonTypeConverter tritonTypeConverter; - - target.addLegalDialect< - func::FuncDialect, arith::ArithDialect, math::MathDialect, - linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, - cf::ControlFlowDialect, tensor::TensorDialect, - bufferization::BufferizationDialect, memref::MemRefDialect, - ttx::TritonTilingExtDialect>(); - - target.addLegalOp(); - - // Update function signature to use memrefs - target.addDynamicallyLegalOp([&](triton::FuncOp op) { - return tritonTypeConverter.isSignatureLegal(op.getFunctionType()); - }); - - // Lower dense constant to linalg.fill - target.addDynamicallyLegalOp([](arith::ConstantOp op) { - if (!isa(op.getResult().getType())) { - return true; - } - - if (auto denseAttr = dyn_cast(op.getValue())) { - if (denseAttr.isSplat() && - isa(denseAttr.getElementType())) { - return false; - } - } - return true; - }); - - target.addDynamicallyLegalOp([](Operation *op) { - return llvm::all_of(op->getOperandTypes(), [](Type t) { - if (isa(t)) { - return false; - } - if (auto shapedType = dyn_cast(t)) { - return shapedType.getElementType().isIntOrFloat(); - } - assert(t.isIntOrIndexOrFloat()); - return true; - }); - }); - - target.addDynamicallyLegalDialect( - [](Operation *op) { - if (op->hasAttr("MetaUse")) { - return false; - } - - if (isa(op)) { - return true; - } - - bool operateOnTensors = - llvm::all_of(op->getOperandTypes(), [](Type type) { - return isa(type); - }); - - return !operateOnTensors; - }); - - triton::populateTritonToLinalgConversionPatterns( - tritonTypeConverter, patterns, LAUNCH_GRID_RANK); - - for (auto func : getOperation().getOps()) - addProgramInfo(func); - - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) - signalPassFailure(); - - // Convert tt.func and tt.return into func's counterparts - moduleOp.walk([&](triton::FuncOp func) { - OpBuilder builder(func); - - auto name = func.getName(); - auto type = func.getFunctionType(); - - SmallVector argAttrs, resAttrs; - func.getAllArgAttrs(argAttrs); - func.getAllResultAttrs(resAttrs); - - auto funcFunc = builder.create(func.getLoc(), name, type); - funcFunc.setAllArgAttrs(argAttrs); - funcFunc.setAllResultAttrs(resAttrs); - - auto &funcFuncBody = funcFunc.getBody(); - auto &funcBody = func.getBody(); - - IRMapping map; - funcBody.cloneInto(&funcFuncBody, map); - - for (Block &block : funcFuncBody.getBlocks()) { - auto term = block.getTerminator(); - builder.setInsertionPoint(term); - builder.create(func.getLoc(), term->getOperands()); - term->erase(); - } - func.erase(); - }); - - // Erase dead code and fold constants created during lowering - PassManager pm(&getContext(), moduleOp.getOperationName()); - pm.addPass(createCanonicalizerPass()); - if (failed(runPipeline(pm, getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> triton::createTritonToLinalgPass() { - return std::make_unique(); -} diff --git a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir b/test/Conversion/TritonToLinalg/addptr_2d_example.mlir deleted file mode 100644 index f0f7d1c7..00000000 --- a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir +++ /dev/null @@ -1,69 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr, - %arg3 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> - // offset = [%arg3,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}: tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %6 = arith.constant 5 : i32 - %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> - // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %10 = tt.load %9 : tensor<4x256x!tt.ptr> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %13 = tt.load %12 : tensor<4x256x!tt.ptr> - %14 = arith.addf %10, %13 : tensor<4x256xbf16> - %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> - %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - tt.store %16, %14 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = arith.constant 5 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_15]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: bf16, %[[VAL_18:.*]]: bf16, %[[VAL_19:.*]]: bf16): -// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16 -// CHECK: linalg.yield %[[VAL_20]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_21]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_add_value.mlir b/test/Conversion/TritonToLinalg/addptr_add_value.mlir deleted file mode 100644 index 0ed60796..00000000 --- a/test/Conversion/TritonToLinalg/addptr_add_value.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32, - %arg3 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32> - %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> - // offset = [%arg2,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %offset2, %arg3splat : tensor<4x256xi32> - // offset = [%arg2+%arg3,0], size = [4,256], stride = [1,0] - %c10 = arith.constant 10 : i32 - %c10splat = tt.splat %c10 : i32 -> tensor<4x256xi32> - %offset4 = arith.addi %offset3, %c10splat : tensor<4x256xi32> - // offset = [%arg2+%arg3+10,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,6] - %7 = arith.addi %offset4, %scale5: tensor<4x256xi32> - // offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>,tensor<4x256xi32> - // source = %arg0, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] - %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source = %arg1, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] - %12 = tt.load %9 : tensor<4x256x!tt.ptr> - tt.store %11, %12 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 10 : index -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : index -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_8]] : index -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_19]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_19]] restrict writable : memref<4x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_20]] in writable %[[VAL_18]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_dim1.mlir b/test/Conversion/TritonToLinalg/addptr_dim1.mlir deleted file mode 100644 index 0e314fa4..00000000 --- a/test/Conversion/TritonToLinalg/addptr_dim1.mlir +++ /dev/null @@ -1,113 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -// XFAIL: * -// This test crashes because tt.broadcast's folder tries to cast -// the src operand to a RankedTensorType value, but the TritonToLinalg -// pass has already replaced the src with a value of a different type. -// We're going to retire the monolith triton-to-linalg pass which prevents -// this problem. xfailing the test for now. - -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : i32 - ) - { - %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - - %splat_arg0 = tt.splat %arg0 : !tt.ptr -> tensor<1x256x!tt.ptr> - %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> - - // 1x256 pointer should have meaningful stride in outer dimension - %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<1x256x!tt.ptr> - - %4 = tt.splat %arg1 : i32 -> tensor<1x256xi32> - // 1x256 pointer should have meaningful stride in outer dimension - %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> - tt.store %5, %3 : tensor<1x256x!tt.ptr>, tensor<1x256x!tt.ptr> - - %10 = arith.constant 0.0 : bf16 - %11 = tt.splat %10 : bf16 -> tensor<4x256xbf16> - - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %c256 = arith.constant 256 : i32 - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { - %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> - - %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - %i_i32 = arith.index_cast %i : index to i32 - %21 = arith.muli %c256, %i_i32 : i32 - %22 = tt.splat %21 : i32 -> tensor<4xi32> - %23 = arith.muli %20, %22 : tensor<4xi32> - %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %25 = tt.broadcast %24 : tensor<4x1xi32> -> tensor<4x256xi32> - - // %bptr should have zero stride and %30 should have correct stride - %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16> - - %40 = tt.splat %c256 : i32 -> tensor<1x256xi32> - %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> - - scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr> - } - - %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - %splat_c256 = tt.splat %c256 : i32 -> tensor<4xi32> - %32 = arith.muli %31, %splat_c256 : tensor<4xi32> - %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %34 = tt.broadcast %33 : tensor<4x1xi32> -> tensor<4x256xi32> - %35 = tt.broadcast %2 : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> - %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - tt.store %36, %sum_out : tensor<4x256x!tt.ptr>, tensor<4x256x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<4x256xbf16>) -> tensor<4x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<1x256xbf16, strided<[256, 1]>> to memref<1x256xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index -// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[CST_0_]], [[VAR_arg8_:%.+]] = [[CST_0_]]) -> (tensor<4x256xbf16>, index, index) { -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg5_]] : index to i32 -// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_1_]] : i32 -// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_arg8_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], [[CST_1_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16> -// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg6_]], [[VAR_9_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg6_]] : tensor<4x256xbf16>) { -// CHECK: ^bb0([[in1:%.+]]: bf16, [[in2:%.+]]: bf16, [[out:%.+]]: bf16): -// CHECK: [[VAR_13_:%.+]] = arith.addf [[in1]], [[in2]] : bf16 -// CHECK: linalg.yield [[VAR_13_]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_256_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_11_]], [[VAR_arg8_]] : index -// CHECK: scf.yield [[VAR_10_]], [[VAR_12_]], [[CST_0_]] : tensor<4x256xbf16>, index, index -// CHECK: } -// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>> -// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir b/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir deleted file mode 100644 index 89cb4590..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr, - %arg3 : i32, - %arg4 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> - // offset = [%arg3,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c5 = arith.constant 5 : i32 - %splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32> - // scalar = 5 - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here? - // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here? - // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> // Why is the input unknown - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %19 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> // this will be replaced with a memref.copy - %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr>) { - %20 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16> - // pointer updates - %17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32> - // offset: [3, 0], size = [4, 256], stride [0, 0] - %ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5] - scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr> - } - %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> - %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - tt.store %16, %sum_out : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_1O:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK: %[[VAL_19:.*]]:3 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index) { -// CHECK: %[[VAL_25:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_22]], %[[VAL_25]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_25]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_21]], %[[VAL_26]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_21]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_28:.*]]: bf16, %[[VAL_29:.*]]: bf16, %[[VAL_30:.*]]: bf16): -// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : bf16 -// CHECK: linalg.yield %[[VAL_31]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index -// CHECK: %[[VAL_34:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_33]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index -// CHECK: } -// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_36]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in writable %[[VAL_37]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir b/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir deleted file mode 100644 index 67d82948..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - - // gep operand is another gep' output, which is passed into the loop as varible, used after update - %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - - %8 = tt.broadcast %7 : tensor<256x1xi32> -> tensor<256x256xi32> - // sizes: [256, 256], offsets: [0, 0], strides: [1, 0] - - %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - - %11 = tt.broadcast %10 : tensor<1x256xi32> -> tensor<256x256xi32> - // sizes: [256, 256], offsets: [0, 256], strides: [0, 1] - - %12 = arith.addi %8, %11 : tensor<256x256xi32> - // sizes: [256, 256], offsets: [0, 256], strides: [1, 1] - - %13 = tt.expand_dims %ptr {axis = 1 : i32} : tensor<256x!tt.ptr> -> tensor<256x1x!tt.ptr> - %14 = tt.broadcast %13 : tensor<256x1x!tt.ptr> -> tensor<256x256x!tt.ptr> - - %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> - // source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [2, 1] - - // perform load - %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256x!tt.ptr> - tt.store %15, %16 : tensor<256x256x!tt.ptr> - // pointer updates - %17 = tt.splat %i_c3 : i32 -> tensor<256xi32> - // sizes: 256, offsets: 3, strides: 0 - %ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024 + i, strides: 4 - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (index) { -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index -// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [256, 256], strides: {{\[}}%[[VAL_4]], 1] : memref<*xbf16> to memref<256x256xbf16, strided<[?, 1], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<256x256xbf16> -// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16> -// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<256x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_16]] in writable %[[VAL_14]] -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : index -// CHECK: scf.yield %[[VAL_17]] : index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir b/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir deleted file mode 100644 index 4d77760e..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir +++ /dev/null @@ -1,71 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c12 = arith.constant 12 : index - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - %3 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> - %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg1, sizes: 256, offsets: 1024, strides: 1 - %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index) { - // perform load - %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr_st, %5 : tensor<256x!tt.ptr> - // pointer updates - %cast3 = arith.index_cast %c3 : index to i32 - %6 = tt.splat %cast3 : i32 -> tensor<256xi32> - %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 1 - %arg2_iter = arith.addi %arg2, %c3 : index - %arg3_iter = arith.addi %arg3, %c3 : index - %arg4_iter = arith.addi %arg4, %c3 : index - %7 = arith.addi %arg2_iter, %arg3_iter : index - %8 = arith.addi %7, %arg4_iter : index - %cast8 = arith.index_cast %8 : index to i32 - %9 = tt.splat %cast8 : i32 -> tensor<256xi32> - %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 1 - scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index - } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_14:.*]]:7 = scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_9]], %[[VAL_19:.*]] = %[[VAL_13]], %[[VAL_20:.*]] = %[[VAL_10]], %[[VAL_21:.*]] = %[[VAL_6]], %[[VAL_22:.*]] = %[[VAL_6]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) { -// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_24]] in writable %[[VAL_19]] -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index -// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index -// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_10]] : index -// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_20]], %[[VAL_10]] : index -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index -// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]] : index -// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_22]], %[[VAL_31]] : index -// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: scf.yield %[[VAL_27]], %[[VAL_26]], %[[VAL_28]], %[[VAL_33]], %[[VAL_29]], %[[VAL_25]], %[[VAL_32]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir deleted file mode 100644 index 60b0b7fc..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir +++ /dev/null @@ -1,98 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - // gep operand is another gep' output, which is passed into the loop as varible, used after update - %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - // pointer updates - %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> - // sizes: 256, offsets: 3, strides: 0 - %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024 + i, strides: 1 - // perform load - %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr_iter, %3 : tensor<256x!tt.ptr> - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // %subview = memref.subview %arg0, [%4][256][4] : memref<> -> memref<> <- generate subview on getelementptr (already done) - // ... - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - // TODO: examples below are not supported since scf.for does not support returning a tensor type - // Example 3, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used after update - //%_ptr3 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { - // // offset update - // %3 = tt.splat %c3 : i32 -> tensor<256xi32> - // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> - // // generate pointer - // %gep_ptr = tt.addptr %0, %ptr_iter : tensor<256x!tt.ptr> - // // perform load - // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - // tt.store %gep_ptr, %4 : tensor<256x!tt.ptr> - // scf.yield %ptr_iter : tensor<256xi32> - //} - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) - // ... - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - //// Example 4, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used before update - //%_ptr4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { - // // generate pointer - // %gep_ptr = tt.addptr %0, %ptr : tensor<256x!tt.ptr> - // - // // perform load - // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - // tt.store %gep_ptr, %4 : tensor<256x!tt.ptr> - // // offset update - // %3 = tt.splat %c3 : i32 -> tensor<256xi32> - // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> - // scf.yield %ptr_iter : tensor<256xi32> - //} - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) - // ... - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_5]]) -> (index) { -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_13]] -// CHECK: scf.yield %[[VAL_12]] : index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir deleted file mode 100644 index 7855730a..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - // Example 2, gep operand is another gep's output, which is passed into the loop as varible, used before update - %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - // perform load - %3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr, %3 : tensor<256x!tt.ptr> - // pointer updates - %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> - %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) - // ... - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_5]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]], %[[VAL_13:.*]] = %[[VAL_5]]) -> (memref<256xbf16, strided<[?], offset: ?>>, index) { -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_12]] -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_16]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: scf.yield %[[VAL_17]], %[[VAL_16]] : memref<256xbf16, strided<[?], offset: ?>>, index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_loopback.mlir deleted file mode 100644 index ee5cb2cc..00000000 --- a/test/Conversion/TritonToLinalg/addptr_loopback.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32> - %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> - // offset = [%arg2,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,6] - %7 = arith.addi %offset2, %scale5: tensor<4x256xi32> - // offset = [%arg2, 0], size = [4, 256], stride = [1, 6] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: arg0, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] - %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: arg1, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] - %12 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - tt.store %11, %12 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_7]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<4x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_10]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir deleted file mode 100644 index 61ddea4f..00000000 --- a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) - { - %0 = tt.get_program_id x : i32 - %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> - %2 = tt.splat %0 : i32 -> tensor<1024xi32> - %3 = arith.addi %2, %1 : tensor<1024xi32> - //%3: splat(%0) + range(0, 1024) - //%3: offset = %0, size = 1024, stride = 1 - // vector and scalar are both constant - %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> - %c10 = arith.constant 10 : i32 - %5 = tt.splat %c10 : i32 -> tensor<1024xi32> - %6 = arith.muli %5, %4 : tensor<1024xi32> - //%6: splat(%c10)*range(2048, 4096); - //%6: offset = %c10*2048, size = 1024, stride = %c10*1 - %7 = arith.addi %3, %6 : tensor<1024xi32> - //%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*1+1 - %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*1+1 - %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> - %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg1, offset = pid0, size = 1024, stride = 1 - %16 = tt.load %9 : tensor<1024x!tt.ptr> - tt.store %11, %16 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 11 : index -// CHECK: %[[VAL_7:.*]] = arith.constant 20480 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}%[[VAL_6]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1024xbf16> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> -// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_14]] in writable %[[VAL_12]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir deleted file mode 100644 index 77907e06..00000000 --- a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) - { - %0 = tt.get_program_id x : i32 - %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> - %2 = tt.splat %0 : i32 -> tensor<1024xi32> - %3 = arith.addi %2, %1 : tensor<1024xi32> - //%3: splat(%0) + range(0, 1024) - //%3: offset = %0, size = 1024, stride = 1 - // vector is constant, scalar is value - %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> - %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> - %6 = arith.muli %5, %4 : tensor<1024xi32> - //%6: splat(%arg2)*range(2048, 3072); - //%6: offset = %arg2*2048, size = 1024, stride = %arg2*1 - %7 = arith.addi %3, %6 : tensor<1024xi32> - //%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*1+1 - %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*1+1 - %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> - %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg1: offset = pid0, size = 1024, stride = 1 - %16 = tt.load %9 : tensor<1024x!tt.ptr> - tt.store %11, %16 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2048 : index -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[ARG_6]] : i32 to index -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_10]], %[[VAL_6]] : index -// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [1024], strides: {{\[}}%[[VAL_14]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[ARG_6]] : i32 to index -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_16]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xbf16> -// CHECK: memref.copy %[[VAL_15]], %[[VAL_18]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> -// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_17]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_nested.mlir b/test/Conversion/TritonToLinalg/addptr_nested.mlir deleted file mode 100644 index bbbc0b22..00000000 --- a/test/Conversion/TritonToLinalg/addptr_nested.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg1splat = tt.splat %arg1 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg1splat : tensor<4x256xi32> - // offset = [%arg1,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %6 = arith.constant 5 : i32 - %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> - // offset = [%arg1, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg1, 0], size = [4, 256], stride = [1, 5] - %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %12 = tt.addptr %9, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg1+%arg1, 0], size = [4, 256], stride = [2, 10] - %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %14 = arith.addf %10, %13 : tensor<4x256xbf16> - %16 = tt.addptr %12, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg1+%arg1+%arg1, 0], size = [4, 256], stride = [3, 15] - tt.store %16, %14 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 15 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 10 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index -// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_14]]], sizes: [4, 256], strides: [2, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[2, ?], offset: ?>> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_15]], %[[VAL_16]] : memref<4x256xbf16, strided<[2, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_17]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: bf16, %[[VAL_20:.*]]: bf16, %[[VAL_21:.*]]: bf16): -// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : bf16 -// CHECK: linalg.yield %[[VAL_22]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : index -// CHECK: %[[VAL_26:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index -// CHECK: %[[VAL_28:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_27]]], sizes: [4, 256], strides: [3, %[[VAL_5]]] : memref<*xbf16> to memref<4x256xbf16, strided<[3, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_29:.*]] in writable %[[VAL_28]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir deleted file mode 100644 index 2f508262..00000000 --- a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir +++ /dev/null @@ -1,43 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -// TODO: expand this example to 3D -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> - // offset = [512] size = 256, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - // offset = [512,0], size = [256,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> - // offset = [512,0], size = [256,128], stride = [1,0] - %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> - // offset = 1024, size = 128, stride = 1 - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - // offset = [0,1024], size = [1,128], stride = [0,1] - %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> - // offset = [0,1024], size = [256,128], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> - %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> - // offset = [0,6144], size = [256,128], stride = [0,6] - %14 = arith.addi %2, %scale7 : tensor<256x128xi32> - // offset = [512,6144], size = [256,128], stride = [1,6] - %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> - %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - %19 = tt.load %18 : tensor<256x128x!tt.ptr> - tt.store %18, %19 : tensor<256x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: {{\[}}1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir deleted file mode 100644 index 2af087ce..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir +++ /dev/null @@ -1,65 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = arg1, offset = %1, size = 1, strides = 0 - %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg1, offset = %1, size = 1024, strides = 0 - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<1024x!tt.ptr> -> tensor<1024x1x!tt.ptr> - // source = arg1, offset = [%1, 0], size = [1024, 1], strides = [0, 0] - %5 = tt.broadcast %4 : tensor<1024x1x!tt.ptr> -> tensor<1024x1024x!tt.ptr> - // source = arg1, offset = [%1, 0], size = [1024, 1024], strides = [0, 0] - %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32> -> tensor<1x1024xi32> - // offset = [0, 0], size = [1, 1024], strides = [0, 1] - %8 = tt.broadcast %7 : tensor<1x1024xi32> -> tensor<1024x1024xi32> - // offset = [0, 0], size = [1024, 1024], strides = [0, 1] - %9 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<1024xi32> -> tensor<1024x1xi32> - // offset = [0, 0], size = [1024, 1], strides = [1, 0] - %11 = tt.broadcast %10 : tensor<1024x1xi32> -> tensor<1024x1024xi32> - // offset = [0, 0], size = [1024, 1024], strides = [1, 0] - %12 = arith.addi %8, %11 : tensor<1024x1024xi32> - // offset = [0, 0], size = [1024, 1024], strides = [1, 1] - %13 = tt.addptr %5, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> - // source = arg1, offset = [pid * %arg2, 0], size = [1024, 1024], strides = [1, 1] - %14 = tt.load %13 : tensor<1024x1024x!tt.ptr> - %17 = math.exp %14 : tensor<1024x1024xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = pid+arg3, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg0, offset = pid+arg3, size = 1024, strides = 0 - %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<1024x!tt.ptr> -> tensor<1024x1x!tt.ptr> - // source = arg0, offset = [pid+arg3, 0], size = [1024, 1], strides = [0, 0] - %22 = tt.broadcast %21 : tensor<1024x1x!tt.ptr> -> tensor<1024x1024x!tt.ptr> - // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [0, 0] - %23 = tt.addptr %22, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> - // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [1, 1] - tt.store %23, %17 : tensor<1024x1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024x1024xf32> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024x1024xf32, strided<[1, 1], offset: ?>> to memref<1024x1024xf32> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024x1024xf32> -// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_12]] : tensor<1024x1024xf32>) outs(%[[VAL_12]] : tensor<1024x1024xf32>) { -// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): -// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -> tensor<1024x1024xf32> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in writable %[[VAL_19]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir deleted file mode 100644 index 466778f4..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = %arg1, offset = %1, size = 1, strides = 0 - %cf0 = arith.constant 0.000000e+00 : f32 - %tensor_cf0 = tt.splat %cf0 : f32 -> tensor<1024xf32> - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %_ptr, %sum_out = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr_iter = %2, %sum_iter = %tensor_cf0) -> (!tt.ptr, tensor<1024xf32>) { - %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %4 = tt.splat %ptr_iter : !tt.ptr -> tensor<1024x!tt.ptr> - // source = %arg1, offset = %1, size = 1024, strides = 0 - %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = %arg1, offset = %1, size = 1024, strides = 1 - %8 = tt.load %5 : tensor<1024x!tt.ptr> - %9 = math.exp %8 : tensor<1024xf32> - %sum_next = arith.addf %sum_iter, %9 : tensor<1024xf32> - %cast_i = arith.index_cast %i : index to i32 - %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr, i32 - // source = %arg1, offset = %1 + %i, size = 1, strides = 0 - scf.yield %ptr_next, %sum_next : !tt.ptr, tensor<1024xf32> - } - %10 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - %21 = tt.addptr %20, %10 : tensor<1024x!tt.ptr>, tensor<1024xi32> - tt.store %21, %sum_out : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_11]] : f32) outs(%[[VAL_14]] : tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: %[[VAL_12:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i32 to index -// CHECK: %[[VAL_16:.*]]:2 = scf.for %[[VAL_17:.*]] = %[[VAL_10]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_13]]) -> (tensor<1024xf32>, index) { -// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_19]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_21:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_20]], %[[VAL_21]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> -// CHECK: %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_21]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_23:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_22]] : tensor<1024xf32>) outs(%[[VAL_22]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_24:.*]]: f32, %[[VAL_25:.*]]: f32): -// CHECK: %[[VAL_26:.*]] = math.exp %[[VAL_24]] : f32 -// CHECK: linalg.yield %[[VAL_26]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18]], %[[VAL_28:.*]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_29:.*]]: f32, %[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32): -// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32 -// CHECK: linalg.yield %[[VAL_32]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index -// CHECK: scf.yield %[[VAL_34:.*]], %[[VAL_33]] : tensor<1024xf32>, index -// CHECK: } -// CHECK: %[[VAL_35:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_35]] : i32 to index -// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_36]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in writable %[[VAL_37]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir deleted file mode 100644 index 39f3913b..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - %cf0 = arith.constant 0.000000e+00 : f32 - %tensor_cf0 = tt.splat %cf0 : f32 -> tensor<128x128xf32> - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %tensor_cf0, %ptr_iter = %2) -> (tensor<128x128xf32>, !tt.ptr ) { - %3 = tt.splat %ptr_iter : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] - %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> - // offset = [0, 0], size = [128, 128], strides = [0, 1] - %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> - %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 0] - %10 = arith.addi %6, %9 : tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 1] - %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] - %12 = tt.load %11 : tensor<128x128x!tt.ptr> - %17 = math.exp %12 : tensor<128x128xf32> - %sum_next = arith.addf %sum_iter, %17 : tensor<128x128xf32> - %cast_i = arith.index_cast %i : index to i32 - %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr, i32 - // source = %arg1, offset = %1 + %i, size = 1, strides = 0 - scf.yield %sum_next, %ptr_next : tensor<128x128xf32>, !tt.ptr - } - %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> - // offset = [0, 0], size = [128, 128], strides = [0, 1] - %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> - %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 0] - %10 = arith.addi %6, %9 : tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 1] - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] - %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] - tt.store %21, %sum_out : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<128x128xf32> -// CHECK: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_15]] : tensor<128x128xf32>) -> tensor<128x128xf32> -// CHECK: %[[VAL_13:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i32 to index -// CHECK: %[[VAL_17:.*]]:2 = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_14]]) -> (tensor<128x128xf32>, index) { -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_22]], %[[VAL_23]] : memref<128x128xf32, strided<[1, 1], offset: ?>> to memref<128x128xf32> -// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_24]] : tensor<128x128xf32>) outs(%[[VAL_24]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_29:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_19]], %[[VAL_30:.*]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_19]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: f32, %[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: f32): -// CHECK: %[[VAL_34:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32 -// CHECK: linalg.yield %[[VAL_34]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_20]], %[[VAL_18]] : index -// CHECK: scf.yield %[[VAL_36:.*]], %[[VAL_35]] : tensor<128x128xf32>, index -// CHECK: } -// CHECK: %[[VAL_37:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_38:.*]] = arith.index_cast %[[VAL_37]] : i32 to index -// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_8]] : index -// CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_39]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_41:.*]]#0 in writable %[[VAL_40]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir deleted file mode 100644 index 567dd950..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) { - %0 = tt.addptr %arg0, %arg2 : !tt.ptr, i32 - %1 = tt.addptr %arg1, %arg2 : !tt.ptr, i32 - %10 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: !tt.ptr - tt.store %1, %10 : !tt.ptr - tt.return - } -} - -// CHECK: module { -// CHECK: func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { -// CHECK: %0 = arith.index_cast %arg2 : i32 to index -// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%0], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1], offset: ?>> -// CHECK: %1 = arith.index_cast %arg2 : i32 to index -// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%1], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1], offset: ?>> -// CHECK: %2 = affine.load %reinterpret_cast[0] : memref<1xbf16, strided<[1], offset: ?>> -// CHECK: affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir deleted file mode 100644 index 1bf5f031..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = arg1, offset = %1, size = 1, strides = 0 - %3 = arith.muli %0, %arg3 : i32 - %4 = tt.addptr %2, %3 : !tt.ptr, i32 - // source = arg1, offset = %1+%3, size = 1, strides = 0 - %5 = arith.muli %0, %arg4 : i32 - %6 = tt.addptr %4, %5 : !tt.ptr, i32 - // source = arg1, offset = %1+%3+%5, size = 1, strides = 0 - %7 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %8 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg1, offset = %1, size = 1024, strides = 0 - %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = arg1, offset = %1+%3+%5, size = 1024, strides = 1 - %10 = tt.load %9 : tensor<1024x!tt.ptr> - %17 = math.exp %10 : tensor<1024xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg0, offset = %18, size = 1024, strides = 0 - %21 = tt.addptr %20, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = arg0, offset = %18, size = 1024, strides = 1 - tt.store %21, %17 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_10:.*]] = arith.muli %[[ARG_8]], %[[VAL_4]] : i32 -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_8]] : i32 to index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_9]] : i32 to index -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_10]] : i32 to index -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index -// CHECK: %[[VAL_16:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_15]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_16]], %[[VAL_17]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> -// CHECK: %[[VAL_18:.*]] = bufferization.to_tensor %[[VAL_17]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18]] : tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): -// CHECK: %[[VAL_22:.*]] = math.exp %[[VAL_20]] : f32 -// CHECK: linalg.yield %[[VAL_22]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_23:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index -// CHECK: %[[VAL_25:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_24]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_26:.*]] in writable %[[VAL_25]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir deleted file mode 100644 index ccb8ce49..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = %arg1, offset = %1, size = 1, strides = 0 - %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %4 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = %arg1, offset = %1, size = 1024, strides = 0 - %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = %arg1, offset = %1, size = 1024, strides = 1 - %8 = tt.load %5 : tensor<1024x!tt.ptr> - %17 = math.exp %8 : tensor<1024xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = %arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = %arg0, offset = %18, size = 1024, strides = 0 - %21 = tt.addptr %20, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = %arg0, offset = %18, size = 1024, strides = 1 - tt.store %21, %17 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_12]] : tensor<1024xf32>) outs(%[[VAL_12]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): -// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in writable %[[VAL_19]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir deleted file mode 100644 index 122d1c40..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir +++ /dev/null @@ -1,56 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - %3 = tt.splat %2 : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] - %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> - // offset = [0, 0], size = [128, 128], strides = [0, 1] - %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> - // offset = 128, size = 128, strides = 1 - %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 0] - %10 = arith.addi %6, %9 : tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 1] - %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] - %12 = tt.load %11 : tensor<128x128x!tt.ptr> - %17 = math.exp %12 : tensor<128x128xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] - %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] - tt.store %21, %17 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.constant 128 : index -// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_8]] : index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_13]] : memref<128x128xf32, strided<[1, 1], offset: ?>> to memref<128x128xf32> -// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_15:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_14]] : tensor<128x128xf32>) outs(%[[VAL_14]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32): -// CHECK: %[[VAL_18:.*]] = math.exp %[[VAL_16]] : f32 -// CHECK: linalg.yield %[[VAL_18]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_19:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : i32 to index -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir b/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir deleted file mode 100644 index e0efdde0..00000000 --- a/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir +++ /dev/null @@ -1,39 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %am = tt.load %9 : tensor<1024x!tt.ptr> - %bm = tt.load %19 : tensor<1024x!tt.ptr> - %5 = arith.addi %am, %bm : tensor<1024xi32> - tt.store %19, %5 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: memref<*xi32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK: %[[VAL_5:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy %[[VAL_5]], %[[VAL_7]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_7]] restrict writable : memref<1024xi32> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy %[[VAL_6]], %[[VAL_9]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<1024xi32> -// CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_8]], %[[VAL_10]] : tensor<1024xi32>, tensor<1024xi32>) outs(%[[VAL_8]] : tensor<1024xi32>) { -// CHECK: ^bb0(%[[VAL_12:.*]]: i32, %[[VAL_13:.*]]: i32, %[[VAL_14:.*]]: i32): -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32 -// CHECK: linalg.yield %[[VAL_15]] : i32 -// CHECK: } -> tensor<1024xi32> -// CHECK: bufferization.materialize_in_destination %[[VAL_16:.*]] in writable %[[VAL_6]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/bitcast.mlir b/test/Conversion/TritonToLinalg/bitcast.mlir deleted file mode 100644 index f838a901..00000000 --- a/test/Conversion/TritonToLinalg/bitcast.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func @kernel(%a : !tt.ptr, %b : !tt.ptr) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - - %am = tt.load %9 : tensor<1024x!tt.ptr> - - // cast result before doing float add - %am_bitcast = tt.bitcast %am : tensor<1024xi32> -> tensor<1024xf32> - - - tt.store %19, %am_bitcast : tensor<1024x!tt.ptr> - tt.return - } -} - -// CHECK: module { -// CHECK: func.func @kernel(%arg0: memref<*xi32>, %arg1: memref<*xf32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { -// CHECK: [[RC_:%.+]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: [[RC_0_:%.+]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: [[ALLOC_:%.+]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy [[RC_]], [[ALLOC_]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: [[VAR_0_:%.+]] = bufferization.to_tensor [[ALLOC_]] restrict writable : memref<1024xi32> -// CHECK: [[VAR_1_:%.+]] = tensor.empty() : tensor<1024xf32> -// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_0_]] : tensor<1024xi32>) outs([[VAR_1_]] : tensor<1024xf32>) { -// CHECK: ^bb0(%in: i32, %out: f32): -// CHECK: [[VAR_5_:%.+]] = arith.bitcast %in : i32 to f32 -// CHECK: linalg.yield [[VAR_5_]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[RC_0_]] -// CHECK: return -// CHECK: } -// CHECK: } - diff --git a/test/Conversion/TritonToLinalg/block_ptr_advance.mlir b/test/Conversion/TritonToLinalg/block_ptr_advance.mlir deleted file mode 100644 index 8cf0fe7d..00000000 --- a/test/Conversion/TritonToLinalg/block_ptr_advance.mlir +++ /dev/null @@ -1,90 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @matmul_kernel_with_block_pointers_01234567891011(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) { - %c64_i32 = arith.constant 64 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant 0.000000e+00 : bf16 - %c256_i32 = arith.constant 256 : i32 - %0 = arith.extsi %arg3 : i32 to i64 - %1 = arith.extsi %arg5 : i32 to i64 - %2 = arith.extsi %arg6 : i32 to i64 - %3 = arith.extsi %arg7 : i32 to i64 - %4 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %3], [%arg12, %c0_i32] {order = array} : > - %5 = tt.advance %4, [%c0_i32, %c64_i32] : > - %6 = tt.splat %cst : bf16 -> tensor<128x64xbf16> - %7:3 = scf.for %arg14 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg15 = %6, %arg16 = %5, %arg17 = %4) -> (tensor<128x64xbf16>, !tt.ptr>, !tt.ptr>) : i32 { - %13 = tt.load %arg16 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> - %14 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> - %15 = arith.addf %13, %14 : tensor<128x64xbf16> - %16 = arith.addf %arg15, %15 : tensor<128x64xbf16> - %17 = tt.advance %arg16, [%c0_i32, %c64_i32] : > - %18 = tt.advance %arg17, [%c64_i32, %c0_i32] : > - scf.yield %16, %17, %18 : tensor<128x64xbf16>, !tt.ptr>, !tt.ptr> - } - %8 = arith.extsi %arg10 : i32 to i64 - %9 = arith.extsi %arg11 : i32 to i64 - %10 = arith.extsi %arg4 : i32 to i64 - %11 = arith.muli %arg13, %c256_i32 : i32 - %12 = tt.make_tensor_ptr %arg2, [%0, %10], [%8, %9], [%arg12, %11] {order = array} : > - tt.store %12, %7#0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr> - tt.return - } -} - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: module { -// CHECK: func.func @matmul_kernel_with_block_pointers_01234567891011(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: memref<*xbf16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32, %arg18: i32, %arg19: i32) { -// CHECK: %c64 = arith.constant 64 : index -// CHECK: %c256_i32 = arith.constant 256 : i32 -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %c64_i32 = arith.constant 64 : i32 -// CHECK: %cst = arith.constant 0.000000e+00 : bf16 -// CHECK: %0 = tensor.empty() : tensor<128x64xbf16> -// CHECK: %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> -// CHECK: %2 = arith.index_cast %arg12 : i32 to index -// CHECK: %3 = arith.index_cast %arg6 : i32 to index -// CHECK: %4 = arith.index_cast %arg7 : i32 to index -// CHECK: %5 = arith.muli %2, %3 : index -// CHECK: %6 = arith.muli %4, %c64 : index -// CHECK: %7 = arith.addi %5, %6 : index -// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%7], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%5], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: %8:5 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %5) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index) : i32 { -// CHECK: %alloc = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy %arg22, %alloc : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> -// CHECK: %17 = bufferization.to_tensor %alloc restrict writable : memref<128x64xbf16> -// CHECK: %alloc_2 = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy %arg23, %alloc_2 : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> -// CHECK: %18 = bufferization.to_tensor %alloc_2 restrict writable : memref<128x64xbf16> -// CHECK: %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%17, %18 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%17 : tensor<128x64xbf16>) { -// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16): -// CHECK: %25 = arith.addf %in, %in_5 : bf16 -// CHECK: linalg.yield %25 : bf16 -// CHECK: } -> tensor<128x64xbf16> -// CHECK: %20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg21, %19 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%arg21 : tensor<128x64xbf16>) { -// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16): -// CHECK: %25 = arith.addf %in, %in_5 : bf16 -// CHECK: linalg.yield %25 : bf16 -// CHECK: } -> tensor<128x64xbf16> -// CHECK: %21 = arith.muli %4, %c64 : index -// CHECK: %22 = arith.addi %arg24, %21 : index -// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%22], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: %23 = arith.muli %3, %c64 : index -// CHECK: %24 = arith.addi %23, %arg25 : index -// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%24], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %22, %24 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index -// CHECK: } -// CHECK: %9 = arith.muli %arg13, %c256_i32 : i32 -// CHECK: %10 = arith.index_cast %arg12 : i32 to index -// CHECK: %11 = arith.index_cast %9 : i32 to index -// CHECK: %12 = arith.index_cast %arg10 : i32 to index -// CHECK: %13 = arith.index_cast %arg11 : i32 to index -// CHECK: %14 = arith.muli %10, %12 : index -// CHECK: %15 = arith.muli %11, %13 : index -// CHECK: %16 = arith.addi %14, %15 : index -// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [%16], sizes: [128, 64], strides: [%12, %13] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %8#0 in writable %reinterpret_cast_1 -// CHECK: return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir b/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir deleted file mode 100644 index 03363130..00000000 --- a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir +++ /dev/null @@ -1,72 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : tensor<1024x!tt.ptr> - ) -> () { - %cst = arith.constant dense : tensor<1024xi1> - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %am = tt.load %9 : tensor<1024x!tt.ptr> - %bm = tt.load %19 : tensor<1024x!tt.ptr> - %1 = arith.addf %am, %bm : tensor<1024xf32> - %2 = arith.subf %1, %bm : tensor<1024xf32> - %3 = arith.mulf %2, %bm : tensor<1024xf32> - %4 = arith.divf %3, %bm : tensor<1024xf32> - %5 = arith.cmpf "oeq", %4, %bm : tensor<1024xf32> - %6 = arith.select %5, %am, %bm : tensor<1024xi1>, tensor<1024xf32> - tt.store %c, %6 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<1024xf32>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_6]], %[[VAL_8]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_12:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_9]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_9]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_17:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): -// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32 -// CHECK: linalg.yield %[[VAL_22]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_23:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_24:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_24]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_25]], %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_29:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_30:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_30]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: f32, %[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: f32): -// CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32 -// CHECK: linalg.yield %[[VAL_34]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_35:.*]] = tensor.empty() : tensor<1024xi1> -// CHECK: %[[VAL_36:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_37:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_35]] : tensor<1024xi1>) { -// CHECK: ^bb0(%[[VAL_38:.*]]: f32, %[[VAL_39:.*]]: f32, %[[VAL_40:.*]]: i1): -// CHECK: %[[VAL_41:.*]] = arith.cmpf oeq, %[[VAL_38]], %[[VAL_39]] : f32 -// CHECK: linalg.yield %[[VAL_41]] : i1 -// CHECK: } -> tensor<1024xi1> -// CHECK: %[[VAL_42:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_43:.*]], %[[VAL_9]], %[[VAL_11]] : tensor<1024xi1>, tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_9]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_44:.*]]: i1, %[[VAL_45:.*]]: f32, %[[VAL_46:.*]]: f32, %[[VAL_47:.*]]: f32): -// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_44]], %[[VAL_45]], %[[VAL_46]] : f32 -// CHECK: linalg.yield %[[VAL_48]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_49:.*]] in writable %[[VAL_2]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir b/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir deleted file mode 100644 index 39f4d5ca..00000000 --- a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : !tt.ptr, - %d : tensor<1024x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // c pointer - %28 = tt.splat %c : !tt.ptr -> tensor<1024x!tt.ptr> - %29 = tt.addptr %28, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %am = tt.load %9 : tensor<1024x!tt.ptr> - %bm = tt.load %19 : tensor<1024x!tt.ptr> - %cm = tt.load %29 : tensor<1024x!tt.ptr> - %10 = arith.select %am, %bm, %cm : tensor<1024xi1>, tensor<1024xf32> - tt.store %d, %10 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi1>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<*xf32>, %[[VAL_3:.*]]: memref<1024xf32>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi1> to memref<1024xi1, strided<[1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<1024xi1> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<1024xi1, strided<[1]>> to memref<1024xi1> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<1024xi1> -// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_12]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_14]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<1024xi1>, tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_13]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: i1, %[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32): -// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32 -// CHECK: linalg.yield %[[VAL_21]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir b/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir deleted file mode 100644 index 457647a1..00000000 --- a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir +++ /dev/null @@ -1,88 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %f32ptr : !tt.ptr, - %intptr : !tt.ptr, - %f16ptr : !tt.ptr, - %save0 : tensor<1024x!tt.ptr>, - %save1 : tensor<1024x!tt.ptr>, - %save2 : tensor<1024x!tt.ptr>, - %save3 : tensor<1024x!tt.ptr>, - %save4 : tensor<1024x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // f32ptr pointer - %8 = tt.splat %f32ptr : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // intptr pointer - %18 = tt.splat %intptr : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // f32ptr pointer - %28 = tt.splat %f16ptr : !tt.ptr -> tensor<1024x!tt.ptr> - %29 = tt.addptr %28, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %afm = tt.load %9 : tensor<1024x!tt.ptr> - %aim = tt.load %19 : tensor<1024x!tt.ptr> - %bfm = tt.load %29 : tensor<1024x!tt.ptr> - %5 = arith.truncf %afm : tensor<1024xf32> to tensor<1024xbf16> - %6 = math.exp %afm : tensor<1024xf32> - %7 = arith.sitofp %aim : tensor<1024xi32> to tensor<1024xf32> - %10 = arith.extf %bfm : tensor<1024xf16> to tensor<1024xf32> - %11 = math.sqrt %afm : tensor<1024xf32> - tt.store %save0, %5 : tensor<1024x!tt.ptr> - tt.store %save1, %6 : tensor<1024x!tt.ptr> - tt.store %save2, %7 : tensor<1024x!tt.ptr> - tt.store %save3, %10 : tensor<1024x!tt.ptr> - tt.store %save4, %11 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xi32>, %[[VAL_2:.*]]: memref<*xf16>, %[[VAL_3:.*]]: memref<1024xbf16>, %[[VAL_4:.*]]: memref<1024xf32>, %[[VAL_5:.*]]: memref<1024xf32>, %[[VAL_6:.*]]: memref<1024xf32>, %[[VAL_7:.*]]: memref<1024xf32>, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32) { -// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf16> to memref<1024xf16, strided<[1]>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_11]], %[[VAL_14]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_16]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<1024xi32> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_18]] : memref<1024xf16, strided<[1]>> to memref<1024xf16> -// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xf16> -// CHECK: %[[VAL_20:.*]] = tensor.empty() : tensor<1024xbf16> -// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_20]] : tensor<1024xbf16>) { -// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: bf16): -// CHECK: %[[VAL_24:.*]] = arith.truncf %[[VAL_22]] : f32 to bf16 -// CHECK: linalg.yield %[[VAL_24]] : bf16 -// CHECK: } -> tensor<1024xbf16> -// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_15]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_30:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_17]] : tensor<1024xi32>) outs(%[[VAL_29]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: i32, %[[VAL_32:.*]]: f32): -// CHECK: %[[VAL_33:.*]] = arith.sitofp %[[VAL_31]] : i32 to f32 -// CHECK: linalg.yield %[[VAL_33]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_19]] : tensor<1024xf16>) outs(%[[VAL_34]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_36:.*]]: f16, %[[VAL_37:.*]]: f32): -// CHECK: %[[VAL_38:.*]] = arith.extf %[[VAL_36]] : f16 to f32 -// CHECK: linalg.yield %[[VAL_38]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_39:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_15]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_40:.*]]: f32, %[[VAL_41:.*]]: f32): -// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32 -// CHECK: linalg.yield %[[VAL_42]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in writable %[[VAL_3]] -// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in writable %[[VAL_4]] -// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in writable %[[VAL_5]] -// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in writable %[[VAL_6]] -// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir b/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir deleted file mode 100644 index 0f855fc7..00000000 --- a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : tensor<128x128x!tt.ptr>, - %d : tensor<128x128x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> - %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<128x128x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<128x128x!tt.ptr> - %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %af = tt.load %9 : tensor<128x128x!tt.ptr> - %bf = tt.load %19 : tensor<128x128x!tt.ptr> - %res0 = arith.addf %af, %bf : tensor<128x128xf32> - %res1 = arith.subf %af, %bf : tensor<128x128xf32> - tt.store %c, %res0 : tensor<128x128x!tt.ptr> - tt.store %d, %res1 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<128x128xf32>, %[[VAL_3:.*]]: memref<128x128xf32>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_10]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32, %[[VAL_16:.*]]: f32): -// CHECK: %[[VAL_17:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 -// CHECK: linalg.yield %[[VAL_17]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_10]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): -// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32 -// CHECK: linalg.yield %[[VAL_22]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_2]] -// CHECK: bufferization.materialize_in_destination %[[VAL_24:.*]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir b/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir deleted file mode 100644 index f0398736..00000000 --- a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : !tt.ptr, - %d : tensor<128x128x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> - %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<128x128x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<128x128x!tt.ptr> - %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // c pointer - %28 = tt.splat %c : !tt.ptr -> tensor<128x128x!tt.ptr> - %29 = tt.addptr %28, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %am = tt.load %9 : tensor<128x128x!tt.ptr> - %bm = tt.load %19 : tensor<128x128x!tt.ptr> - %cm = tt.load %29 : tensor<128x128x!tt.ptr> - %100 = arith.select %am, %bm, %cm : tensor<128x128xi1>, tensor<128x128xf32> - tt.store %d, %100 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi1>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<*xf32>, %[[VAL_3:.*]]: memref<128x128xf32>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xi1> to memref<128x128xi1, strided<[1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128x128xi1> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<128x128xi1, strided<[1, 1]>> to memref<128x128xi1> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128x128xi1> -// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_12]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_14]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<128x128xi1>, tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_13]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: i1, %[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32): -// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32 -// CHECK: linalg.yield %[[VAL_21]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir b/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir deleted file mode 100644 index 835e4e18..00000000 --- a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir +++ /dev/null @@ -1,94 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %f32ptr : !tt.ptr, - %intptr : !tt.ptr, - %f16ptr : !tt.ptr, - %save0 : tensor<128x128x!tt.ptr>, - %save1 : tensor<128x128x!tt.ptr>, - %save2 : tensor<128x128x!tt.ptr>, - %save3 : tensor<128x128x!tt.ptr>, - %save4 : tensor<128x128x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> - %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> - // f32ptr pointer - %8 = tt.splat %f32ptr : !tt.ptr -> tensor<128x128x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // intptr pointer - %18 = tt.splat %intptr : !tt.ptr -> tensor<128x128x!tt.ptr> - %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // f16ptr pointer - %28 = tt.splat %f16ptr : !tt.ptr -> tensor<128x128x!tt.ptr> - %29 = tt.addptr %28, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %afm = tt.load %9 : tensor<128x128x!tt.ptr> - %aim = tt.load %19 : tensor<128x128x!tt.ptr> - %bfm = tt.load %29 : tensor<128x128x!tt.ptr> - %5 = arith.truncf %afm : tensor<128x128xf32> to tensor<128x128xbf16> - %6 = math.exp %afm : tensor<128x128xf32> - %7 = arith.sitofp %aim : tensor<128x128xi32> to tensor<128x128xf32> - %10 = arith.extf %bfm : tensor<128x128xf16> to tensor<128x128xf32> - %11 = math.sqrt %afm : tensor<128x128xf32> - tt.store %save0, %5 : tensor<128x128x!tt.ptr> - tt.store %save1, %6 : tensor<128x128x!tt.ptr> - tt.store %save2, %7 : tensor<128x128x!tt.ptr> - tt.store %save3, %10 : tensor<128x128x!tt.ptr> - tt.store %save4, %11 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xi32>, %[[VAL_2:.*]]: memref<*xf16>, %[[VAL_3:.*]]: memref<128x128xbf16>, %[[VAL_4:.*]]: memref<128x128xf32>, %[[VAL_5:.*]]: memref<128x128xf32>, %[[VAL_6:.*]]: memref<128x128xf32>, %[[VAL_7:.*]]: memref<128x128xf32>, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32) { -// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xi32> to memref<128x128xi32, strided<[1, 1]>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf16> to memref<128x128xf16, strided<[1, 1]>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_11]], %[[VAL_14]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<128x128xi32> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_16]] : memref<128x128xi32, strided<[1, 1]>> to memref<128x128xi32> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<128x128xi32> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<128x128xf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_18]] : memref<128x128xf16, strided<[1, 1]>> to memref<128x128xf16> -// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x128xf16> -// CHECK: %[[VAL_20:.*]] = tensor.empty() : tensor<128x128xbf16> -// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_20]] : tensor<128x128xbf16>) { -// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: bf16): -// CHECK: %[[VAL_24:.*]] = arith.truncf %[[VAL_22]] : f32 to bf16 -// CHECK: linalg.yield %[[VAL_24]] : bf16 -// CHECK: } -> tensor<128x128xbf16> -// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_15]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<128x128xf32> -// CHECK: %[[VAL_30:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_17]] : tensor<128x128xi32>) outs(%[[VAL_29]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: i32, %[[VAL_32:.*]]: f32): -// CHECK: %[[VAL_33:.*]] = arith.sitofp %[[VAL_31]] : i32 to f32 -// CHECK: linalg.yield %[[VAL_33]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<128x128xf32> -// CHECK: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_19]] : tensor<128x128xf16>) outs(%[[VAL_34]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_36:.*]]: f16, %[[VAL_37:.*]]: f32): -// CHECK: %[[VAL_38:.*]] = arith.extf %[[VAL_36]] : f16 to f32 -// CHECK: linalg.yield %[[VAL_38]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_39:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_15]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_40:.*]]: f32, %[[VAL_41:.*]]: f32): -// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32 -// CHECK: linalg.yield %[[VAL_42]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in writable %[[VAL_3]] -// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in writable %[[VAL_4]] -// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in writable %[[VAL_5]] -// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in writable %[[VAL_6]] -// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir b/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir deleted file mode 100644 index f430ce9c..00000000 --- a/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir +++ /dev/null @@ -1,32 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @addi(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.addi %arg14, %arg15 : i32 - tt.reduce.return %69 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK-LABEL: func.func @addi( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_8]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_11:.*]] = tensor.insert %[[VAL_7]] into %[[VAL_10]][] : tensor -// CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<4096xi32>) outs(%[[VAL_11]] : tensor) dimensions = [0] -// CHECK: (%[[VAL_13:.*]]: i32, %[[VAL_14:.*]]: i32) { -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32 -// CHECK: linalg.yield %[[VAL_15]] : i32 -// CHECK: } -// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_12]][] : tensor -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_16]], %[[VAL_17]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir b/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir deleted file mode 100644 index c96738ee..00000000 --- a/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir +++ /dev/null @@ -1,141 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @argmax_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> - %3 = tt.splat %1 : i32 -> tensor<4096xi32> - %4 = arith.addi %3, %2 : tensor<4096xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %7 = tt.load %6 : tensor<4096x!tt.ptr> - %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ - ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): - %11 = arith.cmpf oeq, %arg9, %arg11 : f32 - %12 = arith.cmpi slt, %arg10, %arg12 : i32 - %13 = arith.andi %11, %12 : i1 - %14 = arith.cmpf ogt, %arg9, %arg11 : f32 - %15 = arith.ori %14, %13 : i1 - %16 = arith.select %15, %arg9, %arg11 : f32 - %17 = arith.select %15, %arg10, %arg12 : i32 - tt.reduce.return %16, %17 : f32, i32 - }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) - %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 - tt.store %9, %8#1 : !tt.ptr - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @argmax_012 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> -// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_10_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i32 -// CHECK: linalg.yield [[VAR_11_]] : i32 -// CHECK: } -> tensor<4096xi32> -// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4096], strides: [1] : memref<*xf32> to memref<4096xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_5_]] : tensor) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_7_]] : tensor) -> tensor -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xf32>, tensor<4096xi32>) outs([[VAR_6_]], [[VAR_8_]] : tensor, tensor) dimensions = [0] -// CHECK: ([[in:.+]]: f32, [[in_1:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_10_1_:%.+]] = arith.cmpf oeq, [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_11_1_:%.+]] = arith.cmpi slt, [[in_1]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.andi [[VAR_10_1_]], [[VAR_11_1_]] : i1 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpf ogt, [[in]], [[init]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.ori [[VAR_13_]], [[VAR_12_]] : i1 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.select [[VAR_14_]], [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[in_1]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_15_]], [[VAR_16_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor -// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_0_]][0] : memref<1xi32, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @argmin_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> - %3 = tt.splat %1 : i32 -> tensor<4096xi32> - %4 = arith.addi %3, %2 : tensor<4096xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %7 = tt.load %6 : tensor<4096x!tt.ptr> - %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ - ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): - %11 = arith.cmpf oeq, %arg9, %arg11 : f32 - %12 = arith.cmpi slt, %arg10, %arg12 : i32 - %13 = arith.andi %11, %12 : i1 - %14 = arith.cmpf olt, %arg9, %arg11 : f32 - %15 = arith.ori %14, %13 : i1 - %16 = arith.select %15, %arg9, %arg11 : f32 - %17 = arith.select %15, %arg10, %arg12 : i32 - tt.reduce.return %16, %17 : f32, i32 - }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) - %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 - tt.store %9, %8#1 : !tt.ptr - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @argmin_012 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> -// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_10_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i32 -// CHECK: linalg.yield [[VAR_11_]] : i32 -// CHECK: } -> tensor<4096xi32> -// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4096], strides: [1] : memref<*xf32> to memref<4096xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_5_]] : tensor) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_7_]] : tensor) -> tensor -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xf32>, tensor<4096xi32>) outs([[VAR_6_]], [[VAR_8_]] : tensor, tensor) dimensions = [0] -// CHECK: ([[in:.+]]: f32, [[in_1:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_10_1_:%.+]] = arith.cmpf oeq, [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_11_1_:%.+]] = arith.cmpi slt, [[in_1]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.andi [[VAR_10_1_]], [[VAR_11_1_]] : i1 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpf olt, [[in]], [[init]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.ori [[VAR_13_]], [[VAR_12_]] : i1 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.select [[VAR_14_]], [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[in_1]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_15_]], [[VAR_16_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor -// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_0_]][0] : memref<1xi32, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir b/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir deleted file mode 100644 index 73e8fd6e..00000000 --- a/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir +++ /dev/null @@ -1,215 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -// @triton.jit -// def test( -// a_ptr, c_ptr, stride_am, stride_an -// ): -// offs_am = tl.arange(0, 4) -// offs_an = tl.arange(0, 4) -// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) -// a = tl.load(a_ptrs) -// m = tl.argmax(a, axis=1) -// tl.store(c_ptr + tl.arange(0, 4), m) -// -// ret = triton.compiler.compile( -// test, -// signature=" *fp32,*fp32,i32,i32", -// print_triton_ir_only=True, -// ) - -module { - tt.func public @test_argmax(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4x1xi32> - %3 = arith.muli %1, %2 : tensor<4x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %5 = tt.splat %arg3 : i32 -> tensor<1x4xi32> - %6 = arith.muli %4, %5 : tensor<1x4xi32> - %7 = tt.broadcast %3 : tensor<4x1xi32> -> tensor<4x4xi32> - %8 = tt.broadcast %6 : tensor<1x4xi32> -> tensor<4x4xi32> - %9 = arith.addi %7, %8 : tensor<4x4xi32> - %10 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %11 = tt.addptr %10, %9 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %12 = tt.load %11 : tensor<4x4x!tt.ptr> - %13 = tt.broadcast %4 : tensor<1x4xi32> -> tensor<4x4xi32> - %14:2 = "tt.reduce"(%12, %13) <{axis = 1 : i32}> ({ - ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): - %18 = arith.cmpf oeq, %arg4, %arg6 : f32 - %19 = arith.cmpi slt, %arg5, %arg7 : i32 - %20 = arith.andi %18, %19 : i1 - %21 = arith.cmpf ogt, %arg4, %arg6 : f32 - %22 = arith.ori %21, %20 : i1 - %23 = arith.select %22, %arg4, %arg6 : f32 - %24 = arith.select %22, %arg5, %arg7 : i32 - tt.reduce.return %23, %24 : f32, i32 - }) : (tensor<4x4xf32>, tensor<4x4xi32>) -> (tensor<4xf32>, tensor<4xi32>) - %15 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %16 = tt.addptr %15, %0 : tensor<4x!tt.ptr>, tensor<4xi32> - %17 = arith.sitofp %14#1 : tensor<4xi32> to tensor<4xf32> - tt.store %16, %17 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @test_argmax -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32 -// CHECK: linalg.yield [[VAR_14_]] : i32 -// CHECK: } -> tensor<4xi32> -// CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_1_]] {{.}}[0, 1]{{.}} output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 4], strides: {{.}}[[VAR_2_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x4xf32, strided<[?, ?]>> to memref<4x4xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<4x4xi32> -// CHECK: [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<1x4xi32>) outs([[VAR_5_]] : tensor<4x4xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: i32): -// CHECK: linalg.yield [[in_]] : i32 -// CHECK: } -> tensor<4x4xi32> -// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_7_]] : tensor<4xf32>) -> tensor<4xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_10_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_9_]] : tensor<4xi32>) -> tensor<4xi32> -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_6_]] : tensor<4x4xf32>, tensor<4x4xi32>) outs([[VAR_8_]], [[VAR_10_]] : tensor<4xf32>, tensor<4xi32>) dimensions = [1] -// CHECK: ([[in_:.+]]: f32, [[in_1_:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_13_1_:%.+]] = arith.cmpf oeq, [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_14_1_:%.+]] = arith.cmpi slt, [[in_1_]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_15_:%.+]] = arith.andi [[VAR_13_1_]], [[VAR_14_1_]] : i1 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf ogt, [[in_]], [[init]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.ori [[VAR_16_]], [[VAR_15_]] : i1 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[in_1_]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_18_]], [[VAR_19_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref<*xf32> to memref<4xf32, strided<[1]>> -// CHECK-DAG: [[VAR_11_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK: [[VAR_12_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_reduced_]]#1 : tensor<4xi32>) outs([[VAR_11_]] : tensor<4xf32>) { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: f32): -// CHECK: [[VAR_13_2_:%.+]] = arith.sitofp [[in_]] : i32 to f32 -// CHECK: linalg.yield [[VAR_13_2_]] : f32 -// CHECK: } -> tensor<4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } - -// ----- - -// @triton.jit -// def test( -// a_ptr, c_ptr, stride_am, stride_an -// ): -// offs_am = tl.arange(0, 4) -// offs_an = tl.arange(0, 4) -// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) -// a = tl.load(a_ptrs) -// m = tl.argmin(a, axis=1) -// tl.store(c_ptr + tl.arange(0, 4), m) -// -// ret = triton.compiler.compile( -// test, -// signature=" *fp32,*fp32,i32,i32", -// print_triton_ir_only=True, -// ) - -module { - tt.func public @test_argmin(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4x1xi32> - %3 = arith.muli %1, %2 : tensor<4x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %5 = tt.splat %arg3 : i32 -> tensor<1x4xi32> - %6 = arith.muli %4, %5 : tensor<1x4xi32> - %7 = tt.broadcast %3 : tensor<4x1xi32> -> tensor<4x4xi32> - %8 = tt.broadcast %6 : tensor<1x4xi32> -> tensor<4x4xi32> - %9 = arith.addi %7, %8 : tensor<4x4xi32> - %10 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %11 = tt.addptr %10, %9 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %12 = tt.load %11 : tensor<4x4x!tt.ptr> - %13 = tt.broadcast %4 : tensor<1x4xi32> -> tensor<4x4xi32> - %14:2 = "tt.reduce"(%12, %13) <{axis = 1 : i32}> ({ - ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): - %18 = arith.cmpf oeq, %arg4, %arg6 : f32 - %19 = arith.cmpi slt, %arg5, %arg7 : i32 - %20 = arith.andi %18, %19 : i1 - %21 = arith.cmpf olt, %arg4, %arg6 : f32 - %22 = arith.ori %21, %20 : i1 - %23 = arith.select %22, %arg4, %arg6 : f32 - %24 = arith.select %22, %arg5, %arg7 : i32 - tt.reduce.return %23, %24 : f32, i32 - }) : (tensor<4x4xf32>, tensor<4x4xi32>) -> (tensor<4xf32>, tensor<4xi32>) - %15 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %16 = tt.addptr %15, %0 : tensor<4x!tt.ptr>, tensor<4xi32> - %17 = arith.sitofp %14#1 : tensor<4xi32> to tensor<4xf32> - tt.store %16, %17 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @test_argmin -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32 -// CHECK: linalg.yield [[VAR_14_]] : i32 -// CHECK: } -> tensor<4xi32> -// CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_1_]] {{.}}[0, 1]{{.}} output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 4], strides: {{.}}[[VAR_2_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x4xf32, strided<[?, ?]>> to memref<4x4xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<4x4xi32> -// CHECK: [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<1x4xi32>) outs([[VAR_5_]] : tensor<4x4xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: i32): -// CHECK: linalg.yield [[in_]] : i32 -// CHECK: } -> tensor<4x4xi32> -// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_7_]] : tensor<4xf32>) -> tensor<4xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_10_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_9_]] : tensor<4xi32>) -> tensor<4xi32> -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_6_]] : tensor<4x4xf32>, tensor<4x4xi32>) outs([[VAR_8_]], [[VAR_10_]] : tensor<4xf32>, tensor<4xi32>) dimensions = [1] -// CHECK: ([[in_:.+]]: f32, [[in_1_:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_13_1_:%.+]] = arith.cmpf oeq, [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_14_1_:%.+]] = arith.cmpi slt, [[in_1_]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_15_:%.+]] = arith.andi [[VAR_13_1_]], [[VAR_14_1_]] : i1 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf olt, [[in_]], [[init]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.ori [[VAR_16_]], [[VAR_15_]] : i1 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[in_1_]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_18_]], [[VAR_19_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref<*xf32> to memref<4xf32, strided<[1]>> -// CHECK-DAG: [[VAR_11_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK: [[VAR_12_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_reduced_]]#1 : tensor<4xi32>) outs([[VAR_11_]] : tensor<4xf32>) { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: f32): -// CHECK: [[VAR_13_2_:%.+]] = arith.sitofp [[in_]] : i32 to f32 -// CHECK: linalg.yield [[VAR_13_2_]] : f32 -// CHECK: } -> tensor<4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir b/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir deleted file mode 100644 index a7ab57ab..00000000 --- a/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir +++ /dev/null @@ -1,809 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @atan2_kernel_0123(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg3 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %11 = tt.addptr %10, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %12 = tt.load %11, %6 : tensor<32x!tt.ptr> - %13 = tt.extern_elementwise %9, %12 {libname = "", libpath = "", pure = true, symbol = "__nv_atan2f"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %15 = tt.addptr %14, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %15, %13 : tensor<32x!tt.ptr> - tt.return - } -} -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @atan2_kernel_0123 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]], [[VAR_2:%.+]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_3:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_4:%.+]] = math.atan2 [[in_1]], [[in_2]] : f32 -// CHECK: linalg.yield [[VAR_4]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @pow_kernel_0123(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg3 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %11 = tt.addptr %10, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %12 = tt.load %11, %6 : tensor<32x!tt.ptr> - %13 = tt.extern_elementwise %9, %12 {libname = "", libpath = "", pure = true, symbol = "__nv_powf"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %15 = tt.addptr %14, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %15, %13 : tensor<32x!tt.ptr> - tt.return - } -} -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @pow_kernel_0123 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]], [[VAR_2:%.+]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_3:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_4:%.+]] = math.powf [[in_1]], [[in_2]] : f32 -// CHECK: linalg.yield [[VAR_4]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @fabs_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_fabsf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @fabs_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.absf [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @sin_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sinf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @sin_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.sin [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @cos_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_cosf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @cos_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.cos [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @tan_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_tanf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @tan_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.tan [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @asin_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_asinf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @asin_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.asin [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @acos_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_acosf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @acos_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.acos [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @atan_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_atanf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @atan_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.atan [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @sinh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sinhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @sinh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.sinh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @cosh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_coshf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @cosh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.cosh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @tanh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_tanhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @tanh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.tanh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @asinh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_asinhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @asinh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.asinh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @acosh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_acoshf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @acosh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.acosh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @atanh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_atanhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @atanh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.atanh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @log_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_logf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @log_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.log [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @log10_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_log10f"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @log10_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.log10 [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @log1p_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @log1p_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.log1p [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @exp_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_expf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @exp_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.exp [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @exp2_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_exp2f"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @exp2_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.exp2 [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @erf_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_erff"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @erf_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.erf [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @sqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sqrtf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @sqrt_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.sqrt [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @rsqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_rsqrtf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @rsqrt_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.rsqrt [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @ceil_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_ceilf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @ceil_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.ceil [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @floor_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_floorf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @floor_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.floor [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @trunc_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_truncf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @trunc_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.trunc [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> diff --git a/test/Conversion/TritonToLinalg/convert_minmax.mlir b/test/Conversion/TritonToLinalg/convert_minmax.mlir deleted file mode 100644 index dd7edd3a..00000000 --- a/test/Conversion/TritonToLinalg/convert_minmax.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s -module { - tt.func public @minmax_olt(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf olt, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_olt -// CHECK: %[[VAL:.*]] = arith.minimumf %arg1, %arg2 : f32 - -// ----- - -module { - tt.func public @minmax_ole(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf ole, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_ole -// CHECK: %[[VAL:.*]] = arith.minimumf %arg1, %arg2 : f32 - -// ----- - -module { - tt.func public @minmax_ogt(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf ogt, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_ogt -// CHECK: %[[VAL:.*]] = arith.maximumf %arg1, %arg2 : f32 - -// ----- - -module { - tt.func public @minmax_oge(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf oge, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_oge -// CHECK: %[[VAL:.*]] = arith.maximumf %arg1, %arg2 : f32 diff --git a/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir deleted file mode 100644 index 7d915e27..00000000 --- a/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @maxnumf(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<4096xf32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: f32, %arg15: f32): - %69 = arith.maxnumf %arg14, %arg15 : f32 - tt.reduce.return %69 : f32 - }) {axis = 0 : i32} : (tensor<4096xf32>) -> f32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK-LABEL: func.func @maxnumf( -// CHECK-SAME: %arg0: memref<*xf32>, %[[ARG_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32) -// CHECK: %[[CST:.*]] = arith.constant 0xFF800000 : f32 -// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<4096xf32> -// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[VAL_0]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK: %[[VAL_2:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_3:.*]] = tensor.insert %[[CST]] into %[[VAL_2]][] : tensor -// CHECK: %[[VAL_4:.*]] = linalg.reduce ins(%[[VAL_1]] : tensor<4096xf32>) outs(%[[VAL_3]] : tensor) dimensions = [0] -// CHECK: (%in: f32, %init: f32) { -// CHECK: %[[VAL_5:.*]] = arith.maxnumf %in, %init : f32 -// CHECK: linalg.yield %[[VAL_5]] : f32 -// CHECK: } -// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_4]][] : tensor -// CHECK: %[[VAL_7:.*]] = memref.[[VAL_7]] %arg0 to offset: [0], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1]>> -// CHECK: affine.store %[[VAL_6]], %[[VAL_7]][0] : memref<1xf32, strided<[1]>> -// CHECK: return -// CHECK:} - -// ----- - - -module { - tt.func public @minnumf(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<4096xf32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: f32, %arg15: f32): - %69 = arith.minnumf %arg14, %arg15 : f32 - tt.reduce.return %69 : f32 - }) {axis = 0 : i32} : (tensor<4096xf32>) -> f32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK-LABEL: func.func @minnumf( -// CHECK-SAME: %arg0: memref<*xf32>, %[[ARG_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32) -// CHECK: %[[CST:.*]] = arith.constant 0x7F800000 : f32 -// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<4096xf32> -// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[VAL_0]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK: %[[VAL_2:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_3:.*]] = tensor.insert %[[CST]] into %[[VAL_2]][] : tensor -// CHECK: %[[VAL_4:.*]] = linalg.reduce ins(%[[VAL_1]] : tensor<4096xf32>) outs(%[[VAL_3]] : tensor) dimensions = [0] -// CHECK: (%in: f32, %init: f32) { -// CHECK: %[[VAL_5:.*]] = arith.minnumf %in, %init : f32 -// CHECK: linalg.yield %[[VAL_5]] : f32 -// CHECK: } -// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_4]][] : tensor -// CHECK: %[[VAL_7:.*]] = memref.[[VAL_7]] %arg0 to offset: [0], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1]>> -// CHECK: affine.store %[[VAL_6]], %[[VAL_7]][0] : memref<1xf32, strided<[1]>> -// CHECK: return -// CHECK:} diff --git a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir deleted file mode 100644 index eaf30630..00000000 --- a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir +++ /dev/null @@ -1,126 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s -module { - tt.func public @minmax_sgt(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi sgt, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_sgt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c-2147483648{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.maxsi %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @minmax_ugt(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi ugt, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_ugt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c0{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.maxui %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @minmax_slt(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi slt, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_slt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c2147483647{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.minsi %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @minmax_ult(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi ult, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_ult(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c-1{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.minui %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } \ No newline at end of file diff --git a/test/Conversion/TritonToLinalg/convert_splat_float.mlir b/test/Conversion/TritonToLinalg/convert_splat_float.mlir deleted file mode 100644 index f37b2107..00000000 --- a/test/Conversion/TritonToLinalg/convert_splat_float.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%fin : f32, - %bin : bf16, - %save0 : tensor<1024x!tt.ptr>, - %save1 : tensor<128x256x!tt.ptr>) -> () { - %0 = tt.splat %fin : f32 -> tensor<1024xf32> - %1 = tt.splat %bin : bf16 -> tensor<128x256xbf16> - tt.store %save0, %0 : tensor<1024x!tt.ptr> - tt.store %save1, %1 : tensor<128x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: memref<1024xf32>, %[[VAL_3:.*]]: memref<128x256xbf16>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : f32) outs(%[[VAL_7]] : tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_1]] : bf16) outs(%[[VAL_9]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_8]] in writable %[[VAL_2]] -// CHECK: bufferization.materialize_in_destination %[[VAL_10]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir b/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir deleted file mode 100644 index 33e5e67f..00000000 --- a/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func public @bcast_kernel_01(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32> - %6 = tt.splat %1 : i32 -> tensor<2048xi32> - %7 = arith.addi %6, %5 : tensor<2048xi32> - %8 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %9 = tt.addptr %8, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %10 = tt.load %9 : tensor<32x!tt.ptr> - %11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32> - %12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32> - %13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32> - %14 = tt.splat %arg1 : !tt.ptr -> tensor<2048x!tt.ptr> - %15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr>, tensor<2048xi32> - tt.store %15, %13 : tensor<2048x!tt.ptr> - tt.return - } -} - - -// CHECK-LABEL: func.func @bcast_kernel_01( -// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32 -// CHECK: %[[VAR_0:.*]] = arith.muli %arg5, %[[C32_I32]] : i32 -// CHECK: %[[VAR_1:.*]] = arith.index_cast %[[VAR_0]] : i32 to index -// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[VAR_1]]], sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32> -// CHECK: memref.copy %[[REINTERPRET_CAST:.*]], %[[ALLOC]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32> -// CHECK: %[[VAR_2:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<32xf32> -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAR_2]] {{.}}[0, 1]{{.}} output_shape [1, 32] : tensor<32xf32> into tensor<1x32xf32> -// CHECK: %[[VAR_3:.*]] = tensor.empty() : tensor<64x32xf32> -// CHECK: %[[VAR_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[EXPANDED]] : tensor<1x32xf32>) outs(%[[VAR_3:.*]] : tensor<64x32xf32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%in: f32, %out: f32): -// CHECK: linalg.yield %in : f32 -// CHECK: } -> tensor<64x32xf32> -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAR_4]] {{.}}[0, 1]{{.}} : tensor<64x32xf32> into tensor<2048xf32> -// CHECK: %[[VAR_7:.*]] = arith.index_cast %[[VAR_0]] : i32 to index -// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %arg1 to offset: [%[[VAR_7]]], sizes: [2048], strides: [1] : memref<*xf32> to memref<2048xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[COLLAPSED]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> () -// CHECK: return diff --git a/test/Conversion/TritonToLinalg/cumsum.mlir b/test/Conversion/TritonToLinalg/cumsum.mlir deleted file mode 100644 index b579517a..00000000 --- a/test/Conversion/TritonToLinalg/cumsum.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -// @triton.jit -// def test_cumsum_op( -// input_ptr, output_ptr, n_columns -// ): -// row = tl.program_id(axis=0) -// row_start = row * n_columns -// columns = tl.arange(0, 4096) -// offsets = row_start + columns -// data = tl.load(input_ptr + offsets) -// result = tl.cumsum(data, axis=0) -// tl.store(output_ptr + offsets, result) -// -// ret = triton.compiler.compile( -// test_cumsum_op, -// signature=" *fp32,*i32,i32", -// print_triton_ir_only=True, -// ) -// print(ret.asm["ttir"]) - -module { - tt.func public @test_cumsum_op_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> - %3 = tt.splat %1 : i32 -> tensor<4096xi32> - %4 = arith.addi %3, %2 : tensor<4096xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %7 = tt.load %6 : tensor<4096x!tt.ptr> - %8 = "tt.scan"(%7) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg3: f32, %arg4: f32): - %12 = arith.addf %arg3, %arg4 : f32 - tt.scan.return %12 : f32 - }) : (tensor<4096xf32>) -> tensor<4096xf32> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4096x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %11 = arith.fptosi %8 : tensor<4096xf32> to tensor<4096xi32> - tt.store %10, %11 : tensor<4096x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @test_cumsum_op_012 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [4096], strides: [1] : memref<*xf32> to memref<4096xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = tensor.empty() : tensor<4096xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = ttx.cumsum {axis = 0 : ui32, operandSegmentSizes = array} ins([[VAR_2_]] : tensor<4096xf32>) outs([[VAR_3_]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: [4096], strides: [1] : memref<*xi32> to memref<4096xi32, strided<[1], offset: ?>> -// CHECK-DAG: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096xi32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]] : tensor<4096xf32>) outs([[VAR_6_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[in_:.+]]: f32, [[out_:.+]]: i32): -// CHECK: [[VAR_8_:%.+]] = arith.fptosi [[in_]] : f32 to i32 -// CHECK: linalg.yield [[VAR_8_]] : i32 -// CHECK: } -> tensor<4096xi32> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/dot.mlir b/test/Conversion/TritonToLinalg/dot.mlir deleted file mode 100644 index 95cb91b7..00000000 --- a/test/Conversion/TritonToLinalg/dot.mlir +++ /dev/null @@ -1,84 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr - ) - { - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %c64 = arith.constant 128 : i32 - %1 = tt.splat %c64 : i32 -> tensor<128xi32> - %2 = arith.muli %0, %1 : tensor<128xi32> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x64xi32> - %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %7 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<128x64xi32> - %8 = arith.addi %4, %7 : tensor<128x64xi32> - %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - %12 = tt.broadcast %11 : tensor<256x1xi32> -> tensor<256x64xi32> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %c256 = arith.constant 256 : i32 - %14 = tt.splat %c256 : i32 -> tensor<64xi32> - %15 = arith.muli %13, %14 : tensor<64xi32> - %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %17 = tt.broadcast %16 : tensor<1x64xi32> -> tensor<256x64xi32> - %18 = arith.addi %12, %17 : tensor<256x64xi32> - %20 = tt.splat %c256 : i32 -> tensor<128xi32> - %21 = arith.muli %0, %20 : tensor<128xi32> - %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %23 = tt.broadcast %22 : tensor<128x1xi32> -> tensor<128x256xi32> - %24 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %25 = tt.broadcast %24 {axis = 0 : i32} : tensor<1x256xi32> -> tensor<128x256xi32> - %26 = arith.addi %23, %25 : tensor<128x256xi32> - %30 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> - %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64x!tt.ptr> - %40 = tt.splat %arg1 : !tt.ptr -> tensor<256x64x!tt.ptr> - %41 = tt.addptr %40, %18 : tensor<256x64x!tt.ptr>, tensor<256x64xi32> - %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x64x!tt.ptr> - %43 = tt.trans %42 {order = array} : tensor<256x64xbf16> -> tensor<64x256xbf16> - %50 = tt.splat %arg2 : !tt.ptr -> tensor<128x256x!tt.ptr> - %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %52 = tt.load %51 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x256x!tt.ptr> - %60 = tt.dot %32, %43, %52 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> - tt.store %51, %60 : tensor<128x256x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: {{.}}[[CST_128_]], 1] : memref<*xbf16> to memref<128x64xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[?, 1]>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [256, 64], strides: [1, [[CST_256_]]{{.}} : memref<*xbf16> to memref<256x64xbf16, strided<[1, ?]>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<256x64xbf16, strided<[1, ?]>> to memref<256x64xbf16> -// CHECK-DAG: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x64xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<64x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_transposed_:%.+]] = linalg.transpose ins([[VAR_1_]] : tensor<256x64xbf16>) outs([[VAR_2_]] : tensor<64x256xbf16>) permutation = [1, 0] -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<128x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<128x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_2_]] : memref<128x256xbf16, strided<[?, 1]>> to memref<128x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<128x256xbf16> -// CHECK: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [[[MAP_]], [[MAP_]], [[MAP_]]], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) { -// CHECK: ^bb0([[VAR_in_1:%.+]]: bf16, [[VAR_in_2:%.+]]: bf16, {{%.+}}: bf16): -// CHECK: [[VAR_8_:%.+]] = arith.addf [[VAR_in_1]], [[VAR_in_2]] : bf16 -// CHECK: linalg.yield [[VAR_8_:%.+]] : bf16 -// CHECK: } -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/get_num_programs.mlir b/test/Conversion/TritonToLinalg/get_num_programs.mlir deleted file mode 100644 index afda2996..00000000 --- a/test/Conversion/TritonToLinalg/get_num_programs.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// XFAIL: * -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @num_programs(%arg0: !tt.ptr) { - %0 = tt.get_num_programs x : i32 - %1 = tt.get_num_programs y : i32 - %2 = tt.get_num_programs z : i32 - %3 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> - %4 = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> - %5 = tt.make_range {end = 3 : i32, start = 2 : i32} : tensor<1xi32> - %6 = tt.splat %arg0 : !tt.ptr -> tensor<1x!tt.ptr> - %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr>, tensor<1xi32> - %8 = tt.splat %0 : i32 -> tensor<1xi32> - tt.store %7, %8 : tensor<1x!tt.ptr> - %9 = tt.addptr %6, %4 : tensor<1x!tt.ptr>, tensor<1xi32> - %10 = tt.splat %1 : i32 -> tensor<1xi32> - tt.store %9, %10 : tensor<1x!tt.ptr> - %11 = tt.addptr %6, %5 : tensor<1x!tt.ptr>, tensor<1xi32> - %12 = tt.splat %2 : i32 -> tensor<1xi32> - tt.store %11, %12 : tensor<1x!tt.ptr> - tt.return - } -} - -// CHECK: module { -// CHECK: func.func @num_programs(%arg0: memref<*xi32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { -// CHECK: %c2 = arith.constant 2 : index -// CHECK: %c1 = arith.constant 1 : index -// CHECK: %c0 = arith.constant 0 : index -// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: %0 = tensor.empty() : tensor<1xi32> -// CHECK: %1 = linalg.fill ins(%arg1 : i32) outs(%0 : tensor<1xi32>) -> tensor<1xi32> -// CHECK: bufferization.materialize_in_destination %1 in writable %reinterpret_cast -// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%c1], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: %2 = tensor.empty() : tensor<1xi32> -// CHECK: %3 = linalg.fill ins(%arg2 : i32) outs(%2 : tensor<1xi32>) -> tensor<1xi32> -// CHECK: bufferization.materialize_in_destination %3 in writable %reinterpret_cast_0 -// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg0 to offset: [%c2], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: %4 = tensor.empty() : tensor<1xi32> -// CHECK: %5 = linalg.fill ins(%arg3 : i32) outs(%4 : tensor<1xi32>) -> tensor<1xi32> -// CHECK: bufferization.materialize_in_destination %5 in writable %reinterpret_cast_1 -// CHECK: return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir b/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir deleted file mode 100644 index 78afe418..00000000 --- a/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir +++ /dev/null @@ -1,58 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : tensor<256x16x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<32xi32> - %ws = arith.muli %ct256, %0 : tensor<32xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> - %m2 = tt.broadcast %1 : tensor<32x1xi32> -> tensor<32x256xi32> - %100 = tt.expand_dims %m2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %moff = tt.broadcast %100 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %33 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %34 = tt.expand_dims %33 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %k2 = tt.broadcast %34 : tensor<1x256xi32> -> tensor<32x256xi32> - %200 = tt.expand_dims %k2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %koff = tt.broadcast %200 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %23 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> - %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> - %n2 = tt.broadcast %24 : tensor<1x16xi32> -> tensor<256x16xi32> - %300 = tt.expand_dims %n2 {axis = 0 : i32} : tensor<256x16xi32> -> tensor<1x256x16xi32> - %noff = tt.broadcast %300 : tensor<1x256x16xi32> -> tensor<32x256x16xi32> - %mkoff = arith.addi %moff, %koff : tensor<32x256x16xi32> - %mknoff = arith.addi %mkoff, %noff : tensor<32x256x16xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<32x256x16x!tt.ptr> - %9 = tt.addptr %8, %mknoff : tensor<32x256x16x!tt.ptr>, tensor<32x256x16xi32> - %afm = tt.load %9 : tensor<32x256x16x!tt.ptr> - %6 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.cmpf ogt, %arg5, %arg6 : bf16 - %22 = arith.select %21, %arg5, %arg6 : bf16 - tt.reduce.return %22 : bf16 - }) {axis = 0 : i32} : (tensor<32x256x16xbf16>) -> tensor<256x16xbf16> - tt.store %res, %6 : tensor<256x16x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<256x16xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0xFF80 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_5]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<32x256x16xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<32x256x16xbf16> -// CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<256x16xbf16> -// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : tensor<256x16xbf16>) -> tensor<256x16xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<32x256x16xbf16>) outs(%[[VAL_11]] : tensor<256x16xbf16>) dimensions = [0] -// CHECK: (%[[VAL_13:.*]]: bf16, %[[VAL_14:.*]]: bf16) { -// CHECK: %[[VAL_15:.*]] = arith.maximumf %[[VAL_13]], %[[VAL_14]] : bf16 -// CHECK: linalg.yield %[[VAL_15]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_1]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir deleted file mode 100644 index 5726ea0c..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<256x!tt.ptr> - %19 = tt.addptr %18, %3 : tensor<256x!tt.ptr>, tensor<256xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 0 : i32} : (tensor<512x256xbf16>) -> tensor<256xbf16> - tt.store %19, %5 : tensor<256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xbf16> to memref<512x256xbf16, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[?, 1]>> to memref<512x256xbf16> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_11]] : tensor<256xbf16>) -> tensor<256xbf16> -// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_12]] : tensor<256xbf16>) dimensions = [0] -// CHECK: (%[[VAL_14:.*]]: bf16, %[[VAL_15:.*]]: bf16) { -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : bf16 -// CHECK: linalg.yield %[[VAL_16]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir deleted file mode 100644 index 7f37a9f7..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<512x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<512x!tt.ptr>, tensor<512xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 1 : i32} : (tensor<512x256xbf16>) -> tensor<512xbf16> - tt.store %19, %5 : tensor<512x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xbf16> to memref<512x256xbf16, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref<*xbf16> to memref<512xbf16, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[?, 1]>> to memref<512x256xbf16> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256x512xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_11]] : tensor<256x512xbf16>) permutation = [1, 0] -// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xbf16> -// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_13]] : tensor<512xbf16>) -> tensor<512xbf16> -// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_12]] : tensor<256x512xbf16>) outs(%[[VAL_14]] : tensor<512xbf16>) dimensions = [0] -// CHECK: (%[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16) { -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : bf16 -// CHECK: linalg.yield %[[VAL_18]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir deleted file mode 100644 index a63270ef..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<256x!tt.ptr> - %19 = tt.addptr %18, %3 : tensor<256x!tt.ptr>, tensor<256xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: f32, %arg6: f32): - %21 = arith.addf %arg5, %arg6 : f32 - tt.reduce.return %21 : f32 - }) {axis = 0 : i32} : (tensor<512x256xf32>) -> tensor<256xf32> - tt.store %19, %5 : tensor<256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xf32> to memref<512x256xf32, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[?, 1]>> to memref<512x256xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xf32> -// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : tensor<256xf32>) -> tensor<256xf32> -// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_12]] : tensor<256xf32>) dimensions = [0] -// CHECK: (%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32) { -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir deleted file mode 100644 index 175d33f6..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<512x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<512x!tt.ptr>, tensor<512xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: f32, %arg6: f32): - %21 = arith.addf %arg5, %arg6 : f32 - tt.reduce.return %21 : f32 - }) {axis = 1 : i32} : (tensor<512x256xf32>) -> tensor<512xf32> - tt.store %19, %5 : tensor<512x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xf32> to memref<512x256xf32, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref<*xf32> to memref<512xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[?, 1]>> to memref<512x256xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256x512xf32> -// CHECK: %[[VAL_12:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_11]] : tensor<256x512xf32>) permutation = [1, 0] -// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xf32> -// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_13]] : tensor<512xf32>) -> tensor<512xf32> -// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_12]] : tensor<256x512xf32>) outs(%[[VAL_14]] : tensor<512xf32>) dimensions = [0] -// CHECK: (%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32) { -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32 -// CHECK: linalg.yield %[[VAL_18]] : f32 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir b/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir deleted file mode 100644 index 33b9c7a1..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr, - %out: tensor<32x16x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<32xi32> - %ws = arith.muli %ct256, %0 : tensor<32xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> - %m2 = tt.broadcast %1 : tensor<32x1xi32> -> tensor<32x256xi32> - %100 = tt.expand_dims %m2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %moff = tt.broadcast %100 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %33 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %34 = tt.expand_dims %33 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %k2 = tt.broadcast %34 : tensor<1x256xi32> -> tensor<32x256xi32> - %200 = tt.expand_dims %k2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %koff = tt.broadcast %200 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %23 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> - %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> - %n2 = tt.broadcast %24 : tensor<1x16xi32> -> tensor<256x16xi32> - %300 = tt.expand_dims %n2 {axis = 0 : i32} : tensor<256x16xi32> -> tensor<1x256x16xi32> - %noff = tt.broadcast %300 : tensor<1x256x16xi32> -> tensor<32x256x16xi32> - %mkoff = arith.addi %moff, %koff : tensor<32x256x16xi32> - %mknoff = arith.addi %mkoff, %noff : tensor<32x256x16xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<32x256x16x!tt.ptr> - %9 = tt.addptr %8, %mknoff : tensor<32x256x16x!tt.ptr>, tensor<32x256x16xi32> - %afm = tt.load %9 : tensor<32x256x16x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 1 : i32} : (tensor<32x256x16xbf16>) -> tensor<32x16xbf16> - tt.store %out, %5 : tensor<32x16x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[ARG0:.*]]: memref<*xbf16>, %[[ARG1:.*]]: memref<*xbf16>, %[[ARG2:.*]]: memref<32x16xbf16>, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: i32) { -// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_1]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>> -// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<32x256x16xbf16> -// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16> -// CHECK: %[[VAL_4:.*]] = bufferization.to_tensor %[[VAL_3]] restrict writable : memref<32x256x16xbf16> to tensor<32x256x16xbf16> -// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<256x32x16xbf16> -// CHECK: %[[VAL_6:.*]] = linalg.transpose ins(%[[VAL_4]] : tensor<32x256x16xbf16>) outs(%[[VAL_5]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2] -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x16xbf16> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : bf16) outs(%[[VAL_7]] : tensor<32x16xbf16>) -> tensor<32x16xbf16> -// CHECK: %[[VAL_9:.*]] = linalg.reduce ins(%[[VAL_6]] : tensor<256x32x16xbf16>) outs(%[[VAL_8]] : tensor<32x16xbf16>) dimensions = [0] -// CHECK: (%[[VAL_10:.*]]: bf16, %[[VAL_11:.*]]: bf16) { -// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : bf16 -// CHECK: linalg.yield %[[VAL_12]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[ARG2]] : (tensor<32x16xbf16>, memref<32x16xbf16>) -> () -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_scalar.mlir b/test/Conversion/TritonToLinalg/reducesum_scalar.mlir deleted file mode 100644 index a5ca4d0a..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_scalar.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr) - { - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.splat %afloat : !tt.ptr -> tensor<128x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> - %afm = tt.load %2 : tensor<128x!tt.ptr> - %3 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16 - tt.store %res, %3 : !tt.ptr - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>> -// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<128xbf16> -// CHECK: memref.copy %[[VAL_6]], %[[VAL_7]] : memref<128xbf16, strided<[1]>> to memref<128xbf16> -// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_7]] restrict writable : memref<128xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %[[VAL_5]] into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<128xbf16>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%[[VAL_12:.*]]: bf16, %[[VAL_13:.*]]: f32) { -// CHECK: %[[VAL_14:.*]] = arith.extf %[[VAL_12]] : bf16 to f32 -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_14]], %[[VAL_13]] : f32 -// CHECK: linalg.yield %[[VAL_15]] : f32 -// CHECK: } -// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_17:.*]] = arith.truncf %[[VAL_16]] : f32 to bf16 -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1]>> -// CHECK: affine.store %[[VAL_17]], %[[VAL_18]][0] : memref<1xbf16, strided<[1]>> -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/triton_assert.mlir b/test/Conversion/TritonToLinalg/triton_assert.mlir deleted file mode 100644 index d2ed1e8e..00000000 --- a/test/Conversion/TritonToLinalg/triton_assert.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -tt.func public @assert_tensor_1d() { - %0 = tensor.empty() : tensor<4xi1> - tt.assert %0, "message" : tensor<4xi1> - tt.return -} - -// CHECK-LABEL: func.func @assert_tensor_1d -// CHECK-NOT: tt.assert -// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0 : tensor<4xi1>) { -// CHECK: ^bb0(%in: i1): -// CHECK: cf.assert %in, "Assertion `message` failed" -// CHECK: linalg.yield -// CHECK: } -// CHECK-NOT: tt.assert - -tt.func public @assert_tensor_2d() { - %0 = tensor.empty() : tensor<4x4xi1> - tt.assert %0, "message" : tensor<4x4xi1> - tt.return -} - -// CHECK-LABEL: func.func @assert_tensor_2d -// CHECK-NOT: tt.assert -// CHECK: linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<4x4xi1>) { -// CHECK: ^bb0(%in: i1): -// CHECK: cf.assert %in, "Assertion `message` failed" -// CHECK: linalg.yield -// CHECK: } -// CHECK-NOT: tt.assert - -tt.func public @assert_tensor_3d() { - %0 = tensor.empty() : tensor<4x4x4xi1> - tt.assert %0, "message" : tensor<4x4x4xi1> - tt.return -} - -// CHECK-LABEL: func.func @assert_tensor_3d -// CHECK-NOT: tt.assert -// CHECK: linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<4x4x4xi1>) { -// CHECK: ^bb0(%in: i1): -// CHECK: cf.assert %in, "Assertion `message` failed" -// CHECK: linalg.yield -// CHECK: } -// CHECK-NOT: tt.assert diff --git a/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir b/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir deleted file mode 100644 index 11e588bc..00000000 --- a/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @rand(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.extern_elementwise %3, %0 {libname = "", libpath = "", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> - %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @rand -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<8xi32> -// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<8xi32>) { -// CHECK: ^bb0([[out:.+]]: i32): -// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 -// CHECK: linalg.yield [[VAR_5_]] : i32 -// CHECK: } -> tensor<8xi32> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [8], strides: [1] : memref<*xi32> to memref<8xi32, strided<[1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<8xi32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<8xi32, strided<[1]>> to memref<8xi32> -// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<8xi32> -// CHECK-DAG: [[VAR_3_:%.+]] = tt.extern_elementwise [[VAR_2_]], [[VAR_1_]] {libname = "", libpath = "", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [8], strides: [1] : memref<*xi32> to memref<8xi32, strided<[1]>> -// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_dot_opc.mlir b/test/Conversion/TritonToLinalg/use_dot_opc.mlir deleted file mode 100644 index df5f2140..00000000 --- a/test/Conversion/TritonToLinalg/use_dot_opc.mlir +++ /dev/null @@ -1,76 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr - ) - { - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %c64 = arith.constant 128 : i32 - %1 = tt.splat %c64 : i32 -> tensor<128xi32> - %2 = arith.muli %0, %1 : tensor<128xi32> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x64xi32> - %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %7 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<128x64xi32> - %8 = arith.addi %4, %7 : tensor<128x64xi32> - %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %12 = tt.broadcast %11 : tensor<1x256xi32> -> tensor<64x256xi32> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %c256 = arith.constant 256 : i32 - %14 = tt.splat %c256 : i32 -> tensor<64xi32> - %15 = arith.muli %13, %14 : tensor<64xi32> - %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %17 = tt.broadcast %16 : tensor<64x1xi32> -> tensor<64x256xi32> - %18 = arith.addi %12, %17 : tensor<64x256xi32> - %20 = tt.splat %c256 : i32 -> tensor<128xi32> - %21 = arith.muli %0, %20 : tensor<128xi32> - %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %23 = tt.broadcast %22 : tensor<128x1xi32> -> tensor<128x256xi32> - %24 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %25 = tt.broadcast %24 {axis = 0 : i32} : tensor<1x256xi32> -> tensor<128x256xi32> - %26 = arith.addi %23, %25 : tensor<128x256xi32> - %30 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> - %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64x!tt.ptr> - %40 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> - %41 = tt.addptr %40, %18 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> - %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<64x256x!tt.ptr> - %50 = tt.splat %arg2 : !tt.ptr -> tensor<128x256x!tt.ptr> - %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %cf0 = arith.constant 0.0 : bf16 - %71 = tt.splat %cf0 : bf16 -> tensor<128x256xbf16> - %60 = tt.dot %32, %42, %71 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> - tt.store %51, %60 : tensor<128x256x!tt.ptr> - tt.store %51, %71 : tensor<128x256x!tt.ptr> - tt.return - } -} - -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_0_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: {{.}}[[CST_128_]], 1] : memref<*xbf16> to memref<128x64xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[?, 1]>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [64, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<64x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<64x256xbf16, strided<[?, 1]>> to memref<64x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<128x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_4_:%.+]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_2_]], [[VAR_3_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination [[VAR_6_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_end_chain.mlir b/test/Conversion/TritonToLinalg/use_end_chain.mlir deleted file mode 100644 index a66116d4..00000000 --- a/test/Conversion/TritonToLinalg/use_end_chain.mlir +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> - // offset = [512] size = 256, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - // offset = [512,0], size = [256,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> - // offset = [512,0], size = [256,128], stride = [1,0] - %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> - // offset = 1024, size = 128, stride = 1 - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - // offset = [0,1024], size = [1,128], stride = [0,1] - %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> - // offset = [0,1024], size = [256,128], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> - %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> - // offset = [0,6144], size = [256,128], stride = [0,6] - %14 = arith.addi %2, %scale7 : tensor<256x128xi32> - // offset = [512,6144], size = [256,128], stride = [1,6] - // mixed use - %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> - %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - %19 = tt.load %18 : tensor<256x128x!tt.ptr> - tt.store %18, %19 : tensor<256x128x!tt.ptr> - %20 = arith.sitofp %14 : tensor<256x128xi32> to tensor<256x128xbf16> - tt.store %18, %20 : tensor<256x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : i32 -// CHECK-DAG: %[[CST_512_:.*]] = arith.constant 512 : i32 -// CHECK-DAG: %[[CST_1024_:.*]] = arith.constant 1024 : i32 -// CHECK: %[[VAL_30:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_31:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_30]] : tensor<256x128xi32>) -> tensor<256x128xi32> -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> -// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { -// CHECK: ^bb0(%[[VAL_10:.*]]: i32): -// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_12]], %[[CST_512_]] : i32 -// CHECK: linalg.yield %[[VAL_55]] : i32 -// CHECK: } -> tensor<256xi32> -// CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_13]] : tensor<256x1xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32): -// CHECK: linalg.yield %[[VAL_17]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_19:.*]] = tensor.empty() : tensor<128xi32> -// CHECK: %[[VAL_20:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_19]] : tensor<128xi32>) { -// CHECK: ^bb0(%[[VAL_21:.*]]: i32): -// CHECK: %[[VAL_22:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i32 -// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_23]], %[[CST_1024_]] : i32 -// CHECK: linalg.yield %[[VAL_56]] : i32 -// CHECK: } -> tensor<128xi32> -// CHECK: %[[VAL_24:.*]] = tensor.expand_shape %[[VAL_25:.*]] {{\[\[}}0, 1]] output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32> -// CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map3, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_24]] : tensor<1x128xi32>) outs(%[[VAL_26]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_28:.*]]: i32, %[[VAL_29:.*]]: i32): -// CHECK: linalg.yield %[[VAL_28]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_32:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_33:.*]], %[[VAL_31]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[VAL_33]] : tensor<256x128xi32>) { -// CHECK: ^bb0(%[[VAL_34:.*]]: i32, %[[VAL_35:.*]]: i32, %[[VAL_36:.*]]: i32): -// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_34]], %[[VAL_35]] : i32 -// CHECK: linalg.yield %[[VAL_37]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_38:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_39:.*]], %[[VAL_40:.*]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[VAL_39]] : tensor<256x128xi32>) { -// CHECK: ^bb0(%[[VAL_41:.*]]: i32, %[[VAL_42:.*]]: i32, %[[VAL_43:.*]]: i32): -// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_41]], %[[VAL_42]] : i32 -// CHECK: linalg.yield %[[VAL_44]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_45:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_46:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_45]], %[[VAL_46]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_47:.*]] = bufferization.to_tensor %[[VAL_46]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_47]] in writable %[[VAL_45]] -// CHECK: %[[VAL_48:.*]] = tensor.empty() : tensor<256x128xbf16> -// CHECK: %[[VAL_49:.*]] = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_50:.*]] : tensor<256x128xi32>) outs(%[[VAL_48]] : tensor<256x128xbf16>) { -// CHECK: ^bb0(%[[VAL_51:.*]]: i32, %[[VAL_52:.*]]: bf16): -// CHECK: %[[VAL_53:.*]] = arith.sitofp %[[VAL_51]] : i32 to bf16 -// CHECK: linalg.yield %[[VAL_53]] : bf16 -// CHECK: } -> tensor<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_54:.*]] in writable %[[VAL_45]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_mid_chain.mlir b/test/Conversion/TritonToLinalg/use_mid_chain.mlir deleted file mode 100644 index f4a855aa..00000000 --- a/test/Conversion/TritonToLinalg/use_mid_chain.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr - ) - { - %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> - // offset = [512] size = 256, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - // offset = [512,0], size = [256,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> - // offset = [512,0], size = [256,128], stride = [1,0] - // mixed use - %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> - // offset = 1024, size = 128, stride = 1 - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - // offset = [0,1024], size = [1,128], stride = [0,1] - %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> - // offset = [0,1024], size = [256,128], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> - %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> - // offset = [0,6144], size = [256,128], stride = [0,6] - %14 = arith.addi %2, %scale7 : tensor<256x128xi32> - // offset = [512,6144], size = [256,128], stride = [1,6] - %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> - %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - %19 = tt.load %18 : tensor<256x128x!tt.ptr> - tt.store %18, %19 : tensor<256x128x!tt.ptr> - %20 = tt.splat %arg2 : !tt.ptr -> tensor<256x128x!tt.ptr> - %21 = tt.addptr %20, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - tt.store %21, %2 : tensor<256x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xi32>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_25:.*]] = arith.constant 512 : i32 -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> -// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { -// CHECK: ^bb0(%[[VAL_10:.*]]: i32): -// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_12]], %[[VAL_25]] : i32 -// CHECK: linalg.yield %[[VAL_24]] : i32 -// CHECK: } -> tensor<256xi32> -// CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_13]] : tensor<256x1xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32): -// CHECK: linalg.yield %[[VAL_17]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_19]], %[[VAL_20]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_20]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_21]] in writable %[[VAL_19]] -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_7]]] : memref<*xi32> to memref<256x128xi32, strided<[1, ?], offset: 6656>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir b/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir deleted file mode 100644 index cb27d947..00000000 --- a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir +++ /dev/null @@ -1,133 +0,0 @@ -// XFAIL: * -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c2_i32 = arith.constant 2 : i32 - %cst_0 = arith.constant dense<2> : tensor<4x1xi32> - %cst_1 = arith.constant dense<6> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %c4_i32 = arith.constant 4 : i32 - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = arith.addi %0, %cst_1 : tensor<4xi32> - %3 = tt.splat %arg3 : i32 -> tensor<4xi32> - %4 = arith.remsi %2, %3 : tensor<4xi32> - %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> - %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg4, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31 = arith.muli %arg5, %c4_i32 : i32 - %32 = tt.splat %31 : i32 -> tensor<4x4xi32> - %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { - %34 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> - tt.store %arg10, %34 : tensor<4x4x!tt.ptr> - %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> - } - tt.return - } -} - -// CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index -// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_2_1_]] : index -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_6_]] : index -// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_1_]], [[VAR_4_]] : index -// CHECK: [[VAR_6_:%.+]] = arith.remsi [[VAR_5_]], [[VAR_2_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_5_]], [[VAR_6_]] : index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_6_]], [[CST_4_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_8_]], [[VAR_2_]] : index -// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_9_]], [[VAR_6_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_10_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[CST_4_]], [[VAR_10_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_7_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_11_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_1_]] : i32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_1_]] : index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_19_]], [[CST_6_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_21_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>>) : i32 { -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK: [[VAR_dim_:%.+]] = memref.dim [[VAR_arg15_]], [[CST_1_]] : memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_dim_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_23_:%.+]] = arith.subi [[CST_4_]], [[VAR_22_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_arg15_]][0, 0] [2, [[VAR_22_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_arg20_]][0, 0] [2, [[VAR_23_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_22_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_22_]]{{.}} [2, [[VAR_23_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_3 : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> -// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>> -// CHECK: [[VAR_24_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_24_]] in writable [[VAR_arg16_]] -// CHECK: [[VAR_25_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK: [[VAR_26_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_25_]] : index -// CHECK: [[VAR_27_:%.+]] = arith.addi [[VAR_26_]], [[VAR_20_]] : index -// CHECK: [[VAR_28_:%.+]] = arith.remsi [[VAR_27_]], [[VAR_18_]] : index -// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_27_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_30_:%.+]] = arith.addi [[VAR_28_]], [[CST_4_]] : index -// CHECK: [[VAR_31_:%.+]] = arith.minsi [[VAR_30_]], [[VAR_18_]] : index -// CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_32_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_33_:%.+]] = arith.subi [[CST_4_]], [[VAR_32_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_29_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_33_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_34_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index -// CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_34_]] : index -// CHECK: [[VAR_36_:%.+]] = arith.addi [[VAR_35_]], [[VAR_arg19_]] : index -// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_36_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK: scf.yield [[VAR_reinterpret_cast_5_]], [[VAR_reinterpret_cast_7_]], [[VAR_26_]], [[VAR_36_]], [[CST_0_]], [[VAR_reinterpret_cast_6_]] : memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir b/test/Conversion/TritonToLinalg/wraparound_stacked.mlir deleted file mode 100644 index f0e86002..00000000 --- a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir +++ /dev/null @@ -1,129 +0,0 @@ -// XFAIL: * -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @wrap_stacked_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c2_i32 = arith.constant 2 : i32 - %cst_0 = arith.constant dense<3> : tensor<1x4xi32> - %cst_1 = arith.constant dense<3> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %c4_i32 = arith.constant 4 : i32 - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4xi32> - %3 = arith.remsi %1, %2 : tensor<4xi32> - %4 = arith.addi %0, %cst_1 : tensor<4xi32> - %5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32> - %28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg5, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { - %32 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> - tt.store %arg10, %32 : tensor<4x4x!tt.ptr> - %33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %34 = tt.addptr %arg10, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - scf.yield %33, %34 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> - } - tt.return - } -} - -// CHECK-LABEL: func.func @wrap_stacked_masked_loop_01234567 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_3_]] : index -// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_2_]], [[VAR_4_]] : index -// CHECK-DAG: [[VAR_6_:%.+]] = arith.remsi [[VAR_5_]], [[VAR_1_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_0_]], [[VAR_1_]] : index -// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_7_]], [[VAR_6_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.subi [[VAR_8_]], [[VAR_5_]] : index -// CHECK: [[VAR_10_:%.+]] = arith.divsi [[VAR_9_]], [[VAR_1_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: {{.}}[[VAR_10_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[CST_4_]], [[VAR_10_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_6_]]{{.}}, sizes: {{.}}[[VAR_11_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[VAR_18_]], [[CST_3_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_1_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref>) : i32 { -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK: [[VAR_dim_:%.+]] = memref.dim [[VAR_arg15_]], [[CST_0_]] : memref> -// CHECK: [[VAR_21_:%.+]] = arith.minsi [[VAR_dim_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_22_:%.+]] = arith.subi [[CST_4_]], [[VAR_21_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_arg15_]][0, 0] {{.}}[[VAR_21_]], 3] [1, 1] : memref> to memref> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_arg20_]][0, 0] {{.}}[[VAR_22_]], 3] [1, 1] : memref> to memref> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_21_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_21_]], 0] {{.}}[[VAR_22_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_3 : memref> to memref> -// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref> to memref> -// CHECK: [[VAR_23_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_23_]] in writable [[VAR_arg16_]] -// CHECK: [[VAR_24_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK: [[VAR_25_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_24_]] : index -// CHECK: [[VAR_26_:%.+]] = arith.addi [[VAR_25_]], [[VAR_19_]] : index -// CHECK-DAG: [[VAR_27_:%.+]] = arith.remsi [[VAR_26_]], [[VAR_16_]] : index -// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[VAR_15_]], [[VAR_16_]] : index -// CHECK: [[VAR_29_:%.+]] = arith.addi [[VAR_28_]], [[VAR_27_]] : index -// CHECK: [[VAR_30_:%.+]] = arith.subi [[VAR_29_]], [[VAR_26_]] : index -// CHECK: [[VAR_31_:%.+]] = arith.divsi [[VAR_30_]], [[VAR_16_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: {{.}}[[VAR_31_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_32_:%.+]] = arith.subi [[CST_4_]], [[VAR_31_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[VAR_32_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_33_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK: [[VAR_34_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_33_]] : index -// CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_34_]], [[VAR_arg19_]] : index -// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_35_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK: scf.yield [[VAR_reinterpret_cast_5_]], [[VAR_reinterpret_cast_7_]], [[VAR_25_]], [[VAR_35_]], [[CST_0_]], [[VAR_reinterpret_cast_6_]] : memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref> -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir b/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir deleted file mode 100644 index 9455f1e8..00000000 --- a/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -// XFAIL: * -// We currently do not support this kind of modulo pattern: -// (a + arrange(0, K)) % M -module { - tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c2_i32 = arith.constant 2 : i32 - %cst_0 = arith.constant dense<2> : tensor<4x1xi32> - %cst_1 = arith.constant dense<6> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %c4_i32 = arith.constant 4 : i32 - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = tt.splat %arg3 : i32 -> tensor<4xi32> - %3 = arith.remsi %0, %2 : tensor<4xi32> - %4 = arith.addi %3, %cst_1 : tensor<4xi32> - %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> - %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg4, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31 = arith.muli %arg5, %c4_i32 : i32 - %32 = tt.splat %31 : i32 -> tensor<4x4xi32> - %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { - %34 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> - tt.store %arg10, %34 : tensor<4x4x!tt.ptr> - %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> - } - tt.return - } -} diff --git a/tools/triton-shared-opt/RegisterTritonSharedDialects.h b/tools/triton-shared-opt/RegisterTritonSharedDialects.h index 6d92953f..d78f96be 100644 --- a/tools/triton-shared-opt/RegisterTritonSharedDialects.h +++ b/tools/triton-shared-opt/RegisterTritonSharedDialects.h @@ -13,7 +13,6 @@ #include "triton-shared/Conversion/StructuredToMemref/Passes.h" #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" #include "triton-shared/Conversion/TritonPtrToMemref/Passes.h" -#include "triton-shared/Conversion/TritonToLinalg/Passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h" #include "triton-shared/Conversion/TritonToStructured/Passes.h" #include "triton-shared/Conversion/TritonToUnstructured/Passes.h" @@ -29,7 +28,6 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry ®istry) { mlir::registerAllPasses(); mlir::registerLinalgPasses(); mlir::triton::registerTritonPasses(); - mlir::triton::registerTritonToLinalgPass(); mlir::triton::registerTritonToLinalgExperimentalPasses(); mlir::triton::registerTritonToStructuredPass(); mlir::triton::registerTritonPtrToMemref(); From 4f593cb433708ed487419df61e603a97f50fde10 Mon Sep 17 00:00:00 2001 From: enjustli <798634436@qq.com> Date: Sat, 15 Nov 2025 12:32:54 +0800 Subject: [PATCH 2/3] Update triton to dbfbc1e1e6c bypass bf16 type update api usage for llvm version change update backend update test case --- backend/compiler.py | 9 +- backend/driver.py | 43 +- .../ConversionPatterns.hpp | 401 +++++++++--------- lib/Analysis/MaskAnalysis.cpp | 45 +- lib/Analysis/OpFoldResultUtils.cpp | 66 +-- lib/Analysis/PtrAnalysis.cpp | 107 ++--- lib/AnalysisStructured/PtrAnalysis.cpp | 116 +++-- .../StructuredToMemref/StructuredToMemref.cpp | 353 ++++++++------- .../StructuredToMemrefPass.cpp | 6 +- .../TritonArithToLinalgPass.cpp | 5 +- .../TritonPtrToMemrefPass.cpp | 3 +- .../CollapseShape.cpp | 56 +-- .../ReconcilePtrCastsPass.cpp | 25 +- .../TritonToPtrPass.cpp | 58 +-- .../TritonToStructuredPass.cpp | 26 +- .../TritonToUnstructuredPass.cpp | 49 +-- .../UnstructuredToMemrefPass.cpp | 118 +++--- lib/Dialect/TPtr/IR/TPtrDialect.cpp | 11 +- .../IR/TritonStructuredOps.cpp | 4 +- .../IR/TritonTilingExtDialect.cpp | 8 +- python/examples/conftest.py | 3 +- .../convert_addi_reduce.mlir | 5 +- .../StructuredToMemref/masked_ldst_2d.mlir | 14 +- .../wraparound_side_by_side.mlir | 8 +- .../wraparound_stacked.mlir | 8 +- .../conditional_ptr_as_src.mlir | 5 +- .../convert_unsplat.mlir | 4 +- triton-hash.txt | 2 +- 28 files changed, 794 insertions(+), 764 deletions(-) diff --git a/backend/compiler.py b/backend/compiler.py index 9efdb48c..ee941bb8 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -64,6 +64,7 @@ def _ttir_to_ttsharedir(mod): subprocess_args.insert(2, "--add-llvm-debug-info") subprocess.check_call(subprocess_args) + _dump_ir_if_needed([dst_path]) return Path(dst_path).read_text() @@ -113,6 +114,7 @@ def _ttsharedir_to_llir(ttsharedir: str): "--mlir-print-debuginfo", "-o", llmlir_path]) + _dump_ir_if_needed([llmlir_path]) # LLVM-MLIR to LLVM-IR mlir_translate_path = _get_llvm_bin_path("mlir-translate") @@ -120,7 +122,7 @@ def _ttsharedir_to_llir(ttsharedir: str): "--mlir-to-llvmir", "-o", llir_path]) - _dump_ir_if_needed([ttshared_path, llmlir_path, llir_path]) + _dump_ir_if_needed([llir_path]) return Path(llir_path).read_text() @@ -151,7 +153,7 @@ def _llir_to_bin(llir: str, metadata): sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None)) if not sanitizer_attributes_pass_path: - raise Exception(f"libSanitizerAttributes.so does not exist.") + raise Exception("libSanitizerAttributes.so does not exist.") subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path, "-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path, @@ -194,6 +196,7 @@ class CPUOptions: allow_fp8e4nv: bool = False allowed_dot_input_precisions: Tuple[str] = ("ieee", ) sanitize_overflow: bool = True + instrumentation_mode: str = "" def __post_init__(self): pass @@ -256,7 +259,7 @@ def make_ttir(mod, metadata, options): passes.common.add_symbol_dce(pm) passes.ttir.add_loop_unroll(pm) passes.common.add_cse(pm) - pm.run(mod) + pm.run(mod, 'make_ttir') return mod def add_stages(self, stages, options, language): diff --git a/backend/driver.py b/backend/driver.py index dcfe9b74..5bcd9cdf 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -1,8 +1,9 @@ import hashlib import tempfile -import sysconfig -import os, subprocess, tempfile, platform +import os +import subprocess +import platform import importlib.util import sys @@ -63,7 +64,18 @@ def _ty_to_cpp(ty): "fp64": "double", }[ty] +def _flatten_signature(sig, output): + # Flatten tuples + if isinstance(sig, tuple): + for x in sig: + _flatten_signature(x, output) + else: + output.append(sig) + def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" if ty[0] == '*': return "PyObject*" if ty == "constexpr": @@ -71,6 +83,15 @@ def _extracted_type(ty): return _ty_to_cpp(ty) def _format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(_format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty == "constexpr": + return "O" + if ty.startswith("tensordesc"): + return "O" return { "PyObject*": "O", "constexpr": "O", @@ -85,15 +106,20 @@ def _format_of(ty): "uint16_t": "H", "uint32_t": "I", "uint64_t": "K", - }[ty] + }[_ty_to_cpp(ty)] def _generate_launcher(constants, signature, kernel_name): - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) + args_format = ''.join([_format_of(ty) for ty in signature.values()]) format = "iiiOOOO" + args_format + + flat_signature = [] + for sig in signature.values(): + _flatten_signature(sig, flat_signature) + signature = {i: s for i, s in enumerate(flat_signature)} + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if ty != "constexpr") + kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else "int64_t, void*" for i, ty in signature.items() if ty != "constexpr") kernel_arg_decls += ', ' if kernel_arg_decls else '' kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr") @@ -327,7 +353,7 @@ def launch( libomp_path = next(Path(Path(_get_llvm_bin_path("")).parent).rglob("libomp.so"), None) if not libomp_path: - raise Exception(f"libomp.so does not exist.") + raise Exception("libomp.so does not exist.") libomp_path = str(libomp_path.parent) @@ -364,7 +390,8 @@ def __init__(self, src, metadata): kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER" constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + def cst_key(i): + return src.fn.arg_names.index(i) if isinstance(i, str) else i constants = {cst_key(key): value for key, value in constants.items()} signature = {cst_key(key): value for key, value in src.signature.items()} launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 1d9244df..c509be31 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -59,14 +59,14 @@ static Value getScalarValue(Value operand, Location loc, if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return rewriter.create(loc, resType, src); + return arith::SIToFPOp::create(rewriter, loc, resType, src); }) .Case([&](Operation *op) { auto resType = op->getResults()[0].getType(); if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return rewriter.create(loc, resType, src); + return arith::TruncFOp::create(rewriter, loc, resType, src); }) .Default([](Operation *op) { llvm_unreachable("unsupported op in generating "); @@ -134,11 +134,11 @@ static Value getTransposedValue(Value source, const Location loc, } } - Value transposeInit = rewriter.create( - loc, transposedShape, sourceType.getElementType()); + Value transposeInit = tensor::EmptyOp::create(rewriter, loc, transposedShape, + sourceType.getElementType()); Value transpose = - rewriter.create(loc, source, transposeInit, perm) + linalg::TransposeOp::create(rewriter, loc, source, transposeInit, perm) .getResults()[0]; return transpose; @@ -191,8 +191,8 @@ struct MakeTensorPtrConverter Location loc) const { for (auto opnd : ops) { if (isa(opnd.getType())) { - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), opnd); + auto castOp = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), opnd); vec.push_back(castOp.getResult()); } else { assert(isa(opnd.getType())); @@ -224,8 +224,8 @@ struct MakeTensorPtrConverter SmallVector newOffsets; for (auto [offset, stride] : llvm::zip(pointerState.offsets, pointerState.strides)) { - auto mulOp = rewriter.create(loc, cast(offset), - cast(stride)); + auto mulOp = arith::MulIOp::create(rewriter, loc, cast(offset), + cast(stride)); newOffsets.push_back(mulOp.getResult()); } @@ -278,71 +278,67 @@ struct LoadConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{zero, block1Col}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } void createStackedCopies(Value block1, Value block2, Value dst, Location loc, ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{block1Row, zero}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } public: @@ -359,8 +355,8 @@ struct LoadConverter : public OpConversionPattern { auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), adaptor.getPtr(), loc, rewriter); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - auto loadOp = rewriter.create( - op.getLoc(), sMemRef, zeroMap, ValueRange{}); + auto loadOp = affine::AffineLoadOp::create(rewriter, op.getLoc(), sMemRef, + zeroMap, ValueRange{}); rewriter.replaceOp(op, loadOp.getResult()); return success(); } @@ -378,8 +374,8 @@ struct LoadConverter : public OpConversionPattern { auto tensorType = RankedTensorType::get(type.getShape(), type.getElementType()); - auto alloc = rewriter.create( - loc, MemRefType::get(type.getShape(), type.getElementType())); + auto alloc = memref::AllocOp::create( + rewriter, loc, MemRefType::get(type.getShape(), type.getElementType())); if (!mask) { assert(!other && "other value used in non-masked load"); @@ -404,11 +400,12 @@ struct LoadConverter : public OpConversionPattern { } } else { - rewriter.create(loc, ptr, alloc); + memref::CopyOp::create(rewriter, loc, ptr, alloc); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create( + rewriter, loc, tensorType, alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -435,31 +432,33 @@ struct LoadConverter : public OpConversionPattern { // the result auto shape = type.getShape(); auto accBase = - rewriter.create(loc, rewriter.getBoolAttr(false)) + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)) .getResult(); for (size_t i = 0; i < type.getShape().size(); i++) { - auto shapei = rewriter.create( - loc, rewriter.getIndexAttr(shape[i])); + auto shapei = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(shape[i])); Value dimi = dyn_cast(mstate.dims[i]); if (!dimi) { - dimi = rewriter.create( - loc, cast(cast(mstate.dims[i]))); + dimi = arith::ConstantOp::create( + rewriter, loc, + cast(cast(mstate.dims[i]))); } - auto cmpOp = rewriter.create( - loc, arith::CmpIPredicate::slt, dimi, shapei); - accBase = rewriter.create(loc, accBase, cmpOp.getResult()) - .getResult(); + auto cmpOp = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = + arith::OrIOp::create(rewriter, loc, accBase, cmpOp.getResult()) + .getResult(); } // condition the memset on the or-accumulation // initialize with padding prior to CopyOp - rewriter.create( - loc, accBase, [&](OpBuilder &builder, Location loc) { - builder.create(loc, ValueRange{scalarOther}, - ValueRange{alloc}); - builder.create(loc); + scf::IfOp::create( + rewriter, loc, accBase, [&](OpBuilder &builder, Location loc) { + linalg::FillOp::create(builder, loc, ValueRange{scalarOther}, + ValueRange{alloc}); + scf::YieldOp::create(builder, loc); }); } @@ -492,11 +491,12 @@ struct LoadConverter : public OpConversionPattern { } else { memref::SubViewOp srcSubview = mstate.getSubview(ptr, loc, rewriter); memref::SubViewOp dstSubview = mstate.getSubview(alloc, loc, rewriter); - rewriter.create(loc, srcSubview, dstSubview); + memref::CopyOp::create(rewriter, loc, srcSubview, dstSubview); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create(rewriter, loc, tensorType, + alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -519,16 +519,16 @@ struct StoreConverter : public OpConversionPattern { auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), ptr, loc, rewriter); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - rewriter.create(loc, val, sMemRef, zeroMap, - ValueRange{}); + affine::AffineStoreOp::create(rewriter, loc, val, sMemRef, zeroMap, + ValueRange{}); rewriter.eraseOp(op); return success(); } // 1. Simple case where no mask is used. if (!mask) { - auto storeOp = rewriter.create( - loc, val, ptr); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, val, ptr); storeOp.setWritable(true); rewriter.eraseOp(op); return success(); @@ -546,8 +546,8 @@ struct StoreConverter : public OpConversionPattern { auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); auto dstSubview = mstate.getSubview(ptr, loc, rewriter); - auto storeOp = rewriter.create( - loc, srcSlice, dstSubview); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, srcSlice, dstSubview); storeOp.setWritable(true); rewriter.eraseOp(op); @@ -635,13 +635,12 @@ struct SplatConverter : public OpConversionPattern { auto opType = cast(op.getType()); auto loc = op.getLoc(); - auto init = rewriter.create(loc, opType.getShape(), - opType.getElementType()); + auto init = tensor::EmptyOp::create(rewriter, loc, opType.getShape(), + opType.getElementType()); auto filledTensor = - rewriter - .create(loc, ValueRange{adaptor.getSrc()}, - ValueRange{init}) + linalg::FillOp::create(rewriter, loc, ValueRange{adaptor.getSrc()}, + ValueRange{init}) .result(); rewriter.replaceOp(op, filledTensor); @@ -697,16 +696,16 @@ struct BroadcastConverter : public OpConversionPattern { rewriter.getMultiDimIdentityMap(resultRank)); assert(op->getNumResults() == 1 && "code assumes single result!"); - auto init = rewriter.create(loc, resultType.getShape(), - elementType); + auto init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + elementType); - auto linalgOp = rewriter.create( - loc, op->getResultTypes(), ValueRange{adaptor.getSrc()}, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, op->getResultTypes(), ValueRange{adaptor.getSrc()}, ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(resultRank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value opResult = blockArgs[0]; - nestedBuilder.create(loc, opResult); + linalg::YieldOp::create(nestedBuilder, loc, opResult); }); linalgOp->setAttr("broadcastDims", @@ -740,8 +739,8 @@ struct ExpandDimsConverter : public OpConversionPattern { reassoc.push_back(g); } - auto expandShapeOp = rewriter.create( - op.getLoc(), resType, src, reassoc); + auto expandShapeOp = tensor::ExpandShapeOp::create(rewriter, op.getLoc(), + resType, src, reassoc); rewriter.replaceOp(op, expandShapeOp.getResult()); return success(); @@ -781,22 +780,22 @@ struct MakeRangeConverter : public OpConversionPattern { /* dimCount */ 1, /* symbolCount */ 0, SmallVector{mlir::getAffineDimExpr(0, context)}, context)}; - auto init = rewriter.create(loc, shape, elementType); - auto linalgOp = rewriter.create( - loc, op->getResultTypes(), /* operands */ ValueRange{}, + auto init = tensor::EmptyOp::create(rewriter, loc, shape, elementType); + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, op->getResultTypes(), /* operands */ ValueRange{}, ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(1), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { - Value index = nestedBuilder.create(loc, 0); - Value res = nestedBuilder.create( - loc, type.getElementType(), index); + Value index = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value res = arith::IndexCastOp::create(nestedBuilder, loc, + type.getElementType(), index); if (op.getStart()) { - auto start = rewriter.create( - op.getLoc(), op.getStart(), + auto start = arith::ConstantIntOp::create( + rewriter, op.getLoc(), op.getStart(), type.getElementType().getIntOrFloatBitWidth()); - res = nestedBuilder.create(loc, res, start); + res = arith::AddIOp::create(nestedBuilder, loc, res, start); } - nestedBuilder.create(loc, res); + linalg::YieldOp::create(nestedBuilder, loc, res); }); rewriter.replaceOp(op, linalgOp->getResults()); @@ -819,8 +818,7 @@ struct AssertConverter : public OpConversionPattern { // TritonOps.td. Tensors will always be RankedTensorType. if (isa(condVal.getType())) { // handle scalar case - rewriter.create(op.getLoc(), condVal, - assertMessage.str()); + cf::AssertOp::create(rewriter, op.getLoc(), condVal, assertMessage.str()); } else if (auto tensorType = dyn_cast(condVal.getType())) { // handle tensor case @@ -834,8 +832,8 @@ struct AssertConverter : public OpConversionPattern { SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - rewriter.create( - op.getLoc(), TypeRange{}, condVal, ValueRange{}, + linalg::GenericOp::create( + rewriter, op.getLoc(), TypeRange{}, condVal, ValueRange{}, ArrayRef{indexingMaps}, ArrayRef{iteratorTypes}, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -843,9 +841,9 @@ struct AssertConverter : public OpConversionPattern { Value element = args[0]; // make a cf.assert for the current element - b.create(loc, element, assertMessage.str()); + cf::AssertOp::create(b, loc, element, assertMessage.str()); - b.create(loc); + linalg::YieldOp::create(b, loc); }); } else { op.emitError("Unexpected type in triton::AssertOp"); @@ -868,8 +866,8 @@ struct BitcastConverter : public OpConversionPattern { return failure(); } - auto arithBitcast = rewriter.create( - op.getLoc(), op.getType(), op.getOperand()); + auto arithBitcast = arith::BitcastOp::create(rewriter, op.getLoc(), + op.getType(), op.getOperand()); rewriter.replaceOp(op, arithBitcast.getResult()); return success(); @@ -908,8 +906,8 @@ struct CallConverter : public OpConversionPattern { } } - auto call = rewriter.create(op.getLoc(), op.getCallee(), - op.getResultTypes(), args); + auto call = func::CallOp::create(rewriter, op.getLoc(), op.getCallee(), + op.getResultTypes(), args); if (!call) { op.emitError("Failed to create func::CallOp"); @@ -946,14 +944,14 @@ struct FpToFpConverter : public OpConversionPattern { "Not a float-like operand or result"); if (operandWidth.value() > resultWidth.value()) { - Value truncatedValue = rewriter.create( - op.getLoc(), resultType, op.getOperand()); + Value truncatedValue = arith::TruncFOp::create( + rewriter, op.getLoc(), resultType, op.getOperand()); rewriter.replaceOp(op, truncatedValue); return success(); } - Value extendedValue = rewriter.create( - op.getLoc(), resultType, op.getOperand()); + Value extendedValue = arith::ExtFOp::create(rewriter, op.getLoc(), + resultType, op.getOperand()); rewriter.replaceOp(op, extendedValue); return success(); @@ -975,11 +973,11 @@ struct ClampConverter : public OpConversionPattern { Value clamp; if (propagateNan) { - Value maxMin = rewriter.create(loc, x, min); - clamp = rewriter.create(loc, maxMin, max); + Value maxMin = arith::MaximumFOp::create(rewriter, loc, x, min); + clamp = arith::MinimumFOp::create(rewriter, loc, maxMin, max); } else { - Value maxMin = rewriter.create(loc, x, min); - clamp = rewriter.create(loc, maxMin, max); + Value maxMin = arith::MaxNumFOp::create(rewriter, loc, x, min); + clamp = arith::MinNumFOp::create(rewriter, loc, maxMin, max); } rewriter.replaceOp(op, clamp); @@ -995,7 +993,7 @@ struct PreciseSqrtConverter matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto replacement = - rewriter.create(op.getLoc(), adaptor.getOperands()); + math::SqrtOp::create(rewriter, op.getLoc(), adaptor.getOperands()); rewriter.replaceOp(op, replacement); return success(); @@ -1009,7 +1007,7 @@ struct PreciseDivConverter : public OpConversionPattern { matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto replacement = - rewriter.create(op.getLoc(), adaptor.getOperands()); + arith::DivFOp::create(rewriter, op.getLoc(), adaptor.getOperands()); rewriter.replaceOp(op, replacement); return success(); @@ -1022,8 +1020,8 @@ struct CatConverter : public OpConversionPattern { LogicalResult matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto replacement = rewriter.create( - op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); + auto replacement = tensor::ConcatOp::create( + rewriter, op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); rewriter.replaceOp(op, replacement); @@ -1060,8 +1058,8 @@ struct SplitConverter : public OpConversionPattern { offsets.push_back(rewriter.getIndexAttr(i)); sizes.push_back(rewriter.getIndexAttr(1)); - Value slice = rewriter.create( - loc, resultTensor, input, offsets, sizes, strides); + Value slice = tensor::ExtractSliceOp::create( + rewriter, loc, resultTensor, input, offsets, sizes, strides); results.push_back(slice); } @@ -1081,8 +1079,8 @@ struct JoinConverter : public OpConversionPattern { auto resultType = cast(op.getResult().getType()); auto loc = op.getLoc(); - Value result = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + Value result = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()); auto shape = resultType.getShape(); @@ -1099,8 +1097,8 @@ struct JoinConverter : public OpConversionPattern { offsets.push_back(rewriter.getIndexAttr(i)); sizes.push_back(rewriter.getIndexAttr(1)); - result = rewriter.create(loc, inputs[i], result, - offsets, sizes, strides); + result = tensor::InsertSliceOp::create(rewriter, loc, inputs[i], result, + offsets, sizes, strides); } rewriter.replaceOp(op, result); @@ -1118,7 +1116,7 @@ struct MulHiUIOpConverter : public OpConversionPattern { Location loc = op.getLoc(); auto mulResult = - rewriter.create(loc, adaptor.getOperands()); + arith::MulUIExtendedOp::create(rewriter, loc, adaptor.getOperands()); rewriter.replaceOp(op, mulResult.getHigh()); return success(); @@ -1169,29 +1167,28 @@ struct MatmulConverter : public OpConversionPattern { bool integers = elementType.isInteger(); bool skipC = isZeroTensor(opc, integers); auto init = - rewriter.create(loc, dstType.getShape(), elementType); + tensor::EmptyOp::create(rewriter, loc, dstType.getShape(), elementType); TypedAttr constantAttr = integers ? static_cast(rewriter.getIntegerAttr(elementType, 0)) : static_cast(rewriter.getFloatAttr(elementType, 0)); - auto zero = rewriter.create( - op.getLoc(), elementType, constantAttr); + auto zero = arith::ConstantOp::create(rewriter, op.getLoc(), elementType, + constantAttr); - auto zeroes = - rewriter.create(loc, ValueRange{zero}, ValueRange{init}) - .result(); + auto zeroes = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{init}) + .result(); - auto res = rewriter - .create(loc, ValueRange{opa, opb}, - ValueRange{zeroes}) + auto res = linalg::MatmulOp::create(rewriter, loc, ValueRange{opa, opb}, + ValueRange{zeroes}) .getResult(0); if (!skipC) { if (integers) { - res = rewriter.create(loc, opc, res); + res = arith::AddIOp::create(rewriter, loc, opc, res); } else { - res = rewriter.create(loc, opc, res); + res = arith::AddFOp::create(rewriter, loc, opc, res); } } @@ -1273,8 +1270,8 @@ struct ReduceConverter : public OpConversionPattern { return nullptr; }); - return rewriter.create(redOp->getLoc(), constantType, - attr); + return arith::ConstantOp::create(rewriter, redOp->getLoc(), constantType, + attr); } bool requiresF32Conversion(const Type elemType, Operation *redOp) const { @@ -1291,16 +1288,16 @@ struct ReduceConverter : public OpConversionPattern { return llvm::TypeSwitch(redOp) .Case([&](auto redOp) { if (convertLhsToF32Precision) { - lhs = b.create(loc, Float32Type::get(b.getContext()), - lhs); + lhs = arith::ExtFOp::create(b, loc, + Float32Type::get(b.getContext()), lhs); } - return b.create(loc, lhs, rhs); + return decltype(redOp)::create(b, loc, lhs, rhs); }) .Case([&](auto redOp) { - return b.create(loc, lhs, rhs); + return decltype(redOp)::create(b, loc, lhs, rhs); }) .Default([](Operation *op) { op->dump(); @@ -1379,40 +1376,40 @@ struct ReduceConverter : public OpConversionPattern { // directly instead of EmptyOp so that the subsequent pass can recognize // the patterns (EmptyOp is susceptible to being CSE'd away, making it // harder to match the patterns correctly). - initTensor = rewriter.create( - loc, RankedTensorType::get({}, constantType), ValueRange{}); - initTensor = rewriter.create(loc, accBaseConstOp, - initTensor, ValueRange{}); + initTensor = bufferization::AllocTensorOp::create( + rewriter, loc, RankedTensorType::get({}, constantType), ValueRange{}); + initTensor = tensor::InsertOp::create(rewriter, loc, accBaseConstOp, + initTensor, ValueRange{}); } else { - Value init = rewriter.create( - loc, cast(resType).getShape(), constantType); - initTensor = rewriter - .create(loc, ValueRange{accBaseConstOp}, - ValueRange{init}) - .result(); + Value init = tensor::EmptyOp::create( + rewriter, loc, cast(resType).getShape(), + constantType); + initTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{accBaseConstOp}, + ValueRange{init}) + .result(); } Value finalResult = - rewriter - .create( - loc, ValueRange{source}, ValueRange{initTensor}, - SmallVector{axis}, - [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { - assert(inputs.size() == 2); - Value result = - getRedElement(inputs[0], inputs[1], loc, rop, opBuilder, - convertToF32Precision); - opBuilder.create(loc, result); - }) + linalg::ReduceOp::create( + rewriter, loc, ValueRange{source}, ValueRange{initTensor}, + SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + assert(inputs.size() == 2); + Value result = getRedElement(inputs[0], inputs[1], loc, rop, + opBuilder, convertToF32Precision); + linalg::YieldOp::create(opBuilder, loc, result); + }) .getResult(0); if (isVectorReduce) { finalResult = - rewriter.create(loc, constantType, finalResult); + tensor::ExtractOp::create(rewriter, loc, constantType, finalResult); } if (convertToF32Precision) { - finalResult = rewriter.create(loc, resType, finalResult); + finalResult = + arith::TruncFOp::create(rewriter, loc, resType, finalResult); } rewriter.replaceOp(op, finalResult); @@ -1570,10 +1567,9 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { ArrayRef shape, Value fillValue, Location loc) const { Value initTensor = - rewriter.create(loc, shape, fillValue.getType()); - return rewriter - .create(loc, ValueRange{fillValue}, - ValueRange{initTensor}) + tensor::EmptyOp::create(rewriter, loc, shape, fillValue.getType()); + return linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{initTensor}) .result(); } @@ -1648,15 +1644,15 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { // the result value to either -inf or +inf depending on // whether we're dealing with argmax or argmin auto valueType = elemTypes[0]; - auto valuesAccBaseVal = rewriter.create( - loc, valueType, + auto valuesAccBaseVal = arith::ConstantOp::create( + rewriter, loc, valueType, rewriter.getFloatAttr(valueType, T::getBaseReductionValue())); // Set the initial value of the rank-0 tensor containing the index of the // min or max value to -1 auto indexType = elemTypes[1]; - auto indicesAccBaseVal = rewriter.create( - loc, indexType, rewriter.getIntegerAttr(indexType, -1)); + auto indicesAccBaseVal = arith::ConstantOp::create( + rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, -1)); // Get the shape of the resulting tensors (both for values and indices). If // we are reducing to a single scalar, then the result's type is a tensor of @@ -1671,8 +1667,8 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; - auto linalgOp = rewriter.create( - loc, adaptor.getOperands(), outputs, + auto linalgOp = linalg::ReduceOp::create( + rewriter, loc, adaptor.getOperands(), outputs, SmallVector{adaptor.getAxis()}, [&](OpBuilder &b, Location loc, ValueRange inputs) { assert(inputs.size() == 4); @@ -1690,15 +1686,15 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { return mapping.lookup(val); }); - b.create(loc, results); + linalg::YieldOp::create(b, loc, results); }); if (isScalarReduce) { SmallVector reduceResults{ - rewriter.create( - loc, valueType, linalgOp.getResults()[0], ValueRange{}), - rewriter.create( - loc, indexType, linalgOp.getResults()[1], ValueRange{})}; + tensor::ExtractOp::create(rewriter, loc, valueType, + linalgOp.getResults()[0], ValueRange{}), + tensor::ExtractOp::create(rewriter, loc, indexType, + linalgOp.getResults()[1], ValueRange{})}; rewriter.replaceOp(op, reduceResults); } else { rewriter.replaceOp(op, linalgOp); @@ -1927,8 +1923,9 @@ struct DenseConstantConverter : public OpConversionPattern { auto splatConst = arith::ConstantOp::materialize( rewriter, attr.getSplatValue(), attr.getElementType(), loc); - auto init = rewriter.create( - loc, cast(op.getResult().getType()).getShape(), + auto init = tensor::EmptyOp::create( + rewriter, loc, + cast(op.getResult().getType()).getShape(), attr.getElementType()); rewriter.replaceOpWithNewOp(op, ValueRange{splatConst}, @@ -1996,8 +1993,8 @@ class CumSumConverter : public OpConversionPattern { "= {1, 2} and axis = rank - 1"); } - Value init = rewriter.create(op.getLoc(), type.getShape(), - type.getElementType()); + Value init = tensor::EmptyOp::create(rewriter, op.getLoc(), type.getShape(), + type.getElementType()); rewriter.replaceOpWithNewOp( op, input, rewriter.getUI32IntegerAttr(axis), init); @@ -2032,7 +2029,7 @@ class AddPtrConverter : public OpConversionPattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } @@ -2063,8 +2060,8 @@ class TensorOpConverter : public OpConversionPattern { rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - SmallVector outputs = {rewriter.create( - op->getLoc(), resultTensorType.getShape(), + SmallVector outputs = {tensor::EmptyOp::create( + rewriter, op->getLoc(), resultTensorType.getShape(), resultTensorType.getElementType())}; rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps, @@ -2078,7 +2075,7 @@ class TensorOpConverter : public OpConversionPattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } @@ -2116,7 +2113,7 @@ class StorePtrToLinalgConverter : public OpConversionPattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } @@ -2154,8 +2151,8 @@ class ReshapeConverter : public OpConversionPattern { ArrayRef outputShape = outputType.getShape(); - auto shape = rewriter.create( - loc, rewriter.getI64TensorAttr(outputShape)); + auto shape = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64TensorAttr(outputShape)); rewriter.replaceOpWithNewOp(op, outputType, input, shape); diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 1aa9760f..2fcdded5 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -68,8 +68,8 @@ tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); - return builder.create(loc, dstType, source, offsets, - dims, strides); + return tensor::ExtractSliceOp::create(builder, loc, dstType, source, offsets, + dims, strides); } memref::SubViewOp MaskState::getSubview(Value source, const Location loc, @@ -80,8 +80,8 @@ memref::SubViewOp MaskState::getSubview(Value source, const Location loc, auto dstType = memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); - return builder.create(loc, cast(dstType), - source, offsets, dims, strides); + return memref::SubViewOp::create(builder, loc, cast(dstType), + source, offsets, dims, strides); } static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, @@ -91,8 +91,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, auto srcType = cast(src.getType()); auto dstType = memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return b.create(loc, cast(dstType), src, - offsets, sizes, strides); + return memref::SubViewOp::create(b, loc, cast(dstType), src, + offsets, sizes, strides); } // Assume block1 wraps around and the remainder is block2. @@ -155,7 +155,8 @@ MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc, OpBuilder &builder) const { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; - OpFoldResult col1 = builder.create(loc, block1, 1).getResult(); + OpFoldResult col1 = + memref::DimOp::create(builder, loc, block1, 1).getResult(); OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, builder); OpFoldResult subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, builder); @@ -174,7 +175,8 @@ MaskState::getStackedSubviews(Value block1, Value block2, const Location loc, OpBuilder &builder) const { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; - OpFoldResult row1 = builder.create(loc, block1, 0).getResult(); + OpFoldResult row1 = + memref::DimOp::create(builder, loc, block1, 0).getResult(); OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder); OpFoldResult subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, builder); @@ -314,8 +316,8 @@ LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc, if (scalar.getType().isInteger(1)) { this->scalar = scalar; } else { - auto castOp = - builder.create(loc, builder.getIndexType(), scalar); + auto castOp = arith::IndexCastOp::create(builder, loc, + builder.getIndexType(), scalar); this->scalar = castOp.getResult(); } return success(); @@ -407,18 +409,17 @@ LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, } auto targetTensorType = RankedTensorType::get({size}, builder.getI32Type()); - Value range = - builder - .create(loc, targetTensorType, 0, size) - .getResult(); + Value range = triton::MakeRangeOp::create(builder, loc, + targetTensorType, 0, size) + .getResult(); Value v = ofrToIndexValue(ofr, loc, builder); - v = builder - .create(loc, builder.getI32Type(), v) + v = arith::IndexCastUIOp::create(builder, loc, builder.getI32Type(), + v) .getResult(); - v = builder.create(loc, targetTensorType, v) + v = triton::SplatOp::create(builder, loc, targetTensorType, v) .getResult(); - return builder - .create(loc, arith::CmpIPredicate::ult, range, v) + return arith::CmpIOp::create(builder, loc, + arith::CmpIPredicate::ult, range, v) .getResult(); }; if (!lhsV) { @@ -436,7 +437,7 @@ LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, continue; } // And the mask. - masks.push_back(builder.create(loc, lhsV, rhsV)); + masks.push_back(arith::AndIOp::create(builder, loc, lhsV, rhsV)); } } // Only support one unstructured mask. @@ -494,8 +495,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, SmallVector reassociation = *maybeReassociationMap; // Set masks. - unstructuredMask = builder.create( - loc, flatType, cmpOp, reassociation); + unstructuredMask = tensor::CollapseShapeOp::create( + builder, loc, flatType, cmpOp, reassociation); } masks[cmpOpDim] = unstructuredMask; } diff --git a/lib/Analysis/OpFoldResultUtils.cpp b/lib/Analysis/OpFoldResultUtils.cpp index 49ae4c44..85d7842f 100644 --- a/lib/Analysis/OpFoldResultUtils.cpp +++ b/lib/Analysis/OpFoldResultUtils.cpp @@ -59,7 +59,7 @@ Value ofrToValue(const OpFoldResult ofr, const Location loc, OpBuilder &b) { auto attr = dyn_cast(ofr); auto typedAttr = dyn_cast(attr); - return b.create(loc, typedAttr); + return arith::ConstantOp::create(b, loc, typedAttr); } Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, @@ -67,14 +67,14 @@ Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, if (Value val = dyn_cast(ofr)) { assert(val.getType().isIntOrIndex()); if (!val.getType().isIndex()) { - val = b.create(loc, b.getIndexType(), val); + val = arith::IndexCastOp::create(b, loc, b.getIndexType(), val); } return val; } auto intVal = getIntAttr(ofr); if (intVal.has_value()) { - return b.create(loc, b.getIndexAttr(intVal.value())); + return arith::ConstantOp::create(b, loc, b.getIndexAttr(intVal.value())); } llvm_unreachable("Unexpected OpFoldResult state"); return nullptr; @@ -93,14 +93,14 @@ Value indexTypeCast(Value v, Type targetTy, const Location loc, OpBuilder &b) { if (isa(targetTy) || isa(ty)) { assert((isa(targetTy) || isa(ty)) && "Only cast between index type and integer type"); - return b.create(loc, targetTy, v).getResult(); + return arith::IndexCastOp::create(b, loc, targetTy, v).getResult(); } else { auto targetIntTy = cast(targetTy); auto intTy = cast(ty); if (targetIntTy.getWidth() > intTy.getWidth()) - return b.create(loc, targetTy, v).getResult(); + return arith::ExtSIOp::create(b, loc, targetTy, v).getResult(); else - return b.create(loc, targetTy, v).getResult(); + return arith::TruncIOp::create(b, loc, targetTy, v).getResult(); } } @@ -114,8 +114,8 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy, Value v = dyn_cast(ofr); if (!v) - v = b.create(loc, - cast(cast(ofr))); + v = arith::ConstantOp::create(b, loc, + cast(cast(ofr))); Type ty = v.getType(); if (targetTy == ty) @@ -127,7 +127,7 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy, // cast to target element type first. if (targetEltTy != ty) v = indexTypeCast(v, targetEltTy, loc, b); - return b.create(loc, targetTy, v).getResult(); + return triton::SplatOp::create(b, loc, targetTy, v).getResult(); } else if (targetShapedTy && shapedTy) { Type targetEltTy = targetShapedTy.getElementType(); Type eltTy = shapedTy.getElementType(); @@ -148,26 +148,26 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy, SmallVector shapeValues; for (auto dim : targetShapedTy.getShape()) { shapeValues.push_back( - b.create(loc, b.getIndexAttr(dim))); + arith::ConstantOp::create(b, loc, b.getIndexAttr(dim))); } RankedTensorType targetShapeTensorTy = RankedTensorType::get( targetShapedTy.getShape().size(), b.getIndexType()); - auto shapeTensor = b.create( - loc, targetShapeTensorTy, shapeValues); - return b.create(loc, targetTy, v, shapeTensor) + auto shapeTensor = tensor::FromElementsOp::create( + b, loc, targetShapeTensorTy, shapeValues); + return triton::ReshapeOp::create(b, loc, targetTy, v, shapeTensor) .getResult(); } if (isa(targetEltTy) || isa(eltTy)) { assert((isa(targetEltTy) || isa(eltTy)) && "Only cast between index type and integer type"); - return b.create(loc, targetTy, v).getResult(); + return arith::IndexCastOp::create(b, loc, targetTy, v).getResult(); } else { auto targetIntTy = cast(targetEltTy); auto intTy = cast(eltTy); if (targetIntTy.getWidth() > intTy.getWidth()) - return b.create(loc, targetTy, v).getResult(); + return arith::ExtSIOp::create(b, loc, targetTy, v).getResult(); else - return b.create(loc, targetTy, v).getResult(); + return arith::TruncIOp::create(b, loc, targetTy, v).getResult(); } } else { assert(!shapedTy && "src type rank should be >= target type rank"); @@ -194,18 +194,18 @@ OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - return b.create(loc, lhsValue, rhsValue).getResult(); + return arith::AddIOp::create(b, loc, lhsValue, rhsValue).getResult(); } OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, @@ -225,18 +225,18 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto sumOp = b.create(loc, lhsValue, rhsValue); + auto sumOp = arith::SubIOp::create(b, loc, lhsValue, rhsValue); return sumOp.getResult(); } @@ -280,17 +280,17 @@ OpFoldResult mulOFRs(const OpFoldResult lhs, const OpFoldResult rhs, // otherwise, need to create instructions to calculate new attribute value if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - return b.create(loc, lhsValue, rhsValue).getResult(); + return arith::MulIOp::create(b, loc, lhsValue, rhsValue).getResult(); } OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, @@ -306,18 +306,18 @@ OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto minOp = b.create(loc, lhsValue, rhsValue); + auto minOp = arith::MinSIOp::create(b, loc, lhsValue, rhsValue); return minOp.getResult(); } @@ -334,18 +334,18 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto maxOp = b.create(loc, lhsValue, rhsValue); + auto maxOp = arith::MaxSIOp::create(b, loc, lhsValue, rhsValue); return maxOp.getResult(); } @@ -359,7 +359,7 @@ OpFoldResult selectOFRs(const OpFoldResult condOFR, const OpFoldResult trueOFR, "Condition for selectOp must be a bool type"); auto selectOp = - b.create(loc, condValue, trueValue, falseValue); + arith::SelectOp::create(b, loc, condValue, trueValue, falseValue); return selectOp.getResult(); } @@ -398,7 +398,7 @@ OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = ofrToIndexValue(lhs, loc, b); auto rhsValue = ofrToIndexValue(rhs, loc, b); - auto cmpOp = b.create(loc, pred, lhsValue, rhsValue); + auto cmpOp = arith::CmpIOp::create(b, loc, pred, lhsValue, rhsValue); return selectOFRs(cmpOp.getResult(), trueOFR, falseOFR, loc, b); } diff --git a/lib/Analysis/PtrAnalysis.cpp b/lib/Analysis/PtrAnalysis.cpp index 2fc2f85e..1a175670 100644 --- a/lib/Analysis/PtrAnalysis.cpp +++ b/lib/Analysis/PtrAnalysis.cpp @@ -82,7 +82,7 @@ void PtrState::addState(const PtrState &lhsState, const PtrState &rhsState, if (lhsState.scalar && rhsState.scalar) { auto addOp = - rewriter.create(loc, lhsState.scalar, rhsState.scalar); + arith::AddIOp::create(rewriter, loc, lhsState.scalar, rhsState.scalar); scalar = addOp.getResult(); } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; @@ -210,28 +210,28 @@ PtrState::createStackedCastOps(ArrayRef resultShape, Value strideRow = ofrToIndexValue(strides[0], loc, rewriter); Value strideCol = ofrToIndexValue(strides[1], loc, rewriter); - Value modRow = rewriter.create( - loc, rewriter.getIndexType(), modulos[0]->size); + Value modRow = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), modulos[0]->size); // First chunk Value wrappedAroundOff = - rewriter.create(loc, targetOffset, strideRow); - Value clampedOff = rewriter.create(loc, modRow, strideRow); + arith::RemSIOp::create(rewriter, loc, targetOffset, strideRow); + Value clampedOff = arith::MulIOp::create(rewriter, loc, modRow, strideRow); clampedOff = - rewriter.create(loc, clampedOff, wrappedAroundOff); - Value d1 = rewriter.create(loc, clampedOff, targetOffset); - d1 = rewriter.create(loc, d1, strideRow); + arith::AddIOp::create(rewriter, loc, clampedOff, wrappedAroundOff); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOff, targetOffset); + d1 = arith::DivSIOp::create(rewriter, loc, d1, strideRow); SmallVector sizes1{d1, colSize}; - memref::ReinterpretCastOp cast1 = rewriter.create( - loc, resultType, source, targetOffset, sizes1, + memref::ReinterpretCastOp cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, source, targetOffset, sizes1, ValueRange{strideRow, strideCol}); // Second chunk - Value d2 = rewriter.create(loc, rowSize, d1); + Value d2 = arith::SubIOp::create(rewriter, loc, rowSize, d1); SmallVector sizes2{d2, colSize}; - memref::ReinterpretCastOp cast2 = rewriter.create( - loc, resultType, source, wrappedAroundOff, sizes2, + memref::ReinterpretCastOp cast2 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, source, wrappedAroundOff, sizes2, ValueRange{strideRow, strideCol}); return {cast1, cast2}; @@ -299,29 +299,29 @@ PtrState::createSideBySideCastOps(ArrayRef resultShape, Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); - Value modN = rewriter.create(loc, rewriter.getIndexType(), - modulos[1]->size); + Value modN = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), modulos[1]->size); - Value x = rewriter.create(loc, targetOffset, modN); - Value y = rewriter.create(loc, targetOffset, x); + Value x = arith::RemSIOp::create(rewriter, loc, targetOffset, modN); + Value y = arith::SubIOp::create(rewriter, loc, targetOffset, x); SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); // First chunk - Value nextOffset = rewriter.create(loc, x, colSize); - Value clampedOffset = rewriter.create(loc, nextOffset, modN); - Value d1 = rewriter.create(loc, clampedOffset, x); + Value nextOffset = arith::AddIOp::create(rewriter, loc, x, colSize); + Value clampedOffset = arith::MinSIOp::create(rewriter, loc, nextOffset, modN); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOffset, x); SmallVector sizes1{rowSize, d1}; - auto cast1 = rewriter.create( - loc, resultType, source, targetOffset, sizes1, strideVals); + auto cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, source, targetOffset, sizes1, strideVals); // Second chunk - Value d2 = rewriter.create(loc, colSize, d1); + Value d2 = arith::SubIOp::create(rewriter, loc, colSize, d1); SmallVector sizes2{rowSize, d2}; - auto cast2 = rewriter.create( - loc, resultType, source, y, sizes2, strideVals); + auto cast2 = memref::ReinterpretCastOp::create(rewriter, loc, resultType, + source, y, sizes2, strideVals); return {cast1, cast2}; } @@ -341,8 +341,8 @@ PtrState::createCastOp(ArrayRef resultShape, const Location loc, getResultMemrefType(rewriter.getContext(), staticOffset[0], resultShape); // Create reinterpret cast - return rewriter.create( - loc, resultType, source, targetOffset, sizes, strides); + return memref::ReinterpretCastOp::create(rewriter, loc, resultType, source, + targetOffset, sizes, strides); } void PtrAnalysis::visitOperandAdd( @@ -648,8 +648,8 @@ void PtrAnalysis::visitOperand( } if (isa(operand.getType())) { - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), operand); + auto castOp = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), operand); state.scalar = castOp.getResult(); return; } @@ -805,10 +805,10 @@ void PtrAnalysis::rewriteAddptrOp( rewriter.getContext(), ShapedType::kDynamic, resultShape); UnrealizedConversionCastOp combinedCast = - rewriter.create( - op.getLoc(), resultType, - ValueRange{casts[0].getResult(), casts[1].getResult(), - op.getResult()}); + UnrealizedConversionCastOp::create(rewriter, op.getLoc(), resultType, + ValueRange{casts[0].getResult(), + casts[1].getResult(), + op.getResult()}); combinedCast->setAttr(ModuloState::WraparoundAttr, rewriter.getStringAttr(type)); @@ -858,18 +858,18 @@ void PtrAnalysis::rewriteAdvanceOp( llvm::zip(incrementOffsets, ptrState.offsets, ptrState.strides)) { Value offsetValue; if (auto offsetIntAttr = getIntAttr(offset)) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); + auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(0)); offsetValue = constOp.getResult(); } else { offsetValue = cast(offset); } - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), increment); - auto mulOp = rewriter.create(loc, castOp.getResult(), - cast(stride)); + auto castOp = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), increment); + auto mulOp = arith::MulIOp::create(rewriter, loc, castOp.getResult(), + cast(stride)); auto addOp = - rewriter.create(loc, mulOp.getResult(), offsetValue); + arith::AddIOp::create(rewriter, loc, mulOp.getResult(), offsetValue); newOffsets.push_back(addOp.getResult()); } @@ -995,8 +995,8 @@ void PtrAnalysis::rewriteYieldOp( // zeroes. if (auto sIntAttr = getIntAttr(s)) { assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); + auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(0)); operands.push_back(constOp.getResult()); } else { operands.push_back(cast(s)); @@ -1166,8 +1166,8 @@ void PtrAnalysis::rewriteForOp( for (auto [j, s] : llvm::enumerate(state.offsets)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); newInitArgs.push_back(constOp.getResult()); state.offsets[j] = constOp.getResult(); } else { @@ -1178,8 +1178,8 @@ void PtrAnalysis::rewriteForOp( for (auto [j, s] : llvm::enumerate(state.strides)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); newInitArgs.push_back(constOp.getResult()); state.strides[j] = constOp.getResult(); } else { @@ -1240,9 +1240,10 @@ void PtrAnalysis::rewriteForOp( rewriter.restoreInsertionPoint(origIp); // Create a new scf::ForOp that uses updated init args and same loop body - auto newOp = rewriter.create( - op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), - newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + auto newOp = scf::ForOp::create( + rewriter, op.getLoc(), op.getLowerBound(), op.getUpperBound(), + op.getStep(), newInitArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { IRMapping mapping; mapping.map(op.getInductionVar(), iv); mapping.map(op.getInitArgs(), newInitArgs); @@ -1266,7 +1267,7 @@ void PtrAnalysis::rewriteForOp( b.setInsertionPointToStart(b.getBlock()); Value zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { SmallVector newReinterpretCasts; @@ -1274,9 +1275,9 @@ void PtrAnalysis::rewriteForOp( newReinterpretCasts.push_back(args[chunk.initArgIndex]); } - auto combinedCast = b.create( - loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, - unrealizedCastOp->getAttrs()); + auto combinedCast = UnrealizedConversionCastOp::create( + b, loc, unrealizedCastOp.getResult(0).getType(), + newReinterpretCasts, unrealizedCastOp->getAttrs()); args[chunkData[0].initArgIndex].replaceUsesWithIf( combinedCast.getResult(0), [](OpOperand &operand) { diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index 2253c807..e5b6ece3 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -63,29 +63,26 @@ static Value applyUnstructuredMask(Operation *op, Value ptr, return nullptr; } - ptr = - builder - .create( - loc, gatherScatterPtr.getBase(), - gatherScatterPtr.getGatherScatterOffset(), unstructuredMask, - gatherScatterPtr.getGatherScatterDim(), - gatherScatterPtr.getSizes(), gatherScatterPtr.getMixedStrides(), - gatherScatterPtr.getMixedOffsets()) - .getResult(); + ptr = tts::MakeGatherScatterTensorPtrOp::create( + builder, loc, gatherScatterPtr.getBase(), + gatherScatterPtr.getGatherScatterOffset(), unstructuredMask, + gatherScatterPtr.getGatherScatterDim(), + gatherScatterPtr.getSizes(), gatherScatterPtr.getMixedStrides(), + gatherScatterPtr.getMixedOffsets()) + .getResult(); } else if (auto structuredPtr = ptr.getDefiningOp()) { auto ofrToI32Value = [&](OpFoldResult ofr) { Value v = dyn_cast(ofr); if (!v) { - v = builder - .create( - loc, cast(cast(ofr))) + v = arith::ConstantOp::create(builder, loc, + cast(cast(ofr))) .getResult(); } if (isa(v.getType())) { - v = builder.create(loc, builder.getI32Type(), v) + v = arith::IndexCastOp::create(builder, loc, builder.getI32Type(), v) .getResult(); } else if (v.getType().isInteger(64)) { - v = builder.create(loc, builder.getI32Type(), v) + v = arith::TruncIOp::create(builder, loc, builder.getI32Type(), v) .getResult(); } @@ -100,22 +97,20 @@ static Value applyUnstructuredMask(Operation *op, Value ptr, // Divide stride since offset of tts::MakeTensorPtrOp already include the // stride, but gatherScatterOffset of tts::MakeGatherScatterTensorPtrOp // should not include stride. - offset = builder.create(loc, offset, stride); + offset = arith::DivUIOp::create(builder, loc, offset, stride); Value gatherScatterOffset = - builder.create(loc, offsetRowType, offset).getResult(); - Value range = builder - .create( - loc, offsetRowType, 0, structuredPtr.getSizes()[dim]) + tensor::SplatOp::create(builder, loc, offsetRowType, offset) + .getResult(); + Value range = triton::MakeRangeOp::create(builder, loc, offsetRowType, 0, + structuredPtr.getSizes()[dim]) .getResult(); gatherScatterOffset = - builder.create(loc, gatherScatterOffset, range); - ptr = builder - .create( - loc, structuredPtr.getBase(), gatherScatterOffset, - unstructuredMask, dim, structuredPtr.getSizes(), - structuredPtr.getMixedStrides(), - structuredPtr.getMixedOffsets()) + arith::AddIOp::create(builder, loc, gatherScatterOffset, range); + ptr = tts::MakeGatherScatterTensorPtrOp::create( + builder, loc, structuredPtr.getBase(), gatherScatterOffset, + unstructuredMask, dim, structuredPtr.getSizes(), + structuredPtr.getMixedStrides(), structuredPtr.getMixedOffsets()) .getResult(); } else { return nullptr; @@ -291,7 +286,7 @@ LogicalResult PtrState::addState(const PtrState &lhsState, if (lhsState.scalar && rhsState.scalar) { auto addOp = - builder.create(loc, lhsState.scalar, rhsState.scalar); + arith::AddIOp::create(builder, loc, lhsState.scalar, rhsState.scalar); scalar = addOp.getResult(); } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; @@ -569,7 +564,7 @@ LogicalResult PtrState::mulState(const PtrState &lhsState, if (lhsState.scalar && rhsState.scalar) { scalar = - builder.create(loc, lhsState.scalar, rhsState.scalar); + arith::MulIOp::create(builder, loc, lhsState.scalar, rhsState.scalar); } auto indexTy = IndexType::get(op->getContext()); @@ -662,8 +657,8 @@ tts::MakeTensorPtrOp PtrState::createTTSMakeTensorPtrOp(OpBuilder &builder, staticSizes.push_back(s.value()); } - auto op = builder.create( - loc, source, staticSizes, strides, offsets, shape, order); + auto op = tts::MakeTensorPtrOp::create(builder, loc, source, staticSizes, + strides, offsets, shape, order); LLVM_DEBUG({ llvm::dbgs() << "creating tts::make_tensor_ptr:\n"; op->dump(); @@ -700,16 +695,15 @@ PtrState::createTTSMakeGatherScatterTensorPtrOp(OpBuilder &builder, auto collapseTy = RankedTensorType::get({offsetSize}, offsetTy.getElementType()); nonContinuousOffset = - builder - .create( - loc, collapseTy, nonContinuousOffset, reassociationMap) + tensor::CollapseShapeOp::create(builder, loc, collapseTy, + nonContinuousOffset, reassociationMap) .getResult(); offsets[nonContinuousDim] = nonContinuousOffset; } // Generate tts::make_gather_scatter_tensor_ptr. - auto op = builder.create( - loc, source, nonContinuousOffset, nonContinuousDim, staticSizes, strides, - offsets); + auto op = tts::MakeGatherScatterTensorPtrOp::create( + builder, loc, source, nonContinuousOffset, nonContinuousDim, staticSizes, + strides, offsets); LLVM_DEBUG({ llvm::dbgs() << "creating tts::make_gather_scatter_tensor_ptr:\n"; op->dump(); @@ -1135,19 +1129,19 @@ PtrAnalysis::visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, for (int64_t i = 0; i < pointeeType.getRank(); i++) { state.sizes.push_back(builder.getIndexAttr(shape[i])); - auto strideCst = builder.create( - loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); + auto strideCst = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); state.strides.push_back(strideCst.getResult()); - auto offsetCst = builder.create( - loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); + auto offsetCst = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); - auto scaledOffset = builder.create( - loc, offsetCst.getResult(), strideCst.getResult()); + auto scaledOffset = arith::MulIOp::create( + builder, loc, offsetCst.getResult(), strideCst.getResult()); state.offsets.push_back(scaledOffset.getResult()); - auto shapeCst = builder.create( - loc, builder.getIndexType(), makeTPtrOp.getShape()[i]); + auto shapeCst = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), makeTPtrOp.getShape()[i]); state.shape.push_back(shapeCst.getResult()); } state.order = SmallVector(makeTPtrOp.getOrder()); @@ -1219,8 +1213,8 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, if (!isa(operand) && operand.getDefiningOp()) { builder.setInsertionPointAfter(operand.getDefiningOp()); } - auto castOp = builder.create( - loc, builder.getIndexType(), operand); + auto castOp = arith::IndexCastOp::create(builder, loc, + builder.getIndexType(), operand); state.scalar = castOp.getResult(); return success(); } else if (isa(operand.getType())) { @@ -1390,18 +1384,18 @@ LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) { llvm::zip(incrementOffsets, state.offsets, state.strides)) { Value offsetValue; if (auto offsetIntAttr = getIntAttr(offset)) { - auto constOp = builder.create( - loc, builder.getIndexAttr(offsetIntAttr.value())); + auto constOp = arith::ConstantOp::create( + builder, loc, builder.getIndexAttr(offsetIntAttr.value())); offsetValue = constOp.getResult(); } else { offsetValue = cast(offset); } - auto castOp = builder.create( - loc, builder.getIndexType(), increment); - auto mulOp = builder.create(loc, castOp.getResult(), - cast(stride)); + auto castOp = arith::IndexCastOp::create(builder, loc, + builder.getIndexType(), increment); + auto mulOp = arith::MulIOp::create(builder, loc, castOp.getResult(), + cast(stride)); auto addOp = - builder.create(loc, mulOp.getResult(), offsetValue); + arith::AddIOp::create(builder, loc, mulOp.getResult(), offsetValue); newOffsets.push_back(addOp.getResult()); } @@ -1637,15 +1631,15 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { // This operand is a pointer directly from the kernel arguments. // Use offset 0. assert(!tritonValue.getDefiningOp()); - replacements.push_back(builder.create( - op.getLoc(), builder.getIndexAttr(0))); + replacements.push_back(arith::ConstantOp::create( + builder, op.getLoc(), builder.getIndexAttr(0))); } } else { for (auto [j, s] : llvm::enumerate(state.offsets)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = builder.create( - op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + builder, op.getLoc(), builder.getIndexAttr(sIntAttr.value())); replacements.push_back(constOp.getResult()); } else { replacements.push_back(cast(s)); @@ -1655,8 +1649,8 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { for (auto [j, s] : llvm::enumerate(state.strides)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = builder.create( - op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + builder, op.getLoc(), builder.getIndexAttr(sIntAttr.value())); replacements.push_back(constOp.getResult()); } else { replacements.push_back(cast(s)); @@ -1720,7 +1714,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, } } - auto loadOp = builder.create(loc, ptr, dims, scalarOther); + auto loadOp = tts::LoadOp::create(builder, loc, ptr, dims, scalarOther); LLVM_DEBUG({ llvm::dbgs() << "creating tts::load:\n"; @@ -1850,7 +1844,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op, dims = mstate.dims; } - auto storeOp = builder.create(loc, ptr, val, dims); + auto storeOp = tts::StoreOp::create(builder, loc, ptr, val, dims); LLVM_DEBUG({ llvm::dbgs() << "creating tts::store:\n"; diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp index 10240298..9192e32e 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -56,8 +56,8 @@ static memref::SubViewOp getSubview(int rank, ArrayRef dims, auto dstType = memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); - return b.create(loc, cast(dstType), source, - offsets, dims, strides); + return memref::SubViewOp::create(b, loc, cast(dstType), source, + offsets, dims, strides); } static Type getElementTypeStructuredPtr(tts::MakeTensorPtrOp op) { @@ -182,8 +182,9 @@ static Value rewriteGatherScatterPtrElement( SmallVector dynSizes; // sizes are always static auto sizes = mlir::getMixedValues(staticSizes, dynSizes, rewriter); - auto castOp = rewriter.create( - op.getLoc(), resultType, basePtr, targetOffset, sizes, mixedStrides); + auto castOp = memref::ReinterpretCastOp::create( + rewriter, op.getLoc(), resultType, basePtr, targetOffset, sizes, + mixedStrides); return castOp.getResult(); } @@ -198,28 +199,28 @@ static void fillWithValue(Location loc, Value alloc, Value other, // For each dimension check if dims[i] < shape[i], or-accumulate // the result auto accBase = - rewriter.create(loc, rewriter.getBoolAttr(false)) + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)) .getResult(); for (size_t i = 0; i < shape.size(); i++) { - auto shapei = rewriter.create( - loc, rewriter.getIndexAttr(shape[i])); + auto shapei = arith::ConstantOp::create(rewriter, loc, + rewriter.getIndexAttr(shape[i])); Value dimi = dyn_cast(mixedDims[i]); if (!dimi) { - dimi = rewriter.create( - loc, rewriter.getIndexAttr(staticMaskDims[i])); + dimi = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(staticMaskDims[i])); } - Value cmp = rewriter.create(loc, arith::CmpIPredicate::slt, - dimi, shapei); - accBase = rewriter.create(loc, accBase, cmp); + Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + dimi, shapei); + accBase = arith::OrIOp::create(rewriter, loc, accBase, cmp); } // condition the memset on the or-accumulation // initialize with padding prior to CopyOp - rewriter.create(loc, accBase, [&](OpBuilder &b, Location loc) { - b.create(loc, ValueRange{other}, ValueRange{alloc}); - b.create(loc); + scf::IfOp::create(rewriter, loc, accBase, [&](OpBuilder &b, Location loc) { + linalg::FillOp::create(b, loc, ValueRange{other}, ValueRange{alloc}); + scf::YieldOp::create(b, loc); }); } @@ -321,35 +322,36 @@ struct MakeTensorPtrConverter // wrapping around. ShapedType::kDynamic}); - Value rowSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[0])); - Value colSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[1])); + Value rowSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[1])); Value modN = ofrToIndexValue(op.getMixedShape()[1], loc, rewriter); - Value x = rewriter.create(loc, targetOffset, modN); - Value y = rewriter.create(loc, targetOffset, x); + Value x = arith::RemSIOp::create(rewriter, loc, targetOffset, modN); + Value y = arith::SubIOp::create(rewriter, loc, targetOffset, x); SmallVector strideVals = ofrsToIndexValues(op.getMixedStrides(), loc, rewriter); // First chunk - Value nextOffset = rewriter.create(loc, x, colSize); + Value nextOffset = arith::AddIOp::create(rewriter, loc, x, colSize); Value clampedOffset = - rewriter.create(loc, nextOffset, modN); - Value d1 = rewriter.create(loc, clampedOffset, x); + arith::MinSIOp::create(rewriter, loc, nextOffset, modN); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOffset, x); SmallVector sizes1{rowSize, d1}; - auto cast1 = rewriter.create( - loc, resultType, adaptor.getBase(), targetOffset, sizes1, strideVals); + auto cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), targetOffset, sizes1, + strideVals); // Second chunk - Value d2 = rewriter.create(loc, colSize, d1); + Value d2 = arith::SubIOp::create(rewriter, loc, colSize, d1); SmallVector sizes2{rowSize, d2}; - auto cast2 = rewriter.create( - loc, resultType, adaptor.getBase(), y, sizes2, strideVals); + auto cast2 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), y, sizes2, strideVals); return {cast1, cast2}; } @@ -450,10 +452,10 @@ struct MakeTensorPtrConverter // allow this anymore. So we put dynamic instead. ShapedType::kDynamic}); - Value rowSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[0])); - Value colSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[1])); + Value rowSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[1])); Value strideRow = ofrToIndexValue(op.getMixedStrides()[0], loc, rewriter); Value strideCol = ofrToIndexValue(op.getMixedStrides()[1], loc, rewriter); @@ -462,26 +464,24 @@ struct MakeTensorPtrConverter // First chunk Value wrappedAroundOff = - rewriter.create(loc, targetOffset, strideRow); + arith::RemSIOp::create(rewriter, loc, targetOffset, strideRow); Value clampedOff = - rewriter.create(loc, modRow, wrappedAroundOff); - Value d1 = rewriter.create(loc, clampedOff, targetOffset); - d1 = rewriter.create(loc, d1, strideRow); - d1 = rewriter.create(loc, d1, rowSize); + arith::AddIOp::create(rewriter, loc, modRow, wrappedAroundOff); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOff, targetOffset); + d1 = arith::DivSIOp::create(rewriter, loc, d1, strideRow); + d1 = arith::MinSIOp::create(rewriter, loc, d1, rowSize); SmallVector sizes1{d1, colSize}; - memref::ReinterpretCastOp cast1 = - rewriter.create( - loc, resultType, adaptor.getBase(), targetOffset, sizes1, - ValueRange{strideRow, strideCol}); + memref::ReinterpretCastOp cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), targetOffset, sizes1, + ValueRange{strideRow, strideCol}); // Second chunk - Value d2 = rewriter.create(loc, rowSize, d1); + Value d2 = arith::SubIOp::create(rewriter, loc, rowSize, d1); SmallVector sizes2{d2, colSize}; - memref::ReinterpretCastOp cast2 = - rewriter.create( - loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, - ValueRange{strideRow, strideCol}); + memref::ReinterpretCastOp cast2 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); return {cast1, cast2}; } @@ -515,8 +515,8 @@ struct MakeTensorPtrConverter llvm_unreachable("Unexpected split pointer shape"); } - auto combinedCast = rewriter.create( - op.getLoc(), op.getType(), casts); + auto combinedCast = UnrealizedConversionCastOp::create( + rewriter, op.getLoc(), op.getType(), casts); combinedCast->setAttr(wrapType, rewriter.getUnitAttr()); @@ -541,8 +541,8 @@ struct MakeTensorPtrConverter op, staticTargetOffset.value_or(ShapedType::kDynamic), staticStrides, resultShape); - auto castOp = rewriter.create( - op.getLoc(), resultType, adaptor.getBase(), targetOffset, + auto castOp = memref::ReinterpretCastOp::create( + rewriter, op.getLoc(), resultType, adaptor.getBase(), targetOffset, op.getMixedSizes(), mixedStrides); rewriter.replaceOp(op, castOp); @@ -626,71 +626,67 @@ struct LoadConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); - - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); - - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); - - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{zero, block1Col}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); - - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); + + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); + + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } void createStackedCopies(Value block1, Value block2, Value dst, Location loc, ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); - - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); - - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); - - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{block1Row, zero}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); - - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); + + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); + + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } memref::SubViewOp createSubview(Value src, ArrayRef offsets, @@ -700,8 +696,8 @@ struct LoadConverter : public OpConversionPattern { auto srcType = cast(src.getType()); auto dstType = memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return rewriter.create(loc, cast(dstType), - src, offsets, sizes, strides); + return memref::SubViewOp::create(rewriter, loc, cast(dstType), + src, offsets, sizes, strides); } std::pair @@ -711,9 +707,9 @@ struct LoadConverter : public OpConversionPattern { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; OpFoldResult subviewCol1 = - rewriter.create(loc, block1, 1).getResult(); + memref::DimOp::create(rewriter, loc, block1, 1).getResult(); OpFoldResult subviewCol2 = - rewriter.create(loc, block2, 1).getResult(); + memref::DimOp::create(rewriter, loc, block2, 1).getResult(); SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); @@ -732,9 +728,9 @@ struct LoadConverter : public OpConversionPattern { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; OpFoldResult subviewRow1 = - rewriter.create(loc, block1, 0).getResult(); + memref::DimOp::create(rewriter, loc, block1, 0).getResult(); OpFoldResult subviewRow2 = - rewriter.create(loc, block2, 0).getResult(); + memref::DimOp::create(rewriter, loc, block2, 0).getResult(); SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); @@ -757,8 +753,8 @@ struct LoadConverter : public OpConversionPattern { auto tensorType = cast(op.getType()); auto elemType = tensorType.getElementType(); - auto alloc = rewriter.create( - loc, MemRefType::get(tensorType.getShape(), elemType)); + auto alloc = memref::AllocOp::create( + rewriter, loc, MemRefType::get(tensorType.getShape(), elemType)); // No mask assert(!other && "other value used in non-masked load"); @@ -781,11 +777,12 @@ struct LoadConverter : public OpConversionPattern { llvm_unreachable("unexpected wraparound type"); } } else { - rewriter.create(loc, ptr, alloc); + memref::CopyOp::create(rewriter, loc, ptr, alloc); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create(rewriter, loc, tensorType, + alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -801,8 +798,8 @@ struct LoadConverter : public OpConversionPattern { auto tensorType = cast(op.getType()); auto elemType = tensorType.getElementType(); - auto alloc = rewriter.create( - loc, MemRefType::get(tensorType.getShape(), elemType)); + auto alloc = memref::AllocOp::create( + rewriter, loc, MemRefType::get(tensorType.getShape(), elemType)); SmallVector mixedDims = op.getMixedMaskDims(); @@ -842,11 +839,12 @@ struct LoadConverter : public OpConversionPattern { getSubview(tensorType.getRank(), mixedDims, ptr, loc, rewriter); memref::SubViewOp dstSubview = getSubview(tensorType.getRank(), mixedDims, alloc, loc, rewriter); - rewriter.create(loc, srcSubview, dstSubview); + memref::CopyOp::create(rewriter, loc, srcSubview, dstSubview); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create(rewriter, loc, tensorType, + alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -864,7 +862,7 @@ struct LoadConverter : public OpConversionPattern { auto indexOffsetTy = RankedTensorType::get(offsetShapedType.getShape(), rewriter.getIndexType()); gatherOffset = - rewriter.create(loc, indexOffsetTy, gatherOffset) + arith::IndexCastOp::create(rewriter, loc, indexOffsetTy, gatherOffset) .getResult(); int gatherDim = ptr.getGatherScatterDim(); @@ -881,7 +879,7 @@ struct LoadConverter : public OpConversionPattern { auto resultType = dyn_cast(op.getResult().getType()); auto allocType = MemRefType::get(resultType.getShape(), resultType.getElementType()); - auto alloc = rewriter.create(loc, allocType); + auto alloc = memref::AllocOp::create(rewriter, loc, allocType); auto allocStrides = mlir::getMixedValues( allocType.getStridesAndOffset().first, dynSizes, rewriter); @@ -892,9 +890,9 @@ struct LoadConverter : public OpConversionPattern { } // Create loop to iterate every offset in gatherOffset. - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); Value upperBound = - rewriter.create(loc, offsetSize).getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize).getResult(); if (op.hasMask()) { SmallVector mixedDims = op.getMixedMaskDims(); OpFoldResult gatherMaskDim = mixedDims[gatherDim]; @@ -911,26 +909,26 @@ struct LoadConverter : public OpConversionPattern { gatherMaskDimValue = offsetSize; } offsetSize = std::min(offsetSize, gatherMaskDimValue); - upperBound = rewriter.create(loc, offsetSize) + upperBound = arith::ConstantIndexOp::create(rewriter, loc, offsetSize) .getResult(); } else { // Use arith::MinSIOp to get the minimum value of gatherMaskDim // and offsetSize. auto gatherMaskDimVal = cast(gatherMaskDim); auto offsetSizeVal = - rewriter.create(loc, offsetSize); - upperBound = - rewriter - .create(loc, gatherMaskDimVal, offsetSizeVal) - .getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize); + upperBound = arith::MinSIOp::create(rewriter, loc, gatherMaskDimVal, + offsetSizeVal) + .getResult(); } } - auto step = rewriter.create(loc, 1); - auto loop = rewriter.create(loc, lowerBound, upperBound, step); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto loop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Create tensor from alloc and use it as the result to replace op. - Value tensor = rewriter.create( - loc, op.getType(), alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create( + rewriter, loc, op.getType(), alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); // Build loop body. @@ -941,15 +939,15 @@ struct LoadConverter : public OpConversionPattern { if (Value unstructuredMask = ptr.getGatherScatterMask()) { // If the gather scatter mask is present, we need to use it to guard the // load. - auto maskValue = rewriter.create( - loc, unstructuredMask, ValueRange{inductionVar}); - auto ifOp = rewriter.create(loc, maskValue); + auto maskValue = tensor::ExtractOp::create( + rewriter, loc, unstructuredMask, ValueRange{inductionVar}); + auto ifOp = scf::IfOp::create(rewriter, loc, maskValue); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); } // Load the offsetElt first. - auto gatherOffsetElt = rewriter.create( - loc, gatherOffset, ValueRange{inductionVar}); + auto gatherOffsetElt = tensor::ExtractOp::create( + rewriter, loc, gatherOffset, ValueRange{inductionVar}); // reinterpret_cast to current row as memRefPtr[gatherOffsetElt]. Value srcPtr = rewriteGatherScatterPtrElement(staticSizes, ptr, memRefPtr, @@ -971,11 +969,10 @@ struct LoadConverter : public OpConversionPattern { // Use oneStrides for subview. auto dstSubViewType = memref::SubViewOp::inferResultType( cast(srcPtr.getType()), maskOffsets, sizes, oneStrides); - srcPtr = - rewriter - .create(loc, cast(dstSubViewType), + srcPtr = memref::SubViewOp::create(rewriter, loc, + cast(dstSubViewType), srcPtr, maskOffsets, sizes, oneStrides) - .getResult(); + .getResult(); } // alloc[inductionVar] @@ -983,11 +980,11 @@ struct LoadConverter : public OpConversionPattern { allocOffsets[gatherDim] = inductionVar; auto dstAllocType = memref::SubViewOp::inferResultType( allocType, allocOffsets, sizes, oneStrides); - auto dstSubview = rewriter.create( - loc, cast(dstAllocType), alloc, allocOffsets, sizes, - oneStrides); + auto dstSubview = + memref::SubViewOp::create(rewriter, loc, cast(dstAllocType), + alloc, allocOffsets, sizes, oneStrides); // Copy srcPtr to alloc[inductionVar]. - rewriter.create(loc, srcPtr, dstSubview); + memref::CopyOp::create(rewriter, loc, srcPtr, dstSubview); return success(); } @@ -1027,8 +1024,8 @@ struct StoreConverter : public OpConversionPattern { auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); - return b.create(loc, dstType, source, offsets, dims, - strides); + return tensor::ExtractSliceOp::create(b, loc, dstType, source, offsets, + dims, strides); } LogicalResult rewriteScatter(tts::MakeGatherScatterTensorPtrOp ptr, @@ -1043,7 +1040,7 @@ struct StoreConverter : public OpConversionPattern { auto indexOffsetTy = RankedTensorType::get(offsetShapedType.getShape(), rewriter.getIndexType()); gatherOffset = - rewriter.create(loc, indexOffsetTy, gatherOffset) + arith::IndexCastOp::create(rewriter, loc, indexOffsetTy, gatherOffset) .getResult(); int gatherDim = ptr.getGatherScatterDim(); @@ -1057,9 +1054,9 @@ struct StoreConverter : public OpConversionPattern { auto sizes = mlir::getMixedValues(staticSizes, dynSizes, rewriter); // Create loop to iterate every offset in gatherOffset. - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); Value upperBound = - rewriter.create(loc, offsetSize).getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize).getResult(); if (op.hasMask()) { SmallVector mixedDims = op.getMixedMaskDims(); OpFoldResult gatherMaskDim = mixedDims[gatherDim]; @@ -1076,22 +1073,21 @@ struct StoreConverter : public OpConversionPattern { gatherMaskDimValue = offsetSize; } offsetSize = std::min(offsetSize, gatherMaskDimValue); - upperBound = rewriter.create(loc, offsetSize) + upperBound = arith::ConstantIndexOp::create(rewriter, loc, offsetSize) .getResult(); } else { // Use arith::MinSIOp to get the minimum value of gatherMaskDim // and offsetSize. auto gatherMaskDimVal = cast(gatherMaskDim); auto offsetSizeVal = - rewriter.create(loc, offsetSize); - upperBound = - rewriter - .create(loc, gatherMaskDimVal, offsetSizeVal) - .getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize); + upperBound = arith::MinSIOp::create(rewriter, loc, gatherMaskDimVal, + offsetSizeVal) + .getResult(); } } - auto step = rewriter.create(loc, 1); - auto loop = rewriter.create(loc, lowerBound, upperBound, step); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto loop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Build loop body. rewriter.setInsertionPointToStart(loop.getBody()); @@ -1101,15 +1097,15 @@ struct StoreConverter : public OpConversionPattern { if (Value unstructuredMask = ptr.getGatherScatterMask()) { // If the gather scatter mask is present, we need to use it to guard the // store. - auto maskValue = rewriter.create( - loc, unstructuredMask, ValueRange{inductionVar}); - auto ifOp = rewriter.create(loc, maskValue); + auto maskValue = tensor::ExtractOp::create( + rewriter, loc, unstructuredMask, ValueRange{inductionVar}); + auto ifOp = scf::IfOp::create(rewriter, loc, maskValue); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); } // Load the offsetElt first. - auto gatherOffsetElt = rewriter.create( - loc, gatherOffset, ValueRange{inductionVar}); + auto gatherOffsetElt = tensor::ExtractOp::create( + rewriter, loc, gatherOffset, ValueRange{inductionVar}); // Create extract_slice stVal[inductionVar]. unsigned rank = ptr.getSizes().size(); @@ -1125,8 +1121,8 @@ struct StoreConverter : public OpConversionPattern { } // The subview should not apply an additional stride to the source. SmallVector oneStrides(rank, OpFoldResult(step)); - auto slice = rewriter.create( - loc, stVal, stValOffsets, sizes, oneStrides); + auto slice = tensor::ExtractSliceOp::create( + rewriter, loc, stVal, stValOffsets, sizes, oneStrides); // reinterpret_cast to current row as memRefPtr[gatherOffsetElt]. Value dstPtr = rewriteGatherScatterPtrElement(staticSizes, ptr, memRefPtr, @@ -1141,14 +1137,13 @@ struct StoreConverter : public OpConversionPattern { cast(dstPtr.getType()), maskOffsets, sizes, oneStrides); dstPtr = - rewriter - .create(loc, cast(dstType), dstPtr, - maskOffsets, sizes, oneStrides) + memref::SubViewOp::create(rewriter, loc, cast(dstType), + dstPtr, maskOffsets, sizes, oneStrides) .getResult(); } // store slice to dstPtr. - auto storeOp = rewriter.create( - loc, slice, dstPtr); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, slice, dstPtr); storeOp.setWritable(true); rewriter.eraseOp(op); @@ -1182,12 +1177,12 @@ struct StoreConverter : public OpConversionPattern { getExtractSlice(rank, mixedDims, storeValue, loc, rewriter); auto dstSubview = getSubview(rank, mixedDims, ptr, loc, rewriter); - auto storeOp = rewriter.create( - loc, srcSlice, dstSubview); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, srcSlice, dstSubview); storeOp.setWritable(true); } else { - auto storeOp = rewriter.create( - loc, storeValue, ptr); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, storeValue, ptr); storeOp.setWritable(true); } diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp index c0d4cd0a..2a393e8f 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -51,13 +51,15 @@ class PtrToUnrankedMemrefConverter : public TypeConverter { addTargetMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); } diff --git a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp index 8c86fe26..39cf0ba7 100644 --- a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp +++ b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp @@ -215,7 +215,8 @@ class TritonArithToLinalgPass func.getAllArgAttrs(argAttrs); func.getAllResultAttrs(resAttrs); - auto funcFunc = builder.create(func.getLoc(), name, type); + auto funcFunc = + func::FuncOp::create(builder, func.getLoc(), name, type); // Preserve the visibility attribute funcFunc.setVisibility(func.getVisibility()); funcFunc.setAllArgAttrs(argAttrs); @@ -234,7 +235,7 @@ class TritonArithToLinalgPass // considered terminators. if (isa(term)) { builder.setInsertionPoint(term); - builder.create(func.getLoc(), term->getOperands()); + func::ReturnOp::create(builder, func.getLoc(), term->getOperands()); term->erase(); } } diff --git a/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp b/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp index b75e7ddb..09d8c23a 100644 --- a/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp +++ b/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp @@ -61,7 +61,8 @@ class TritonFunctionSignatureConverter : public TypeConverter { auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }; addSourceMaterialization(createUnrealizedCast); diff --git a/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp b/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp index 95061ac4..0130ede6 100644 --- a/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp @@ -83,8 +83,8 @@ struct CollapseFill : public OpRewritePattern { } auto elementType = resultType.getElementType(); - auto output = rewriter.create( - loc, + auto output = memref::CollapseShapeOp::create( + rewriter, loc, MemRefType::get(llvm::ArrayRef{resultType.getNumElements()}, elementType), result, reassociationMap); @@ -112,14 +112,15 @@ struct CollapseFill : public OpRewritePattern { reassociationMap[0].push_back(rewriter.getAffineDimExpr(i)); } - auto init = rewriter.create( - loc, RankedTensorType::get({resultType.getNumElements()}, elementType), + auto init = tensor::CollapseShapeOp::create( + rewriter, loc, + RankedTensorType::get({resultType.getNumElements()}, elementType), op.getOutputs()[0], reassociationMap); auto fillOp = - rewriter.create(loc, op.getInputs(), ValueRange{init}); + linalg::FillOp::create(rewriter, loc, op.getInputs(), ValueRange{init}); - auto expandOp = rewriter.create( - loc, result.getType(), fillOp.getResult(0), reassociationMap); + auto expandOp = tensor::ExpandShapeOp::create( + rewriter, loc, result.getType(), fillOp.getResult(0), reassociationMap); rewriter.replaceOp(op, expandOp.getResult()); return success(); @@ -224,8 +225,8 @@ struct CollapseTranspose : public OpRewritePattern { auto loc = op.getLoc(); sourceType = RankedTensorType::get(collapseShapeInput, elementType); - source = rewriter.create(loc, sourceType, source, - reassociationMap); + source = tensor::CollapseShapeOp::create(rewriter, loc, sourceType, source, + reassociationMap); SmallVector reassociationMapRe(reassociationMap.size()); int idx = 0; @@ -235,12 +236,12 @@ struct CollapseTranspose : public OpRewritePattern { } } - Value transposeInit = rewriter.create( - loc, RankedTensorType::get(transposedShape, elementType), op.getInit(), - reassociationMapRe); + Value transposeInit = tensor::CollapseShapeOp::create( + rewriter, loc, RankedTensorType::get(transposedShape, elementType), + op.getInit(), reassociationMapRe); Value transpose = - rewriter.create(loc, source, transposeInit, perm) + linalg::TransposeOp::create(rewriter, loc, source, transposeInit, perm) .getResults()[0]; rewriter.replaceOpWithNewOp( @@ -337,8 +338,8 @@ struct CollapseBroadCast : public OpRewritePattern { auto loc = op.getLoc(); sourceType = RankedTensorType::get(collapseShapeInput, elementType); - input = rewriter.create(loc, sourceType, input, - reassociationMap); + input = tensor::CollapseShapeOp::create(rewriter, loc, sourceType, input, + reassociationMap); resultType = RankedTensorType::get(collapseShapeOutput, elementType); size_t resultRank = resultType.getRank(); @@ -351,13 +352,14 @@ struct CollapseBroadCast : public OpRewritePattern { assert(op->getNumResults() == 1 && "code assumes single result!"); - auto init = rewriter.create( - loc, RankedTensorType::get(resultType.getShape(), elementType), + auto init = tensor::CollapseShapeOp::create( + rewriter, loc, + RankedTensorType::get(resultType.getShape(), elementType), op.getOutputs()[0], reassociationMap); - auto linalgOp = rewriter.create( - loc, init->getResultTypes(), ValueRange{input}, ValueRange{init}, - indexingMaps, getNParallelLoopsAttrs(resultRank)); + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, init->getResultTypes(), ValueRange{input}, + ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(resultRank)); rewriter.cloneRegionBefore(op.getRegion(), linalgOp.getRegion(), linalgOp.getRegion().begin()); linalgOp->setAttr("broadcastDims", @@ -453,8 +455,8 @@ struct CollapseReduce : public OpRewritePattern { auto elementType = inputType.getElementType(); auto loc = op.getLoc(); auto newInputType = RankedTensorType::get(collapseShapeInput, elementType); - input = rewriter.create(loc, newInputType, input, - reassociationMap); + input = tensor::CollapseShapeOp::create(rewriter, loc, newInputType, input, + reassociationMap); SmallVector reassociationMapOutput; int idx = 0; @@ -469,12 +471,12 @@ struct CollapseReduce : public OpRewritePattern { rewriter.getAffineDimExpr(idx++)); } } - auto init = rewriter.create( - loc, RankedTensorType::get(collapseShapeOutput, elementType), + auto init = tensor::CollapseShapeOp::create( + rewriter, loc, RankedTensorType::get(collapseShapeOutput, elementType), op.getInits()[0], reassociationMapOutput); - auto newReduce = rewriter.create( - loc, init->getResultTypes(), ValueRange{input}, ValueRange{init}, - newDims); + auto newReduce = + linalg::ReduceOp::create(rewriter, loc, init->getResultTypes(), + ValueRange{input}, ValueRange{init}, newDims); rewriter.cloneRegionBefore(op.getRegion(), newReduce.getRegion(), newReduce.getRegion().begin()); diff --git a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp index e8a3b7f6..0f7eda66 100644 --- a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp @@ -61,8 +61,8 @@ struct SimplifyUnrealizedCast } auto prevInput = unrealizedCast.getInputs().front(); - auto newCast = rewriter.create( - op->getLoc(), op->getResultTypes(), ValueRange{prevInput}); + auto newCast = UnrealizedConversionCastOp::create( + rewriter, op->getLoc(), op->getResultTypes(), ValueRange{prevInput}); rewriter.replaceOp(op, newCast); return success(); @@ -90,11 +90,11 @@ struct FromMemrefConverter if (unrankedInput && isa(outType)) { // from_memref only takes ranked memref, cast the unranked memref to // ranked memref first. - auto rankedMemref = rewriter.create( - op.getLoc(), MemRefType::get({1}, unrankedInput.getElementType()), - input); - auto memrefToPtr = rewriter.create( - op->getLoc(), + auto rankedMemref = memref::CastOp::create( + rewriter, op.getLoc(), + MemRefType::get({1}, unrankedInput.getElementType()), input); + auto memrefToPtr = tptr::FromMemrefOp::create( + rewriter, op->getLoc(), ptr::PtrType::get( rewriter.getContext(), tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())), @@ -128,14 +128,15 @@ struct ToMemrefConverter : public OpRewritePattern { // to_memref can only cast to ranked static shape memref, we have to cast // the resulting memref back to unranked auto elemType = outUnrankedMemrefType.getElementType(); - auto ptrToMemref = rewriter.create( - op->getLoc(), MemRefType::get({1}, elemType), input); + auto ptrToMemref = tptr::ToMemrefOp::create( + rewriter, op->getLoc(), MemRefType::get({1}, elemType), input); SmallVector sizes = {rewriter.getIndexAttr(1)}; SmallVector newStrides = {rewriter.getIndexAttr(1)}; - auto newUnrankedMemref = rewriter.create( - op->getLoc(), MemRefType::get({ShapedType::kDynamic}, elemType), - ptrToMemref, rewriter.getIndexAttr(0), sizes, newStrides); + auto newUnrankedMemref = memref::ReinterpretCastOp::create( + rewriter, op->getLoc(), + MemRefType::get({ShapedType::kDynamic}, elemType), ptrToMemref, + rewriter.getIndexAttr(0), sizes, newStrides); rewriter.replaceAllUsesWith(output, newUnrankedMemref); rewriter.eraseOp(op); diff --git a/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp b/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp index 8234360d..6e3cc9d6 100644 --- a/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp @@ -182,9 +182,9 @@ struct AddPtrConverter : public OpConversionPattern { auto pointeeType = cast(op.getType()).getPointeeType(); auto offsetType = op.getOffset().getType(); auto pointeeSizeInBytes = - rewriter.create(loc, offsetType, pointeeType); - auto scaledOffset = - rewriter.create(loc, op.getOffset(), pointeeSizeInBytes); + tptr::TypeOffsetOp::create(rewriter, loc, offsetType, pointeeType); + auto scaledOffset = arith::MulIOp::create(rewriter, loc, op.getOffset(), + pointeeSizeInBytes); rewriter.replaceOpWithNewOp( op, ptr::PtrType::get( @@ -214,36 +214,37 @@ struct LoadConverter : public OpConversionPattern { auto pointeeType = cast(ptr.getType()).getPointeeType(); - auto memref = rewriter.create( - op->getLoc(), MemRefType::get({1}, pointeeType), adaptor.getPtr()); + auto memref = tptr::ToMemrefOp::create(rewriter, op->getLoc(), + MemRefType::get({1}, pointeeType), + adaptor.getPtr()); - auto zero = rewriter.create(op.getLoc(), 0); + auto zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); if (op.getMask()) { - auto ifOp = rewriter.create( - op->getLoc(), op.getMask(), + auto ifOp = scf::IfOp::create( + rewriter, op->getLoc(), op.getMask(), [&](OpBuilder &b, Location loc) { // Truthy case, load from the index. - Value memrefLoad = rewriter.create( - op->getLoc(), memref, ValueRange{zero}); - b.create(loc, memrefLoad); + Value memrefLoad = memref::LoadOp::create(rewriter, op->getLoc(), + memref, ValueRange{zero}); + scf::YieldOp::create(b, loc, memrefLoad); }, [&](OpBuilder &b, Location loc) { // Falsy case, yield `other` or 0 as the default value. if (op.getOther()) { - b.create(loc, op.getOther()); + scf::YieldOp::create(b, loc, op.getOther()); } else { auto elemType = op.getType(); auto zeroAttr = b.getZeroAttr(elemType); assert(zeroAttr && "unexpected element type"); - Value val = b.create(loc, zeroAttr); - b.create(loc, val); + Value val = arith::ConstantOp::create(b, loc, zeroAttr); + scf::YieldOp::create(b, loc, val); } }); rewriter.replaceOp(op, ifOp); } else { - auto memrefLoad = rewriter.create(op->getLoc(), memref, - ValueRange{zero}); + auto memrefLoad = memref::LoadOp::create(rewriter, op->getLoc(), memref, + ValueRange{zero}); rewriter.replaceOp(op, memrefLoad); } @@ -272,18 +273,19 @@ struct StoreConverter : public OpConversionPattern { IRRewriter::InsertionGuard g(rewriter); if (op.getMask()) { - auto ifOp = rewriter.create(op->getLoc(), op.getMask(), - /*withElseRegion*/ false); + auto ifOp = scf::IfOp::create(rewriter, op->getLoc(), op.getMask(), + /*withElseRegion*/ false); rewriter.setInsertionPointToStart( &ifOp.getThenRegion().getBlocks().front()); } - auto memref = rewriter.create( - op->getLoc(), MemRefType::get({1}, pointeeType), adaptor.getPtr()); - auto zero = rewriter.create(op.getLoc(), 0); + auto memref = tptr::ToMemrefOp::create(rewriter, op->getLoc(), + MemRefType::get({1}, pointeeType), + adaptor.getPtr()); + auto zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - rewriter.create(op->getLoc(), op.getValue(), memref, - ValueRange{zero}); + memref::StoreOp::create(rewriter, op->getLoc(), op.getValue(), memref, + ValueRange{zero}); rewriter.eraseOp(op); @@ -353,9 +355,10 @@ struct LinalgPtrConverter : public OpConversionPattern { return failure(); } - auto replacement = rewriter.create( - op.getLoc(), convertedTypes, adaptor.getInputs(), adaptor.getOutputs(), - op.getIndexingMapsArray(), op.getIteratorTypesArray()); + auto replacement = linalg::GenericOp::create( + rewriter, op.getLoc(), convertedTypes, adaptor.getInputs(), + adaptor.getOutputs(), op.getIndexingMapsArray(), + op.getIteratorTypesArray()); Region ®ion = op.getRegion(); Block &block = region.front(); @@ -431,7 +434,8 @@ class TritonPtrTypeConverter : public TypeConverter { }); auto createCast = [&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }; addTargetMaterialization(createCast); diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index 0c3269e1..ec80fcd9 100644 --- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -137,20 +137,20 @@ class TritonToStructuredPass // result is still being used by another tt.load or tt.store. converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); // Compute the target materialization, given a value with the pointer type, // convert that value to a tuple type. - converter.addTargetMaterialization([](OpBuilder &builder, - TypeRange resultTypes, - ValueRange inputs, - Location loc) -> SmallVector { - return builder - .create(loc, resultTypes, inputs.front()) - ->getResults(); - }); + converter.addTargetMaterialization( + [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc) -> SmallVector { + return UnrealizedConversionCastOp::create(builder, loc, resultTypes, + inputs.front()) + ->getResults(); + }); ConversionTarget target(getContext()); scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, @@ -200,8 +200,8 @@ class TritonToStructuredPass // during reconcile-unrealized-conversion-casts. converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder - .create(loc, resultType, inputs[0]) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs[0]) ->getResult(0); }); @@ -214,8 +214,8 @@ class TritonToStructuredPass converter.addTargetMaterialization([](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, Location loc) { - auto placeholder = builder.create( - loc, inputs.front().getDefiningOp()->getOperand(0)); + auto placeholder = tts::GetStructuredStateOp::create( + builder, loc, inputs.front().getDefiningOp()->getOperand(0)); assert(llvm::equal(placeholder.getResultTypes(), resultTypes)); return placeholder.getResults(); }); diff --git a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp index 88568797..867ebd4d 100644 --- a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp +++ b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp @@ -252,8 +252,8 @@ class TritonToUnstructuredPass } OpBuilder b(func->getRegion(0)); - Value zero = b.create( - arg.getLoc(), + Value zero = arith::ConstantOp::create( + b, arg.getLoc(), b.getIntegerAttr(IntegerType::get(&getContext(), defaultBitWidth), 0)); @@ -271,8 +271,8 @@ class TritonToUnstructuredPass } auto res = op.getResult(); OpBuilder b(op); - Value zero = b.create( - op.getLoc(), + Value zero = arith::ConstantOp::create( + b, op.getLoc(), b.getIntegerAttr(IntegerType::get(&getContext(), defaultBitWidth), 0)); @@ -304,8 +304,8 @@ class TritonToUnstructuredPass // We are converting a pointer to an integer here, // materialized the pointer using the accumulated offset // that we have stored so far. - auto materializedAddPtr = b.create( - op->getLoc(), offsetInfo.ptrType, offsetInfo.ptr, + auto materializedAddPtr = triton::AddPtrOp::create( + b, op->getLoc(), offsetInfo.ptrType, offsetInfo.ptr, offsetInfo.offset); // Change the op to use the "simplified" pointer above. @@ -337,19 +337,19 @@ class TritonToUnstructuredPass auto resWidth = std::max(lhsWidth, rhsWidth); if (lhsWidth < resWidth) { - prevOff = b.create( - loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), + prevOff = arith::ExtSIOp::create( + b, loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), prevOff); } if (rhsWidth < resWidth) { - off = b.create( - loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), + off = arith::ExtSIOp::create( + b, loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), off); } - auto accumulatedOff = b.create( - loc, getPtrOffsetType(addptr.getType(), resWidth), + auto accumulatedOff = arith::AddIOp::create( + b, loc, getPtrOffsetType(addptr.getType(), resWidth), prevOff, off); PtrOffset newOffsetInfo{offsetInfo.ptr, addptr.getType(), @@ -491,8 +491,8 @@ class TritonToUnstructuredPass } } - auto gather = b.create( - loc, load.getType(), offsetInfo.ptr, offsetInfo.offset, + auto gather = tts::GatherOp::create( + b, loc, load.getType(), offsetInfo.ptr, offsetInfo.offset, load.getMask(), other); load->replaceAllUsesWith(gather->getResults()); @@ -501,8 +501,9 @@ class TritonToUnstructuredPass }) .Case([&](triton::StoreOp store) { auto offsetInfo = offsetMap.at(store.getPtr()); - b.create(loc, offsetInfo.ptr, offsetInfo.offset, - store.getValue(), store.getMask()); + tts::ScatterOp::create(b, loc, offsetInfo.ptr, + offsetInfo.offset, store.getValue(), + store.getMask()); store->erase(); return success(); }) @@ -526,26 +527,26 @@ class TritonToUnstructuredPass if (baseOffType != currOffType) { if (currOffType.isIndex()) { - baseOffset = b.create( - loc, b.getIndexType(), baseOffset); + baseOffset = arith::IndexCastOp::create( + b, loc, b.getIndexType(), baseOffset); } else if (currOffType.isInteger()) { if (baseOffType.getIntOrFloatBitWidth() < currOffType.getIntOrFloatBitWidth()) { - baseOffset = b.create(loc, currOffType, - baseOffset); + baseOffset = arith::ExtSIOp::create(b, loc, currOffType, + baseOffset); } else { // MakeTensorPtrOp only takes i32 offsets, so we need // to truncate if the offsets were already in i64 makeTensorPtr.emitWarning( "truncating offsets which may result in data loss"); - baseOffset = b.create(loc, currOffType, - baseOffset); + baseOffset = arith::TruncIOp::create(b, loc, currOffType, + baseOffset); } } } - auto accumulatedOffset = b.create( - loc, currOffset.getType(), baseOffset, currOffset); + auto accumulatedOffset = arith::AddIOp::create( + b, loc, currOffset.getType(), baseOffset, currOffset); offsetOpnd.set(accumulatedOffset); diff --git a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp index 316537bf..df97ac6a 100644 --- a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp +++ b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp @@ -53,7 +53,8 @@ class PtrToUnrankedMemrefConverter : public TypeConverter { addTargetMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); } @@ -89,11 +90,11 @@ struct ScalarLoadConverter : public OpConversionPattern { auto basePtr = adaptor.getPtr(); auto offset = adaptor.getOffset(); - Value loadIndex = rewriter.create( - loc, rewriter.getIndexType(), offset); + Value loadIndex = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), offset); - auto memref = rewriter.create( - loc, + auto memref = memref::ReinterpretCastOp::create( + rewriter, loc, getMemrefTypeForScalarPtr( cast(gatherOp.getPtr().getType()), rewriter.getContext()), @@ -103,8 +104,8 @@ struct ScalarLoadConverter : public OpConversionPattern { auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - auto scalarLoadOp = rewriter.create( - loc, memref, zeroMap, ValueRange{}); + auto scalarLoadOp = affine::AffineLoadOp::create(rewriter, loc, memref, + zeroMap, ValueRange{}); rewriter.replaceOp(gatherOp, scalarLoadOp.getResult()); @@ -134,11 +135,11 @@ struct ScalarStoreConverter : public OpConversionPattern { auto basePtr = adaptor.getPtr(); auto offset = adaptor.getOffset(); - Value storeIndex = rewriter.create( - loc, rewriter.getIndexType(), offset); + Value storeIndex = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), offset); - auto memref = rewriter.create( - loc, + auto memref = memref::ReinterpretCastOp::create( + rewriter, loc, getMemrefTypeForScalarPtr( cast(scatterOp.getPtr().getType()), rewriter.getContext()), @@ -149,8 +150,8 @@ struct ScalarStoreConverter : public OpConversionPattern { auto storeVal = scatterOp.getValue(); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - rewriter.create(loc, storeVal, memref, zeroMap, - ValueRange{}); + affine::AffineStoreOp::create(rewriter, loc, storeVal, memref, zeroMap, + ValueRange{}); rewriter.eraseOp(scatterOp); return success(); @@ -186,22 +187,19 @@ struct GatherConverter : public OpConversionPattern { // Treat the base pointer (memref) as 1D because the offsets are all // relative to a single base pointer (already collapsed). - auto baseMemref = rewriter - .create( - loc, - MemRefType::get({ShapedType::kDynamic}, - resultType.getElementType()), - ptr) - .getResult(); + auto baseMemref = + memref::CastOp::create(rewriter, loc, + MemRefType::get({ShapedType::kDynamic}, + resultType.getElementType()), + ptr) + .getResult(); auto baseTensor = - rewriter - .create( - loc, - RankedTensorType::get( - SmallVector(1, ShapedType::kDynamic), - resultType.getElementType()), - baseMemref, true /* restrict */, false /* writable */) + bufferization::ToTensorOp::create( + rewriter, loc, + RankedTensorType::get(SmallVector(1, ShapedType::kDynamic), + resultType.getElementType()), + baseMemref, true /* restrict */, false /* writable */) .getResult(); // The linalg.generic op should have the following inputs: @@ -213,10 +211,10 @@ struct GatherConverter : public OpConversionPattern { inputs.push_back(gatherOp.getMask()); } - auto emptyTensor = rewriter - .create(loc, resultType.getShape(), - resultType.getElementType()) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()) + .getResult(); // Affine maps for the inputs and one additional output. SmallVector affineMaps( @@ -227,16 +225,17 @@ struct GatherConverter : public OpConversionPattern { SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); - auto genericOp = rewriter.create( - loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps, - iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + auto genericOp = linalg::GenericOp::create( + rewriter, loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, + affineMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc, Value index) -> Value { Value index0 = - b.create(loc, b.getIndexType(), index); + arith::IndexCastOp::create(b, loc, b.getIndexType(), index); - return b.create(loc, baseTensor, - ValueRange{index0}); + return tensor::ExtractOp::create(b, loc, baseTensor, + ValueRange{index0}); }; auto offset = args[0]; @@ -245,34 +244,34 @@ struct GatherConverter : public OpConversionPattern { // If there is no mask, simply extract the current element from the // base tensor and use it as the yield value. auto loadValue = getValueAtIndex(b, loc, offset); - b.create(loc, loadValue); + linalg::YieldOp::create(b, loc, loadValue); } else { // If the mask value is truthy, the current element is loaded from // the base tensor using its offset. Otherwise, if `other` is // present, yield `other`. If `other` is not present, a default // value of 0 is used. auto mask = args[1]; - auto ifOp = b.create( - loc, mask, + auto ifOp = scf::IfOp::create( + b, loc, mask, [&](OpBuilder &b, Location loc) { // Truthy case, load from the index. auto value = getValueAtIndex(b, loc, offset); - b.create(loc, value); + scf::YieldOp::create(b, loc, value); }, [&](OpBuilder &b, Location loc) { // Falsy case, yield `other` or 0 as the default value. if (gatherOp.getOther()) { - b.create(loc, gatherOp.getOther()); + scf::YieldOp::create(b, loc, gatherOp.getOther()); } else { auto elemType = resultType.getElementType(); auto zeroAttr = b.getZeroAttr(elemType); assert(zeroAttr && "unexpected element type"); - Value extract = b.create(loc, zeroAttr); - b.create(loc, extract); + Value extract = arith::ConstantOp::create(b, loc, zeroAttr); + scf::YieldOp::create(b, loc, extract); } }); - b.create(loc, ifOp->getResult(0)); + linalg::YieldOp::create(b, loc, ifOp->getResult(0)); } }); @@ -312,11 +311,10 @@ struct ScatterConverter : public OpConversionPattern { // Treat the base pointer (memref) as 1D because the offsets are all // relative to a single base pointer (already collapsed). auto baseMemref = - rewriter - .create(loc, - MemRefType::get({ShapedType::kDynamic}, - valueType.getElementType()), - ptr) + memref::CastOp::create( + rewriter, loc, + MemRefType::get({ShapedType::kDynamic}, valueType.getElementType()), + ptr) .getResult(); // The linalg.generic op should have the following inputs: @@ -339,16 +337,16 @@ struct ScatterConverter : public OpConversionPattern { rewriter.setInsertionPoint(scatterOp); - auto genericOp = rewriter.create( - loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { + auto genericOp = linalg::GenericOp::create( + rewriter, loc, TypeRange{}, inputs, ValueRange{}, affineMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc, Value index, Value value) { Value index0 = - b.create(loc, b.getIndexType(), index); + arith::IndexCastOp::create(b, loc, b.getIndexType(), index); - b.create(loc, value, baseMemref, - ValueRange{index0}); + memref::StoreOp::create(b, loc, value, baseMemref, + ValueRange{index0}); }; auto offset = args[0]; @@ -362,14 +360,14 @@ struct ScatterConverter : public OpConversionPattern { // If the mask value is truthy, insert the current value to the // the base memref using its offset. Otherwise, noop. auto mask = args[2]; - auto ifOp = - b.create(loc, mask, [&](OpBuilder &b, Location loc) { + auto ifOp = scf::IfOp::create( + b, loc, mask, [&](OpBuilder &b, Location loc) { storeValueAtIndex(b, loc, offset, value); - b.create(loc); + scf::YieldOp::create(b, loc); }); } - b.create(loc); + linalg::YieldOp::create(b, loc); }); rewriter.eraseOp(scatterOp); diff --git a/lib/Dialect/TPtr/IR/TPtrDialect.cpp b/lib/Dialect/TPtr/IR/TPtrDialect.cpp index 3235a1e0..c4b8b7c7 100644 --- a/lib/Dialect/TPtr/IR/TPtrDialect.cpp +++ b/lib/Dialect/TPtr/IR/TPtrDialect.cpp @@ -51,27 +51,30 @@ void mlir::tptr::TPtrDialect::initialize() { } bool tptr::DefaultMemorySpaceAttr::isValidLoad( - Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, mlir::ptr::AtomicOrdering ordering, + std::optional alignment, const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } bool tptr::DefaultMemorySpaceAttr::isValidStore( - Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, mlir::ptr::AtomicOrdering ordering, + std::optional alignment, const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } bool tptr::DefaultMemorySpaceAttr::isValidAtomicOp( mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering, - IntegerAttr alignment, + std::optional alignment, const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } bool tptr::DefaultMemorySpaceAttr::isValidAtomicXchg( Type type, mlir::ptr::AtomicOrdering successOrdering, - mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + mlir::ptr::AtomicOrdering failureOrdering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp index df7de996..150e926d 100644 --- a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp +++ b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -46,14 +46,14 @@ Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return builder.create(loc, resType, src); + return arith::SIToFPOp::create(builder, loc, resType, src); }) .Case([&](Operation *op) { auto resType = op->getResults()[0].getType(); if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return builder.create(loc, resType, src); + return arith::TruncFOp::create(builder, loc, resType, src); }) .Default([](Operation *op) { llvm_unreachable("unsupported op in generating "); diff --git a/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp b/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp index da4d9767..acf118a5 100644 --- a/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp +++ b/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp @@ -29,12 +29,12 @@ Value getSlice(OpBuilder &b, Location loc, Value source, ArrayRef strides) { return TypeSwitch(source.getType()) .Case([&](RankedTensorType t) -> Value { - return b.create(loc, source, offsets, sizes, - strides); + return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes, + strides); }) .Case([&](MemRefType type) -> Value { - return b.create(loc, source, offsets, sizes, - strides); + return memref::SubViewOp::create(b, loc, source, offsets, sizes, + strides); }) .Default([&](Type t) { return nullptr; }); } diff --git a/python/examples/conftest.py b/python/examples/conftest.py index 4bac989c..33397c5a 100644 --- a/python/examples/conftest.py +++ b/python/examples/conftest.py @@ -119,7 +119,8 @@ def pytest_collection_modifyitems(config, items): for param_name, param_value in item.callspec.params.items(): if (param_name.startswith('dtype') or param_name.endswith('dtype')) and param_value == 'bfloat16': item.add_marker(skip_marker_bfloat) - if param_name.startswith('input_precision') and param_value.startswith('tf32'): + if param_name.startswith('input_precision') and (param_value.startswith('tf32') + or param_value.startswith('bf16')): item.add_marker(skip_marker_tf32) if (param_name.startswith('dtype') or param_name.endswith('dtype')) and ('float8' in str(param_value)): item.add_marker(skip_marker_float8) diff --git a/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir b/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir index d8308a17..3d527765 100644 --- a/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir +++ b/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir @@ -15,7 +15,6 @@ module { // CHECK-LABEL: func.func @addi // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK-NOT: separator of consecutive DAGs @@ -28,7 +27,7 @@ module { // CHECK: linalg.yield [[VAR_3_]] : i32 // CHECK: } // CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1]>> // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir b/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir index a753cd61..c2ee2e09 100644 --- a/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir +++ b/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir @@ -71,10 +71,9 @@ module { // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 -// CHECK-DAG: [[CST_3074_:%.+]] = arith.constant 3074 : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_3074_]]{{.}}, sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: ?>> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_3074_]]{{.}}, sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> // CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_130_]] : index // CHECK: [[VAR_2_:%.+]] = arith.maxsi [[VAR_1_]], [[CST_2_]] : index @@ -93,12 +92,13 @@ module { // CHECK: scf.if [[VAR_12_]] { // CHECK: linalg.fill ins([[CST_0_]] : bf16) outs([[RES_]] : memref<128x256xbf16>) // CHECK: } -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> // CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref> to memref> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref> to memref> // CHECK: [[VAR_13_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x256xbf16> // CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_13_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: ?>> to memref> -// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_2_]] : (tensor, memref>) -> () +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: [[CAST_0:%.+]] = memref.cast [[VAR_subview_2_]] : memref> to memref> +// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[CAST_0]] : (tensor, memref>) -> () // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir b/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir index 5c7f4f4e..d896683d 100644 --- a/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir +++ b/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir @@ -90,15 +90,15 @@ module { // CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_14_]], [[CST_4_]] : index // CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_16_]], [[VAR_5_]] : index // CHECK: [[VAR_18_:%.+]] = arith.subi [[VAR_17_]], [[VAR_14_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_18_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: [4, [[VAR_18_]]], strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK-DAG: [[VAR_19_:%.+]] = arith.subi [[CST_4_]], [[VAR_18_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: [4, [[VAR_19_]]], strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> // CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_19_]]{{.}} [1, 1] : memref> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> // CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> // CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_18_]]{{.}} [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> diff --git a/test/Conversion/StructuredToMemref/wraparound_stacked.mlir b/test/Conversion/StructuredToMemref/wraparound_stacked.mlir index 4e5ce945..6ee1f857 100644 --- a/test/Conversion/StructuredToMemref/wraparound_stacked.mlir +++ b/test/Conversion/StructuredToMemref/wraparound_stacked.mlir @@ -84,15 +84,15 @@ module { // CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_11_]] : index // CHECK: [[VAR_15_:%.+]] = arith.divsi [[VAR_14_]], [[VAR_1_]] : index // CHECK: [[VAR_16_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: [[[VAR_16_]], 4], strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> // CHECK-DAG: [[VAR_17_:%.+]] = arith.subi [[CST_4_]], [[VAR_16_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_17_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: [[[VAR_17_]], 4], strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> // CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref> to memref> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref> to memref> +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref> to memref> // CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<4x4xf32> to memref> // CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_16_]], 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref> to memref> diff --git a/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir b/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir index d8c015c7..f4138993 100644 --- a/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir +++ b/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir @@ -29,7 +29,6 @@ module { // CHECK-LABEL: func.func @simple_cf_into_structured_load // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { // CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset f32 : i32 -// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 // CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[PARAM_0_]] : memref<*xf32> to memref<1xf32> @@ -48,9 +47,9 @@ module { // CHECK: scf.yield [[VAR_7_1_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } // CHECK: [[VAR_4_:%.+]] = tptr.to_memref [[VAR_3_]] : <#tptr.default_memory_space> to memref<1xf32> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[VAR_4_]] to offset: {{.}}[[CST_6_]]{{.}}, sizes: [4], strides: [1] : memref<1xf32> to memref<4xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[VAR_4_]] to offset: [6], sizes: [4], strides: [1] : memref<1xf32> to memref<4xf32, strided<[1], offset: 6>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4xf32, strided<[1], offset: 6>> to memref<4xf32> // CHECK-DAG: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4xf32> to tensor<4xf32> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref<*xf32> to memref<4xf32, strided<[1]>> // CHECK: bufferization.materialize_in_destination [[VAR_5_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4xf32>, memref<4xf32, strided<[1]>>) -> () diff --git a/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir b/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir index 9fd9cc52..d48ebeb8 100644 --- a/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir +++ b/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir @@ -41,8 +41,8 @@ module { // CHECK: } -> tensor<1xi1> // CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_6_]]{{.}}[[CST_0_1_]]{{.}} : tensor<1xi1> // CHECK: scf.if [[VAR_extracted_]] { -// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_1_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[CST_42_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1], offset: ?>> +// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store [[CST_42_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1]>> // CHECK: } // CHECK: return // CHECK: } diff --git a/triton-hash.txt b/triton-hash.txt index 80f8308f..0f8b487b 100644 --- a/triton-hash.txt +++ b/triton-hash.txt @@ -1 +1 @@ -e44bd1c83c1c3e8deac7c4f02683cfb3cc395c8b +dbfbc1e1e6cca56eeaa853050b7962e7445f0a82 \ No newline at end of file From 422b446957e899b6f56364ba293327e38db65e3a Mon Sep 17 00:00:00 2001 From: enjustli <798634436@qq.com> Date: Fri, 5 Dec 2025 00:03:45 +0800 Subject: [PATCH 3/3] Bump to triton-lang/triton@dbfbc1e1e6c --- lib/Analysis/MaskAnalysis.cpp | 3 +-- .../StructuredToMemref/StructuredToMemref.cpp | 3 +-- .../StructuredToMemref/wraparound_side_by_side.mlir | 11 +++++++---- .../StructuredToMemref/wraparound_stacked.mlir | 10 ++++++---- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 2fcdded5..5b84859b 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -65,8 +65,7 @@ tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, SmallVector offsets(getRank(), builder.getIndexAttr(0)); SmallVector strides(getRank(), builder.getIndexAttr(1)); - auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, - dims, strides); + auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, dims); return tensor::ExtractSliceOp::create(builder, loc, dstType, source, offsets, dims, strides); diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp index 9192e32e..6a8ca9ea 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -1021,8 +1021,7 @@ struct StoreConverter : public OpConversionPattern { SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector strides(rank, b.getIndexAttr(1)); - auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, - dims, strides); + auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, dims); return tensor::ExtractSliceOp::create(b, loc, dstType, source, offsets, dims, strides); diff --git a/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir b/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir index d896683d..0075387c 100644 --- a/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir +++ b/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir @@ -57,6 +57,7 @@ module { // CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { +// CHECK: [[CONSTANT_0:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 @@ -96,11 +97,13 @@ module { // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: [4, [[VAR_19_]]], strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> // CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK: [[DIM_0:%.+]] = memref.dim [[VAR_reinterpret_cast_0_]], [[CONSTANT_0]] : memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: [[DIM_1:%.+]] = memref.dim [[VAR_reinterpret_cast_1_]], [[CONSTANT_0]] : memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[DIM_0]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_18_]]{{.}} [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[DIM_1]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[DIM_0]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[DIM_0]]{{.}} [2, [[DIM_1]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> // CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>> // CHECK: [[VAR_20_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> diff --git a/test/Conversion/StructuredToMemref/wraparound_stacked.mlir b/test/Conversion/StructuredToMemref/wraparound_stacked.mlir index 6ee1f857..953b3724 100644 --- a/test/Conversion/StructuredToMemref/wraparound_stacked.mlir +++ b/test/Conversion/StructuredToMemref/wraparound_stacked.mlir @@ -90,11 +90,13 @@ module { // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: [[[VAR_17_]], 4], strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> // CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref> to memref> +// CHECK: [[DIM_0:%.+]] = memref.dim [[VAR_reinterpret_cast_0_]], [[CST_0_1_]] : memref> +// CHECK: [[DIM_1:%.+]] = memref.dim [[VAR_reinterpret_cast_1_]], [[CST_0_1_]] : memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[DIM_0]], 3] [1, 1] : memref> to memref> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref> to memref> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_16_]], 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref> +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[DIM_1]], 3] [1, 1] : memref> to memref> +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[DIM_0]], 3] [1, 1] : memref<4x4xf32> to memref> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[DIM_0]], 0] {{.}}[[DIM_1]], 3] [1, 1] : memref<4x4xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref> to memref> // CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref> to memref> // CHECK: [[VAR_18_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>