From df485d7f8c4fddf0bfc7d73f4e35d25c3ff327cd Mon Sep 17 00:00:00 2001
From: Ziliang Zhang
Date: Fri, 12 Sep 2025 18:07:36 +0800
Subject: [PATCH 1/3] Add tl.gather support

---
 .../ConversionPatterns.hpp                    | 66 +++++++++++++++++++
 .../TritonArithToLinalg.cpp                   |  1 +
 .../TritonToLinalg/TritonToLinalg.cpp         |  1 +
 python/examples/conftest.py                   |  1 +
 test/Conversion/TritonToLinalg/gather.mlir    | 31 +++++++++
 5 files changed, 100 insertions(+)
 create mode 100644 test/Conversion/TritonToLinalg/gather.mlir

diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
index 1d9244df..9154ada5 100644
--- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
+++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
@@ -2162,6 +2162,72 @@ class ReshapeConverter : public OpConversionPattern<triton::ReshapeOp> {
     return success();
   }
 };
+
+struct GatherConverter : public OpConversionPattern<triton::GatherOp> {
+  using OpConversionPattern<triton::GatherOp>::OpConversionPattern;
+  Value castIntToIndex(OpBuilder &b, Location loc, Value v) const {
+    return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
+  }
+
+  void createGatherPayload(OpBuilder &b, Location loc, Value input, Value index,
+                           int64_t axis, int64_t rank) const {
+    SmallVector<Value> indices;
+    for (int i = 0; i < rank; i++) {
+      if (i == axis) {
+        indices.push_back(castIntToIndex(b, loc, index));
+      } else {
+        indices.push_back(b.create<linalg::IndexOp>(loc, i));
+      }
+    }
+    // Assert index < input.sizes[axis]
+    auto dim = b.create<tensor::DimOp>(loc, input, axis);
+    auto indexOverflow = b.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, castIntToIndex(b, loc, index), dim);
+    b.create<cf::AssertOp>(
+        loc, indexOverflow,
+        b.getStringAttr("index must be smaller than axis size"));
+
+    // Assert index >= 0
+    auto cst0 =
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
+    auto indexUnderflow =
+        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
+    b.create<cf::AssertOp>(
+        loc, indexUnderflow,
+        b.getStringAttr("index must be larger or equal to 0"));
+
+    Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
+    b.create<linalg::YieldOp>(loc, extract);
+  }
+
+  LogicalResult
+  matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto src = adaptor.getSrc();
+    auto indices = adaptor.getIndices();
+    auto axis = op.getAxis();
+    auto resultType = cast<RankedTensorType>(op.getType());
+    int64_t rank = resultType.getRank();
+
+    auto empty = rewriter
+                     .create<tensor::EmptyOp>(loc, resultType.getShape(),
+                                              resultType.getElementType())
+                     .getResult();
+
+    SmallVector<AffineMap> affineMaps(2,
+                                      rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<utils::IteratorType> iteratorTypes(
+        rank, utils::IteratorType::parallel);
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultType, indices, empty, affineMaps, iteratorTypes,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto index = args[0];
+          createGatherPayload(b, loc, src, index, axis, rank);
+        });
+    return success();
+  }
+};
 
 class ExternElementwiseBinaryOpConverter
     : public OpConversionPattern<triton::ExternElementwiseOp> {
diff --git a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
index 44ecd721..11de188a 100644
--- a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
+++ b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
@@ -82,6 +82,7 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
+  patterns.add<GatherConverter>(patterns.getContext());
 
   populateExternElementwiseOpToMLIROps(patterns);
 
diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
index 1c8ed9cf..4fb70d2d 100644
--- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
+++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -68,6 +68,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext());
+  patterns.add<GatherConverter>(patterns.getContext());
 
   populateExternElementwiseOpToMLIROps(patterns);
 
diff --git a/python/examples/conftest.py b/python/examples/conftest.py
index 4bac989c..d58077ff 100644
--- a/python/examples/conftest.py
+++ b/python/examples/conftest.py
@@ -88,6 +88,7 @@ def with_allocator():
     "test_trans_4d",
     "test_unsplat",
     "test_arange",
