diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
index 1d9244df..32e279b4 100644
--- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
+++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
@@ -1845,49 +1845,180 @@ template <typename CmpOp> struct MinMaxConverter : public OpRewritePattern<CmpOp> {
   using OpRewritePattern<CmpOp>::OpRewritePattern;
 
+  // Returns the single user of `val` if it exists and has type T,
+  // nullptr otherwise.
+  template <typename T> static T getOnlyUserOfType(Value val) {
+    if (!val || !val.hasOneUse()) {
+      return nullptr;
+    }
+    return dyn_cast<T>(*val.getUsers().begin());
+  }
+
+  // Only handle the Cmp + OrIOp + Select pattern here: walk from the compare
+  // result through a single-use arith.ori to a single-use arith.select.
+  static arith::SelectOp findSelectThroughOr(Value cond) {
+    if (auto ori = getOnlyUserOfType<arith::OrIOp>(cond)) {
+      return getOnlyUserOfType<arith::SelectOp>(ori.getResult());
+    }
+    return nullptr;
+  }
+
   MinMaxConverter(MLIRContext *context)
       : OpRewritePattern<CmpOp>(context, /*benefit=*/10) {}
 
+  /// Helper that maps a floating-point compare predicate to the
+  /// corresponding min/max operation. This is parametrized by
+  /// whether we want NaN-aware operations (MaximumFOp/MinimumFOp) or
+  /// numeric operations (MaxNumFOp/MinNumFOp).
+  /// Returns the result of the newly created op, or failure() for
+  /// predicates that do not correspond to a min/max.
+  FailureOr<Value> foldCmpToMinMax(PatternRewriter &rewriter, Location loc,
+                                   Value lhs, Value rhs,
+                                   arith::CmpFPredicate pred,
+                                   bool useNaNOps) const {
+    switch (pred) {
+    case arith::CmpFPredicate::OGT:
+    case arith::CmpFPredicate::OGE:
+      if (useNaNOps) {
+        return rewriter.create<arith::MaximumFOp>(loc, lhs, rhs).getResult();
+      } else {
+        return rewriter.create<arith::MaxNumFOp>(loc, lhs, rhs).getResult();
+      }
+    case arith::CmpFPredicate::OLT:
+    case arith::CmpFPredicate::OLE:
+      if (useNaNOps) {
+        return rewriter.create<arith::MinimumFOp>(loc, lhs, rhs).getResult();
+      } else {
+        return rewriter.create<arith::MinNumFOp>(loc, lhs, rhs).getResult();
+      }
+    default:
+      return failure();
+    }
+  }
+
   LogicalResult matchAndRewrite(CmpOp cmpOp,
                                 PatternRewriter &rewriter) const final {
-    if (!cmpOp.getResult().hasOneUse()) {
+    Value result = cmpOp.getResult();
+    if (!result.hasOneUse()) {
       return failure();
     }
-    auto selectOp =
-        dyn_cast<arith::SelectOp>(*cmpOp.getResult().getUsers().begin());
+
+    // 1. Simple pattern: cmpf + select.
+    if (auto selectOp = dyn_cast<arith::SelectOp>(*result.getUsers().begin())) {
+      if (!(result == selectOp.getCondition() &&
+            (cmpOp.getLhs() == selectOp.getTrueValue() &&
+             cmpOp.getRhs() == selectOp.getFalseValue()))) {
+        return failure();
+      }
+
+      rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate());
+      rewriter.eraseOp(cmpOp);
+      return success();
+    }
+
+    // 2. NaN-aware pattern: cmpf + or + select.
+    auto selectOp = findSelectThroughOr(result);
     if (!selectOp) {
       return failure();
     }
 
-    if (!(cmpOp.getResult() == selectOp.getCondition() &&
-          cmpOp.getLhs() == selectOp.getTrueValue() &&
-          cmpOp.getRhs() == selectOp.getFalseValue())) {
+    if (failed(foldCmpSelectToMinMax(rewriter, selectOp))) {
       return failure();
     }
+    return success();
+  }
 
-    rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate());
-    rewriter.eraseOp(cmpOp);
+  /// foldCmpSelectToMinMax performs an optimization pattern that matches
+  /// 'arith.select' operations based on a floating-point comparison
+  /// and rewrites them into equivalent numeric min/max operations.
+  ///
+  /// This pattern handles the following case:
+  ///
+  /// ** NaN-Aware Min/Max Reduction **
+  ///  - select ((cmpf ogt a, b) || (cmpf une a, a)), a, b
+  ///      --> arith.maximumf(a, b)
+  ///  - select ((cmpf olt a, b) || (cmpf une a, a)), a, b
+  ///      --> arith.minimumf(a, b)
+  ///
+  /// These transformations not only improve IR canonicalization but also
+  /// allow the successful lowering of tt.reduce operations to linalg
+  /// operations, which is already supported in the triton-shared dialect
+  /// conversion pipeline.
+  LogicalResult foldCmpSelectToMinMax(PatternRewriter &rewriter,
+                                      arith::SelectOp sel) const {
+
+    if (!isa<FloatType>(sel.getType())) {
+      return failure();
+    }
 
-    return success();
+    Operation *condOp = sel.getCondition().getDefiningOp();
+    if (!condOp) {
+      return failure();
+    }
+
+    Value trueVal = sel.getTrueValue();
+    Value falseVal = sel.getFalseValue();
+
+    // NaN-Aware Min/Max Reduction.
+    auto ori = dyn_cast<arith::OrIOp>(condOp);
+    if (!ori)
+      return failure();
+    // Extract both sides of the OR condition.
+    auto cmp1 = ori.getLhs().getDefiningOp<arith::CmpFOp>();
+    auto cmp2 = ori.getRhs().getDefiningOp<arith::CmpFOp>();
+    if (!cmp1 || !cmp2)
+      return failure();
+
+    // Helper lambdas to identify comparison patterns.
+    auto isOGT = [&](arith::CmpFOp cmp) {
+      return cmp.getPredicate() == arith::CmpFPredicate::OGT &&
+             trueVal == cmp.getLhs() && falseVal == cmp.getRhs();
+    };
+    auto isOLT = [&](arith::CmpFOp cmp) {
+      return cmp.getPredicate() == arith::CmpFPredicate::OLT &&
+             trueVal == cmp.getLhs() && falseVal == cmp.getRhs();
+    };
+    // une(a, a) is the canonical "a is NaN" check.
+    auto isNaN = [&](arith::CmpFOp cmp) {
+      return cmp.getPredicate() == arith::CmpFPredicate::UNE &&
+             trueVal == cmp.getLhs() && trueVal == cmp.getRhs();
+    };
+
+    // Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b).
+    if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) {
+      PatternRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(sel);
+      FailureOr<Value> foldResult = foldCmpToMinMax(
+          rewriter, sel.getLoc(), trueVal, falseVal, arith::CmpFPredicate::OGT,
+          /*useNaNOps=*/true);
+      if (failed(foldResult)) {
+        return failure();
+      }
+      rewriter.replaceOp(sel, *foldResult);
+      return success();
+    }
+
+    // Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b).
+    if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) {
+      PatternRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(sel);
+      FailureOr<Value> foldResult = foldCmpToMinMax(
+          rewriter, sel.getLoc(), trueVal, falseVal, arith::CmpFPredicate::OLT,
+          /*useNaNOps=*/true);
+      if (failed(foldResult)) {
+        return failure();
+      }
+      rewriter.replaceOp(sel, *foldResult);
+      return success();
+    }
+    return failure();
   }
 
   void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp,
                            arith::SelectOp selectOp,
                            arith::CmpFPredicate pred) const {
-    switch (pred) {
-    case arith::CmpFPredicate::OGT:
-    case arith::CmpFPredicate::OGE:
-      rewriter.replaceOpWithNewOp<arith::MaximumFOp>(selectOp, cmpOp.getLhs(),
-                                                     cmpOp.getRhs());
-      break;
-    case arith::CmpFPredicate::OLT:
-    case arith::CmpFPredicate::OLE:
-      rewriter.replaceOpWithNewOp<arith::MinimumFOp>(selectOp, cmpOp.getLhs(),
-                                                     cmpOp.getRhs());
-      break;
-    default:
+    // useNaNOps=true keeps the previous MaximumFOp/MinimumFOp behavior.
+    FailureOr<Value> foldedResult =
+        foldCmpToMinMax(rewriter, selectOp.getLoc(), cmpOp.getLhs(),
+                        cmpOp.getRhs(), pred, /*useNaNOps=*/true);
+    if (failed(foldedResult)) {
       llvm_unreachable("Unhandled predicate");
     }
+    rewriter.replaceOp(selectOp, *foldedResult);
   }
 
   void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpIOp cmpOp,
diff --git a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir
index eaf30630..c02a2a81 100644
--- a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir
+++ b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir
@@ -123,4 +123,36 @@ module {
 // CHECK:           %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>>
 // CHECK:           affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>>
 // CHECK:           return
-// CHECK:         }
\ No newline at end of file
+// CHECK:         }
+
+// -----
+
+module {
+  tt.func public @nan_aware_max(%arg0: tensor<1024xf32>, %arg_out: !tt.ptr<f32>) {
+    %res = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%lhs: f32, %rhs: f32):
+      %cmp_gt = arith.cmpf ogt, %lhs, %rhs : f32
+      %lhs_nan = arith.cmpf une, %lhs, %lhs : f32
+      %pred = arith.ori %cmp_gt, %lhs_nan : i1
+      %sel = arith.select %pred, %lhs, %rhs : f32
+      tt.reduce.return %sel : f32
+    }) : (tensor<1024xf32>) -> f32
+    tt.store %arg_out, %res : !tt.ptr<f32>
+    tt.return
+  }
+}
+
+// CHECK-LABEL:  func.func @nan_aware_max
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1024xf32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
+// CHECK-DAG:      [[CST_nan_:%.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG:      [[VAR_0_:%.+]] = bufferization.alloc_tensor() : tensor<f32>
+// CHECK:          [[VAR_inserted_:%.+]] = tensor.insert [[CST_nan_]] into [[VAR_0_]][] : tensor<f32>
+// CHECK:          [[VAR_reduced_:%.+]] = linalg.reduce ins([[PARAM_0_]] : tensor<1024xf32>) outs([[VAR_inserted_]] : tensor<f32>) dimensions = [0]
+// CHECK:            ([[in_:%.+]]: f32, [[in_]]it: f32) {
+// CHECK:              [[CMP_gt_:%.+]] = arith.maximumf [[in_]], [[in_]]it : f32
+// CHECK:              linalg.yield [[CMP_gt_]] : f32
+// CHECK:            }
+// CHECK:          [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor<f32>
+// CHECK:          tt.store [[PARAM_1_]], [[VAR_extracted_]] : !tt.ptr<f32>
+// CHECK:          return
+// CHECK:        }