+    "test_gather",
 }
 
 annotations_tests_supported = {
diff --git a/test/Conversion/TritonToLinalg/gather.mlir b/test/Conversion/TritonToLinalg/gather.mlir
new file mode 100644
index 00000000..6ace6e76
--- /dev/null
+++ b/test/Conversion/TritonToLinalg/gather.mlir
@@ -0,0 +1,31 @@
+// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
+module {
+  tt.func public @gather_test_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<4> : tensor<8x1xi32>
+    %cst_0 = arith.constant dense<4> : tensor<4x1xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %2 = arith.muli %1, %cst_0 : tensor<4x1xi32>
+    %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
+    %4 = tt.broadcast %2 : tensor<4x1xi32> -> tensor<4x4xi32>
+    %5 = tt.broadcast %3 : tensor<1x4xi32> -> tensor<4x4xi32>
+    %6 = arith.addi %4, %5 : tensor<4x4xi32>
+    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
+    %8 = tt.addptr %7, %6 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %9 = tt.load %8 : tensor<4x4x!tt.ptr<f32>>
+    %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
+    %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32>
+    %12 = arith.muli %11, %cst : tensor<8x1xi32>
+    %13 = tt.broadcast %12 : tensor<8x1xi32> -> tensor<8x4xi32>
+    %14 = tt.broadcast %3 : tensor<1x4xi32> -> tensor<8x4xi32>
+    %15 = arith.addi %13, %14 : tensor<8x4xi32>
+    %16 = tt.splat %arg1 : !tt.ptr<i64> -> tensor<8x4x!tt.ptr<i64>>
+    %17 = tt.addptr %16, %15 : tensor<8x4x!tt.ptr<i64>>, tensor<8x4xi32>
+    %18 = tt.load %17 : tensor<8x4x!tt.ptr<i64>>
+    %19 = tt.gather %9[%18] {axis = 0 : i32} : (tensor<4x4xf32>, tensor<8x4xi64>) -> tensor<8x4xf32>
+    %20 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<8x4x!tt.ptr<f32>>
+    %21 = tt.addptr %20, %15 : tensor<8x4x!tt.ptr<f32>>, tensor<8x4xi32>
+    tt.store %21, %19 : tensor<8x4x!tt.ptr<f32>>
+    tt.return
+  }
+}

From 1384e0a9e836501548e4eea32640499f36cecaeb Mon Sep 17 00:00:00 2001
From: Ziliang Zhang
Date: Sat, 13 Sep 2025 10:46:05 +0800
Subject: [PATCH 2/3] Update tests

---
 python/examples/test_gather.py             | 66 ++++++++++++++++++++++
 test/Conversion/TritonToLinalg/gather.mlir | 40 +++++++++++++
 2 files changed, 106 insertions(+)
 create mode 100644 python/examples/test_gather.py

diff --git a/python/examples/test_gather.py b/python/examples/test_gather.py
new file mode 100644
index 00000000..535d9bab
--- /dev/null
+++ b/python/examples/test_gather.py
@@ -0,0 +1,66 @@
+import torch
+import triton
+import pytest
+
+import triton.language as tl
+
+
+@triton.jit
+def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
+                       src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
+                       idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
+                       out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr,
+                       out_stride1: tl.constexpr):
+    src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1)
+    src = tl.load(src_ptr + src_offs)
+
+    idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1)
+    idx = tl.load(idx_ptr + idx_offs)
+
+    out = tl.gather(src, idx, axis)
+
+    out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1)
+    tl.store(out_ptr + out_offs, out)
+
+
+@triton.jit
+def gather_test_kernel_1d(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, idx_dim0: tl.constexpr,
+                          out_dim0: tl.constexpr):
+    src_offs = tl.arange(0, src_dim0)
+    src = tl.load(src_ptr + src_offs)
+
+    idx_offs = tl.arange(0, idx_dim0)
+    idx = tl.load(idx_ptr + idx_offs)
+
+    out = tl.gather(src, idx, axis)
+
+    out_offs = tl.arange(0, out_dim0)
+    tl.store(out_ptr + out_offs, out)
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("src_shape, indices_shape, axis", [
+    ([32], [64], 0),
+    ([4, 4], [8, 4], 0),
+    ([128, 64], [256, 64], 0),
+    ([128, 64], [128, 128], 1),
+])
+def test_gather(src_shape, indices_shape, axis, device):
+
+    def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
+        output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
+
+        if len(src_shape) == 1:
+            gather_test_kernel_1d[(1, )](src, indices, output, axis, src.shape[0], indices.shape[0], output.shape[0])
+        else:
+            gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0),
+                                      src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0),
+                                      indices.stride(1), output.shape[0], output.shape[1], output.stride(0),
+                                      output.stride(1))
+
+        return output
+
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
+    ref = torch.gather(src, axis, indices)
+    result = triton_gather(src, axis, indices)
+    torch.testing.assert_close(result, ref, rtol=0, atol=0)
diff --git a/test/Conversion/TritonToLinalg/gather.mlir b/test/Conversion/TritonToLinalg/gather.mlir
index 6ace6e76..d626a038 100644
--- a/test/Conversion/TritonToLinalg/gather.mlir
+++ b/test/Conversion/TritonToLinalg/gather.mlir
@@ -29,3 +29,43 @@ module {
     tt.return
   }
 }
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @gather_test_kernel(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32> {tt.divisibility = 16 : i32},
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xi64> {tt.divisibility = 16 : i32},
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32> {tt.divisibility = 16 : i32},
+// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
+// CHECK-SAME: %[[VAL_4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
+// CHECK-SAME: %[[VAL_5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
+// CHECK-SAME: %[[VAL_6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
+// CHECK-SAME: %[[VAL_7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
+// CHECK-SAME: %[[VAL_8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) {
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_10:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [4, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xf32> to memref<4x4xf32, strided<[?, 1]>>
+// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<4x4xf32>
+// CHECK: memref.copy %[[VAL_11]], %[[VAL_12]] : memref<4x4xf32, strided<[?, 1]>> to memref<4x4xf32>
+// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<4x4xf32> to tensor<4x4xf32>
+// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [8, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xi64> to memref<8x4xi64, strided<[?, 1]>>
+// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<8x4xi64>
+// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<8x4xi64, strided<[?, 1]>> to memref<8x4xi64>
+// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<8x4xi64> to tensor<8x4xi64>
+// CHECK: %[[VAL_17:.*]] = tensor.empty() : tensor<8x4xf32>
+// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_16]] : tensor<8x4xi64>) outs(%[[VAL_17]] : tensor<8x4xf32>) {
+// CHECK: ^bb0(%[[VAL_19:.*]]: i64, %[[VAL_20:.*]]: f32):
+// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_19]] : i64 to index
+// CHECK: %[[VAL_22:.*]] = linalg.index 1 : index
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : i64 to index
+// CHECK: %[[VAL_24:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_10]] : index
+// CHECK: cf.assert %[[VAL_24]], "index must be smaller than axis size"
+// CHECK: %[[VAL_25:.*]] = arith.cmpi sge, %[[VAL_19]], %[[VAL_9]] : i64
+// CHECK: cf.assert %[[VAL_25]], "index must be larger or equal to 0"
+// CHECK: %[[VAL_26:.*]] = tensor.extract %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : tensor<4x4xf32>
+// CHECK: linalg.yield %[[VAL_26]] : f32
+// CHECK: } -> tensor<8x4xf32>
+// CHECK: %[[VAL_27:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [8, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xf32> to memref<8x4xf32, strided<[?, 1]>>
+// CHECK: bufferization.materialize_in_destination %[[VAL_18]] in writable %[[VAL_27]] : (tensor<8x4xf32>, memref<8x4xf32, strided<[?, 1]>>) -> ()
+// CHECK: return
+// CHECK: }
\ No newline at end of file

From 1107149b859ebe7e2b3126762999272f5a41cabc Mon Sep 17 00:00:00 2001
From: Ziliang Zhang
Date: Sat, 13 Sep 2025 14:25:58 +0800
Subject: [PATCH 3/3] .

---
 .../Conversion/TritonArithToLinalg/ConversionPatterns.hpp | 6 ++----
 test/Conversion/TritonToLinalg/gather.mlir                | 2 +-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
index 9154ada5..77cbe0a3 100644
--- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
+++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
@@ -2210,11 +2210,9 @@ struct GatherConverter : public OpConversionPattern<triton::GatherOp> {
     auto resultType = cast<RankedTensorType>(op.getType());
     int64_t rank = resultType.getRank();
 
-    auto empty = rewriter
+    Value empty = rewriter
                      .create<tensor::EmptyOp>(loc, resultType.getShape(),
-                                              resultType.getElementType())
-                     .getResult();
-
+                                              resultType.getElementType());
     SmallVector<AffineMap> affineMaps(2,
                                       rewriter.getMultiDimIdentityMap(rank));
     SmallVector<utils::IteratorType> iteratorTypes(
diff --git a/test/Conversion/TritonToLinalg/gather.mlir b/test/Conversion/TritonToLinalg/gather.mlir
index d626a038..a9fa775b 100644
--- a/test/Conversion/TritonToLinalg/gather.mlir
+++ b/test/Conversion/TritonToLinalg/gather.mlir
@@ -68,4 +68,4 @@ module {
 // CHECK: %[[VAL_27:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [8, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xf32> to memref<8x4xf32, strided<[?, 1]>>
 // CHECK: bufferization.materialize_in_destination %[[VAL_18]] in writable %[[VAL_27]] : (tensor<8x4xf32>, memref<8x4xf32, strided<[?, 1]>>) -> ()
 // CHECK: return
-// CHECK: }
\ No newline at end of file
+// CHECK: }