diff --git a/backend/compiler.py b/backend/compiler.py index 9efdb48c..ee941bb8 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -64,6 +64,7 @@ def _ttir_to_ttsharedir(mod): subprocess_args.insert(2, "--add-llvm-debug-info") subprocess.check_call(subprocess_args) + _dump_ir_if_needed([dst_path]) return Path(dst_path).read_text() @@ -113,6 +114,7 @@ def _ttsharedir_to_llir(ttsharedir: str): "--mlir-print-debuginfo", "-o", llmlir_path]) + _dump_ir_if_needed([llmlir_path]) # LLVM-MLIR to LLVM-IR mlir_translate_path = _get_llvm_bin_path("mlir-translate") @@ -120,7 +122,7 @@ def _ttsharedir_to_llir(ttsharedir: str): "--mlir-to-llvmir", "-o", llir_path]) - _dump_ir_if_needed([ttshared_path, llmlir_path, llir_path]) + _dump_ir_if_needed([llir_path]) return Path(llir_path).read_text() @@ -151,7 +153,7 @@ def _llir_to_bin(llir: str, metadata): sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None)) if not sanitizer_attributes_pass_path: - raise Exception(f"libSanitizerAttributes.so does not exist.") + raise Exception("libSanitizerAttributes.so does not exist.") subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path, "-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path, @@ -194,6 +196,7 @@ class CPUOptions: allow_fp8e4nv: bool = False allowed_dot_input_precisions: Tuple[str] = ("ieee", ) sanitize_overflow: bool = True + instrumentation_mode: str = "" def __post_init__(self): pass @@ -256,7 +259,7 @@ def make_ttir(mod, metadata, options): passes.common.add_symbol_dce(pm) passes.ttir.add_loop_unroll(pm) passes.common.add_cse(pm) - pm.run(mod) + pm.run(mod, 'make_ttir') return mod def add_stages(self, stages, options, language): diff --git a/backend/driver.py b/backend/driver.py index dcfe9b74..5bcd9cdf 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -1,8 +1,9 @@ import hashlib import tempfile -import sysconfig -import os, subprocess, 
tempfile, platform +import os +import subprocess +import platform import importlib.util import sys @@ -63,7 +64,18 @@ def _ty_to_cpp(ty): "fp64": "double", }[ty] +def _flatten_signature(sig, output): + # Flatten tuples + if isinstance(sig, tuple): + for x in sig: + _flatten_signature(x, output) + else: + output.append(sig) + def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" if ty[0] == '*': return "PyObject*" if ty == "constexpr": @@ -71,6 +83,15 @@ def _extracted_type(ty): return _ty_to_cpp(ty) def _format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(_format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty == "constexpr": + return "O" + if ty.startswith("tensordesc"): + return "O" return { "PyObject*": "O", "constexpr": "O", @@ -85,15 +106,20 @@ def _format_of(ty): "uint16_t": "H", "uint32_t": "I", "uint64_t": "K", - }[ty] + }[_ty_to_cpp(ty)] def _generate_launcher(constants, signature, kernel_name): - arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()]) + args_format = ''.join([_format_of(ty) for ty in signature.values()]) format = "iiiOOOO" + args_format + + flat_signature = [] + for sig in signature.values(): + _flatten_signature(sig, flat_signature) + signature = {i: s for i, s in enumerate(flat_signature)} + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if ty != "constexpr") + kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else "int64_t, void*" for i, ty in signature.items() if ty != "constexpr") kernel_arg_decls += ', ' if kernel_arg_decls else '' kernel_parameters = ', 
'.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr") @@ -327,7 +353,7 @@ def launch( libomp_path = next(Path(Path(_get_llvm_bin_path("")).parent).rglob("libomp.so"), None) if not libomp_path: - raise Exception(f"libomp.so does not exist.") + raise Exception("libomp.so does not exist.") libomp_path = str(libomp_path.parent) @@ -364,7 +390,8 @@ def __init__(self, src, metadata): kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER" constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + def cst_key(i): + return src.fn.arg_names.index(i) if isinstance(i, str) else i constants = {cst_key(key): value for key, value in constants.items()} signature = {cst_key(key): value for key, value in src.signature.items()} launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 1d9244df..c509be31 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -59,14 +59,14 @@ static Value getScalarValue(Value operand, Location loc, if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return rewriter.create(loc, resType, src); + return arith::SIToFPOp::create(rewriter, loc, resType, src); }) .Case([&](Operation *op) { auto resType = op->getResults()[0].getType(); if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return rewriter.create(loc, resType, src); + return arith::TruncFOp::create(rewriter, loc, resType, src); }) .Default([](Operation *op) { llvm_unreachable("unsupported op in generating "); @@ -134,11 +134,11 @@ static Value 
getTransposedValue(Value source, const Location loc, } } - Value transposeInit = rewriter.create( - loc, transposedShape, sourceType.getElementType()); + Value transposeInit = tensor::EmptyOp::create(rewriter, loc, transposedShape, + sourceType.getElementType()); Value transpose = - rewriter.create(loc, source, transposeInit, perm) + linalg::TransposeOp::create(rewriter, loc, source, transposeInit, perm) .getResults()[0]; return transpose; @@ -191,8 +191,8 @@ struct MakeTensorPtrConverter Location loc) const { for (auto opnd : ops) { if (isa(opnd.getType())) { - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), opnd); + auto castOp = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), opnd); vec.push_back(castOp.getResult()); } else { assert(isa(opnd.getType())); @@ -224,8 +224,8 @@ struct MakeTensorPtrConverter SmallVector newOffsets; for (auto [offset, stride] : llvm::zip(pointerState.offsets, pointerState.strides)) { - auto mulOp = rewriter.create(loc, cast(offset), - cast(stride)); + auto mulOp = arith::MulIOp::create(rewriter, loc, cast(offset), + cast(stride)); newOffsets.push_back(mulOp.getResult()); } @@ -278,71 +278,67 @@ struct LoadConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = 
memref::DimOp::create(rewriter, loc, block2, 1); - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{zero, block1Col}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } void createStackedCopies(Value block1, Value block2, Value dst, Location loc, ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - 
ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{block1Row, zero}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } public: @@ -359,8 +355,8 @@ struct LoadConverter : public OpConversionPattern { auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), adaptor.getPtr(), loc, rewriter); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - auto loadOp = rewriter.create( - op.getLoc(), sMemRef, zeroMap, ValueRange{}); + auto loadOp = affine::AffineLoadOp::create(rewriter, op.getLoc(), sMemRef, + zeroMap, ValueRange{}); rewriter.replaceOp(op, loadOp.getResult()); return success(); } @@ -378,8 +374,8 @@ struct LoadConverter : public OpConversionPattern { auto tensorType = RankedTensorType::get(type.getShape(), type.getElementType()); - auto alloc = rewriter.create( - loc, MemRefType::get(type.getShape(), type.getElementType())); + auto alloc = memref::AllocOp::create( + rewriter, loc, MemRefType::get(type.getShape(), type.getElementType())); if (!mask) { assert(!other && "other value used in non-masked load"); @@ -404,11 +400,12 @@ struct LoadConverter : public OpConversionPattern { } } else { - rewriter.create(loc, ptr, alloc); + memref::CopyOp::create(rewriter, loc, ptr, alloc); } - Value tensor = rewriter.create( 
- loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create( + rewriter, loc, tensorType, alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -435,31 +432,33 @@ struct LoadConverter : public OpConversionPattern { // the result auto shape = type.getShape(); auto accBase = - rewriter.create(loc, rewriter.getBoolAttr(false)) + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)) .getResult(); for (size_t i = 0; i < type.getShape().size(); i++) { - auto shapei = rewriter.create( - loc, rewriter.getIndexAttr(shape[i])); + auto shapei = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(shape[i])); Value dimi = dyn_cast(mstate.dims[i]); if (!dimi) { - dimi = rewriter.create( - loc, cast(cast(mstate.dims[i]))); + dimi = arith::ConstantOp::create( + rewriter, loc, + cast(cast(mstate.dims[i]))); } - auto cmpOp = rewriter.create( - loc, arith::CmpIPredicate::slt, dimi, shapei); - accBase = rewriter.create(loc, accBase, cmpOp.getResult()) - .getResult(); + auto cmpOp = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, dimi, shapei); + accBase = + arith::OrIOp::create(rewriter, loc, accBase, cmpOp.getResult()) + .getResult(); } // condition the memset on the or-accumulation // initialize with padding prior to CopyOp - rewriter.create( - loc, accBase, [&](OpBuilder &builder, Location loc) { - builder.create(loc, ValueRange{scalarOther}, - ValueRange{alloc}); - builder.create(loc); + scf::IfOp::create( + rewriter, loc, accBase, [&](OpBuilder &builder, Location loc) { + linalg::FillOp::create(builder, loc, ValueRange{scalarOther}, + ValueRange{alloc}); + scf::YieldOp::create(builder, loc); }); } @@ -492,11 +491,12 @@ struct LoadConverter : public OpConversionPattern { } else { memref::SubViewOp srcSubview = mstate.getSubview(ptr, loc, rewriter); memref::SubViewOp dstSubview = mstate.getSubview(alloc, loc, rewriter); - 
rewriter.create(loc, srcSubview, dstSubview); + memref::CopyOp::create(rewriter, loc, srcSubview, dstSubview); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create(rewriter, loc, tensorType, + alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -519,16 +519,16 @@ struct StoreConverter : public OpConversionPattern { auto sMemRef = PtrAnalysis::getScalarMemRef(op.getPtr(), ptr, loc, rewriter); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - rewriter.create(loc, val, sMemRef, zeroMap, - ValueRange{}); + affine::AffineStoreOp::create(rewriter, loc, val, sMemRef, zeroMap, + ValueRange{}); rewriter.eraseOp(op); return success(); } // 1. Simple case where no mask is used. if (!mask) { - auto storeOp = rewriter.create( - loc, val, ptr); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, val, ptr); storeOp.setWritable(true); rewriter.eraseOp(op); return success(); @@ -546,8 +546,8 @@ struct StoreConverter : public OpConversionPattern { auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); auto dstSubview = mstate.getSubview(ptr, loc, rewriter); - auto storeOp = rewriter.create( - loc, srcSlice, dstSubview); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, srcSlice, dstSubview); storeOp.setWritable(true); rewriter.eraseOp(op); @@ -635,13 +635,12 @@ struct SplatConverter : public OpConversionPattern { auto opType = cast(op.getType()); auto loc = op.getLoc(); - auto init = rewriter.create(loc, opType.getShape(), - opType.getElementType()); + auto init = tensor::EmptyOp::create(rewriter, loc, opType.getShape(), + opType.getElementType()); auto filledTensor = - rewriter - .create(loc, ValueRange{adaptor.getSrc()}, - ValueRange{init}) + linalg::FillOp::create(rewriter, loc, ValueRange{adaptor.getSrc()}, + ValueRange{init}) 
.result(); rewriter.replaceOp(op, filledTensor); @@ -697,16 +696,16 @@ struct BroadcastConverter : public OpConversionPattern { rewriter.getMultiDimIdentityMap(resultRank)); assert(op->getNumResults() == 1 && "code assumes single result!"); - auto init = rewriter.create(loc, resultType.getShape(), - elementType); + auto init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + elementType); - auto linalgOp = rewriter.create( - loc, op->getResultTypes(), ValueRange{adaptor.getSrc()}, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, op->getResultTypes(), ValueRange{adaptor.getSrc()}, ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(resultRank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value opResult = blockArgs[0]; - nestedBuilder.create(loc, opResult); + linalg::YieldOp::create(nestedBuilder, loc, opResult); }); linalgOp->setAttr("broadcastDims", @@ -740,8 +739,8 @@ struct ExpandDimsConverter : public OpConversionPattern { reassoc.push_back(g); } - auto expandShapeOp = rewriter.create( - op.getLoc(), resType, src, reassoc); + auto expandShapeOp = tensor::ExpandShapeOp::create(rewriter, op.getLoc(), + resType, src, reassoc); rewriter.replaceOp(op, expandShapeOp.getResult()); return success(); @@ -781,22 +780,22 @@ struct MakeRangeConverter : public OpConversionPattern { /* dimCount */ 1, /* symbolCount */ 0, SmallVector{mlir::getAffineDimExpr(0, context)}, context)}; - auto init = rewriter.create(loc, shape, elementType); - auto linalgOp = rewriter.create( - loc, op->getResultTypes(), /* operands */ ValueRange{}, + auto init = tensor::EmptyOp::create(rewriter, loc, shape, elementType); + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, op->getResultTypes(), /* operands */ ValueRange{}, ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(1), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { - Value index = nestedBuilder.create(loc, 0); - Value res = 
nestedBuilder.create( - loc, type.getElementType(), index); + Value index = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value res = arith::IndexCastOp::create(nestedBuilder, loc, + type.getElementType(), index); if (op.getStart()) { - auto start = rewriter.create( - op.getLoc(), op.getStart(), + auto start = arith::ConstantIntOp::create( + rewriter, op.getLoc(), op.getStart(), type.getElementType().getIntOrFloatBitWidth()); - res = nestedBuilder.create(loc, res, start); + res = arith::AddIOp::create(nestedBuilder, loc, res, start); } - nestedBuilder.create(loc, res); + linalg::YieldOp::create(nestedBuilder, loc, res); }); rewriter.replaceOp(op, linalgOp->getResults()); @@ -819,8 +818,7 @@ struct AssertConverter : public OpConversionPattern { // TritonOps.td. Tensors will always be RankedTensorType. if (isa(condVal.getType())) { // handle scalar case - rewriter.create(op.getLoc(), condVal, - assertMessage.str()); + cf::AssertOp::create(rewriter, op.getLoc(), condVal, assertMessage.str()); } else if (auto tensorType = dyn_cast(condVal.getType())) { // handle tensor case @@ -834,8 +832,8 @@ struct AssertConverter : public OpConversionPattern { SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - rewriter.create( - op.getLoc(), TypeRange{}, condVal, ValueRange{}, + linalg::GenericOp::create( + rewriter, op.getLoc(), TypeRange{}, condVal, ValueRange{}, ArrayRef{indexingMaps}, ArrayRef{iteratorTypes}, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -843,9 +841,9 @@ struct AssertConverter : public OpConversionPattern { Value element = args[0]; // make a cf.assert for the current element - b.create(loc, element, assertMessage.str()); + cf::AssertOp::create(b, loc, element, assertMessage.str()); - b.create(loc); + linalg::YieldOp::create(b, loc); }); } else { op.emitError("Unexpected type in triton::AssertOp"); @@ -868,8 +866,8 @@ struct BitcastConverter : public OpConversionPattern { return failure(); } - auto arithBitcast = rewriter.create( - 
op.getLoc(), op.getType(), op.getOperand()); + auto arithBitcast = arith::BitcastOp::create(rewriter, op.getLoc(), + op.getType(), op.getOperand()); rewriter.replaceOp(op, arithBitcast.getResult()); return success(); @@ -908,8 +906,8 @@ struct CallConverter : public OpConversionPattern { } } - auto call = rewriter.create(op.getLoc(), op.getCallee(), - op.getResultTypes(), args); + auto call = func::CallOp::create(rewriter, op.getLoc(), op.getCallee(), + op.getResultTypes(), args); if (!call) { op.emitError("Failed to create func::CallOp"); @@ -946,14 +944,14 @@ struct FpToFpConverter : public OpConversionPattern { "Not a float-like operand or result"); if (operandWidth.value() > resultWidth.value()) { - Value truncatedValue = rewriter.create( - op.getLoc(), resultType, op.getOperand()); + Value truncatedValue = arith::TruncFOp::create( + rewriter, op.getLoc(), resultType, op.getOperand()); rewriter.replaceOp(op, truncatedValue); return success(); } - Value extendedValue = rewriter.create( - op.getLoc(), resultType, op.getOperand()); + Value extendedValue = arith::ExtFOp::create(rewriter, op.getLoc(), + resultType, op.getOperand()); rewriter.replaceOp(op, extendedValue); return success(); @@ -975,11 +973,11 @@ struct ClampConverter : public OpConversionPattern { Value clamp; if (propagateNan) { - Value maxMin = rewriter.create(loc, x, min); - clamp = rewriter.create(loc, maxMin, max); + Value maxMin = arith::MaximumFOp::create(rewriter, loc, x, min); + clamp = arith::MinimumFOp::create(rewriter, loc, maxMin, max); } else { - Value maxMin = rewriter.create(loc, x, min); - clamp = rewriter.create(loc, maxMin, max); + Value maxMin = arith::MaxNumFOp::create(rewriter, loc, x, min); + clamp = arith::MinNumFOp::create(rewriter, loc, maxMin, max); } rewriter.replaceOp(op, clamp); @@ -995,7 +993,7 @@ struct PreciseSqrtConverter matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto replacement = - 
rewriter.create(op.getLoc(), adaptor.getOperands()); + math::SqrtOp::create(rewriter, op.getLoc(), adaptor.getOperands()); rewriter.replaceOp(op, replacement); return success(); @@ -1009,7 +1007,7 @@ struct PreciseDivConverter : public OpConversionPattern { matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto replacement = - rewriter.create(op.getLoc(), adaptor.getOperands()); + arith::DivFOp::create(rewriter, op.getLoc(), adaptor.getOperands()); rewriter.replaceOp(op, replacement); return success(); @@ -1022,8 +1020,8 @@ struct CatConverter : public OpConversionPattern { LogicalResult matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto replacement = rewriter.create( - op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); + auto replacement = tensor::ConcatOp::create( + rewriter, op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); rewriter.replaceOp(op, replacement); @@ -1060,8 +1058,8 @@ struct SplitConverter : public OpConversionPattern { offsets.push_back(rewriter.getIndexAttr(i)); sizes.push_back(rewriter.getIndexAttr(1)); - Value slice = rewriter.create( - loc, resultTensor, input, offsets, sizes, strides); + Value slice = tensor::ExtractSliceOp::create( + rewriter, loc, resultTensor, input, offsets, sizes, strides); results.push_back(slice); } @@ -1081,8 +1079,8 @@ struct JoinConverter : public OpConversionPattern { auto resultType = cast(op.getResult().getType()); auto loc = op.getLoc(); - Value result = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + Value result = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()); auto shape = resultType.getShape(); @@ -1099,8 +1097,8 @@ struct JoinConverter : public OpConversionPattern { offsets.push_back(rewriter.getIndexAttr(i)); sizes.push_back(rewriter.getIndexAttr(1)); - result = 
rewriter.create(loc, inputs[i], result, - offsets, sizes, strides); + result = tensor::InsertSliceOp::create(rewriter, loc, inputs[i], result, + offsets, sizes, strides); } rewriter.replaceOp(op, result); @@ -1118,7 +1116,7 @@ struct MulHiUIOpConverter : public OpConversionPattern { Location loc = op.getLoc(); auto mulResult = - rewriter.create(loc, adaptor.getOperands()); + arith::MulUIExtendedOp::create(rewriter, loc, adaptor.getOperands()); rewriter.replaceOp(op, mulResult.getHigh()); return success(); @@ -1169,29 +1167,28 @@ struct MatmulConverter : public OpConversionPattern { bool integers = elementType.isInteger(); bool skipC = isZeroTensor(opc, integers); auto init = - rewriter.create(loc, dstType.getShape(), elementType); + tensor::EmptyOp::create(rewriter, loc, dstType.getShape(), elementType); TypedAttr constantAttr = integers ? static_cast(rewriter.getIntegerAttr(elementType, 0)) : static_cast(rewriter.getFloatAttr(elementType, 0)); - auto zero = rewriter.create( - op.getLoc(), elementType, constantAttr); + auto zero = arith::ConstantOp::create(rewriter, op.getLoc(), elementType, + constantAttr); - auto zeroes = - rewriter.create(loc, ValueRange{zero}, ValueRange{init}) - .result(); + auto zeroes = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{init}) + .result(); - auto res = rewriter - .create(loc, ValueRange{opa, opb}, - ValueRange{zeroes}) + auto res = linalg::MatmulOp::create(rewriter, loc, ValueRange{opa, opb}, + ValueRange{zeroes}) .getResult(0); if (!skipC) { if (integers) { - res = rewriter.create(loc, opc, res); + res = arith::AddIOp::create(rewriter, loc, opc, res); } else { - res = rewriter.create(loc, opc, res); + res = arith::AddFOp::create(rewriter, loc, opc, res); } } @@ -1273,8 +1270,8 @@ struct ReduceConverter : public OpConversionPattern { return nullptr; }); - return rewriter.create(redOp->getLoc(), constantType, - attr); + return arith::ConstantOp::create(rewriter, redOp->getLoc(), constantType, + attr); } bool 
requiresF32Conversion(const Type elemType, Operation *redOp) const { @@ -1291,16 +1288,16 @@ struct ReduceConverter : public OpConversionPattern { return llvm::TypeSwitch(redOp) .Case([&](auto redOp) { if (convertLhsToF32Precision) { - lhs = b.create(loc, Float32Type::get(b.getContext()), - lhs); + lhs = arith::ExtFOp::create(b, loc, + Float32Type::get(b.getContext()), lhs); } - return b.create(loc, lhs, rhs); + return decltype(redOp)::create(b, loc, lhs, rhs); }) .Case([&](auto redOp) { - return b.create(loc, lhs, rhs); + return decltype(redOp)::create(b, loc, lhs, rhs); }) .Default([](Operation *op) { op->dump(); @@ -1379,40 +1376,40 @@ struct ReduceConverter : public OpConversionPattern { // directly instead of EmptyOp so that the subsequent pass can recognize // the patterns (EmptyOp is susceptible to being CSE'd away, making it // harder to match the patterns correctly). - initTensor = rewriter.create( - loc, RankedTensorType::get({}, constantType), ValueRange{}); - initTensor = rewriter.create(loc, accBaseConstOp, - initTensor, ValueRange{}); + initTensor = bufferization::AllocTensorOp::create( + rewriter, loc, RankedTensorType::get({}, constantType), ValueRange{}); + initTensor = tensor::InsertOp::create(rewriter, loc, accBaseConstOp, + initTensor, ValueRange{}); } else { - Value init = rewriter.create( - loc, cast(resType).getShape(), constantType); - initTensor = rewriter - .create(loc, ValueRange{accBaseConstOp}, - ValueRange{init}) - .result(); + Value init = tensor::EmptyOp::create( + rewriter, loc, cast(resType).getShape(), + constantType); + initTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{accBaseConstOp}, + ValueRange{init}) + .result(); } Value finalResult = - rewriter - .create( - loc, ValueRange{source}, ValueRange{initTensor}, - SmallVector{axis}, - [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { - assert(inputs.size() == 2); - Value result = - getRedElement(inputs[0], inputs[1], loc, rop, opBuilder, - 
convertToF32Precision); - opBuilder.create(loc, result); - }) + linalg::ReduceOp::create( + rewriter, loc, ValueRange{source}, ValueRange{initTensor}, + SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + assert(inputs.size() == 2); + Value result = getRedElement(inputs[0], inputs[1], loc, rop, + opBuilder, convertToF32Precision); + linalg::YieldOp::create(opBuilder, loc, result); + }) .getResult(0); if (isVectorReduce) { finalResult = - rewriter.create(loc, constantType, finalResult); + tensor::ExtractOp::create(rewriter, loc, constantType, finalResult); } if (convertToF32Precision) { - finalResult = rewriter.create(loc, resType, finalResult); + finalResult = + arith::TruncFOp::create(rewriter, loc, resType, finalResult); } rewriter.replaceOp(op, finalResult); @@ -1570,10 +1567,9 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { ArrayRef shape, Value fillValue, Location loc) const { Value initTensor = - rewriter.create(loc, shape, fillValue.getType()); - return rewriter - .create(loc, ValueRange{fillValue}, - ValueRange{initTensor}) + tensor::EmptyOp::create(rewriter, loc, shape, fillValue.getType()); + return linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{initTensor}) .result(); } @@ -1648,15 +1644,15 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { // the result value to either -inf or +inf depending on // whether we're dealing with argmax or argmin auto valueType = elemTypes[0]; - auto valuesAccBaseVal = rewriter.create( - loc, valueType, + auto valuesAccBaseVal = arith::ConstantOp::create( + rewriter, loc, valueType, rewriter.getFloatAttr(valueType, T::getBaseReductionValue())); // Set the initial value of the rank-0 tensor containing the index of the // min or max value to -1 auto indexType = elemTypes[1]; - auto indicesAccBaseVal = rewriter.create( - loc, indexType, rewriter.getIntegerAttr(indexType, -1)); + auto indicesAccBaseVal = arith::ConstantOp::create( + rewriter, 
loc, indexType, rewriter.getIntegerAttr(indexType, -1)); // Get the shape of the resulting tensors (both for values and indices). If // we are reducing to a single scalar, then the result's type is a tensor of @@ -1671,8 +1667,8 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; - auto linalgOp = rewriter.create( - loc, adaptor.getOperands(), outputs, + auto linalgOp = linalg::ReduceOp::create( + rewriter, loc, adaptor.getOperands(), outputs, SmallVector{adaptor.getAxis()}, [&](OpBuilder &b, Location loc, ValueRange inputs) { assert(inputs.size() == 4); @@ -1690,15 +1686,15 @@ class ArgMinMaxBaseConverter : public OpConversionPattern { llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { return mapping.lookup(val); }); - b.create(loc, results); + linalg::YieldOp::create(b, loc, results); }); if (isScalarReduce) { SmallVector reduceResults{ - rewriter.create( - loc, valueType, linalgOp.getResults()[0], ValueRange{}), - rewriter.create( - loc, indexType, linalgOp.getResults()[1], ValueRange{})}; + tensor::ExtractOp::create(rewriter, loc, valueType, + linalgOp.getResults()[0], ValueRange{}), + tensor::ExtractOp::create(rewriter, loc, indexType, + linalgOp.getResults()[1], ValueRange{})}; rewriter.replaceOp(op, reduceResults); } else { rewriter.replaceOp(op, linalgOp); @@ -1927,8 +1923,9 @@ struct DenseConstantConverter : public OpConversionPattern { auto splatConst = arith::ConstantOp::materialize( rewriter, attr.getSplatValue(), attr.getElementType(), loc); - auto init = rewriter.create( - loc, cast(op.getResult().getType()).getShape(), + auto init = tensor::EmptyOp::create( + rewriter, loc, + cast(op.getResult().getType()).getShape(), attr.getElementType()); rewriter.replaceOpWithNewOp(op, ValueRange{splatConst}, @@ -1996,8 +1993,8 @@ class CumSumConverter : public OpConversionPattern { "= {1, 2} and 
axis = rank - 1"); } - Value init = rewriter.create(op.getLoc(), type.getShape(), - type.getElementType()); + Value init = tensor::EmptyOp::create(rewriter, op.getLoc(), type.getShape(), + type.getElementType()); rewriter.replaceOpWithNewOp( op, input, rewriter.getUI32IntegerAttr(axis), init); @@ -2032,7 +2029,7 @@ class AddPtrConverter : public OpConversionPattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } @@ -2063,8 +2060,8 @@ class TensorOpConverter : public OpConversionPattern { rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - SmallVector outputs = {rewriter.create( - op->getLoc(), resultTensorType.getShape(), + SmallVector outputs = {tensor::EmptyOp::create( + rewriter, op->getLoc(), resultTensorType.getShape(), resultTensorType.getElementType())}; rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps, @@ -2078,7 +2075,7 @@ class TensorOpConverter : public OpConversionPattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } @@ -2116,7 +2113,7 @@ class StorePtrToLinalgConverter : public OpConversionPattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } @@ -2154,8 +2151,8 @@ class ReshapeConverter : public OpConversionPattern { ArrayRef outputShape = outputType.getShape(); - auto shape = rewriter.create( - loc, 
rewriter.getI64TensorAttr(outputShape)); + auto shape = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64TensorAttr(outputShape)); rewriter.replaceOpWithNewOp(op, outputType, input, shape); diff --git a/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td b/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td index 04d894ff..3aa7245a 100644 --- a/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td +++ b/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td @@ -93,122 +93,4 @@ def TPTR_PtrToIntOp : TPTR_Op<"ptrtoint", [ let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; } -def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> { - let summary = "Creates a type offset constant."; - let description = [{ - The `addr.type_offset` operation produces an int or index-typed SSA value - equal to a target-specific constant representing the offset of a single - element of the given type. The default return type is `index`. - Example: - - ```mlir - %0 = addr.type_offset f32 - %1 = addr.type_offset memref<12 x f64> : i32 - ``` - }]; - - let arguments = (ins TypeAttr:$baseType); - let results = (outs AnySignlessIntegerOrIndex:$result); - let assemblyFormat = [{ - attr-dict $baseType custom(type($result)) - }]; - let hasFolder = 1; -} - -def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs Ptr_PtrType:$result); - let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; -} - -def TPTR_ToMemrefOp : TPTR_Op<"to_memref", [ - Pure ]> { - let arguments = (ins Ptr_PtrType:$arg); - let results = (outs AnyStaticShapeMemRef:$res); - let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; -} - -def TPTR_PtrAddOp : TPTR_Op<"ptradd", [Pure, AllTypesMatch<["base", "result"]>]> { - let summary = "Pointer-index add operation"; - let description = [{ - The `ptradd` operation adds an `address` and an integer or index to - produce a new address. 
- - Example: - ```mlir - %addr = ptr.ptradd %addr : !ptr.ptr<3 : i32>, %c10 : i32 - ``` - }]; - - let arguments = (ins Ptr_PtrType:$base, AnySignlessIntegerOrIndex:$offset); - let results = (outs Ptr_PtrType:$result); - let assemblyFormat = "$base $offset attr-dict `:` type($base) `,` type($offset) `to` type($result)"; -} - -def TPTR_LoadOp : TPTR_Op<"load", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Load operation"; - let description = [{ - The `load` operation is used to read from memory. A load may be marked as - atomic, volatile, and/or nontemporal, and takes a number of optional - attributes that specify aliasing information. - - An atomic load only supports a limited set of pointer, integer, and - floating point types, and requires an explicit alignment. - - Examples: - ```mlir - // A volatile load of a float variable. - %0 = ptr.load volatile %ptr : !ptr.ptr -> f32 - - // A nontemporal load of a float variable. - %0 = ptr.load %ptr {nontemporal} : !ptr.ptr -> f32 - - // An atomic load of an integer variable. - %0 = ptr.load %ptr atomic monotonic {alignment = 8 : i64} - : !ptr.ptr -> i64 - ``` - }]; - let arguments = (ins AnyType:$addr); - let results = (outs AnyType:$res); - let assemblyFormat = [{ - $addr - attr-dict `:` qualified(type($addr)) `->` type($res) - }]; -} - -def TTPTR_StoreOp : TPTR_Op<"store", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Store operation"; - let description = [{ - The `store` operation is used to write to memory. A store may be marked as - atomic, volatile, and/or nontemporal, and takes a number of optional - attributes that specify aliasing information. - - An atomic store only supports a limited set of pointer, integer, and - floating point types, and requires an explicit alignment. - - Examples: - ```mlir - // A volatile store of a float variable. - ptr.store volatile %val, %ptr : f32, !ptr.ptr - - // A nontemporal store of a float variable. 
- ptr.store %val, %ptr {nontemporal} : f32, !ptr.ptr - - // An atomic store of an integer variable. - ptr.store %val, %ptr atomic monotonic {alignment = 8 : i64} - : i64, !ptr.ptr - ``` - }]; - let arguments = (ins AnyType:$value, - AnyType:$addr); - let assemblyFormat = [{ - $value `,` $addr - attr-dict `:` type($value) `,` qualified(type($addr)) - }]; -} - #endif // TPTR_DIALECT diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 1aa9760f..2fcdded5 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -68,8 +68,8 @@ tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); - return builder.create(loc, dstType, source, offsets, - dims, strides); + return tensor::ExtractSliceOp::create(builder, loc, dstType, source, offsets, + dims, strides); } memref::SubViewOp MaskState::getSubview(Value source, const Location loc, @@ -80,8 +80,8 @@ memref::SubViewOp MaskState::getSubview(Value source, const Location loc, auto dstType = memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); - return builder.create(loc, cast(dstType), - source, offsets, dims, strides); + return memref::SubViewOp::create(builder, loc, cast(dstType), + source, offsets, dims, strides); } static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, @@ -91,8 +91,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, auto srcType = cast(src.getType()); auto dstType = memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return b.create(loc, cast(dstType), src, - offsets, sizes, strides); + return memref::SubViewOp::create(b, loc, cast(dstType), src, + offsets, sizes, strides); } // Assume block1 wraps around and the remainder is block2. 
@@ -155,7 +155,8 @@ MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc, OpBuilder &builder) const { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; - OpFoldResult col1 = builder.create(loc, block1, 1).getResult(); + OpFoldResult col1 = + memref::DimOp::create(builder, loc, block1, 1).getResult(); OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, builder); OpFoldResult subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, builder); @@ -174,7 +175,8 @@ MaskState::getStackedSubviews(Value block1, Value block2, const Location loc, OpBuilder &builder) const { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; - OpFoldResult row1 = builder.create(loc, block1, 0).getResult(); + OpFoldResult row1 = + memref::DimOp::create(builder, loc, block1, 0).getResult(); OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder); OpFoldResult subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, builder); @@ -314,8 +316,8 @@ LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc, if (scalar.getType().isInteger(1)) { this->scalar = scalar; } else { - auto castOp = - builder.create(loc, builder.getIndexType(), scalar); + auto castOp = arith::IndexCastOp::create(builder, loc, + builder.getIndexType(), scalar); this->scalar = castOp.getResult(); } return success(); @@ -407,18 +409,17 @@ LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, } auto targetTensorType = RankedTensorType::get({size}, builder.getI32Type()); - Value range = - builder - .create(loc, targetTensorType, 0, size) - .getResult(); + Value range = triton::MakeRangeOp::create(builder, loc, + targetTensorType, 0, size) + .getResult(); Value v = ofrToIndexValue(ofr, loc, builder); - v = builder - .create(loc, builder.getI32Type(), v) + v = arith::IndexCastUIOp::create(builder, loc, builder.getI32Type(), + v) .getResult(); - v = builder.create(loc, targetTensorType, v) + v = 
triton::SplatOp::create(builder, loc, targetTensorType, v) .getResult(); - return builder - .create(loc, arith::CmpIPredicate::ult, range, v) + return arith::CmpIOp::create(builder, loc, + arith::CmpIPredicate::ult, range, v) .getResult(); }; if (!lhsV) { @@ -436,7 +437,7 @@ LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, continue; } // And the mask. - masks.push_back(builder.create(loc, lhsV, rhsV)); + masks.push_back(arith::AndIOp::create(builder, loc, lhsV, rhsV)); } } // Only support one unstructured mask. @@ -494,8 +495,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, SmallVector reassociation = *maybeReassociationMap; // Set masks. - unstructuredMask = builder.create( - loc, flatType, cmpOp, reassociation); + unstructuredMask = tensor::CollapseShapeOp::create( + builder, loc, flatType, cmpOp, reassociation); } masks[cmpOpDim] = unstructuredMask; } diff --git a/lib/Analysis/OpFoldResultUtils.cpp b/lib/Analysis/OpFoldResultUtils.cpp index 49ae4c44..85d7842f 100644 --- a/lib/Analysis/OpFoldResultUtils.cpp +++ b/lib/Analysis/OpFoldResultUtils.cpp @@ -59,7 +59,7 @@ Value ofrToValue(const OpFoldResult ofr, const Location loc, OpBuilder &b) { auto attr = dyn_cast(ofr); auto typedAttr = dyn_cast(attr); - return b.create(loc, typedAttr); + return arith::ConstantOp::create(b, loc, typedAttr); } Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, @@ -67,14 +67,14 @@ Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, if (Value val = dyn_cast(ofr)) { assert(val.getType().isIntOrIndex()); if (!val.getType().isIndex()) { - val = b.create(loc, b.getIndexType(), val); + val = arith::IndexCastOp::create(b, loc, b.getIndexType(), val); } return val; } auto intVal = getIntAttr(ofr); if (intVal.has_value()) { - return b.create(loc, b.getIndexAttr(intVal.value())); + return arith::ConstantOp::create(b, loc, b.getIndexAttr(intVal.value())); } llvm_unreachable("Unexpected OpFoldResult state"); 
return nullptr; @@ -93,14 +93,14 @@ Value indexTypeCast(Value v, Type targetTy, const Location loc, OpBuilder &b) { if (isa(targetTy) || isa(ty)) { assert((isa(targetTy) || isa(ty)) && "Only cast between index type and integer type"); - return b.create(loc, targetTy, v).getResult(); + return arith::IndexCastOp::create(b, loc, targetTy, v).getResult(); } else { auto targetIntTy = cast(targetTy); auto intTy = cast(ty); if (targetIntTy.getWidth() > intTy.getWidth()) - return b.create(loc, targetTy, v).getResult(); + return arith::ExtSIOp::create(b, loc, targetTy, v).getResult(); else - return b.create(loc, targetTy, v).getResult(); + return arith::TruncIOp::create(b, loc, targetTy, v).getResult(); } } @@ -114,8 +114,8 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy, Value v = dyn_cast(ofr); if (!v) - v = b.create(loc, - cast(cast(ofr))); + v = arith::ConstantOp::create(b, loc, + cast(cast(ofr))); Type ty = v.getType(); if (targetTy == ty) @@ -127,7 +127,7 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy, // cast to target element type first. 
if (targetEltTy != ty) v = indexTypeCast(v, targetEltTy, loc, b); - return b.create(loc, targetTy, v).getResult(); + return triton::SplatOp::create(b, loc, targetTy, v).getResult(); } else if (targetShapedTy && shapedTy) { Type targetEltTy = targetShapedTy.getElementType(); Type eltTy = shapedTy.getElementType(); @@ -148,26 +148,26 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy, SmallVector shapeValues; for (auto dim : targetShapedTy.getShape()) { shapeValues.push_back( - b.create(loc, b.getIndexAttr(dim))); + arith::ConstantOp::create(b, loc, b.getIndexAttr(dim))); } RankedTensorType targetShapeTensorTy = RankedTensorType::get( targetShapedTy.getShape().size(), b.getIndexType()); - auto shapeTensor = b.create( - loc, targetShapeTensorTy, shapeValues); - return b.create(loc, targetTy, v, shapeTensor) + auto shapeTensor = tensor::FromElementsOp::create( + b, loc, targetShapeTensorTy, shapeValues); + return triton::ReshapeOp::create(b, loc, targetTy, v, shapeTensor) .getResult(); } if (isa(targetEltTy) || isa(eltTy)) { assert((isa(targetEltTy) || isa(eltTy)) && "Only cast between index type and integer type"); - return b.create(loc, targetTy, v).getResult(); + return arith::IndexCastOp::create(b, loc, targetTy, v).getResult(); } else { auto targetIntTy = cast(targetEltTy); auto intTy = cast(eltTy); if (targetIntTy.getWidth() > intTy.getWidth()) - return b.create(loc, targetTy, v).getResult(); + return arith::ExtSIOp::create(b, loc, targetTy, v).getResult(); else - return b.create(loc, targetTy, v).getResult(); + return arith::TruncIOp::create(b, loc, targetTy, v).getResult(); } } else { assert(!shapedTy && "src type rank should be >= target type rank"); @@ -194,18 +194,18 @@ OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = 
lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - return b.create(loc, lhsValue, rhsValue).getResult(); + return arith::AddIOp::create(b, loc, lhsValue, rhsValue).getResult(); } OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, @@ -225,18 +225,18 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto sumOp = b.create(loc, lhsValue, rhsValue); + auto sumOp = arith::SubIOp::create(b, loc, lhsValue, rhsValue); return sumOp.getResult(); } @@ -280,17 +280,17 @@ OpFoldResult mulOFRs(const OpFoldResult lhs, const OpFoldResult rhs, // otherwise, need to create instructions to calculate new attribute value if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - return b.create(loc, lhsValue, rhsValue).getResult(); + return arith::MulIOp::create(b, loc, lhsValue, rhsValue).getResult(); } OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, @@ -306,18 +306,18 @@ OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { 
auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto minOp = b.create(loc, lhsValue, rhsValue); + auto minOp = arith::MinSIOp::create(b, loc, lhsValue, rhsValue); return minOp.getResult(); } @@ -334,18 +334,18 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = dyn_cast(lhs); if (lhsIntAttr) { auto lhsOp = - b.create(loc, b.getIndexAttr(lhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(lhsIntAttr.value())); lhsValue = lhsOp.getResult(); } auto rhsValue = dyn_cast(rhs); if (rhsIntAttr) { auto rhsOp = - b.create(loc, b.getIndexAttr(rhsIntAttr.value())); + arith::ConstantOp::create(b, loc, b.getIndexAttr(rhsIntAttr.value())); rhsValue = rhsOp.getResult(); } - auto maxOp = b.create(loc, lhsValue, rhsValue); + auto maxOp = arith::MaxSIOp::create(b, loc, lhsValue, rhsValue); return maxOp.getResult(); } @@ -359,7 +359,7 @@ OpFoldResult selectOFRs(const OpFoldResult condOFR, const OpFoldResult trueOFR, "Condition for selectOp must be a bool type"); auto selectOp = - b.create(loc, condValue, trueValue, falseValue); + arith::SelectOp::create(b, loc, condValue, trueValue, falseValue); return selectOp.getResult(); } @@ -398,7 +398,7 @@ OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs, auto lhsValue = ofrToIndexValue(lhs, loc, b); auto rhsValue = ofrToIndexValue(rhs, loc, b); - auto cmpOp = b.create(loc, pred, lhsValue, rhsValue); + auto cmpOp = arith::CmpIOp::create(b, loc, pred, lhsValue, rhsValue); return selectOFRs(cmpOp.getResult(), trueOFR, falseOFR, loc, b); } diff --git a/lib/Analysis/PtrAnalysis.cpp b/lib/Analysis/PtrAnalysis.cpp index 
2fc2f85e..1a175670 100644 --- a/lib/Analysis/PtrAnalysis.cpp +++ b/lib/Analysis/PtrAnalysis.cpp @@ -82,7 +82,7 @@ void PtrState::addState(const PtrState &lhsState, const PtrState &rhsState, if (lhsState.scalar && rhsState.scalar) { auto addOp = - rewriter.create(loc, lhsState.scalar, rhsState.scalar); + arith::AddIOp::create(rewriter, loc, lhsState.scalar, rhsState.scalar); scalar = addOp.getResult(); } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; @@ -210,28 +210,28 @@ PtrState::createStackedCastOps(ArrayRef resultShape, Value strideRow = ofrToIndexValue(strides[0], loc, rewriter); Value strideCol = ofrToIndexValue(strides[1], loc, rewriter); - Value modRow = rewriter.create( - loc, rewriter.getIndexType(), modulos[0]->size); + Value modRow = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), modulos[0]->size); // First chunk Value wrappedAroundOff = - rewriter.create(loc, targetOffset, strideRow); - Value clampedOff = rewriter.create(loc, modRow, strideRow); + arith::RemSIOp::create(rewriter, loc, targetOffset, strideRow); + Value clampedOff = arith::MulIOp::create(rewriter, loc, modRow, strideRow); clampedOff = - rewriter.create(loc, clampedOff, wrappedAroundOff); - Value d1 = rewriter.create(loc, clampedOff, targetOffset); - d1 = rewriter.create(loc, d1, strideRow); + arith::AddIOp::create(rewriter, loc, clampedOff, wrappedAroundOff); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOff, targetOffset); + d1 = arith::DivSIOp::create(rewriter, loc, d1, strideRow); SmallVector sizes1{d1, colSize}; - memref::ReinterpretCastOp cast1 = rewriter.create( - loc, resultType, source, targetOffset, sizes1, + memref::ReinterpretCastOp cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, source, targetOffset, sizes1, ValueRange{strideRow, strideCol}); // Second chunk - Value d2 = rewriter.create(loc, rowSize, d1); + Value d2 = 
arith::SubIOp::create(rewriter, loc, rowSize, d1); SmallVector sizes2{d2, colSize}; - memref::ReinterpretCastOp cast2 = rewriter.create( - loc, resultType, source, wrappedAroundOff, sizes2, + memref::ReinterpretCastOp cast2 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, source, wrappedAroundOff, sizes2, ValueRange{strideRow, strideCol}); return {cast1, cast2}; @@ -299,29 +299,29 @@ PtrState::createSideBySideCastOps(ArrayRef resultShape, Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); - Value modN = rewriter.create(loc, rewriter.getIndexType(), - modulos[1]->size); + Value modN = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), modulos[1]->size); - Value x = rewriter.create(loc, targetOffset, modN); - Value y = rewriter.create(loc, targetOffset, x); + Value x = arith::RemSIOp::create(rewriter, loc, targetOffset, modN); + Value y = arith::SubIOp::create(rewriter, loc, targetOffset, x); SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); // First chunk - Value nextOffset = rewriter.create(loc, x, colSize); - Value clampedOffset = rewriter.create(loc, nextOffset, modN); - Value d1 = rewriter.create(loc, clampedOffset, x); + Value nextOffset = arith::AddIOp::create(rewriter, loc, x, colSize); + Value clampedOffset = arith::MinSIOp::create(rewriter, loc, nextOffset, modN); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOffset, x); SmallVector sizes1{rowSize, d1}; - auto cast1 = rewriter.create( - loc, resultType, source, targetOffset, sizes1, strideVals); + auto cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, source, targetOffset, sizes1, strideVals); // Second chunk - Value d2 = rewriter.create(loc, colSize, d1); + Value d2 = arith::SubIOp::create(rewriter, loc, colSize, d1); SmallVector sizes2{rowSize, d2}; - auto cast2 = rewriter.create( - loc, resultType, source, y, sizes2, strideVals); + auto cast2 = 
memref::ReinterpretCastOp::create(rewriter, loc, resultType, + source, y, sizes2, strideVals); return {cast1, cast2}; } @@ -341,8 +341,8 @@ PtrState::createCastOp(ArrayRef resultShape, const Location loc, getResultMemrefType(rewriter.getContext(), staticOffset[0], resultShape); // Create reinterpret cast - return rewriter.create( - loc, resultType, source, targetOffset, sizes, strides); + return memref::ReinterpretCastOp::create(rewriter, loc, resultType, source, + targetOffset, sizes, strides); } void PtrAnalysis::visitOperandAdd( @@ -648,8 +648,8 @@ void PtrAnalysis::visitOperand( } if (isa(operand.getType())) { - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), operand); + auto castOp = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), operand); state.scalar = castOp.getResult(); return; } @@ -805,10 +805,10 @@ void PtrAnalysis::rewriteAddptrOp( rewriter.getContext(), ShapedType::kDynamic, resultShape); UnrealizedConversionCastOp combinedCast = - rewriter.create( - op.getLoc(), resultType, - ValueRange{casts[0].getResult(), casts[1].getResult(), - op.getResult()}); + UnrealizedConversionCastOp::create(rewriter, op.getLoc(), resultType, + ValueRange{casts[0].getResult(), + casts[1].getResult(), + op.getResult()}); combinedCast->setAttr(ModuloState::WraparoundAttr, rewriter.getStringAttr(type)); @@ -858,18 +858,18 @@ void PtrAnalysis::rewriteAdvanceOp( llvm::zip(incrementOffsets, ptrState.offsets, ptrState.strides)) { Value offsetValue; if (auto offsetIntAttr = getIntAttr(offset)) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); + auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(0)); offsetValue = constOp.getResult(); } else { offsetValue = cast(offset); } - auto castOp = rewriter.create( - loc, rewriter.getIndexType(), increment); - auto mulOp = rewriter.create(loc, castOp.getResult(), - cast(stride)); + auto castOp = arith::IndexCastOp::create( + rewriter, loc, 
rewriter.getIndexType(), increment); + auto mulOp = arith::MulIOp::create(rewriter, loc, castOp.getResult(), + cast(stride)); auto addOp = - rewriter.create(loc, mulOp.getResult(), offsetValue); + arith::AddIOp::create(rewriter, loc, mulOp.getResult(), offsetValue); newOffsets.push_back(addOp.getResult()); } @@ -995,8 +995,8 @@ void PtrAnalysis::rewriteYieldOp( // zeroes. if (auto sIntAttr = getIntAttr(s)) { assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); + auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(0)); operands.push_back(constOp.getResult()); } else { operands.push_back(cast(s)); @@ -1166,8 +1166,8 @@ void PtrAnalysis::rewriteForOp( for (auto [j, s] : llvm::enumerate(state.offsets)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); newInitArgs.push_back(constOp.getResult()); state.offsets[j] = constOp.getResult(); } else { @@ -1178,8 +1178,8 @@ void PtrAnalysis::rewriteForOp( for (auto [j, s] : llvm::enumerate(state.strides)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); newInitArgs.push_back(constOp.getResult()); state.strides[j] = constOp.getResult(); } else { @@ -1240,9 +1240,10 @@ void PtrAnalysis::rewriteForOp( rewriter.restoreInsertionPoint(origIp); // Create a new scf::ForOp that uses updated init args and same loop body - auto newOp = rewriter.create( - op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), - newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + auto newOp = 
scf::ForOp::create( + rewriter, op.getLoc(), op.getLowerBound(), op.getUpperBound(), + op.getStep(), newInitArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { IRMapping mapping; mapping.map(op.getInductionVar(), iv); mapping.map(op.getInitArgs(), newInitArgs); @@ -1266,7 +1267,7 @@ void PtrAnalysis::rewriteForOp( b.setInsertionPointToStart(b.getBlock()); Value zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { SmallVector newReinterpretCasts; @@ -1274,9 +1275,9 @@ void PtrAnalysis::rewriteForOp( newReinterpretCasts.push_back(args[chunk.initArgIndex]); } - auto combinedCast = b.create( - loc, unrealizedCastOp.getResult(0).getType(), newReinterpretCasts, - unrealizedCastOp->getAttrs()); + auto combinedCast = UnrealizedConversionCastOp::create( + b, loc, unrealizedCastOp.getResult(0).getType(), + newReinterpretCasts, unrealizedCastOp->getAttrs()); args[chunkData[0].initArgIndex].replaceUsesWithIf( combinedCast.getResult(0), [](OpOperand &operand) { diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index 2253c807..e5b6ece3 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -63,29 +63,26 @@ static Value applyUnstructuredMask(Operation *op, Value ptr, return nullptr; } - ptr = - builder - .create( - loc, gatherScatterPtr.getBase(), - gatherScatterPtr.getGatherScatterOffset(), unstructuredMask, - gatherScatterPtr.getGatherScatterDim(), - gatherScatterPtr.getSizes(), gatherScatterPtr.getMixedStrides(), - gatherScatterPtr.getMixedOffsets()) - .getResult(); + ptr = tts::MakeGatherScatterTensorPtrOp::create( + builder, loc, gatherScatterPtr.getBase(), + gatherScatterPtr.getGatherScatterOffset(), unstructuredMask, + gatherScatterPtr.getGatherScatterDim(), + gatherScatterPtr.getSizes(), gatherScatterPtr.getMixedStrides(), + 
gatherScatterPtr.getMixedOffsets()) + .getResult(); } else if (auto structuredPtr = ptr.getDefiningOp()) { auto ofrToI32Value = [&](OpFoldResult ofr) { Value v = dyn_cast(ofr); if (!v) { - v = builder - .create( - loc, cast(cast(ofr))) + v = arith::ConstantOp::create(builder, loc, + cast(cast(ofr))) .getResult(); } if (isa(v.getType())) { - v = builder.create(loc, builder.getI32Type(), v) + v = arith::IndexCastOp::create(builder, loc, builder.getI32Type(), v) .getResult(); } else if (v.getType().isInteger(64)) { - v = builder.create(loc, builder.getI32Type(), v) + v = arith::TruncIOp::create(builder, loc, builder.getI32Type(), v) .getResult(); } @@ -100,22 +97,20 @@ static Value applyUnstructuredMask(Operation *op, Value ptr, // Divide stride since offset of tts::MakeTensorPtrOp already include the // stride, but gatherScatterOffset of tts::MakeGatherScatterTensorPtrOp // should not include stride. - offset = builder.create(loc, offset, stride); + offset = arith::DivUIOp::create(builder, loc, offset, stride); Value gatherScatterOffset = - builder.create(loc, offsetRowType, offset).getResult(); - Value range = builder - .create( - loc, offsetRowType, 0, structuredPtr.getSizes()[dim]) + tensor::SplatOp::create(builder, loc, offsetRowType, offset) + .getResult(); + Value range = triton::MakeRangeOp::create(builder, loc, offsetRowType, 0, + structuredPtr.getSizes()[dim]) .getResult(); gatherScatterOffset = - builder.create(loc, gatherScatterOffset, range); - ptr = builder - .create( - loc, structuredPtr.getBase(), gatherScatterOffset, - unstructuredMask, dim, structuredPtr.getSizes(), - structuredPtr.getMixedStrides(), - structuredPtr.getMixedOffsets()) + arith::AddIOp::create(builder, loc, gatherScatterOffset, range); + ptr = tts::MakeGatherScatterTensorPtrOp::create( + builder, loc, structuredPtr.getBase(), gatherScatterOffset, + unstructuredMask, dim, structuredPtr.getSizes(), + structuredPtr.getMixedStrides(), structuredPtr.getMixedOffsets()) .getResult(); } else { 
return nullptr; @@ -291,7 +286,7 @@ LogicalResult PtrState::addState(const PtrState &lhsState, if (lhsState.scalar && rhsState.scalar) { auto addOp = - builder.create(loc, lhsState.scalar, rhsState.scalar); + arith::AddIOp::create(builder, loc, lhsState.scalar, rhsState.scalar); scalar = addOp.getResult(); } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; @@ -569,7 +564,7 @@ LogicalResult PtrState::mulState(const PtrState &lhsState, if (lhsState.scalar && rhsState.scalar) { scalar = - builder.create(loc, lhsState.scalar, rhsState.scalar); + arith::MulIOp::create(builder, loc, lhsState.scalar, rhsState.scalar); } auto indexTy = IndexType::get(op->getContext()); @@ -662,8 +657,8 @@ tts::MakeTensorPtrOp PtrState::createTTSMakeTensorPtrOp(OpBuilder &builder, staticSizes.push_back(s.value()); } - auto op = builder.create( - loc, source, staticSizes, strides, offsets, shape, order); + auto op = tts::MakeTensorPtrOp::create(builder, loc, source, staticSizes, + strides, offsets, shape, order); LLVM_DEBUG({ llvm::dbgs() << "creating tts::make_tensor_ptr:\n"; op->dump(); @@ -700,16 +695,15 @@ PtrState::createTTSMakeGatherScatterTensorPtrOp(OpBuilder &builder, auto collapseTy = RankedTensorType::get({offsetSize}, offsetTy.getElementType()); nonContinuousOffset = - builder - .create( - loc, collapseTy, nonContinuousOffset, reassociationMap) + tensor::CollapseShapeOp::create(builder, loc, collapseTy, + nonContinuousOffset, reassociationMap) .getResult(); offsets[nonContinuousDim] = nonContinuousOffset; } // Generate tts::make_gather_scatter_tensor_ptr. 
- auto op = builder.create( - loc, source, nonContinuousOffset, nonContinuousDim, staticSizes, strides, - offsets); + auto op = tts::MakeGatherScatterTensorPtrOp::create( + builder, loc, source, nonContinuousOffset, nonContinuousDim, staticSizes, + strides, offsets); LLVM_DEBUG({ llvm::dbgs() << "creating tts::make_gather_scatter_tensor_ptr:\n"; op->dump(); @@ -1135,19 +1129,19 @@ PtrAnalysis::visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, for (int64_t i = 0; i < pointeeType.getRank(); i++) { state.sizes.push_back(builder.getIndexAttr(shape[i])); - auto strideCst = builder.create( - loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); + auto strideCst = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); state.strides.push_back(strideCst.getResult()); - auto offsetCst = builder.create( - loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); + auto offsetCst = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); - auto scaledOffset = builder.create( - loc, offsetCst.getResult(), strideCst.getResult()); + auto scaledOffset = arith::MulIOp::create( + builder, loc, offsetCst.getResult(), strideCst.getResult()); state.offsets.push_back(scaledOffset.getResult()); - auto shapeCst = builder.create( - loc, builder.getIndexType(), makeTPtrOp.getShape()[i]); + auto shapeCst = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), makeTPtrOp.getShape()[i]); state.shape.push_back(shapeCst.getResult()); } state.order = SmallVector(makeTPtrOp.getOrder()); @@ -1219,8 +1213,8 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, if (!isa(operand) && operand.getDefiningOp()) { builder.setInsertionPointAfter(operand.getDefiningOp()); } - auto castOp = builder.create( - loc, builder.getIndexType(), operand); + auto castOp = arith::IndexCastOp::create(builder, loc, + builder.getIndexType(), operand); state.scalar = castOp.getResult(); return 
success(); } else if (isa(operand.getType())) { @@ -1390,18 +1384,18 @@ LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) { llvm::zip(incrementOffsets, state.offsets, state.strides)) { Value offsetValue; if (auto offsetIntAttr = getIntAttr(offset)) { - auto constOp = builder.create( - loc, builder.getIndexAttr(offsetIntAttr.value())); + auto constOp = arith::ConstantOp::create( + builder, loc, builder.getIndexAttr(offsetIntAttr.value())); offsetValue = constOp.getResult(); } else { offsetValue = cast(offset); } - auto castOp = builder.create( - loc, builder.getIndexType(), increment); - auto mulOp = builder.create(loc, castOp.getResult(), - cast(stride)); + auto castOp = arith::IndexCastOp::create(builder, loc, + builder.getIndexType(), increment); + auto mulOp = arith::MulIOp::create(builder, loc, castOp.getResult(), + cast(stride)); auto addOp = - builder.create(loc, mulOp.getResult(), offsetValue); + arith::AddIOp::create(builder, loc, mulOp.getResult(), offsetValue); newOffsets.push_back(addOp.getResult()); } @@ -1637,15 +1631,15 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { // This operand is a pointer directly from the kernel arguments. // Use offset 0. 
assert(!tritonValue.getDefiningOp()); - replacements.push_back(builder.create( - op.getLoc(), builder.getIndexAttr(0))); + replacements.push_back(arith::ConstantOp::create( + builder, op.getLoc(), builder.getIndexAttr(0))); } } else { for (auto [j, s] : llvm::enumerate(state.offsets)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = builder.create( - op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + builder, op.getLoc(), builder.getIndexAttr(sIntAttr.value())); replacements.push_back(constOp.getResult()); } else { replacements.push_back(cast(s)); @@ -1655,8 +1649,8 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { for (auto [j, s] : llvm::enumerate(state.strides)) { auto sIntAttr = getIntAttr(s); if (sIntAttr) { - auto constOp = builder.create( - op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + auto constOp = arith::ConstantOp::create( + builder, op.getLoc(), builder.getIndexAttr(sIntAttr.value())); replacements.push_back(constOp.getResult()); } else { replacements.push_back(cast(s)); @@ -1720,7 +1714,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, } } - auto loadOp = builder.create(loc, ptr, dims, scalarOther); + auto loadOp = tts::LoadOp::create(builder, loc, ptr, dims, scalarOther); LLVM_DEBUG({ llvm::dbgs() << "creating tts::load:\n"; @@ -1850,7 +1844,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op, dims = mstate.dims; } - auto storeOp = builder.create(loc, ptr, val, dims); + auto storeOp = tts::StoreOp::create(builder, loc, ptr, val, dims); LLVM_DEBUG({ llvm::dbgs() << "creating tts::store:\n"; diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp index 10240298..9192e32e 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -56,8 +56,8 @@ static memref::SubViewOp 
getSubview(int rank, ArrayRef dims, auto dstType = memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); - return b.create(loc, cast(dstType), source, - offsets, dims, strides); + return memref::SubViewOp::create(b, loc, cast(dstType), source, + offsets, dims, strides); } static Type getElementTypeStructuredPtr(tts::MakeTensorPtrOp op) { @@ -182,8 +182,9 @@ static Value rewriteGatherScatterPtrElement( SmallVector dynSizes; // sizes are always static auto sizes = mlir::getMixedValues(staticSizes, dynSizes, rewriter); - auto castOp = rewriter.create( - op.getLoc(), resultType, basePtr, targetOffset, sizes, mixedStrides); + auto castOp = memref::ReinterpretCastOp::create( + rewriter, op.getLoc(), resultType, basePtr, targetOffset, sizes, + mixedStrides); return castOp.getResult(); } @@ -198,28 +199,28 @@ static void fillWithValue(Location loc, Value alloc, Value other, // For each dimension check if dims[i] < shape[i], or-accumulate // the result auto accBase = - rewriter.create(loc, rewriter.getBoolAttr(false)) + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)) .getResult(); for (size_t i = 0; i < shape.size(); i++) { - auto shapei = rewriter.create( - loc, rewriter.getIndexAttr(shape[i])); + auto shapei = arith::ConstantOp::create(rewriter, loc, + rewriter.getIndexAttr(shape[i])); Value dimi = dyn_cast(mixedDims[i]); if (!dimi) { - dimi = rewriter.create( - loc, rewriter.getIndexAttr(staticMaskDims[i])); + dimi = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(staticMaskDims[i])); } - Value cmp = rewriter.create(loc, arith::CmpIPredicate::slt, - dimi, shapei); - accBase = rewriter.create(loc, accBase, cmp); + Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + dimi, shapei); + accBase = arith::OrIOp::create(rewriter, loc, accBase, cmp); } // condition the memset on the or-accumulation // initialize with padding prior to CopyOp - rewriter.create(loc, accBase, [&](OpBuilder &b, Location 
loc) { - b.create(loc, ValueRange{other}, ValueRange{alloc}); - b.create(loc); + scf::IfOp::create(rewriter, loc, accBase, [&](OpBuilder &b, Location loc) { + linalg::FillOp::create(b, loc, ValueRange{other}, ValueRange{alloc}); + scf::YieldOp::create(b, loc); }); } @@ -321,35 +322,36 @@ struct MakeTensorPtrConverter // wrapping around. ShapedType::kDynamic}); - Value rowSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[0])); - Value colSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[1])); + Value rowSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[1])); Value modN = ofrToIndexValue(op.getMixedShape()[1], loc, rewriter); - Value x = rewriter.create(loc, targetOffset, modN); - Value y = rewriter.create(loc, targetOffset, x); + Value x = arith::RemSIOp::create(rewriter, loc, targetOffset, modN); + Value y = arith::SubIOp::create(rewriter, loc, targetOffset, x); SmallVector strideVals = ofrsToIndexValues(op.getMixedStrides(), loc, rewriter); // First chunk - Value nextOffset = rewriter.create(loc, x, colSize); + Value nextOffset = arith::AddIOp::create(rewriter, loc, x, colSize); Value clampedOffset = - rewriter.create(loc, nextOffset, modN); - Value d1 = rewriter.create(loc, clampedOffset, x); + arith::MinSIOp::create(rewriter, loc, nextOffset, modN); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOffset, x); SmallVector sizes1{rowSize, d1}; - auto cast1 = rewriter.create( - loc, resultType, adaptor.getBase(), targetOffset, sizes1, strideVals); + auto cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), targetOffset, sizes1, + strideVals); // Second chunk - Value d2 = rewriter.create(loc, colSize, d1); + Value d2 = arith::SubIOp::create(rewriter, loc, colSize, d1); SmallVector sizes2{rowSize, d2}; - auto cast2 = rewriter.create( - loc, 
resultType, adaptor.getBase(), y, sizes2, strideVals); + auto cast2 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), y, sizes2, strideVals); return {cast1, cast2}; } @@ -450,10 +452,10 @@ struct MakeTensorPtrConverter // allow this anymore. So we put dynamic instead. ShapedType::kDynamic}); - Value rowSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[0])); - Value colSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[1])); + Value rowSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[0])); + Value colSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(op.getSizes()[1])); Value strideRow = ofrToIndexValue(op.getMixedStrides()[0], loc, rewriter); Value strideCol = ofrToIndexValue(op.getMixedStrides()[1], loc, rewriter); @@ -462,26 +464,24 @@ struct MakeTensorPtrConverter // First chunk Value wrappedAroundOff = - rewriter.create(loc, targetOffset, strideRow); + arith::RemSIOp::create(rewriter, loc, targetOffset, strideRow); Value clampedOff = - rewriter.create(loc, modRow, wrappedAroundOff); - Value d1 = rewriter.create(loc, clampedOff, targetOffset); - d1 = rewriter.create(loc, d1, strideRow); - d1 = rewriter.create(loc, d1, rowSize); + arith::AddIOp::create(rewriter, loc, modRow, wrappedAroundOff); + Value d1 = arith::SubIOp::create(rewriter, loc, clampedOff, targetOffset); + d1 = arith::DivSIOp::create(rewriter, loc, d1, strideRow); + d1 = arith::MinSIOp::create(rewriter, loc, d1, rowSize); SmallVector sizes1{d1, colSize}; - memref::ReinterpretCastOp cast1 = - rewriter.create( - loc, resultType, adaptor.getBase(), targetOffset, sizes1, - ValueRange{strideRow, strideCol}); + memref::ReinterpretCastOp cast1 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), targetOffset, sizes1, + ValueRange{strideRow, strideCol}); // Second chunk - Value d2 = rewriter.create(loc, rowSize, d1); + Value d2 = 
arith::SubIOp::create(rewriter, loc, rowSize, d1); SmallVector sizes2{d2, colSize}; - memref::ReinterpretCastOp cast2 = - rewriter.create( - loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, - ValueRange{strideRow, strideCol}); + memref::ReinterpretCastOp cast2 = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, + ValueRange{strideRow, strideCol}); return {cast1, cast2}; } @@ -515,8 +515,8 @@ struct MakeTensorPtrConverter llvm_unreachable("Unexpected split pointer shape"); } - auto combinedCast = rewriter.create( - op.getLoc(), op.getType(), casts); + auto combinedCast = UnrealizedConversionCastOp::create( + rewriter, op.getLoc(), op.getType(), casts); combinedCast->setAttr(wrapType, rewriter.getUnitAttr()); @@ -541,8 +541,8 @@ struct MakeTensorPtrConverter op, staticTargetOffset.value_or(ShapedType::kDynamic), staticStrides, resultShape); - auto castOp = rewriter.create( - op.getLoc(), resultType, adaptor.getBase(), targetOffset, + auto castOp = memref::ReinterpretCastOp::create( + rewriter, op.getLoc(), resultType, adaptor.getBase(), targetOffset, op.getMixedSizes(), mixedStrides); rewriter.replaceOp(op, castOp); @@ -626,71 +626,67 @@ struct LoadConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); - - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); - - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); - - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{zero, 
block1Col}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); - - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); + + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); + + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } void createStackedCopies(Value block1, Value block2, Value dst, Location loc, ConversionPatternRewriter &rewriter) const { auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); - - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); - - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); - - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{block1Row, zero}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); - - rewriter.create(loc, block1, 
block1Dst); - rewriter.create(loc, block2, block2Dst); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + Value block1Row = memref::DimOp::create(rewriter, loc, block1, 0); + Value block1Col = memref::DimOp::create(rewriter, loc, block1, 1); + + Value block2Row = memref::DimOp::create(rewriter, loc, block2, 0); + Value block2Col = memref::DimOp::create(rewriter, loc, block2, 1); + + auto block1Dst = memref::SubViewOp::create(rewriter, loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = memref::SubViewOp::create(rewriter, loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + memref::CopyOp::create(rewriter, loc, block1, block1Dst); + memref::CopyOp::create(rewriter, loc, block2, block2Dst); } memref::SubViewOp createSubview(Value src, ArrayRef offsets, @@ -700,8 +696,8 @@ struct LoadConverter : public OpConversionPattern { auto srcType = cast(src.getType()); auto dstType = memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return rewriter.create(loc, cast(dstType), - src, offsets, sizes, strides); + return memref::SubViewOp::create(rewriter, loc, cast(dstType), + src, offsets, sizes, strides); } std::pair @@ -711,9 +707,9 @@ struct LoadConverter : public OpConversionPattern { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; OpFoldResult subviewCol1 = - rewriter.create(loc, block1, 1).getResult(); + memref::DimOp::create(rewriter, loc, block1, 1).getResult(); OpFoldResult subviewCol2 = - rewriter.create(loc, block2, 1).getResult(); + memref::DimOp::create(rewriter, loc, block2, 1).getResult(); SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); @@ -732,9 +728,9 @@ struct LoadConverter : public OpConversionPattern { 
OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; OpFoldResult subviewRow1 = - rewriter.create(loc, block1, 0).getResult(); + memref::DimOp::create(rewriter, loc, block1, 0).getResult(); OpFoldResult subviewRow2 = - rewriter.create(loc, block2, 0).getResult(); + memref::DimOp::create(rewriter, loc, block2, 0).getResult(); SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); @@ -757,8 +753,8 @@ struct LoadConverter : public OpConversionPattern { auto tensorType = cast(op.getType()); auto elemType = tensorType.getElementType(); - auto alloc = rewriter.create( - loc, MemRefType::get(tensorType.getShape(), elemType)); + auto alloc = memref::AllocOp::create( + rewriter, loc, MemRefType::get(tensorType.getShape(), elemType)); // No mask assert(!other && "other value used in non-masked load"); @@ -781,11 +777,12 @@ struct LoadConverter : public OpConversionPattern { llvm_unreachable("unexpected wraparound type"); } } else { - rewriter.create(loc, ptr, alloc); + memref::CopyOp::create(rewriter, loc, ptr, alloc); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create(rewriter, loc, tensorType, + alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -801,8 +798,8 @@ struct LoadConverter : public OpConversionPattern { auto tensorType = cast(op.getType()); auto elemType = tensorType.getElementType(); - auto alloc = rewriter.create( - loc, MemRefType::get(tensorType.getShape(), elemType)); + auto alloc = memref::AllocOp::create( + rewriter, loc, MemRefType::get(tensorType.getShape(), elemType)); SmallVector mixedDims = op.getMixedMaskDims(); @@ -842,11 +839,12 @@ struct LoadConverter : public OpConversionPattern { getSubview(tensorType.getRank(), mixedDims, ptr, loc, rewriter); memref::SubViewOp dstSubview = 
getSubview(tensorType.getRank(), mixedDims, alloc, loc, rewriter); - rewriter.create(loc, srcSubview, dstSubview); + memref::CopyOp::create(rewriter, loc, srcSubview, dstSubview); } - Value tensor = rewriter.create( - loc, tensorType, alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create(rewriter, loc, tensorType, + alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); return success(); @@ -864,7 +862,7 @@ struct LoadConverter : public OpConversionPattern { auto indexOffsetTy = RankedTensorType::get(offsetShapedType.getShape(), rewriter.getIndexType()); gatherOffset = - rewriter.create(loc, indexOffsetTy, gatherOffset) + arith::IndexCastOp::create(rewriter, loc, indexOffsetTy, gatherOffset) .getResult(); int gatherDim = ptr.getGatherScatterDim(); @@ -881,7 +879,7 @@ struct LoadConverter : public OpConversionPattern { auto resultType = dyn_cast(op.getResult().getType()); auto allocType = MemRefType::get(resultType.getShape(), resultType.getElementType()); - auto alloc = rewriter.create(loc, allocType); + auto alloc = memref::AllocOp::create(rewriter, loc, allocType); auto allocStrides = mlir::getMixedValues( allocType.getStridesAndOffset().first, dynSizes, rewriter); @@ -892,9 +890,9 @@ struct LoadConverter : public OpConversionPattern { } // Create loop to iterate every offset in gatherOffset. 
- auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); Value upperBound = - rewriter.create(loc, offsetSize).getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize).getResult(); if (op.hasMask()) { SmallVector mixedDims = op.getMixedMaskDims(); OpFoldResult gatherMaskDim = mixedDims[gatherDim]; @@ -911,26 +909,26 @@ struct LoadConverter : public OpConversionPattern { gatherMaskDimValue = offsetSize; } offsetSize = std::min(offsetSize, gatherMaskDimValue); - upperBound = rewriter.create(loc, offsetSize) + upperBound = arith::ConstantIndexOp::create(rewriter, loc, offsetSize) .getResult(); } else { // Use arith::MinSIOp to get the minimum value of gatherMaskDim // and offsetSize. auto gatherMaskDimVal = cast(gatherMaskDim); auto offsetSizeVal = - rewriter.create(loc, offsetSize); - upperBound = - rewriter - .create(loc, gatherMaskDimVal, offsetSizeVal) - .getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize); + upperBound = arith::MinSIOp::create(rewriter, loc, gatherMaskDimVal, + offsetSizeVal) + .getResult(); } } - auto step = rewriter.create(loc, 1); - auto loop = rewriter.create(loc, lowerBound, upperBound, step); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto loop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Create tensor from alloc and use it as the result to replace op. - Value tensor = rewriter.create( - loc, op.getType(), alloc, true /* restrict */, true /* writable */); + Value tensor = bufferization::ToTensorOp::create( + rewriter, loc, op.getType(), alloc, true /* restrict */, + true /* writable */); rewriter.replaceOp(op, tensor); // Build loop body. @@ -941,15 +939,15 @@ struct LoadConverter : public OpConversionPattern { if (Value unstructuredMask = ptr.getGatherScatterMask()) { // If the gather scatter mask is present, we need to use it to guard the // load. 
- auto maskValue = rewriter.create( - loc, unstructuredMask, ValueRange{inductionVar}); - auto ifOp = rewriter.create(loc, maskValue); + auto maskValue = tensor::ExtractOp::create( + rewriter, loc, unstructuredMask, ValueRange{inductionVar}); + auto ifOp = scf::IfOp::create(rewriter, loc, maskValue); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); } // Load the offsetElt first. - auto gatherOffsetElt = rewriter.create( - loc, gatherOffset, ValueRange{inductionVar}); + auto gatherOffsetElt = tensor::ExtractOp::create( + rewriter, loc, gatherOffset, ValueRange{inductionVar}); // reinterpret_cast to current row as memRefPtr[gatherOffsetElt]. Value srcPtr = rewriteGatherScatterPtrElement(staticSizes, ptr, memRefPtr, @@ -971,11 +969,10 @@ struct LoadConverter : public OpConversionPattern { // Use oneStrides for subview. auto dstSubViewType = memref::SubViewOp::inferResultType( cast(srcPtr.getType()), maskOffsets, sizes, oneStrides); - srcPtr = - rewriter - .create(loc, cast(dstSubViewType), + srcPtr = memref::SubViewOp::create(rewriter, loc, + cast(dstSubViewType), srcPtr, maskOffsets, sizes, oneStrides) - .getResult(); + .getResult(); } // alloc[inductionVar] @@ -983,11 +980,11 @@ struct LoadConverter : public OpConversionPattern { allocOffsets[gatherDim] = inductionVar; auto dstAllocType = memref::SubViewOp::inferResultType( allocType, allocOffsets, sizes, oneStrides); - auto dstSubview = rewriter.create( - loc, cast(dstAllocType), alloc, allocOffsets, sizes, - oneStrides); + auto dstSubview = + memref::SubViewOp::create(rewriter, loc, cast(dstAllocType), + alloc, allocOffsets, sizes, oneStrides); // Copy srcPtr to alloc[inductionVar]. 
- rewriter.create(loc, srcPtr, dstSubview); + memref::CopyOp::create(rewriter, loc, srcPtr, dstSubview); return success(); } @@ -1027,8 +1024,8 @@ struct StoreConverter : public OpConversionPattern { auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); - return b.create(loc, dstType, source, offsets, dims, - strides); + return tensor::ExtractSliceOp::create(b, loc, dstType, source, offsets, + dims, strides); } LogicalResult rewriteScatter(tts::MakeGatherScatterTensorPtrOp ptr, @@ -1043,7 +1040,7 @@ struct StoreConverter : public OpConversionPattern { auto indexOffsetTy = RankedTensorType::get(offsetShapedType.getShape(), rewriter.getIndexType()); gatherOffset = - rewriter.create(loc, indexOffsetTy, gatherOffset) + arith::IndexCastOp::create(rewriter, loc, indexOffsetTy, gatherOffset) .getResult(); int gatherDim = ptr.getGatherScatterDim(); @@ -1057,9 +1054,9 @@ struct StoreConverter : public OpConversionPattern { auto sizes = mlir::getMixedValues(staticSizes, dynSizes, rewriter); // Create loop to iterate every offset in gatherOffset. - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); Value upperBound = - rewriter.create(loc, offsetSize).getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize).getResult(); if (op.hasMask()) { SmallVector mixedDims = op.getMixedMaskDims(); OpFoldResult gatherMaskDim = mixedDims[gatherDim]; @@ -1076,22 +1073,21 @@ struct StoreConverter : public OpConversionPattern { gatherMaskDimValue = offsetSize; } offsetSize = std::min(offsetSize, gatherMaskDimValue); - upperBound = rewriter.create(loc, offsetSize) + upperBound = arith::ConstantIndexOp::create(rewriter, loc, offsetSize) .getResult(); } else { // Use arith::MinSIOp to get the minimum value of gatherMaskDim // and offsetSize. 
auto gatherMaskDimVal = cast(gatherMaskDim); auto offsetSizeVal = - rewriter.create(loc, offsetSize); - upperBound = - rewriter - .create(loc, gatherMaskDimVal, offsetSizeVal) - .getResult(); + arith::ConstantIndexOp::create(rewriter, loc, offsetSize); + upperBound = arith::MinSIOp::create(rewriter, loc, gatherMaskDimVal, + offsetSizeVal) + .getResult(); } } - auto step = rewriter.create(loc, 1); - auto loop = rewriter.create(loc, lowerBound, upperBound, step); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto loop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Build loop body. rewriter.setInsertionPointToStart(loop.getBody()); @@ -1101,15 +1097,15 @@ struct StoreConverter : public OpConversionPattern { if (Value unstructuredMask = ptr.getGatherScatterMask()) { // If the gather scatter mask is present, we need to use it to guard the // store. - auto maskValue = rewriter.create( - loc, unstructuredMask, ValueRange{inductionVar}); - auto ifOp = rewriter.create(loc, maskValue); + auto maskValue = tensor::ExtractOp::create( + rewriter, loc, unstructuredMask, ValueRange{inductionVar}); + auto ifOp = scf::IfOp::create(rewriter, loc, maskValue); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); } // Load the offsetElt first. - auto gatherOffsetElt = rewriter.create( - loc, gatherOffset, ValueRange{inductionVar}); + auto gatherOffsetElt = tensor::ExtractOp::create( + rewriter, loc, gatherOffset, ValueRange{inductionVar}); // Create extract_slice stVal[inductionVar]. unsigned rank = ptr.getSizes().size(); @@ -1125,8 +1121,8 @@ struct StoreConverter : public OpConversionPattern { } // The subview should not apply an additional stride to the source. 
SmallVector oneStrides(rank, OpFoldResult(step)); - auto slice = rewriter.create( - loc, stVal, stValOffsets, sizes, oneStrides); + auto slice = tensor::ExtractSliceOp::create( + rewriter, loc, stVal, stValOffsets, sizes, oneStrides); // reinterpret_cast to current row as memRefPtr[gatherOffsetElt]. Value dstPtr = rewriteGatherScatterPtrElement(staticSizes, ptr, memRefPtr, @@ -1141,14 +1137,13 @@ struct StoreConverter : public OpConversionPattern { cast(dstPtr.getType()), maskOffsets, sizes, oneStrides); dstPtr = - rewriter - .create(loc, cast(dstType), dstPtr, - maskOffsets, sizes, oneStrides) + memref::SubViewOp::create(rewriter, loc, cast(dstType), + dstPtr, maskOffsets, sizes, oneStrides) .getResult(); } // store slice to dstPtr. - auto storeOp = rewriter.create( - loc, slice, dstPtr); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, slice, dstPtr); storeOp.setWritable(true); rewriter.eraseOp(op); @@ -1182,12 +1177,12 @@ struct StoreConverter : public OpConversionPattern { getExtractSlice(rank, mixedDims, storeValue, loc, rewriter); auto dstSubview = getSubview(rank, mixedDims, ptr, loc, rewriter); - auto storeOp = rewriter.create( - loc, srcSlice, dstSubview); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, srcSlice, dstSubview); storeOp.setWritable(true); } else { - auto storeOp = rewriter.create( - loc, storeValue, ptr); + auto storeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, storeValue, ptr); storeOp.setWritable(true); } diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp index c0d4cd0a..fdd3bd90 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -51,13 +51,15 @@ class PtrToUnrankedMemrefConverter : public TypeConverter { addTargetMaterialization([&](OpBuilder &builder, 
UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); } @@ -70,9 +72,9 @@ class StructuredToMemrefPass public: void getDependentDialects(DialectRegistry ®istry) const override { registry - .insert(); } diff --git a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp index 8c86fe26..39cf0ba7 100644 --- a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp +++ b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp @@ -215,7 +215,8 @@ class TritonArithToLinalgPass func.getAllArgAttrs(argAttrs); func.getAllResultAttrs(resAttrs); - auto funcFunc = builder.create(func.getLoc(), name, type); + auto funcFunc = + func::FuncOp::create(builder, func.getLoc(), name, type); // Preserve the visibility attribute funcFunc.setVisibility(func.getVisibility()); funcFunc.setAllArgAttrs(argAttrs); @@ -234,7 +235,7 @@ class TritonArithToLinalgPass // considered terminators. 
if (isa(term)) { builder.setInsertionPoint(term); - builder.create(func.getLoc(), term->getOperands()); + func::ReturnOp::create(builder, func.getLoc(), term->getOperands()); term->erase(); } } diff --git a/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp b/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp index b75e7ddb..09d8c23a 100644 --- a/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp +++ b/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp @@ -61,7 +61,8 @@ class TritonFunctionSignatureConverter : public TypeConverter { auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }; addSourceMaterialization(createUnrealizedCast); diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp index 25b7db85..46eabbe5 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp @@ -195,7 +195,7 @@ class TritonToLinalgPass : public TritonToLinalgBase { func.getAllArgAttrs(argAttrs); func.getAllResultAttrs(resAttrs); - auto funcFunc = builder.create(func.getLoc(), name, type); + auto funcFunc = func::FuncOp::create(builder, func.getLoc(), name, type); funcFunc.setAllArgAttrs(argAttrs); funcFunc.setAllResultAttrs(resAttrs); @@ -208,7 +208,7 @@ class TritonToLinalgPass : public TritonToLinalgBase { for (Block &block : funcFuncBody.getBlocks()) { auto term = block.getTerminator(); builder.setInsertionPoint(term); - builder.create(func.getLoc(), term->getOperands()); + func::ReturnOp::create(builder, func.getLoc(), term->getOperands()); term->erase(); } func.erase(); diff --git a/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp b/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp index 
95061ac4..0130ede6 100644 --- a/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/CollapseShape.cpp @@ -83,8 +83,8 @@ struct CollapseFill : public OpRewritePattern { } auto elementType = resultType.getElementType(); - auto output = rewriter.create( - loc, + auto output = memref::CollapseShapeOp::create( + rewriter, loc, MemRefType::get(llvm::ArrayRef{resultType.getNumElements()}, elementType), result, reassociationMap); @@ -112,14 +112,15 @@ struct CollapseFill : public OpRewritePattern { reassociationMap[0].push_back(rewriter.getAffineDimExpr(i)); } - auto init = rewriter.create( - loc, RankedTensorType::get({resultType.getNumElements()}, elementType), + auto init = tensor::CollapseShapeOp::create( + rewriter, loc, + RankedTensorType::get({resultType.getNumElements()}, elementType), op.getOutputs()[0], reassociationMap); auto fillOp = - rewriter.create(loc, op.getInputs(), ValueRange{init}); + linalg::FillOp::create(rewriter, loc, op.getInputs(), ValueRange{init}); - auto expandOp = rewriter.create( - loc, result.getType(), fillOp.getResult(0), reassociationMap); + auto expandOp = tensor::ExpandShapeOp::create( + rewriter, loc, result.getType(), fillOp.getResult(0), reassociationMap); rewriter.replaceOp(op, expandOp.getResult()); return success(); @@ -224,8 +225,8 @@ struct CollapseTranspose : public OpRewritePattern { auto loc = op.getLoc(); sourceType = RankedTensorType::get(collapseShapeInput, elementType); - source = rewriter.create(loc, sourceType, source, - reassociationMap); + source = tensor::CollapseShapeOp::create(rewriter, loc, sourceType, source, + reassociationMap); SmallVector reassociationMapRe(reassociationMap.size()); int idx = 0; @@ -235,12 +236,12 @@ struct CollapseTranspose : public OpRewritePattern { } } - Value transposeInit = rewriter.create( - loc, RankedTensorType::get(transposedShape, elementType), op.getInit(), - reassociationMapRe); + Value transposeInit = 
tensor::CollapseShapeOp::create( + rewriter, loc, RankedTensorType::get(transposedShape, elementType), + op.getInit(), reassociationMapRe); Value transpose = - rewriter.create(loc, source, transposeInit, perm) + linalg::TransposeOp::create(rewriter, loc, source, transposeInit, perm) .getResults()[0]; rewriter.replaceOpWithNewOp( @@ -337,8 +338,8 @@ struct CollapseBroadCast : public OpRewritePattern { auto loc = op.getLoc(); sourceType = RankedTensorType::get(collapseShapeInput, elementType); - input = rewriter.create(loc, sourceType, input, - reassociationMap); + input = tensor::CollapseShapeOp::create(rewriter, loc, sourceType, input, + reassociationMap); resultType = RankedTensorType::get(collapseShapeOutput, elementType); size_t resultRank = resultType.getRank(); @@ -351,13 +352,14 @@ struct CollapseBroadCast : public OpRewritePattern { assert(op->getNumResults() == 1 && "code assumes single result!"); - auto init = rewriter.create( - loc, RankedTensorType::get(resultType.getShape(), elementType), + auto init = tensor::CollapseShapeOp::create( + rewriter, loc, + RankedTensorType::get(resultType.getShape(), elementType), op.getOutputs()[0], reassociationMap); - auto linalgOp = rewriter.create( - loc, init->getResultTypes(), ValueRange{input}, ValueRange{init}, - indexingMaps, getNParallelLoopsAttrs(resultRank)); + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, init->getResultTypes(), ValueRange{input}, + ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(resultRank)); rewriter.cloneRegionBefore(op.getRegion(), linalgOp.getRegion(), linalgOp.getRegion().begin()); linalgOp->setAttr("broadcastDims", @@ -453,8 +455,8 @@ struct CollapseReduce : public OpRewritePattern { auto elementType = inputType.getElementType(); auto loc = op.getLoc(); auto newInputType = RankedTensorType::get(collapseShapeInput, elementType); - input = rewriter.create(loc, newInputType, input, - reassociationMap); + input = tensor::CollapseShapeOp::create(rewriter, loc, 
newInputType, input, + reassociationMap); SmallVector reassociationMapOutput; int idx = 0; @@ -469,12 +471,12 @@ struct CollapseReduce : public OpRewritePattern { rewriter.getAffineDimExpr(idx++)); } } - auto init = rewriter.create( - loc, RankedTensorType::get(collapseShapeOutput, elementType), + auto init = tensor::CollapseShapeOp::create( + rewriter, loc, RankedTensorType::get(collapseShapeOutput, elementType), op.getInits()[0], reassociationMapOutput); - auto newReduce = rewriter.create( - loc, init->getResultTypes(), ValueRange{input}, ValueRange{init}, - newDims); + auto newReduce = + linalg::ReduceOp::create(rewriter, loc, init->getResultTypes(), + ValueRange{input}, ValueRange{init}, newDims); rewriter.cloneRegionBefore(op.getRegion(), newReduce.getRegion(), newReduce.getRegion().begin()); diff --git a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp index e8a3b7f6..3547efe5 100644 --- a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" @@ -61,8 +62,8 @@ struct SimplifyUnrealizedCast } auto prevInput = unrealizedCast.getInputs().front(); - auto newCast = rewriter.create( - op->getLoc(), op->getResultTypes(), ValueRange{prevInput}); + auto newCast = UnrealizedConversionCastOp::create( + rewriter, op->getLoc(), op->getResultTypes(), ValueRange{prevInput}); rewriter.replaceOp(op, newCast); return success(); @@ -90,14 +91,17 @@ struct FromMemrefConverter if (unrankedInput && isa(outType)) { // from_memref only takes ranked memref, cast the unranked memref to // ranked memref first. 
- auto rankedMemref = rewriter.create( + auto memSpace = tptr::DefaultMemorySpaceAttr::get(rewriter.getContext()); + Value rankedMemref = rewriter.create( op.getLoc(), MemRefType::get({1}, unrankedInput.getElementType()), input); - auto memrefToPtr = rewriter.create( - op->getLoc(), - ptr::PtrType::get( - rewriter.getContext(), - tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())), + rankedMemref = rewriter.create( + op.getLoc(), + MemRefType::get({1}, unrankedInput.getElementType(), + MemRefLayoutAttrInterface{}, memSpace), + rankedMemref); + auto memrefToPtr = rewriter.create( + op->getLoc(), ptr::PtrType::get(rewriter.getContext(), memSpace), rankedMemref); rewriter.replaceAllUsesWith(output, memrefToPtr); @@ -128,14 +132,27 @@ struct ToMemrefConverter : public OpRewritePattern { // to_memref can only cast to ranked static shape memref, we have to cast // the resulting memref back to unranked auto elemType = outUnrankedMemrefType.getElementType(); - auto ptrToMemref = rewriter.create( - op->getLoc(), MemRefType::get({1}, elemType), input); + mlir::Attribute memSpace = + tptr::DefaultMemorySpaceAttr::get(rewriter.getContext()); + if (auto ptrType = dyn_cast(inType)) { + memSpace = ptrType.getMemorySpace(); + } + Value ptrToMemref = rewriter.create( + op->getLoc(), + MemRefType::get({1}, elemType, MemRefLayoutAttrInterface{}, memSpace), + input); + ptrToMemref = rewriter.create( + op->getLoc(), + MemRefType::get({1}, elemType, MemRefLayoutAttrInterface{}, + outUnrankedMemrefType.getMemorySpace()), + ptrToMemref); SmallVector sizes = {rewriter.getIndexAttr(1)}; SmallVector newStrides = {rewriter.getIndexAttr(1)}; - auto newUnrankedMemref = rewriter.create( - op->getLoc(), MemRefType::get({ShapedType::kDynamic}, elemType), - ptrToMemref, rewriter.getIndexAttr(0), sizes, newStrides); + auto newUnrankedMemref = memref::ReinterpretCastOp::create( + rewriter, op->getLoc(), + MemRefType::get({ShapedType::kDynamic}, elemType), ptrToMemref, + 
rewriter.getIndexAttr(0), sizes, newStrides); rewriter.replaceAllUsesWith(output, newUnrankedMemref); rewriter.eraseOp(op); @@ -151,7 +168,8 @@ class ReconcilePtrCastsPass public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { diff --git a/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp b/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp index 8234360d..727ec91a 100644 --- a/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" @@ -180,12 +181,12 @@ struct AddPtrConverter : public OpConversionPattern { } auto loc = op->getLoc(); auto pointeeType = cast(op.getType()).getPointeeType(); - auto offsetType = op.getOffset().getType(); + auto offsetType = adaptor.getOffset().getType(); auto pointeeSizeInBytes = - rewriter.create(loc, offsetType, pointeeType); - auto scaledOffset = - rewriter.create(loc, op.getOffset(), pointeeSizeInBytes); - rewriter.replaceOpWithNewOp( + rewriter.create(loc, offsetType, pointeeType); + auto scaledOffset = rewriter.create(loc, adaptor.getOffset(), + pointeeSizeInBytes); + auto dddd = rewriter.replaceOpWithNewOp( op, ptr::PtrType::get( rewriter.getContext(), @@ -213,37 +214,40 @@ struct LoadConverter : public OpConversionPattern { auto ptr = op.getPtr(); auto pointeeType = cast(ptr.getType()).getPointeeType(); + auto ptrType = cast(adaptor.getPtr().getType()); + auto memref = rewriter.create( + op->getLoc(), + MemRefType::get({1}, pointeeType, MemRefLayoutAttrInterface{}, + ptrType.getMemorySpace()), + adaptor.getPtr()); - auto 
memref = rewriter.create( - op->getLoc(), MemRefType::get({1}, pointeeType), adaptor.getPtr()); - - auto zero = rewriter.create(op.getLoc(), 0); + auto zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); if (op.getMask()) { - auto ifOp = rewriter.create( - op->getLoc(), op.getMask(), + auto ifOp = scf::IfOp::create( + rewriter, op->getLoc(), op.getMask(), [&](OpBuilder &b, Location loc) { // Truthy case, load from the index. - Value memrefLoad = rewriter.create( - op->getLoc(), memref, ValueRange{zero}); - b.create(loc, memrefLoad); + Value memrefLoad = memref::LoadOp::create(rewriter, op->getLoc(), + memref, ValueRange{zero}); + scf::YieldOp::create(b, loc, memrefLoad); }, [&](OpBuilder &b, Location loc) { // Falsy case, yield `other` or 0 as the default value. if (op.getOther()) { - b.create(loc, op.getOther()); + scf::YieldOp::create(b, loc, op.getOther()); } else { auto elemType = op.getType(); auto zeroAttr = b.getZeroAttr(elemType); assert(zeroAttr && "unexpected element type"); - Value val = b.create(loc, zeroAttr); - b.create(loc, val); + Value val = arith::ConstantOp::create(b, loc, zeroAttr); + scf::YieldOp::create(b, loc, val); } }); rewriter.replaceOp(op, ifOp); } else { - auto memrefLoad = rewriter.create(op->getLoc(), memref, - ValueRange{zero}); + auto memrefLoad = memref::LoadOp::create(rewriter, op->getLoc(), memref, + ValueRange{zero}); rewriter.replaceOp(op, memrefLoad); } @@ -272,18 +276,21 @@ struct StoreConverter : public OpConversionPattern { IRRewriter::InsertionGuard g(rewriter); if (op.getMask()) { - auto ifOp = rewriter.create(op->getLoc(), op.getMask(), - /*withElseRegion*/ false); + auto ifOp = scf::IfOp::create(rewriter, op->getLoc(), op.getMask(), + /*withElseRegion*/ false); rewriter.setInsertionPointToStart( &ifOp.getThenRegion().getBlocks().front()); } - - auto memref = rewriter.create( - op->getLoc(), MemRefType::get({1}, pointeeType), adaptor.getPtr()); + auto ptrType = cast(adaptor.getPtr().getType()); + auto memref 
= rewriter.create( + op->getLoc(), + MemRefType::get({1}, pointeeType, MemRefLayoutAttrInterface{}, + ptrType.getMemorySpace()), + adaptor.getPtr()); auto zero = rewriter.create(op.getLoc(), 0); - rewriter.create(op->getLoc(), op.getValue(), memref, - ValueRange{zero}); + memref::StoreOp::create(rewriter, op->getLoc(), op.getValue(), memref, + ValueRange{zero}); rewriter.eraseOp(op); @@ -353,9 +360,10 @@ struct LinalgPtrConverter : public OpConversionPattern { return failure(); } - auto replacement = rewriter.create( - op.getLoc(), convertedTypes, adaptor.getInputs(), adaptor.getOutputs(), - op.getIndexingMapsArray(), op.getIteratorTypesArray()); + auto replacement = linalg::GenericOp::create( + rewriter, op.getLoc(), convertedTypes, adaptor.getInputs(), + adaptor.getOutputs(), op.getIndexingMapsArray(), + op.getIteratorTypesArray()); Region ®ion = op.getRegion(); Block &block = region.front(); @@ -431,7 +439,8 @@ class TritonPtrTypeConverter : public TypeConverter { }); auto createCast = [&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }; addTargetMaterialization(createCast); @@ -479,7 +488,8 @@ class TritonToPtrPass : public impl::TritonToPtrBase { target.addLegalDialect(); + tptr::TPtrDialect, ptr::PtrDialect, + memref::MemRefDialect>(); patterns .add(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); // Compute the target materialization, given a value with the pointer type, // convert that value to a tuple type. 
- converter.addTargetMaterialization([](OpBuilder &builder, - TypeRange resultTypes, - ValueRange inputs, - Location loc) -> SmallVector { - return builder - .create(loc, resultTypes, inputs.front()) - ->getResults(); - }); + converter.addTargetMaterialization( + [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc) -> SmallVector { + return UnrealizedConversionCastOp::create(builder, loc, resultTypes, + inputs.front()) + ->getResults(); + }); ConversionTarget target(getContext()); scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, @@ -200,8 +200,8 @@ class TritonToStructuredPass // during reconcile-unrealized-conversion-casts. converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder - .create(loc, resultType, inputs[0]) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs[0]) ->getResult(0); }); @@ -214,8 +214,8 @@ class TritonToStructuredPass converter.addTargetMaterialization([](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, Location loc) { - auto placeholder = builder.create( - loc, inputs.front().getDefiningOp()->getOperand(0)); + auto placeholder = tts::GetStructuredStateOp::create( + builder, loc, inputs.front().getDefiningOp()->getOperand(0)); assert(llvm::equal(placeholder.getResultTypes(), resultTypes)); return placeholder.getResults(); }); diff --git a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp index 88568797..867ebd4d 100644 --- a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp +++ b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp @@ -252,8 +252,8 @@ class TritonToUnstructuredPass } OpBuilder b(func->getRegion(0)); - Value zero = b.create( - arg.getLoc(), + Value zero = arith::ConstantOp::create( + b, arg.getLoc(), b.getIntegerAttr(IntegerType::get(&getContext(), 
defaultBitWidth), 0)); @@ -271,8 +271,8 @@ class TritonToUnstructuredPass } auto res = op.getResult(); OpBuilder b(op); - Value zero = b.create( - op.getLoc(), + Value zero = arith::ConstantOp::create( + b, op.getLoc(), b.getIntegerAttr(IntegerType::get(&getContext(), defaultBitWidth), 0)); @@ -304,8 +304,8 @@ class TritonToUnstructuredPass // We are converting a pointer to an integer here, // materialized the pointer using the accumulated offset // that we have stored so far. - auto materializedAddPtr = b.create( - op->getLoc(), offsetInfo.ptrType, offsetInfo.ptr, + auto materializedAddPtr = triton::AddPtrOp::create( + b, op->getLoc(), offsetInfo.ptrType, offsetInfo.ptr, offsetInfo.offset); // Change the op to use the "simplified" pointer above. @@ -337,19 +337,19 @@ class TritonToUnstructuredPass auto resWidth = std::max(lhsWidth, rhsWidth); if (lhsWidth < resWidth) { - prevOff = b.create( - loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), + prevOff = arith::ExtSIOp::create( + b, loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), prevOff); } if (rhsWidth < resWidth) { - off = b.create( - loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), + off = arith::ExtSIOp::create( + b, loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), off); } - auto accumulatedOff = b.create( - loc, getPtrOffsetType(addptr.getType(), resWidth), + auto accumulatedOff = arith::AddIOp::create( + b, loc, getPtrOffsetType(addptr.getType(), resWidth), prevOff, off); PtrOffset newOffsetInfo{offsetInfo.ptr, addptr.getType(), @@ -491,8 +491,8 @@ class TritonToUnstructuredPass } } - auto gather = b.create( - loc, load.getType(), offsetInfo.ptr, offsetInfo.offset, + auto gather = tts::GatherOp::create( + b, loc, load.getType(), offsetInfo.ptr, offsetInfo.offset, load.getMask(), other); load->replaceAllUsesWith(gather->getResults()); @@ -501,8 +501,9 @@ class TritonToUnstructuredPass }) .Case([&](triton::StoreOp store) { auto offsetInfo = offsetMap.at(store.getPtr()); - b.create(loc, 
offsetInfo.ptr, offsetInfo.offset, - store.getValue(), store.getMask()); + tts::ScatterOp::create(b, loc, offsetInfo.ptr, + offsetInfo.offset, store.getValue(), + store.getMask()); store->erase(); return success(); }) @@ -526,26 +527,26 @@ class TritonToUnstructuredPass if (baseOffType != currOffType) { if (currOffType.isIndex()) { - baseOffset = b.create( - loc, b.getIndexType(), baseOffset); + baseOffset = arith::IndexCastOp::create( + b, loc, b.getIndexType(), baseOffset); } else if (currOffType.isInteger()) { if (baseOffType.getIntOrFloatBitWidth() < currOffType.getIntOrFloatBitWidth()) { - baseOffset = b.create(loc, currOffType, - baseOffset); + baseOffset = arith::ExtSIOp::create(b, loc, currOffType, + baseOffset); } else { // MakeTensorPtrOp only takes i32 offsets, so we need // to truncate if the offsets were already in i64 makeTensorPtr.emitWarning( "truncating offsets which may result in data loss"); - baseOffset = b.create(loc, currOffType, - baseOffset); + baseOffset = arith::TruncIOp::create(b, loc, currOffType, + baseOffset); } } } - auto accumulatedOffset = b.create( - loc, currOffset.getType(), baseOffset, currOffset); + auto accumulatedOffset = arith::AddIOp::create( + b, loc, currOffset.getType(), baseOffset, currOffset); offsetOpnd.set(accumulatedOffset); diff --git a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp index 316537bf..df97ac6a 100644 --- a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp +++ b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp @@ -53,7 +53,8 @@ class PtrToUnrankedMemrefConverter : public TypeConverter { addTargetMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Value { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); } @@ -89,11 +90,11 @@ struct 
ScalarLoadConverter : public OpConversionPattern { auto basePtr = adaptor.getPtr(); auto offset = adaptor.getOffset(); - Value loadIndex = rewriter.create( - loc, rewriter.getIndexType(), offset); + Value loadIndex = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), offset); - auto memref = rewriter.create( - loc, + auto memref = memref::ReinterpretCastOp::create( + rewriter, loc, getMemrefTypeForScalarPtr( cast(gatherOp.getPtr().getType()), rewriter.getContext()), @@ -103,8 +104,8 @@ struct ScalarLoadConverter : public OpConversionPattern { auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - auto scalarLoadOp = rewriter.create( - loc, memref, zeroMap, ValueRange{}); + auto scalarLoadOp = affine::AffineLoadOp::create(rewriter, loc, memref, + zeroMap, ValueRange{}); rewriter.replaceOp(gatherOp, scalarLoadOp.getResult()); @@ -134,11 +135,11 @@ struct ScalarStoreConverter : public OpConversionPattern { auto basePtr = adaptor.getPtr(); auto offset = adaptor.getOffset(); - Value storeIndex = rewriter.create( - loc, rewriter.getIndexType(), offset); + Value storeIndex = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), offset); - auto memref = rewriter.create( - loc, + auto memref = memref::ReinterpretCastOp::create( + rewriter, loc, getMemrefTypeForScalarPtr( cast(scatterOp.getPtr().getType()), rewriter.getContext()), @@ -149,8 +150,8 @@ struct ScalarStoreConverter : public OpConversionPattern { auto storeVal = scatterOp.getValue(); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); - rewriter.create(loc, storeVal, memref, zeroMap, - ValueRange{}); + affine::AffineStoreOp::create(rewriter, loc, storeVal, memref, zeroMap, + ValueRange{}); rewriter.eraseOp(scatterOp); return success(); @@ -186,22 +187,19 @@ struct GatherConverter : public OpConversionPattern { // Treat the base pointer (memref) as 1D because the offsets are all // relative to a single base pointer (already collapsed). 
- auto baseMemref = rewriter - .create( - loc, - MemRefType::get({ShapedType::kDynamic}, - resultType.getElementType()), - ptr) - .getResult(); + auto baseMemref = + memref::CastOp::create(rewriter, loc, + MemRefType::get({ShapedType::kDynamic}, + resultType.getElementType()), + ptr) + .getResult(); auto baseTensor = - rewriter - .create( - loc, - RankedTensorType::get( - SmallVector(1, ShapedType::kDynamic), - resultType.getElementType()), - baseMemref, true /* restrict */, false /* writable */) + bufferization::ToTensorOp::create( + rewriter, loc, + RankedTensorType::get(SmallVector(1, ShapedType::kDynamic), + resultType.getElementType()), + baseMemref, true /* restrict */, false /* writable */) .getResult(); // The linalg.generic op should have the following inputs: @@ -213,10 +211,10 @@ struct GatherConverter : public OpConversionPattern { inputs.push_back(gatherOp.getMask()); } - auto emptyTensor = rewriter - .create(loc, resultType.getShape(), - resultType.getElementType()) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()) + .getResult(); // Affine maps for the inputs and one additional output. 
SmallVector affineMaps( @@ -227,16 +225,17 @@ struct GatherConverter : public OpConversionPattern { SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); - auto genericOp = rewriter.create( - loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps, - iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + auto genericOp = linalg::GenericOp::create( + rewriter, loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, + affineMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc, Value index) -> Value { Value index0 = - b.create(loc, b.getIndexType(), index); + arith::IndexCastOp::create(b, loc, b.getIndexType(), index); - return b.create(loc, baseTensor, - ValueRange{index0}); + return tensor::ExtractOp::create(b, loc, baseTensor, + ValueRange{index0}); }; auto offset = args[0]; @@ -245,34 +244,34 @@ struct GatherConverter : public OpConversionPattern { // If there is no mask, simply extract the current element from the // base tensor and use it as the yield value. auto loadValue = getValueAtIndex(b, loc, offset); - b.create(loc, loadValue); + linalg::YieldOp::create(b, loc, loadValue); } else { // If the mask value is truthy, the current element is loaded from // the base tensor using its offset. Otherwise, if `other` is // present, yield `other`. If `other` is not present, a default // value of 0 is used. auto mask = args[1]; - auto ifOp = b.create( - loc, mask, + auto ifOp = scf::IfOp::create( + b, loc, mask, [&](OpBuilder &b, Location loc) { // Truthy case, load from the index. auto value = getValueAtIndex(b, loc, offset); - b.create(loc, value); + scf::YieldOp::create(b, loc, value); }, [&](OpBuilder &b, Location loc) { // Falsy case, yield `other` or 0 as the default value. 
if (gatherOp.getOther()) { - b.create(loc, gatherOp.getOther()); + scf::YieldOp::create(b, loc, gatherOp.getOther()); } else { auto elemType = resultType.getElementType(); auto zeroAttr = b.getZeroAttr(elemType); assert(zeroAttr && "unexpected element type"); - Value extract = b.create(loc, zeroAttr); - b.create(loc, extract); + Value extract = arith::ConstantOp::create(b, loc, zeroAttr); + scf::YieldOp::create(b, loc, extract); } }); - b.create(loc, ifOp->getResult(0)); + linalg::YieldOp::create(b, loc, ifOp->getResult(0)); } }); @@ -312,11 +311,10 @@ struct ScatterConverter : public OpConversionPattern { // Treat the base pointer (memref) as 1D because the offsets are all // relative to a single base pointer (already collapsed). auto baseMemref = - rewriter - .create(loc, - MemRefType::get({ShapedType::kDynamic}, - valueType.getElementType()), - ptr) + memref::CastOp::create( + rewriter, loc, + MemRefType::get({ShapedType::kDynamic}, valueType.getElementType()), + ptr) .getResult(); // The linalg.generic op should have the following inputs: @@ -339,16 +337,16 @@ struct ScatterConverter : public OpConversionPattern { rewriter.setInsertionPoint(scatterOp); - auto genericOp = rewriter.create( - loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { + auto genericOp = linalg::GenericOp::create( + rewriter, loc, TypeRange{}, inputs, ValueRange{}, affineMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc, Value index, Value value) { Value index0 = - b.create(loc, b.getIndexType(), index); + arith::IndexCastOp::create(b, loc, b.getIndexType(), index); - b.create(loc, value, baseMemref, - ValueRange{index0}); + memref::StoreOp::create(b, loc, value, baseMemref, + ValueRange{index0}); }; auto offset = args[0]; @@ -362,14 +360,14 @@ struct ScatterConverter : public OpConversionPattern { // If the mask value is truthy, 
insert the current value to the // the base memref using its offset. Otherwise, noop. auto mask = args[2]; - auto ifOp = - b.create(loc, mask, [&](OpBuilder &b, Location loc) { + auto ifOp = scf::IfOp::create( + b, loc, mask, [&](OpBuilder &b, Location loc) { storeValueAtIndex(b, loc, offset, value); - b.create(loc); + scf::YieldOp::create(b, loc); }); } - b.create(loc); + linalg::YieldOp::create(b, loc); }); rewriter.eraseOp(scatterOp); diff --git a/lib/Dialect/TPtr/IR/TPtrDialect.cpp b/lib/Dialect/TPtr/IR/TPtrDialect.cpp index 3235a1e0..c4b8b7c7 100644 --- a/lib/Dialect/TPtr/IR/TPtrDialect.cpp +++ b/lib/Dialect/TPtr/IR/TPtrDialect.cpp @@ -51,27 +51,30 @@ void mlir::tptr::TPtrDialect::initialize() { } bool tptr::DefaultMemorySpaceAttr::isValidLoad( - Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, mlir::ptr::AtomicOrdering ordering, + std::optional alignment, const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } bool tptr::DefaultMemorySpaceAttr::isValidStore( - Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, mlir::ptr::AtomicOrdering ordering, + std::optional alignment, const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } bool tptr::DefaultMemorySpaceAttr::isValidAtomicOp( mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering, - IntegerAttr alignment, + std::optional alignment, const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } bool tptr::DefaultMemorySpaceAttr::isValidAtomicXchg( Type type, mlir::ptr::AtomicOrdering successOrdering, - mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + mlir::ptr::AtomicOrdering failureOrdering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, llvm::function_ref emitError) const { return true; } diff --git a/lib/Dialect/TPtr/IR/TPtrOps.cpp b/lib/Dialect/TPtr/IR/TPtrOps.cpp index 
dad589ea..6cd2f4ec 100644 --- a/lib/Dialect/TPtr/IR/TPtrOps.cpp +++ b/lib/Dialect/TPtr/IR/TPtrOps.cpp @@ -18,21 +18,3 @@ using namespace mlir; using namespace mlir::tptr; - -void LoadOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable(), - SideEffects::DefaultResource::get()); -} - -void StoreOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable(), - SideEffects::DefaultResource::get()); -} - -OpFoldResult TypeOffsetOp::fold(FoldAdaptor adaptor) { - return adaptor.getBaseTypeAttr(); -} diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp index df7de996..150e926d 100644 --- a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp +++ b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -46,14 +46,14 @@ Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return builder.create(loc, resType, src); + return arith::SIToFPOp::create(builder, loc, resType, src); }) .Case([&](Operation *op) { auto resType = op->getResults()[0].getType(); if (auto shapedType = dyn_cast(resType)) { resType = shapedType.getElementType(); } - return builder.create(loc, resType, src); + return arith::TruncFOp::create(builder, loc, resType, src); }) .Default([](Operation *op) { llvm_unreachable("unsupported op in generating "); diff --git a/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp b/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp index da4d9767..acf118a5 100644 --- a/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp +++ b/lib/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.cpp @@ -29,12 +29,12 @@ Value getSlice(OpBuilder &b, Location loc, Value source, ArrayRef strides) { return TypeSwitch(source.getType()) .Case([&](RankedTensorType t) -> Value { - 
return b.create(loc, source, offsets, sizes, - strides); + return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes, + strides); }) .Case([&](MemRefType type) -> Value { - return b.create(loc, source, offsets, sizes, - strides); + return memref::SubViewOp::create(b, loc, source, offsets, sizes, + strides); }) .Default([&](Type t) { return nullptr; }); } diff --git a/python/examples/conftest.py b/python/examples/conftest.py index 4bac989c..33397c5a 100644 --- a/python/examples/conftest.py +++ b/python/examples/conftest.py @@ -119,7 +119,8 @@ def pytest_collection_modifyitems(config, items): for param_name, param_value in item.callspec.params.items(): if (param_name.startswith('dtype') or param_name.endswith('dtype')) and param_value == 'bfloat16': item.add_marker(skip_marker_bfloat) - if param_name.startswith('input_precision') and param_value.startswith('tf32'): + if param_name.startswith('input_precision') and (param_value.startswith('tf32') + or param_value.startswith('bf16')): item.add_marker(skip_marker_tf32) if (param_name.startswith('dtype') or param_name.endswith('dtype')) and ('float8' in str(param_value)): item.add_marker(skip_marker_float8) diff --git a/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir b/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir index decf2293..542b9033 100644 --- a/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir +++ b/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir @@ -2,7 +2,7 @@ module { func.func @bitcast_ptr_as_src(%arg0: memref<*xi32>, %arg1: memref<*xi32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %0 = tptr.type_offset i32 : i32 + %0 = ptr.type_offset i32 : i32 %c1_i32 = arith.constant 1 : i32 %c2 = arith.constant 2 : index %1 = builtin.unrealized_conversion_cast %arg1 : memref<*xi32> to !tt.ptr @@ -10,14 +10,14 @@ module { %3 = builtin.unrealized_conversion_cast %arg0 : memref<*xi32> to !tt.ptr %4 = builtin.unrealized_conversion_cast %3 : !tt.ptr to 
!ptr.ptr<#tptr.default_memory_space> %5 = arith.muli %c1_i32, %0 : i32 - %6 = tptr.ptradd %4 %5 : !ptr.ptr<#tptr.default_memory_space>, i32 to !ptr.ptr<#tptr.default_memory_space> + %6 = ptr.ptr_add %4, %5 : !ptr.ptr<#tptr.default_memory_space>, i32 %7 = builtin.unrealized_conversion_cast %6 : !ptr.ptr<#tptr.default_memory_space> to !tt.ptr %8 = builtin.unrealized_conversion_cast %7 : !tt.ptr to memref<*xi64> %reinterpret_cast = memref.reinterpret_cast %8 to offset: [%c2], sizes: [16], strides: [1] : memref<*xi64> to memref<16xi64, strided<[1], offset: ?>> %alloc = memref.alloc() : memref<16xi64> memref.copy %reinterpret_cast, %alloc : memref<16xi64, strided<[1], offset: ?>> to memref<16xi64> %9 = bufferization.to_tensor %alloc restrict writable : memref<16xi64> to tensor<16xi64> - %10 = tptr.ptradd %2 %5 : !ptr.ptr<#tptr.default_memory_space>, i32 to !ptr.ptr<#tptr.default_memory_space> + %10 = ptr.ptr_add %2, %5 : !ptr.ptr<#tptr.default_memory_space>, i32 %11 = builtin.unrealized_conversion_cast %10 : !ptr.ptr<#tptr.default_memory_space> to !tt.ptr %12 = builtin.unrealized_conversion_cast %11 : !tt.ptr to memref<*xi64> %reinterpret_cast_0 = memref.reinterpret_cast %12 to offset: [%c2], sizes: [16], strides: [1] : memref<*xi64> to memref<16xi64, strided<[1], offset: ?>> diff --git a/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir b/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir index d8308a17..1be6bbd1 100644 --- a/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir +++ b/test/Conversion/StructuredToMemref/convert_addi_reduce.mlir @@ -13,22 +13,26 @@ module { } } -// CHECK-LABEL: func.func @addi -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[VAR_0_:%.+]] = 
tensor.empty() : tensor<4096xi32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_1_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_1_]] into [[VAR_2_]][] : tensor -// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] -// CHECK: ([[in_:.+]]: i32, [[init_:.+]]: i32) { -// CHECK: [[VAR_3_:%.+]] = arith.addi [[in_]], [[init_]] : i32 -// CHECK: linalg.yield [[VAR_3_]] : i32 +// CHECK-LABEL: func.func @addi( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xi32>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : i32) outs(%[[EMPTY_0]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[ALLOC_TENSOR_0:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[INSERT_0:.*]] = tensor.insert %[[CONSTANT_0]] into %[[ALLOC_TENSOR_0]][] : tensor +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce ins(%[[FILL_0]] : tensor<4096xi32>) outs(%[[INSERT_0]] : tensor) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32) { +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : i32 +// CHECK: linalg.yield %[[ADDI_0]] : i32 // CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor -// CHECK-DAG: 
[[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1], offset: ?>> +// CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store %[[EXTRACT_0]], %[[REINTERPRET_CAST_0]][0] : memref<1xi32, strided<[1]>> // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir b/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir index a753cd61..77a7e5ae 100644 --- a/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir +++ b/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir @@ -1,5 +1,62 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + + // RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG9:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 130 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 259 : index +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 3 : index +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 128 : index +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 256 : index +// CHECK: %[[CONSTANT_6:.*]] = arith.constant 0xFF80 : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK: %[[MINSI_0:.*]] = arith.minsi %[[INDEX_CAST_0]], %[[CONSTANT_1]] : index +// CHECK: %[[MAXSI_0:.*]] = arith.maxsi %[[MINSI_0]], %[[CONSTANT_0]] 
: index +// CHECK: %[[SUBI_0:.*]] = arith.subi %[[MAXSI_0]], %[[CONSTANT_0]] : index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[MINSI_1:.*]] = arith.minsi %[[INDEX_CAST_1]], %[[CONSTANT_2]] : index +// CHECK: %[[MAXSI_1:.*]] = arith.maxsi %[[MINSI_1]], %[[CONSTANT_3]] : index +// CHECK: %[[SUBI_1:.*]] = arith.subi %[[MAXSI_1]], %[[CONSTANT_3]] : index +// CHECK: %[[MINSI_2:.*]] = arith.minsi %[[SUBI_0]], %[[CONSTANT_4]] : index +// CHECK: %[[MINSI_3:.*]] = arith.minsi %[[SUBI_1]], %[[CONSTANT_5]] : index +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<128x256xbf16> +// CHECK: %[[CMPI_0:.*]] = arith.cmpi slt, %[[MINSI_2]], %[[CONSTANT_4]] : index +// CHECK: %[[CMPI_1:.*]] = arith.cmpi slt, %[[MINSI_3]], %[[CONSTANT_5]] : index +// CHECK: %[[ORI_0:.*]] = arith.ori %[[CMPI_0]], %[[CMPI_1]] : i1 +// CHECK: scf.if %[[ORI_0]] { +// CHECK: linalg.fill ins(%[[CONSTANT_6]] : bf16) outs(%[[ALLOC_0]] : memref<128x256xbf16>) +// CHECK: } +// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[REINTERPRET_CAST_0]][0, 0] {{\[}}%[[MINSI_2]], %[[MINSI_3]]] [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: %[[SUBVIEW_1:.*]] = memref.subview %[[ALLOC_0]][0, 0] {{\[}}%[[MINSI_2]], %[[MINSI_3]]] [1, 1] : memref<128x256xbf16> to memref> +// CHECK: memref.copy %[[SUBVIEW_0]], %[[SUBVIEW_1]] : memref> to memref> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<128x256xbf16> to tensor<128x256xbf16> +// CHECK: %[[EXTRACT_SLICE_0:.*]] = tensor.extract_slice %[[TO_TENSOR_0]][0, 0] {{\[}}%[[MINSI_2]], %[[MINSI_3]]] [1, 1] : tensor<128x256xbf16> to tensor +// CHECK: %[[SUBVIEW_2:.*]] = memref.subview %[[REINTERPRET_CAST_1]][0, 0] {{\[}}%[[MINSI_2]], %[[MINSI_3]]] [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[SUBVIEW_2]] : memref> to memref> +// CHECK: bufferization.materialize_in_destination 
%[[EXTRACT_SLICE_0]] in writable %[[CAST_0]] : (tensor, memref>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -61,44 +118,3 @@ module { tt.return } } - -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_130_:%.+]] = arith.constant 130 : index -// CHECK-DAG: [[CST_259_:%.+]] = arith.constant 259 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 -// CHECK-DAG: [[CST_3074_:%.+]] = arith.constant 3074 : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_3074_]]{{.}}, sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: ?>> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_3074_]]{{.}}, sizes: [128, 256], strides: [1, 1024] : memref<*xbf16> to memref<128x256xbf16, strided<[1, 1024], offset: ?>> -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_130_]] : index -// CHECK: [[VAR_2_:%.+]] = arith.maxsi [[VAR_1_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.subi [[VAR_2_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_4_]], [[CST_259_]] : index -// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], 
[[CST_3_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[CST_3_]] : index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.minsi [[VAR_3_]], [[CST_128_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[CST_256_]] : index -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x256xbf16> -// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_128_]] : index -// CHECK: [[VAR_11_:%.+]] = arith.cmpi slt, [[VAR_9_]], [[CST_256_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.ori [[VAR_10_]], [[VAR_11_]] : i1 -// CHECK: scf.if [[VAR_12_]] { -// CHECK: linalg.fill ins([[CST_0_]] : bf16) outs([[RES_]] : memref<128x256xbf16>) -// CHECK: } -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref> to memref> -// CHECK: [[VAR_13_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x256xbf16> -// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_13_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: ?>> to memref> -// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_2_]] : (tensor, memref>) -> () -// CHECK: return -// CHECK: } diff --git a/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir b/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir index 5c7f4f4e..81864ce7 100644 --- a/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir 
+++ b/test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir @@ -55,59 +55,66 @@ module { } -// CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_2_1_]] : index -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_6_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.muli [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_1_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[VAR_8_]] : i32 to index -// CHECK-DAG: [[VAR_10_:%.+]] = 
arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : i32 to index -// CHECK-DAG: [[VAR_12_:%.+]]:2 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_1_]], [[VAR_arg16_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 { -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg16_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_6_]], [[VAR_7_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_4_]] : index -// CHECK: [[VAR_14_:%.+]] = arith.remsi [[VAR_13_]], [[VAR_5_]] : index -// CHECK-DAG: [[VAR_15_:%.+]] = arith.subi [[VAR_13_]], [[VAR_14_]] : index -// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_14_]], [[CST_4_]] : index -// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_16_]], [[VAR_5_]] : index -// CHECK: [[VAR_18_:%.+]] = arith.subi [[VAR_17_]], [[VAR_14_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_18_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_19_:%.+]] = arith.subi [[CST_4_]], [[VAR_18_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: 
separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_19_]]{{.}} [1, 1] : memref> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_18_]]{{.}} [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> -// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>> -// CHECK: [[VAR_20_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_20_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> () -// CHECK-DAG: [[VAR_21_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index -// CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_11_]] : index -// CHECK: scf.yield [[VAR_21_]], [[VAR_22_]] : index, index +// CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: 
%[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG9:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG10:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG11:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG12:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG13:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 4 : i32 +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 2 : index +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 6 : index +// CHECK: %[[CONSTANT_6:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_7:.*]] = arith.constant -9.900000e+01 : f32 +// CHECK: %[[CONSTANT_8:.*]] = arith.constant 0 : i32 +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG4]] : i32 to index +// CHECK: %[[MULI_0:.*]] = arith.muli %[[INDEX_CAST_0]], %[[CONSTANT_4]] : index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[ARG5]] : i32 to index +// CHECK: %[[MULI_1:.*]] = arith.muli %[[INDEX_CAST_2]], %[[CONSTANT_5]] : index +// CHECK: %[[MULI_2:.*]] = arith.muli %[[INDEX_CAST_1]], %[[INDEX_CAST_2]] : index +// CHECK: %[[INDEX_CAST_3:.*]] = arith.index_cast %[[ARG6]] : i32 to index +// CHECK: %[[INDEX_CAST_4:.*]] = arith.index_cast %[[ARG7]] : i32 to index +// CHECK: %[[MULI_3:.*]] = arith.muli %[[ARG4]], %[[CONSTANT_1]] : i32 +// CHECK: %[[INDEX_CAST_5:.*]] = arith.index_cast %[[MULI_3]] : i32 to index +// CHECK: %[[MULI_4:.*]] = arith.muli %[[ARG5]], %[[CONSTANT_1]] : i32 +// CHECK: %[[INDEX_CAST_6:.*]] = arith.index_cast %[[MULI_4]] : i32 to index +// CHECK: %[[FOR_0:.*]]:2 = scf.for %[[VAL_0:.*]] = %[[CONSTANT_8]] to %[[CONSTANT_2]] step %[[CONSTANT_3]] iter_args(%[[VAL_1:.*]] = %[[MULI_0]], 
%[[VAL_2:.*]] = %[[CONSTANT_6]]) -> (index, index) : i32 { +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[VAL_2]]], sizes: [4, 4], strides: {{\[}}%[[INDEX_CAST_3]], %[[INDEX_CAST_4]]] : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_1]], %[[MULI_1]] : index +// CHECK: %[[REMSI_0:.*]] = arith.remsi %[[ADDI_0]], %[[MULI_2]] : index +// CHECK: %[[SUBI_0:.*]] = arith.subi %[[ADDI_0]], %[[REMSI_0]] : index +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[REMSI_0]], %[[CONSTANT_0]] : index +// CHECK: %[[MINSI_0:.*]] = arith.minsi %[[ADDI_1]], %[[MULI_2]] : index +// CHECK: %[[SUBI_1:.*]] = arith.subi %[[MINSI_0]], %[[REMSI_0]] : index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [4, %[[SUBI_1]]], strides: {{\[}}%[[INDEX_CAST_0]], %[[INDEX_CAST_2]]] : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[SUBI_2:.*]] = arith.subi %[[CONSTANT_0]], %[[SUBI_1]] : index +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[SUBI_0]]], sizes: [4, %[[SUBI_2]]], strides: {{\[}}%[[INDEX_CAST_0]], %[[INDEX_CAST_2]]] : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x4xf32> +// CHECK: linalg.fill ins(%[[CONSTANT_7]] : f32) outs(%[[ALLOC_0]] : memref<4x4xf32>) +// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[REINTERPRET_CAST_1]][0, 0] [2, %[[SUBI_1]]] [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[SUBVIEW_1:.*]] = memref.subview %[[REINTERPRET_CAST_2]][0, 0] [2, %[[SUBI_2]]] [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[SUBVIEW_2:.*]] = memref.subview %[[ALLOC_0]][0, 0] [2, %[[SUBI_1]]] [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> +// CHECK: 
%[[SUBVIEW_3:.*]] = memref.subview %[[ALLOC_0]][0, %[[SUBI_1]]] [2, %[[SUBI_2]]] [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> +// CHECK: memref.copy %[[SUBVIEW_0]], %[[SUBVIEW_2]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> +// CHECK: memref.copy %[[SUBVIEW_1]], %[[SUBVIEW_3]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x4xf32> to tensor<4x4xf32> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[REINTERPRET_CAST_0]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> () +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_1]], %[[INDEX_CAST_5]] : index +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[VAL_2]], %[[INDEX_CAST_6]] : index +// CHECK: scf.yield %[[ADDI_2]], %[[ADDI_3]] : index, index // CHECK: } // CHECK: return -// CHECK: } +// CHECK: } \ No newline at end of file diff --git a/test/Conversion/StructuredToMemref/wraparound_stacked.mlir b/test/Conversion/StructuredToMemref/wraparound_stacked.mlir index 4e5ce945..1fe4fbc7 100644 --- a/test/Conversion/StructuredToMemref/wraparound_stacked.mlir +++ b/test/Conversion/StructuredToMemref/wraparound_stacked.mlir @@ -52,56 +52,64 @@ module { } } -// CHECK-LABEL: func.func @wrap_stacked_masked_loop_01234567 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = 
arith.constant 0 : i32 -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_1_]] : index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_0_]], [[VAR_1_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[CST_3_]] : index -// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[VAR_8_]] : i32 to index -// CHECK-DAG: [[VAR_10_:%.+]]:2 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_2_]], [[VAR_arg16_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 { -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg16_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_6_]], [[VAR_7_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_5_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.remsi [[VAR_11_]], [[VAR_1_]] : index -// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_3_]], [[VAR_12_]] : index -// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_11_]] : index -// CHECK: 
[[VAR_15_:%.+]] = arith.divsi [[VAR_14_]], [[VAR_1_]] : index -// CHECK: [[VAR_16_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_17_:%.+]] = arith.subi [[CST_4_]], [[VAR_16_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_17_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref> to memref> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref> to memref> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_16_]], 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref> to memref> -// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref> to memref> -// CHECK: [[VAR_18_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_18_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> () -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index -// CHECK-DAG: 
[[VAR_20_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_9_]] : index -// CHECK: scf.yield [[VAR_19_]], [[VAR_20_]] : index, index +// CHECK-LABEL: func.func @wrap_stacked_masked_loop_01234567( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG9:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG10:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG11:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG12:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG13:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 4 : i32 +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 2 : index +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 3 : index +// CHECK: %[[CONSTANT_6:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_7:.*]] = arith.constant -9.900000e+01 : f32 +// CHECK: %[[CONSTANT_8:.*]] = arith.constant 0 : i32 +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG4]] : i32 to index +// CHECK: %[[MULI_0:.*]] = arith.muli %[[INDEX_CAST_1]], %[[CONSTANT_4]] : index +// CHECK: %[[MULI_1:.*]] = arith.muli %[[INDEX_CAST_0]], %[[INDEX_CAST_1]] : index 
+// CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[ARG5]] : i32 to index +// CHECK: %[[MULI_2:.*]] = arith.muli %[[INDEX_CAST_2]], %[[CONSTANT_5]] : index +// CHECK: %[[INDEX_CAST_3:.*]] = arith.index_cast %[[ARG6]] : i32 to index +// CHECK: %[[INDEX_CAST_4:.*]] = arith.index_cast %[[ARG7]] : i32 to index +// CHECK: %[[MULI_3:.*]] = arith.muli %[[ARG5]], %[[CONSTANT_1]] : i32 +// CHECK: %[[INDEX_CAST_5:.*]] = arith.index_cast %[[MULI_3]] : i32 to index +// CHECK: %[[FOR_0:.*]]:2 = scf.for %[[VAL_0:.*]] = %[[CONSTANT_8]] to %[[CONSTANT_2]] step %[[CONSTANT_3]] iter_args(%[[VAL_1:.*]] = %[[MULI_0]], %[[VAL_2:.*]] = %[[CONSTANT_6]]) -> (index, index) : i32 { +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[VAL_2]]], sizes: [4, 4], strides: {{\[}}%[[INDEX_CAST_3]], %[[INDEX_CAST_4]]] : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_1]], %[[MULI_2]] : index +// CHECK: %[[REMSI_0:.*]] = arith.remsi %[[ADDI_0]], %[[INDEX_CAST_1]] : index +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[MULI_1]], %[[REMSI_0]] : index +// CHECK: %[[SUBI_0:.*]] = arith.subi %[[ADDI_1]], %[[ADDI_0]] : index +// CHECK: %[[DIVSI_0:.*]] = arith.divsi %[[SUBI_0]], %[[INDEX_CAST_1]] : index +// CHECK: %[[MINSI_0:.*]] = arith.minsi %[[DIVSI_0]], %[[CONSTANT_0]] : index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: {{\[}}%[[MINSI_0]], 4], strides: {{\[}}%[[INDEX_CAST_1]], %[[INDEX_CAST_2]]] : memref<*xf32> to memref> +// CHECK: %[[SUBI_1:.*]] = arith.subi %[[CONSTANT_0]], %[[MINSI_0]] : index +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[REMSI_0]]], sizes: {{\[}}%[[SUBI_1]], 4], strides: {{\[}}%[[INDEX_CAST_1]], %[[INDEX_CAST_2]]] : memref<*xf32> to memref> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x4xf32> +// CHECK: linalg.fill ins(%[[CONSTANT_7]] : f32) outs(%[[ALLOC_0]] : 
memref<4x4xf32>) +// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[REINTERPRET_CAST_1]][0, 0] {{\[}}%[[MINSI_0]], 3] [1, 1] : memref> to memref> +// CHECK: %[[SUBVIEW_1:.*]] = memref.subview %[[REINTERPRET_CAST_2]][0, 0] {{\[}}%[[SUBI_1]], 3] [1, 1] : memref> to memref> +// CHECK: %[[SUBVIEW_2:.*]] = memref.subview %[[ALLOC_0]][0, 0] {{\[}}%[[MINSI_0]], 3] [1, 1] : memref<4x4xf32> to memref> +// CHECK: %[[SUBVIEW_3:.*]] = memref.subview %[[ALLOC_0]]{{\[}}%[[MINSI_0]], 0] {{\[}}%[[SUBI_1]], 3] [1, 1] : memref<4x4xf32> to memref> +// CHECK: memref.copy %[[SUBVIEW_0]], %[[SUBVIEW_2]] : memref> to memref> +// CHECK: memref.copy %[[SUBVIEW_1]], %[[SUBVIEW_3]] : memref> to memref> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x4xf32> to tensor<4x4xf32> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[REINTERPRET_CAST_0]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> () +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_1]], %[[INDEX_CAST_5]] : index +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[VAL_2]], %[[INDEX_CAST_5]] : index +// CHECK: scf.yield %[[ADDI_2]], %[[ADDI_3]] : index, index // CHECK: } // CHECK: return -// CHECK: } +// CHECK: } \ No newline at end of file diff --git a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir b/test/Conversion/TritonToLinalg/addptr_2d_example.mlir index f0f7d1c7..82fec127 100644 --- a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir +++ b/test/Conversion/TritonToLinalg/addptr_2d_example.mlir @@ -1,5 +1,47 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG9:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[INDEX_CAST_0]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[INDEX_CAST_1]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_1]], 
%[[ALLOC_1]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_1:.*]] = bufferization.to_tensor %[[ALLOC_1]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[TO_TENSOR_0]], %[[TO_TENSOR_1]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[TO_TENSOR_0]] : tensor<4x256xbf16>) { +// CHECK: ^bb0(%[[VAL_0:.*]]: bf16, %[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: bf16): +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG2]] to offset: {{\[}}%[[INDEX_CAST_2]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_2]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16, strided<[1, ?], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[GENERIC_0]] in writable %[[CAST_0]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -44,26 +86,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = arith.constant 5 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, 
%[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_15]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: bf16, %[[VAL_18:.*]]: bf16, %[[VAL_19:.*]]: bf16): -// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16 -// CHECK: linalg.yield %[[VAL_20]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_21]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_add_value.mlir b/test/Conversion/TritonToLinalg/addptr_add_value.mlir index 0ed60796..ace0df75 100644 --- a/test/Conversion/TritonToLinalg/addptr_add_value.mlir 
+++ b/test/Conversion/TritonToLinalg/addptr_add_value.mlir @@ -1,5 +1,43 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG9:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 10 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[INDEX_CAST_0]], %[[INDEX_CAST_1]] : index +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[CONSTANT_0]] : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_1]]], sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[INDEX_CAST_2:.*]] = 
arith.index_cast %[[ARG2]] : i32 to index +// CHECK: %[[INDEX_CAST_3:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[INDEX_CAST_2]], %[[INDEX_CAST_3]] : index +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[ADDI_2]], %[[CONSTANT_0]] : index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[ADDI_3]]], sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_1]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16, strided<[1, ?], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[CAST_0]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -46,23 +84,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 10 : index -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : index -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast 
%[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_8]] : index -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_19]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_19]] restrict writable : memref<4x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_20]] in writable %[[VAL_18]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir b/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir index 89cb4590..72d21969 100644 --- a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir +++ b/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir @@ -1,5 +1,58 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG9:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG10:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 12 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[INDEX_CAST_0]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to 
offset: {{\[}}%[[INDEX_CAST_1]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_1]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16, strided<[?, ?], offset: ?>> +// CHECK: %[[FOR_0:.*]]:3 = scf.for %[[VAL_0:.*]] = %[[CONSTANT_2]] to %[[CONSTANT_1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[TO_TENSOR_0]], %[[VAL_2:.*]] = %[[CAST_0]], %[[VAL_3:.*]] = %[[INDEX_CAST_1]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index) { +// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_2]], %[[ALLOC_1]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_1:.*]] = bufferization.to_tensor %[[ALLOC_1]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_1]], %[[TO_TENSOR_1]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_1]] : tensor<4x256xbf16>) { +// CHECK: ^bb0(%[[VAL_4:.*]]: bf16, %[[VAL_5:.*]]: bf16, %[[VAL_6:.*]]: bf16): +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_0]] : index +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[ADDI_0]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[CAST_1:.*]] = memref.cast %[[REINTERPRET_CAST_2]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16, strided<[?, ?], offset: ?>> +// CHECK: scf.yield %[[GENERIC_0]], %[[CAST_1]], %[[ADDI_0]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index +// CHECK: } +// 
CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_3:.*]] = memref.reinterpret_cast %[[ARG2]] to offset: {{\[}}%[[INDEX_CAST_2]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[CAST_2:.*]] = memref.cast %[[REINTERPRET_CAST_3]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16, strided<[1, ?], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_7:.*]]#0 in writable %[[CAST_2]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -58,35 +111,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_1O:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_18:.*]] = 
memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK: %[[VAL_19:.*]]:3 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index) { -// CHECK: %[[VAL_25:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_22]], %[[VAL_25]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_25]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_21]], %[[VAL_26]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_21]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_28:.*]]: bf16, %[[VAL_29:.*]]: bf16, %[[VAL_30:.*]]: bf16): -// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : bf16 -// CHECK: linalg.yield %[[VAL_31]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index -// CHECK: %[[VAL_34:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_33]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index -// CHECK: } -// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_36]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination 
%[[VAL_38:.*]]#0 in writable %[[VAL_37]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir b/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir index 67d82948..15010680 100644 --- a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir +++ b/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir @@ -1,5 +1,40 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 256 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1024 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 12 : index +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 3 : index +// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[CONSTANT_2]] to %[[CONSTANT_3]] step %[[CONSTANT_4]] iter_args(%[[VAL_1:.*]] = %[[CONSTANT_1]]) -> (index) { +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_1]], %[[CONSTANT_0]] : index 
+// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [256, 256], strides: [2, 1] : memref<*xbf16> to memref<256x256xbf16, strided<[2, 1], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<256x256xbf16, strided<[2, 1], offset: ?>> to memref<256x256xbf16, strided<[?, 1], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<256x256xbf16, strided<[2, 1], offset: ?>> to memref<256x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<256x256xbf16> to tensor<256x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[CAST_0]] : (tensor<256x256xbf16>, memref<256x256xbf16, strided<[?, 1], offset: ?>>) -> () +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_1]], %[[CONSTANT_4]] : index +// CHECK: scf.yield %[[ADDI_1]] : index +// CHECK: } +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr ) @@ -51,23 +86,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (index) { -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index -// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [256, 256], strides: {{\[}}%[[VAL_4]], 1] : memref<*xbf16> to 
memref<256x256xbf16, strided<[?, 1], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<256x256xbf16> -// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16> -// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<256x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_16]] in writable %[[VAL_14]] -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : index -// CHECK: scf.yield %[[VAL_17]] : index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir b/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir index 4d77760e..3ae2f33c 100644 --- a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir +++ b/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir @@ -1,5 +1,53 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 2 : index +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 3 : index +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 12 : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [1024], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: 1024>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<256xbf16, strided<[1], offset: 1024>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [1024], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: 1024>> +// CHECK: %[[CAST_1:.*]] = memref.cast %[[REINTERPRET_CAST_1]] : memref<256xbf16, strided<[1], offset: 1024>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[FOR_0:.*]]:7 = scf.for %[[VAL_0:.*]] = %[[CONSTANT_1]] to %[[CONSTANT_5]] step %[[CONSTANT_4]] iter_args(%[[VAL_1:.*]] = %[[CONSTANT_2]], %[[VAL_2:.*]] = %[[CAST_0]], %[[VAL_3:.*]] = 
%[[CONSTANT_3]], %[[VAL_4:.*]] = %[[CAST_1]], %[[VAL_5:.*]] = %[[CONSTANT_4]], %[[VAL_6:.*]] = %[[CONSTANT_0]], %[[VAL_7:.*]] = %[[CONSTANT_0]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) { +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256xbf16> +// CHECK: memref.copy %[[VAL_2]], %[[ALLOC_0]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<256xbf16> to tensor<256xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[VAL_4]] : (tensor<256xbf16>, memref<256xbf16, strided<[?], offset: ?>>) -> () +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_6]], %[[CONSTANT_4]] : index +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>> +// CHECK: %[[CAST_2:.*]] = memref.cast %[[REINTERPRET_CAST_2]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_1]], %[[CONSTANT_4]] : index +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_4]] : index +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[VAL_5]], %[[CONSTANT_4]] : index +// CHECK: %[[ADDI_4:.*]] = arith.addi %[[ADDI_1]], %[[ADDI_2]] : index +// CHECK: %[[ADDI_5:.*]] = arith.addi %[[ADDI_4]], %[[ADDI_3]] : index +// CHECK: %[[ADDI_6:.*]] = arith.addi %[[VAL_7]], %[[ADDI_5]] : index +// CHECK: %[[REINTERPRET_CAST_3:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[ADDI_6]]], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>> +// CHECK: %[[CAST_3:.*]] = memref.cast %[[REINTERPRET_CAST_3]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: scf.yield %[[ADDI_1]], %[[CAST_2]], 
%[[ADDI_2]], %[[CAST_3]], %[[ADDI_3]], %[[ADDI_0]], %[[ADDI_6]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index +// CHECK: } +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr @@ -41,31 +89,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_14:.*]]:7 = scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_9]], %[[VAL_19:.*]] = %[[VAL_13]], %[[VAL_20:.*]] = %[[VAL_10]], %[[VAL_21:.*]] = %[[VAL_6]], %[[VAL_22:.*]] = %[[VAL_6]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) { -// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination 
%[[VAL_24]] in writable %[[VAL_19]] -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index -// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index -// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_10]] : index -// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_20]], %[[VAL_10]] : index -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index -// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]] : index -// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_22]], %[[VAL_31]] : index -// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: scf.yield %[[VAL_27]], %[[VAL_26]], %[[VAL_28]], %[[VAL_33]], %[[VAL_29]], %[[VAL_25]], %[[VAL_32]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir index 60b0b7fc..4cf7a3e1 100644 --- a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir +++ b/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir @@ -1,5 +1,38 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 12 : index +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 3 : index +// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[CONSTANT_1]] to %[[CONSTANT_2]] step %[[CONSTANT_3]] iter_args(%[[VAL_1:.*]] = %[[CONSTANT_0]]) -> (index) { +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_1]], %[[CONSTANT_3]] : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<256xbf16> to tensor<256xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[CAST_0]] : (tensor<256xbf16>, memref<256xbf16, strided<[?], offset: ?>>) -> () +// 
CHECK: scf.yield %[[ADDI_0]] : index +// CHECK: } +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr ) @@ -78,21 +111,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_5]]) -> (index) { -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_13]] -// CHECK: scf.yield %[[VAL_12]] : index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir index 7855730a..8875f9a4 100644 --- a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir +++ b/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir @@ -1,5 +1,40 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. 
+// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 12 : index +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 3 : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [1024], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: 1024>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<256xbf16, strided<[1], offset: 1024>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[FOR_0:.*]]:2 = scf.for %[[VAL_0:.*]] = %[[CONSTANT_1]] to %[[CONSTANT_2]] step %[[CONSTANT_3]] iter_args(%[[VAL_1:.*]] = %[[CAST_0]], %[[VAL_2:.*]] = %[[CONSTANT_0]]) -> (memref<256xbf16, strided<[?], offset: ?>>, index) { +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256xbf16> +// CHECK: memref.copy %[[VAL_1]], %[[ALLOC_0]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict 
writable : memref<256xbf16> to tensor<256xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[VAL_1]] : (tensor<256xbf16>, memref<256xbf16, strided<[?], offset: ?>>) -> () +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_3]] : index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>> +// CHECK: %[[CAST_1:.*]] = memref.cast %[[REINTERPRET_CAST_1]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: scf.yield %[[CAST_1]], %[[ADDI_0]] : memref<256xbf16, strided<[?], offset: ?>>, index +// CHECK: } +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr ) @@ -34,22 +69,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_5]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]], %[[VAL_13:.*]] = %[[VAL_5]]) -> (memref<256xbf16, strided<[?], offset: ?>>, index) { -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> -// CHECK: 
bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_12]] -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_16]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: scf.yield %[[VAL_17]], %[[VAL_16]] : memref<256xbf16, strided<[?], offset: ?>>, index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_loopback.mlir index ee5cb2cc..ca190aca 100644 --- a/test/Conversion/TritonToLinalg/addptr_loopback.mlir +++ b/test/Conversion/TritonToLinalg/addptr_loopback.mlir @@ -1,5 +1,35 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[INDEX_CAST_0]]], sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[INDEX_CAST_1]]], sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_1]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16, strided<[1, ?], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable 
%[[CAST_0]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -38,16 +68,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_7]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<4x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_10]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir index 61ddea4f..eaff44c5 100644 --- a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir +++ b/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir @@ -1,5 +1,36 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. 
+// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 20480 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG6]] : i32 to index +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[INDEX_CAST_0]], %[[CONSTANT_0]] : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [1024], strides: [11] : memref<*xbf16> to memref<1024xbf16, strided<[11], offset: ?>> +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG6]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: {{\[}}%[[INDEX_CAST_1]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<1024xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<1024xbf16, strided<[11], offset: ?>> to memref<1024xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = 
bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<1024xbf16> to tensor<1024xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<1024xbf16>, memref<1024xbf16, strided<[1], offset: ?>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -32,18 +63,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 11 : index -// CHECK: %[[VAL_7:.*]] = arith.constant 20480 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}%[[VAL_6]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1024xbf16> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> -// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_14]] in writable %[[VAL_12]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_nested.mlir b/test/Conversion/TritonToLinalg/addptr_nested.mlir index bbbc0b22..5e06b5f9 100644 --- a/test/Conversion/TritonToLinalg/addptr_nested.mlir +++ b/test/Conversion/TritonToLinalg/addptr_nested.mlir @@ -1,5 +1,51 @@ +// NOTE: Assertions have been 
autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG1]] : i32 to index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[INDEX_CAST_0]]], sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG1]] : i32 to index +// CHECK: %[[INDEX_CAST_2:.*]] = arith.index_cast %[[ARG1]] : i32 to index +// CHECK: %[[ADDI_0:.*]] = arith.addi 
%[[INDEX_CAST_1]], %[[INDEX_CAST_2]] : index +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_0]]], sizes: [4, 256], strides: [2, 10] : memref<*xbf16> to memref<4x256xbf16, strided<[2, 10], offset: ?>> +// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_1]], %[[ALLOC_1]] : memref<4x256xbf16, strided<[2, 10], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[TO_TENSOR_1:.*]] = bufferization.to_tensor %[[ALLOC_1]] restrict writable : memref<4x256xbf16> to tensor<4x256xbf16> +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[TO_TENSOR_0]], %[[TO_TENSOR_1]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[TO_TENSOR_0]] : tensor<4x256xbf16>) { +// CHECK: ^bb0(%[[VAL_0:.*]]: bf16, %[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: bf16): +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: %[[INDEX_CAST_3:.*]] = arith.index_cast %[[ARG1]] : i32 to index +// CHECK: %[[INDEX_CAST_4:.*]] = arith.index_cast %[[ARG1]] : i32 to index +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[INDEX_CAST_3]], %[[INDEX_CAST_4]] : index +// CHECK: %[[INDEX_CAST_5:.*]] = arith.index_cast %[[ARG1]] : i32 to index +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[INDEX_CAST_5]] : index +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[ADDI_2]]], sizes: [4, 256], strides: [3, 15] : memref<*xbf16> to memref<4x256xbf16, strided<[3, 15], offset: ?>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_2]] : memref<4x256xbf16, strided<[3, 15], offset: ?>> to memref<4x256xbf16, strided<[3, ?], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[GENERIC_0]] in writable %[[CAST_0]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[3, ?], 
offset: ?>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : i32 @@ -40,34 +86,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 15 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 10 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index -// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_14]]], sizes: [4, 256], strides: [2, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[2, ?], offset: ?>> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_15]], %[[VAL_16]] : memref<4x256xbf16, strided<[2, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_17]] : tensor<4x256xbf16>, tensor<4x256xbf16>) 
outs(%[[VAL_11]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: bf16, %[[VAL_20:.*]]: bf16, %[[VAL_21:.*]]: bf16): -// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : bf16 -// CHECK: linalg.yield %[[VAL_22]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : index -// CHECK: %[[VAL_26:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index -// CHECK: %[[VAL_28:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_27]]], sizes: [4, 256], strides: [3, %[[VAL_5]]] : memref<*xbf16> to memref<4x256xbf16, strided<[3, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_29:.*]] in writable %[[VAL_28]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir index 2f508262..ee854b90 100644 --- a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir +++ b/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir @@ -1,6 +1,32 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s // TODO: expand this example to 3D module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [6656], sizes: [256, 128], strides: [1, 6] : memref<*xbf16> to memref<256x128xbf16, strided<[1, 6], offset: 6656>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256x128xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<256x128xbf16> to tensor<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[CAST_0]] : (tensor<256x128xbf16>, memref<256x128xbf16, strided<[1, ?], offset: 6656>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr @@ -31,13 +57,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { 
-// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: {{\[}}1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/dot.mlir b/test/Conversion/TritonToLinalg/dot.mlir index 95cb91b7..bf777839 100644 --- a/test/Conversion/TritonToLinalg/dot.mlir +++ b/test/Conversion/TritonToLinalg/dot.mlir @@ -1,5 +1,52 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [128, 64], strides: [128, 1] : memref<*xbf16> to memref<128x64xbf16, strided<[128, 1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<128x64xbf16, strided<[128, 1]>> to memref<128x64xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<128x64xbf16> to tensor<128x64xbf16> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [256, 64], strides: [1, 256] : memref<*xbf16> to memref<256x64xbf16, strided<[1, 256]>> +// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<256x64xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_1]], %[[ALLOC_1]] : memref<256x64xbf16, strided<[1, 256]>> to memref<256x64xbf16> +// CHECK: %[[TO_TENSOR_1:.*]] = bufferization.to_tensor %[[ALLOC_1]] restrict writable : memref<256x64xbf16> to 
tensor<256x64xbf16> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<64x256xbf16> +// CHECK: %[[TRANSPOSE_0:.*]] = linalg.transpose ins(%[[TO_TENSOR_1]] : tensor<256x64xbf16>) outs(%[[EMPTY_0]] : tensor<64x256xbf16>) permutation = [1, 0] +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG2]] to offset: [0], sizes: [128, 256], strides: [256, 1] : memref<*xbf16> to memref<128x256xbf16, strided<[256, 1]>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_2]] : memref<128x256xbf16, strided<[256, 1]>> to memref<128x256xbf16, strided<[?, 1]>> +// CHECK: %[[ALLOC_2:.*]] = memref.alloc() : memref<128x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_2]], %[[ALLOC_2]] : memref<128x256xbf16, strided<[256, 1]>> to memref<128x256xbf16> +// CHECK: %[[TO_TENSOR_2:.*]] = bufferization.to_tensor %[[ALLOC_2]] restrict writable : memref<128x256xbf16> to tensor<128x256xbf16> +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_1]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: %[[MATMUL_0:.*]] = linalg.matmul ins(%[[TO_TENSOR_0]], %[[TRANSPOSE_0]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%[[FILL_0]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[TO_TENSOR_2]], %[[MATMUL_0]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%[[TO_TENSOR_2]] : tensor<128x256xbf16>) { +// CHECK: ^bb0(%[[VAL_0:.*]]: bf16, %[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: bf16): +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } -> tensor<128x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[GENERIC_0]] in writable %[[CAST_0]] : (tensor<128x256xbf16>, memref<128x256xbf16, strided<[?, 1]>>) -> () +// CHECK: return +// CHECK: } 
tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -49,36 +96,3 @@ module { } } -// CHECK-DAG: [[MAP_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: {{.}}[[CST_128_]], 1] : memref<*xbf16> to memref<128x64xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[?, 1]>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [256, 64], strides: [1, [[CST_256_]]{{.}} : memref<*xbf16> to memref<256x64xbf16, strided<[1, ?]>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<256x64xbf16, strided<[1, ?]>> to memref<256x64xbf16> -// CHECK-DAG: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x64xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<64x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_transposed_:%.+]] = linalg.transpose ins([[VAR_1_]] : tensor<256x64xbf16>) outs([[VAR_2_]] : tensor<64x256xbf16>) permutation = [1, 0] -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] 
= memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<128x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<128x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_2_]] : memref<128x256xbf16, strided<[?, 1]>> to memref<128x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<128x256xbf16> -// CHECK: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [[[MAP_]], [[MAP_]], [[MAP_]]], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) { -// CHECK: ^bb0([[VAR_in_1:%.+]]: bf16, [[VAR_in_2:%.+]]: bf16, {{%.+}}: bf16): -// CHECK: [[VAR_8_:%.+]] = arith.addf [[VAR_in_1]], [[VAR_in_2]] : bf16 -// CHECK: linalg.yield [[VAR_8_:%.+]] : bf16 -// CHECK: } -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir b/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir index 78afe418..8987d71f 100644 --- a/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir +++ b/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir @@ -1,5 +1,38 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. 
+// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<256x16xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0xFF80 : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32, 256, 16], strides: [256, 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[256, 1, 1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<32x256x16xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<32x256x16xbf16, strided<[256, 1, 1]>> to memref<32x256x16xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<32x256x16xbf16> to tensor<32x256x16xbf16> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256x16xbf16> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_0]] : tensor<256x16xbf16>) -> tensor<256x16xbf16> +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce ins(%[[TO_TENSOR_0]] : tensor<32x256x16xbf16>) outs(%[[FILL_0]] : tensor<256x16xbf16>) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: bf16, 
%[[VAL_1:.*]]: bf16) { +// CHECK: %[[MAXIMUMF_0:.*]] = arith.maximumf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[MAXIMUMF_0]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[REDUCE_0]] in writable %[[ARG1]] : (tensor<256x16xbf16>, memref<256x16xbf16>) -> () +// CHECK: return +// CHECK: } tt.func @kernel(%afloat : !tt.ptr, %res : tensor<256x16x!tt.ptr> ) -> () { @@ -38,21 +71,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<256x16xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0xFF80 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_5]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<32x256x16xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<32x256x16xbf16> -// CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<256x16xbf16> -// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : tensor<256x16xbf16>) -> tensor<256x16xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<32x256x16xbf16>) outs(%[[VAL_11]] : tensor<256x16xbf16>) dimensions = [0] -// CHECK: (%[[VAL_13:.*]]: bf16, %[[VAL_14:.*]]: bf16) { -// CHECK: %[[VAL_15:.*]] = arith.maximumf %[[VAL_13]], %[[VAL_14]] : bf16 -// CHECK: linalg.yield %[[VAL_15]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_1]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir 
b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir index 5726ea0c..3b94f6b8 100644 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir +++ b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir @@ -1,5 +1,39 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref<*xbf16> to memref<512x256xbf16, strided<[256, 1]>> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<512x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : 
memref<512x256xbf16, strided<[256, 1]>> to memref<512x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<512x256xbf16> to tensor<512x256xbf16> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256xbf16> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_0]] : tensor<256xbf16>) -> tensor<256xbf16> +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce ins(%[[TO_TENSOR_0]] : tensor<512x256xbf16>) outs(%[[FILL_0]] : tensor<256xbf16>) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: bf16, %[[VAL_1:.*]]: bf16) { +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[REDUCE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<256xbf16>, memref<256xbf16, strided<[1]>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr ) -> () { @@ -30,22 +64,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xbf16> to memref<512x256xbf16, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[?, 1]>> to memref<512x256xbf16> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xbf16> -// CHECK: 
%[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_11]] : tensor<256xbf16>) -> tensor<256xbf16> -// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_12]] : tensor<256xbf16>) dimensions = [0] -// CHECK: (%[[VAL_14:.*]]: bf16, %[[VAL_15:.*]]: bf16) { -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : bf16 -// CHECK: linalg.yield %[[VAL_16]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir index 7f37a9f7..c5b02648 100644 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir +++ b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir @@ -1,5 +1,41 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref<*xbf16> to memref<512x256xbf16, strided<[256, 1]>> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [512], strides: [1] : memref<*xbf16> to memref<512xbf16, strided<[1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<512x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<512x256xbf16, strided<[256, 1]>> to memref<512x256xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<512x256xbf16> to tensor<512x256xbf16> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256x512xbf16> +// CHECK: %[[TRANSPOSE_0:.*]] = linalg.transpose ins(%[[TO_TENSOR_0]] : tensor<512x256xbf16>) outs(%[[EMPTY_0]] : tensor<256x512xbf16>) permutation = [1, 0] +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<512xbf16> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_1]] : tensor<512xbf16>) -> tensor<512xbf16> +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce 
ins(%[[TRANSPOSE_0]] : tensor<256x512xbf16>) outs(%[[FILL_0]] : tensor<512xbf16>) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: bf16, %[[VAL_1:.*]]: bf16) { +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[REDUCE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<512xbf16>, memref<512xbf16, strided<[1]>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr ) -> () { @@ -30,24 +66,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xbf16> to memref<512x256xbf16, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref<*xbf16> to memref<512xbf16, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[?, 1]>> to memref<512x256xbf16> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256x512xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_11]] : tensor<256x512xbf16>) permutation = [1, 0] -// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xbf16> -// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_13]] : tensor<512xbf16>) -> tensor<512xbf16> -// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_12]] : tensor<256x512xbf16>) outs(%[[VAL_14]] : 
tensor<512xbf16>) dimensions = [0] -// CHECK: (%[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16) { -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : bf16 -// CHECK: linalg.yield %[[VAL_18]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir index a63270ef..47a11ba1 100644 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir +++ b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir @@ -1,5 +1,39 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref<*xf32> to memref<512x256xf32, strided<[256, 1]>> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<512x256xf32> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<512x256xf32, strided<[256, 1]>> to memref<512x256xf32> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<512x256xf32> to tensor<512x256xf32> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256xf32> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : f32) outs(%[[EMPTY_0]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce ins(%[[TO_TENSOR_0]] : tensor<512x256xf32>) outs(%[[FILL_0]] : tensor<256xf32>) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32) { +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : f32 +// CHECK: linalg.yield %[[ADDF_0]] : 
f32 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[REDUCE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<256xf32>, memref<256xf32, strided<[1]>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr ) -> () { @@ -30,22 +64,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xf32> to memref<512x256xf32, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[?, 1]>> to memref<512x256xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xf32> -// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : tensor<256xf32>) -> tensor<256xf32> -// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_12]] : tensor<256xf32>) dimensions = [0] -// CHECK: (%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32) { -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir index 175d33f6..49dd9ca7 
100644 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir +++ b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir @@ -1,5 +1,41 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref<*xf32> to memref<512x256xf32, strided<[256, 1]>> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [512], strides: [1] : memref<*xf32> to memref<512xf32, strided<[1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<512x256xf32> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<512x256xf32, strided<[256, 1]>> to memref<512x256xf32> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor 
%[[ALLOC_0]] restrict writable : memref<512x256xf32> to tensor<512x256xf32> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256x512xf32> +// CHECK: %[[TRANSPOSE_0:.*]] = linalg.transpose ins(%[[TO_TENSOR_0]] : tensor<512x256xf32>) outs(%[[EMPTY_0]] : tensor<256x512xf32>) permutation = [1, 0] +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<512xf32> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : f32) outs(%[[EMPTY_1]] : tensor<512xf32>) -> tensor<512xf32> +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce ins(%[[TRANSPOSE_0]] : tensor<256x512xf32>) outs(%[[FILL_0]] : tensor<512xf32>) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32) { +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : f32 +// CHECK: linalg.yield %[[ADDF_0]] : f32 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[REDUCE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<512xf32>, memref<512xf32, strided<[1]>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr ) -> () { @@ -30,24 +66,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xf32> to memref<512x256xf32, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref<*xf32> to memref<512xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[?, 1]>> to memref<512x256xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : 
memref<512x256xf32> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256x512xf32> -// CHECK: %[[VAL_12:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_11]] : tensor<256x512xf32>) permutation = [1, 0] -// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xf32> -// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_13]] : tensor<512xf32>) -> tensor<512xf32> -// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_12]] : tensor<256x512xf32>) outs(%[[VAL_14]] : tensor<512xf32>) dimensions = [0] -// CHECK: (%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32) { -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32 -// CHECK: linalg.yield %[[VAL_18]] : f32 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir b/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir index 33b9c7a1..1295e95e 100644 --- a/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir +++ b/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir @@ -1,5 +1,41 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. 
+// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<32x16xbf16>, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32, 256, 16], strides: [256, 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[256, 1, 1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<32x256x16xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<32x256x16xbf16, strided<[256, 1, 1]>> to memref<32x256x16xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<32x256x16xbf16> to tensor<32x256x16xbf16> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256x32x16xbf16> +// CHECK: %[[TRANSPOSE_0:.*]] = linalg.transpose ins(%[[TO_TENSOR_0]] : tensor<32x256x16xbf16>) outs(%[[EMPTY_0]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2] +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<32x16xbf16> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_1]] : tensor<32x16xbf16>) -> tensor<32x16xbf16> +// CHECK: %[[REDUCE_0:.*]] = linalg.reduce ins(%[[TRANSPOSE_0]] : tensor<256x32x16xbf16>) 
outs(%[[FILL_0]] : tensor<32x16xbf16>) dimensions = [0] +// CHECK: (%[[VAL_0:.*]]: bf16, %[[VAL_1:.*]]: bf16) { +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : bf16 +// CHECK: linalg.yield %[[ADDF_0]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[REDUCE_0]] in writable %[[ARG2]] : (tensor<32x16xbf16>, memref<32x16xbf16>) -> () +// CHECK: return +// CHECK: } tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr, %out: tensor<32x16x!tt.ptr> @@ -38,23 +74,3 @@ module { tt.return } } -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[ARG0:.*]]: memref<*xbf16>, %[[ARG1:.*]]: memref<*xbf16>, %[[ARG2:.*]]: memref<32x16xbf16>, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: i32) { -// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_1]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>> -// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<32x256x16xbf16> -// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16> -// CHECK: %[[VAL_4:.*]] = bufferization.to_tensor %[[VAL_3]] restrict writable : memref<32x256x16xbf16> to tensor<32x256x16xbf16> -// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<256x32x16xbf16> -// CHECK: %[[VAL_6:.*]] = linalg.transpose ins(%[[VAL_4]] : tensor<32x256x16xbf16>) outs(%[[VAL_5]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2] -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x16xbf16> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : bf16) outs(%[[VAL_7]] : tensor<32x16xbf16>) -> tensor<32x16xbf16> -// CHECK: %[[VAL_9:.*]] = linalg.reduce ins(%[[VAL_6]] : tensor<256x32x16xbf16>) outs(%[[VAL_8]] : tensor<32x16xbf16>) dimensions = [0] -// CHECK: (%[[VAL_10:.*]]: bf16, 
%[[VAL_11:.*]]: bf16) { -// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : bf16 -// CHECK: linalg.yield %[[VAL_12]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[ARG2]] : (tensor<32x16xbf16>, memref<32x16xbf16>) -> () -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_dot_opc.mlir b/test/Conversion/TritonToLinalg/use_dot_opc.mlir index df5f2140..aa421631 100644 --- a/test/Conversion/TritonToLinalg/use_dot_opc.mlir +++ b/test/Conversion/TritonToLinalg/use_dot_opc.mlir @@ -1,5 +1,44 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: %[[FILL_0:.*]] = 
linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_0]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [128, 64], strides: [128, 1] : memref<*xbf16> to memref<128x64xbf16, strided<[128, 1]>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<128x64xbf16, strided<[128, 1]>> to memref<128x64xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<128x64xbf16> to tensor<128x64xbf16> +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [64, 256], strides: [256, 1] : memref<*xbf16> to memref<64x256xbf16, strided<[256, 1]>> +// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<64x256xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_1]], %[[ALLOC_1]] : memref<64x256xbf16, strided<[256, 1]>> to memref<64x256xbf16> +// CHECK: %[[TO_TENSOR_1:.*]] = bufferization.to_tensor %[[ALLOC_1]] restrict writable : memref<64x256xbf16> to tensor<64x256xbf16> +// CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast %[[ARG2]] to offset: [0], sizes: [128, 256], strides: [256, 1] : memref<*xbf16> to memref<128x256xbf16, strided<[256, 1]>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_2]] : memref<128x256xbf16, strided<[256, 1]>> to memref<128x256xbf16, strided<[?, 1]>> +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: %[[FILL_1:.*]] = linalg.fill ins(%[[CONSTANT_0]] : bf16) outs(%[[EMPTY_1]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: %[[MATMUL_0:.*]] = linalg.matmul ins(%[[TO_TENSOR_0]], %[[TO_TENSOR_1]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%[[FILL_1]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[MATMUL_0]] in writable %[[CAST_0]] : (tensor<128x256xbf16>, memref<128x256xbf16, strided<[?, 
1]>>) -> () +// CHECK: bufferization.materialize_in_destination %[[FILL_0]] in writable %[[CAST_0]] : (tensor<128x256xbf16>, memref<128x256xbf16, strided<[?, 1]>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -50,27 +89,3 @@ module { } } -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_0_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: {{.}}[[CST_128_]], 1] : memref<*xbf16> to memref<128x64xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[?, 1]>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [64, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<64x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<64x256xbf16, strided<[?, 1]>> to memref<64x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable 
: memref<64x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<128x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_4_:%.+]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_2_]], [[VAR_3_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination [[VAR_6_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_end_chain.mlir b/test/Conversion/TritonToLinalg/use_end_chain.mlir index a66116d4..232c46e8 100644 --- a/test/Conversion/TritonToLinalg/use_end_chain.mlir +++ b/test/Conversion/TritonToLinalg/use_end_chain.mlir @@ -33,63 +33,67 @@ module { tt.return } } +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, 0)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : i32 -// CHECK-DAG: %[[CST_512_:.*]] = arith.constant 512 : i32 -// CHECK-DAG: %[[CST_1024_:.*]] = arith.constant 1024 : i32 -// CHECK: %[[VAL_30:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_31:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_30]] : tensor<256x128xi32>) -> tensor<256x128xi32> -// CHECK: 
%[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> -// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { -// CHECK: ^bb0(%[[VAL_10:.*]]: i32): -// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_12]], %[[CST_512_]] : i32 -// CHECK: linalg.yield %[[VAL_55]] : i32 +// CHECK-SAME: %[[ARG0:.*]]: memref<*xbf16>, %[[ARG1:.*]]: memref<*xbf16>, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32 +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 512 : i32 +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 6 : i32 +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256x128xi32> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_2]] : i32) outs(%[[EMPTY_0]] : tensor<256x128xi32>) -> tensor<256x128xi32> +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<256xi32> +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel"]} outs(%[[EMPTY_1]] : tensor<256xi32>) { +// CHECK: ^bb0(%[[VAL_0:.*]]: i32): +// CHECK: %[[INDEX_0:.*]] = linalg.index 0 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[INDEX_0]] : index to i32 +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[INDEX_CAST_0]], %[[CONSTANT_1]] : i32 +// CHECK: linalg.yield %[[ADDI_0]] : i32 // CHECK: } -> tensor<256xi32> -// CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_13]] : tensor<256x1xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: 
^bb0(%[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32): -// CHECK: linalg.yield %[[VAL_17]] : i32 +// CHECK: %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> +// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<256x128xi32> +// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXPAND_SHAPE_0]] : tensor<256x1xi32>) outs(%[[EMPTY_2]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { +// CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32): +// CHECK: linalg.yield %[[VAL_1]] : i32 // CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_19:.*]] = tensor.empty() : tensor<128xi32> -// CHECK: %[[VAL_20:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_19]] : tensor<128xi32>) { -// CHECK: ^bb0(%[[VAL_21:.*]]: i32): -// CHECK: %[[VAL_22:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i32 -// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_23]], %[[CST_1024_]] : i32 -// CHECK: linalg.yield %[[VAL_56]] : i32 +// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<128xi32> +// CHECK: %[[GENERIC_2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel"]} outs(%[[EMPTY_3]] : tensor<128xi32>) { +// CHECK: ^bb0(%[[VAL_3:.*]]: i32): +// CHECK: %[[INDEX_1:.*]] = linalg.index 0 : index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[INDEX_1]] : index to i32 +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[INDEX_CAST_1]], %[[CONSTANT_0]] : i32 +// CHECK: linalg.yield %[[ADDI_1]] : i32 // CHECK: } -> tensor<128xi32> -// CHECK: %[[VAL_24:.*]] = tensor.expand_shape %[[VAL_25:.*]] {{\[\[}}0, 1]] output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32> -// CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map3, #map2], iterator_types = 
["parallel", "parallel"]} ins(%[[VAL_24]] : tensor<1x128xi32>) outs(%[[VAL_26]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_28:.*]]: i32, %[[VAL_29:.*]]: i32): -// CHECK: linalg.yield %[[VAL_28]] : i32 +// CHECK: %[[EXPAND_SHAPE_1:.*]] = tensor.expand_shape %[[GENERIC_2]] {{\[\[}}0, 1]] output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32> +// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<256x128xi32> +// CHECK: %[[GENERIC_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXPAND_SHAPE_1]] : tensor<1x128xi32>) outs(%[[EMPTY_4]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { +// CHECK: ^bb0(%[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32): +// CHECK: linalg.yield %[[VAL_4]] : i32 // CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_32:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_33:.*]], %[[VAL_31]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[VAL_33]] : tensor<256x128xi32>) { -// CHECK: ^bb0(%[[VAL_34:.*]]: i32, %[[VAL_35:.*]]: i32, %[[VAL_36:.*]]: i32): -// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_34]], %[[VAL_35]] : i32 -// CHECK: linalg.yield %[[VAL_37]] : i32 +// CHECK: %[[GENERIC_4:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[GENERIC_3]], %[[FILL_0]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[GENERIC_3]] : tensor<256x128xi32>) { +// CHECK: ^bb0(%[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32): +// CHECK: %[[MULI_0:.*]] = arith.muli %[[VAL_6]], %[[VAL_7]] : i32 +// CHECK: linalg.yield %[[MULI_0]] : i32 // CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_38:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_39:.*]], %[[VAL_40:.*]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[VAL_39]] : 
tensor<256x128xi32>) { -// CHECK: ^bb0(%[[VAL_41:.*]]: i32, %[[VAL_42:.*]]: i32, %[[VAL_43:.*]]: i32): -// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_41]], %[[VAL_42]] : i32 -// CHECK: linalg.yield %[[VAL_44]] : i32 +// CHECK: %[[GENERIC_5:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[GENERIC_1]], %[[GENERIC_4]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[GENERIC_1]] : tensor<256x128xi32>) { +// CHECK: ^bb0(%[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i32): +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : i32 +// CHECK: linalg.yield %[[ADDI_2]] : i32 // CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_45:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_46:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_45]], %[[VAL_46]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_47:.*]] = bufferization.to_tensor %[[VAL_46]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_47]] in writable %[[VAL_45]] -// CHECK: %[[VAL_48:.*]] = tensor.empty() : tensor<256x128xbf16> -// CHECK: %[[VAL_49:.*]] = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_50:.*]] : tensor<256x128xi32>) outs(%[[VAL_48]] : tensor<256x128xbf16>) { -// CHECK: ^bb0(%[[VAL_51:.*]]: i32, %[[VAL_52:.*]]: bf16): -// CHECK: %[[VAL_53:.*]] = arith.sitofp %[[VAL_51]] : i32 to bf16 -// CHECK: linalg.yield %[[VAL_53]] : bf16 +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [6656], sizes: [256, 128], strides: [1, 6] : memref<*xbf16> to memref<256x128xbf16, strided<[1, 6], offset: 6656>> +// CHECK: %[[CAST_0:.*]] = memref.cast 
%[[REINTERPRET_CAST_0]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256x128xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<256x128xbf16> to tensor<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[CAST_0]] : (tensor<256x128xbf16>, memref<256x128xbf16, strided<[1, ?], offset: 6656>>) -> () +// CHECK: %[[EMPTY_5:.*]] = tensor.empty() : tensor<256x128xbf16> +// CHECK: %[[GENERIC_6:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[GENERIC_5]] : tensor<256x128xi32>) outs(%[[EMPTY_5]] : tensor<256x128xbf16>) { +// CHECK: ^bb0(%[[VAL_12:.*]]: i32, %[[VAL_13:.*]]: bf16): +// CHECK: %[[SITOFP_0:.*]] = arith.sitofp %[[VAL_12]] : i32 to bf16 +// CHECK: linalg.yield %[[SITOFP_0]] : bf16 // CHECK: } -> tensor<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_54:.*]] in writable %[[VAL_45]] +// CHECK: bufferization.materialize_in_destination %[[GENERIC_6]] in writable %[[CAST_0]] : (tensor<256x128xbf16>, memref<256x128xbf16, strided<[1, ?], offset: 6656>>) -> () // CHECK: return // CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_mid_chain.mlir b/test/Conversion/TritonToLinalg/use_mid_chain.mlir index f4a855aa..4ab25e6e 100644 --- a/test/Conversion/TritonToLinalg/use_mid_chain.mlir +++ b/test/Conversion/TritonToLinalg/use_mid_chain.mlir @@ -1,5 +1,53 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. 
After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ + // RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s module { +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, 0)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xbf16>, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xi32>, +// CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 512 : i32 +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<256xi32> +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel"]} outs(%[[EMPTY_0]] : tensor<256xi32>) { +// CHECK: ^bb0(%[[VAL_0:.*]]: i32): +// CHECK: %[[INDEX_0:.*]] = linalg.index 0 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[INDEX_0]] : index to i32 +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[INDEX_CAST_0]], %[[CONSTANT_0]] : i32 +// CHECK: linalg.yield %[[ADDI_0]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : 
tensor<256x128xi32> +// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXPAND_SHAPE_0]] : tensor<256x1xi32>) outs(%[[EMPTY_1]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { +// CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32): +// CHECK: linalg.yield %[[VAL_1]] : i32 +// CHECK: } -> tensor<256x128xi32> +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [6656], sizes: [256, 128], strides: [1, 6] : memref<*xbf16> to memref<256x128xbf16, strided<[1, 6], offset: 6656>> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[REINTERPRET_CAST_0]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<256x128xbf16> +// CHECK: memref.copy %[[REINTERPRET_CAST_0]], %[[ALLOC_0]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[ALLOC_0]] restrict writable : memref<256x128xbf16> to tensor<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[TO_TENSOR_0]] in writable %[[CAST_0]] : (tensor<256x128xbf16>, memref<256x128xbf16, strided<[1, ?], offset: 6656>>) -> () +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %[[ARG2]] to offset: [6656], sizes: [256, 128], strides: [1, 6] : memref<*xi32> to memref<256x128xi32, strided<[1, 6], offset: 6656>> +// CHECK: %[[CAST_1:.*]] = memref.cast %[[REINTERPRET_CAST_1]] : memref<256x128xi32, strided<[1, 6], offset: 6656>> to memref<256x128xi32, strided<[1, ?], offset: 6656>> +// CHECK: bufferization.materialize_in_destination %[[GENERIC_1]] in writable %[[CAST_1]] : (tensor<256x128xi32>, memref<256x128xi32, strided<[1, ?], offset: 6656>>) -> () +// CHECK: return +// CHECK: } tt.func @kernel( %arg0 : !tt.ptr, %arg1 : !tt.ptr, @@ -35,30 +83,3 @@ module { tt.return } } -// CHECK-LABEL: func.func 
@kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xi32>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_25:.*]] = arith.constant 512 : i32 -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> -// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { -// CHECK: ^bb0(%[[VAL_10:.*]]: i32): -// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_12]], %[[VAL_25]] : i32 -// CHECK: linalg.yield %[[VAL_24]] : i32 -// CHECK: } -> tensor<256xi32> -// CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_13]] : tensor<256x1xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32): -// CHECK: linalg.yield %[[VAL_17]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_19]], %[[VAL_20]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_20]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_21]] in writable %[[VAL_19]] -// CHECK: %[[VAL_22:.*]] = 
memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_7]]] : memref<*xi32> to memref<256x128xi32, strided<[1, ?], offset: 6656>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir b/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir index d8c015c7..7ec5ea25 100644 --- a/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir +++ b/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir @@ -28,26 +28,26 @@ module { // CHECK-LABEL: func.func @simple_cf_into_structured_load // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset f32 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = ptr.type_offset f32 : i32 // CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 // CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[PARAM_0_]] : memref<*xf32> to memref<1xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = tptr.from_memref [[VAR_cast_]] : memref<1xf32> to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_1_:%.+]] = ptr.to_ptr [[VAR_cast_]] : memref<1xf32> to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_2_:%.+]] = arith.cmpi eq, [[PARAM_2_]], [[CST_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_3_:%.+]] = scf.if [[VAR_2_]] -> (!ptr.ptr<#tptr.default_memory_space>) { // CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2_]] : i32 // CHECK: [[VAR_7_:%.+]] = arith.muli [[VAR_6_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_8_:%.+]] = tptr.ptradd 
[[VAR_1_]] [[VAR_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_8_:%.+]] = ptr.ptr_add [[VAR_1_]] [[VAR_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: scf.yield [[VAR_8_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } else { // CHECK: [[VAR_6_1_:%.+]] = arith.muli [[PARAM_2_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_7_1_:%.+]] = tptr.ptradd [[VAR_1_]] [[VAR_6_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_7_1_:%.+]] = ptr.ptr_add [[VAR_1_]] [[VAR_6_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: scf.yield [[VAR_7_1_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -// CHECK: [[VAR_4_:%.+]] = tptr.to_memref [[VAR_3_]] : <#tptr.default_memory_space> to memref<1xf32> +// CHECK: [[VAR_4_:%.+]] = ptr.from_ptr [[VAR_3_]] : <#tptr.default_memory_space> to memref<1xf32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[VAR_4_]] to offset: {{.}}[[CST_6_]]{{.}}, sizes: [4], strides: [1] : memref<1xf32> to memref<4xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> // CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32> @@ -55,4 +55,4 @@ module { // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref<*xf32> to memref<4xf32, strided<[1]>> // CHECK: bufferization.materialize_in_destination [[VAR_5_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4xf32>, memref<4xf32, strided<[1]>>) -> () // CHECK: return -// CHECK: } +// CHECK: } \ No newline at end of file diff --git a/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir b/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir index 9fd9cc52..7fc8891f 100644 --- a/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir +++ 
b/test/Conversion/TritonToLinalgExperimental/convert_unsplat.mlir @@ -14,35 +14,39 @@ module { } } -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @unsplat_kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32> {maia.rank = 1 : i32, tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { -// CHECK-DAG: [[CST_42_:%.+]] = arith.constant 42 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1xi32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_42_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32> -// CHECK-DAG: [[VAR_2_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32> -// CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[PARAM_0_]] : memref<*xi32> to memref -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[VAR_cast_]] restrict : memref to tensor -// CHECK-DAG: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_2_]] : tensor<1xi32>) outs([[VAR_0_]] : tensor<1xi32>) { -// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: i32): -// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index -// CHECK: [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[VAR_7_]]{{.}} : tensor -// CHECK: linalg.yield [[VAR_extracted_0_]] : i32 +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @unsplat_kernel( +// CHECK-SAME: %[[ARG0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xi32> {maia.rank = 1 : i32, tt.divisibility = 16 : i32}, +// CHECK-SAME: %[[ARG1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// 
CHECK-SAME: %[[ARG3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32, +// CHECK-SAME: %[[ARG6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 42 : i32 +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : index +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<1xi32> +// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CONSTANT_0]] : i32) outs(%[[EMPTY_0]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %[[FILL_1:.*]] = linalg.fill ins(%[[CONSTANT_1]] : i32) outs(%[[EMPTY_0]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<*xi32> to memref +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[CAST_0]] restrict : memref to tensor +// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[FILL_1]] : tensor<1xi32>) outs(%[[EMPTY_0]] : tensor<1xi32>) { +// CHECK: ^bb0(%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32): +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i32 to index +// CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[TO_TENSOR_0]]{{\[}}%[[INDEX_CAST_0]]] : tensor +// CHECK: linalg.yield %[[EXTRACT_0]] : i32 // CHECK: } -> tensor<1xi32> -// CHECK: [[VAR_5_:%.+]] = tensor.empty() : tensor<1xi1> -// CHECK: [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_1_]] : tensor<1xi32>, tensor<1xi32>) outs([[VAR_5_]] : tensor<1xi1>) { -// CHECK: ^bb0([[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: i1): -// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi sgt, [[IN_2_]], [[IN_3_]] : i32 -// CHECK: linalg.yield [[VAR_7_1_]] : i1 +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<1xi1> +// CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], 
#[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[GENERIC_0]], %[[FILL_0]] : tensor<1xi32>, tensor<1xi32>) outs(%[[EMPTY_1]] : tensor<1xi1>) { +// CHECK: ^bb0(%[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i1): +// CHECK: %[[CMPI_0:.*]] = arith.cmpi sgt, %[[VAL_2]], %[[VAL_3]] : i32 +// CHECK: linalg.yield %[[CMPI_0]] : i1 // CHECK: } -> tensor<1xi1> -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_6_]]{{.}}[[CST_0_1_]]{{.}} : tensor<1xi1> -// CHECK: scf.if [[VAR_extracted_]] { -// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_1_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[CST_42_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1], offset: ?>> +// CHECK: %[[EXTRACT_1:.*]] = tensor.extract %[[GENERIC_1]]{{\[}}%[[CONSTANT_2]]] : tensor<1xi1> +// CHECK: scf.if %[[EXTRACT_1]] { +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> +// CHECK: affine.store %[[CONSTANT_0]], %[[REINTERPRET_CAST_0]][0] : memref<1xi32, strided<[1]>> // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir b/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir index a64c185c..bb55ec0b 100644 --- a/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir +++ b/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir @@ -42,11 +42,11 @@ module { // CHECK-LABEL: func.func @cast_with_int_ptr // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset i16 : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tptr.type_offset i64 : i32 
-// CHECK-DAG: [[VAR_2_:%.+]] = tptr.type_offset i32 : i64 -// CHECK-DAG: [[VAR_3_:%.+]] = tptr.type_offset i8 : i32 -// CHECK-DAG: [[VAR_4_:%.+]] = tptr.type_offset i32 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = ptr.type_offset i16 : i32 +// CHECK-DAG: [[VAR_1_:%.+]] = ptr.type_offset i64 : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = ptr.type_offset i32 : i64 +// CHECK-DAG: [[VAR_3_:%.+]] = ptr.type_offset i8 : i32 +// CHECK-DAG: [[VAR_4_:%.+]] = ptr.type_offset i32 : i32 // CHECK-DAG: [[CST_111_:%.+]] = arith.constant 111 : i32 // CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : i32 // CHECK-DAG: [[CST_9_:%.+]] = arith.constant 9 : i32 @@ -58,46 +58,46 @@ module { // CHECK-DAG: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> // CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> // CHECK: [[VAR_7_:%.+]] = arith.muli [[CST_111_]], [[VAR_4_]] : i32 -// CHECK-DAG: [[VAR_8_:%.+]] = tptr.ptradd [[VAR_6_]] [[VAR_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_8_:%.+]] = ptr.ptr_add [[VAR_6_]] [[VAR_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_9_:%.+]] = arith.muli [[CST_10_]], [[VAR_3_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = tptr.ptradd [[VAR_8_]] [[VAR_9_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_10_:%.+]] = ptr.ptr_add [[VAR_8_]] [[VAR_9_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_11_:%.+]] = tptr.ptrtoint [[VAR_5_]] : <#tptr.default_memory_space> to i64 // CHECK: [[VAR_12_:%.+]] = arith.muli [[VAR_11_]], [[VAR_2_]] : i64 -// CHECK-DAG: [[VAR_13_:%.+]] = tptr.ptradd [[VAR_5_]] [[VAR_12_]] : <#tptr.default_memory_space>, i64 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_13_:%.+]] = ptr.ptr_add [[VAR_5_]] 
[[VAR_12_]] : <#tptr.default_memory_space>, i64 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[CST_9_]], [[VAR_4_]] : i32 -// CHECK: [[VAR_15_:%.+]] = tptr.ptradd [[VAR_13_]] [[VAR_14_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_15_:%.+]] = ptr.ptr_add [[VAR_13_]] [[VAR_14_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: [[VAR_16_:%.+]] = tptr.ptrtoint [[VAR_15_]] : <#tptr.default_memory_space> to i64 // CHECK-DAG: [[VAR_17_:%.+]] = arith.remsi [[VAR_16_]], [[CST_10_1_]] : i64 // CHECK-DAG: [[VAR_18_:%.+]] = arith.muli [[CST_1_]], [[VAR_4_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = tptr.ptradd [[VAR_10_]] [[VAR_18_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_19_:%.+]] = ptr.ptr_add [[VAR_10_]] [[VAR_18_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_17_]], [[VAR_2_]] : i64 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_21_:%.+]] = tptr.ptradd [[VAR_19_]] [[VAR_20_]] : <#tptr.default_memory_space>, i64 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_21_:%.+]] = ptr.ptr_add [[VAR_19_]] [[VAR_20_]] : <#tptr.default_memory_space>, i64 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[CST_2_]], [[VAR_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = tptr.ptradd [[VAR_21_]] [[VAR_22_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_23_:%.+]] = ptr.ptr_add [[VAR_21_]] [[VAR_22_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_2_]], [[VAR_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_25_:%.+]] = tptr.ptradd [[VAR_23_]] [[VAR_24_]] : <#tptr.default_memory_space>, i32 to 
<#tptr.default_memory_space> +// CHECK-DAG: [[VAR_25_:%.+]] = ptr.ptr_add [[VAR_23_]] [[VAR_24_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_26_:%.+]] = arith.muli [[CST_3_]], [[VAR_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = tptr.ptradd [[VAR_25_]] [[VAR_26_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_27_:%.+]] = ptr.ptr_add [[VAR_25_]] [[VAR_26_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[CST_4_]], [[VAR_0_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_29_:%.+]] = tptr.ptradd [[VAR_27_]] [[VAR_28_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_29_:%.+]] = ptr.ptr_add [[VAR_27_]] [[VAR_28_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[PARAM_2_]], [[VAR_0_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_31_:%.+]] = tptr.ptradd [[VAR_29_]] [[VAR_30_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_31_:%.+]] = ptr.ptr_add [[VAR_29_]] [[VAR_30_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK-DAG: [[VAR_32_:%.+]] = arith.muli [[CST_3_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_33_:%.+]] = tptr.ptradd [[VAR_31_]] [[VAR_32_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: [[VAR_34_:%.+]] = tptr.to_memref [[VAR_33_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_33_:%.+]] = ptr.ptr_add [[VAR_31_]] [[VAR_32_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_34_:%.+]] = ptr.from_ptr [[VAR_33_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK-DAG: [[LOAD_VAR_34_MEM_:%.+]] = memref.load [[VAR_34_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK-DAG: 
[[VAR_36_:%.+]] = arith.extsi [[PARAM_2_]] : i32 to i64 // CHECK: [[VAR_37_:%.+]] = arith.addi [[VAR_17_]], [[VAR_36_]] : i64 // CHECK: [[VAR_38_:%.+]] = tptr.inttoptr [[VAR_37_]] : i64 to <#tptr.default_memory_space> -// CHECK: [[VAR_39_:%.+]] = tptr.to_memref [[VAR_38_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_39_:%.+]] = ptr.from_ptr [[VAR_38_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: memref.store [[LOAD_VAR_34_MEM_]], [[VAR_39_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: return // CHECK: } diff --git a/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir b/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir index b8a37cbc..a1bdf743 100644 --- a/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir +++ b/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir @@ -50,9 +50,9 @@ module { // CHECK-LABEL: func.func @ptr_cat // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset i1 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = ptr.type_offset i1 : i32 // CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_1_:%.+]] = tptr.type_offset i32 : i32 +// CHECK-DAG: [[VAR_1_:%.+]] = ptr.type_offset i32 : i32 // CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : i8 // CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i32 // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 @@ -84,14 +84,14 @@ module { // CHECK: [[VAR_15_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_14_]], [[VAR_12_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>, tensor<8xi32>) outs([[VAR_14_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_2_:%.+]]: 
!ptr.ptr<#tptr.default_memory_space>, [[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_2_:%.+]] = arith.muli [[IN_3_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_2_:%.+]] = tptr.ptradd [[IN_2_]] [[VAR_35_2_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_2_:%.+]] = ptr.ptr_add [[IN_2_]] [[VAR_35_2_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_2_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_16_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_13_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_17_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]], [[VAR_12_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>, tensor<8xi32>) outs([[VAR_16_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_5_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_6_:%.+]]: i32, [[IN_7_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_3_:%.+]] = arith.muli [[IN_6_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_3_:%.+]] = tptr.ptradd [[IN_5_]] [[VAR_35_3_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_3_:%.+]] = ptr.ptr_add [[IN_5_]] [[VAR_35_3_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_3_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_18_:%.+]] = tensor.empty() : tensor<16x!ptr.ptr<#tptr.default_memory_space>> @@ -100,12 +100,12 @@ module { // CHECK: [[VAR_19_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_10_]] : 
tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_8_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_9_:%.+]]: i32, [[IN_10_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_4_:%.+]] = arith.muli [[IN_9_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_4_:%.+]] = tptr.ptradd [[IN_8_]] [[VAR_35_4_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_4_:%.+]] = ptr.ptr_add [[IN_8_]] [[VAR_35_4_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_4_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_20_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_19_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_7_]] : tensor<16xi32>) { // CHECK: ^bb0([[IN_11_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_12_:%.+]]: i32): -// CHECK: [[VAR_35_5_:%.+]] = tptr.to_memref [[IN_11_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_35_5_:%.+]] = ptr.from_ptr [[IN_11_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: [[VAR_36_4_:%.+]] = memref.load [[VAR_35_5_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_36_4_]] : i32 // CHECK: } -> tensor<16xi32> @@ -117,7 +117,7 @@ module { // CHECK: [[VAR_22_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_21_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_16_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_17_:%.+]]: i32, [[IN_18_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_7_:%.+]] = arith.muli [[IN_17_]], [[VAR_1_]] : i32 -// 
CHECK: [[VAR_36_5_:%.+]] = tptr.ptradd [[IN_16_]] [[VAR_35_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_5_:%.+]] = ptr.ptr_add [[IN_16_]] [[VAR_35_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_5_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_23_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_20_]], [[VAR_8_]] : tensor<16xi32>, tensor<16xi32>) outs([[VAR_20_]] : tensor<16xi32>) { @@ -128,19 +128,19 @@ module { // CHECK: [[VAR_24_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_23_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_22_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_23_:%.+]]: i32, [[IN_24_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_9_:%.+]] = arith.muli [[IN_23_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_6_:%.+]] = tptr.ptradd [[IN_22_]] [[VAR_35_9_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_6_:%.+]] = ptr.ptr_add [[IN_22_]] [[VAR_35_9_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_6_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_25_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_18_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_26_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_25_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) 
outs([[VAR_25_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_25_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_26_:%.+]]: i32, [[IN_27_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_10_:%.+]] = arith.muli [[IN_26_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_36_7_:%.+]] = tptr.ptradd [[IN_25_]] [[VAR_35_10_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_7_:%.+]] = ptr.ptr_add [[IN_25_]] [[VAR_35_10_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_7_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_27_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_26_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_5_]] : tensor<16xi8>) { // CHECK: ^bb0([[IN_28_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_29_:%.+]]: i8): -// CHECK: [[VAR_35_11_:%.+]] = tptr.to_memref [[IN_28_]] : <#tptr.default_memory_space> to memref<1xi8> +// CHECK: [[VAR_35_11_:%.+]] = ptr.from_ptr [[IN_28_]] : <#tptr.default_memory_space> to memref<1xi8> // CHECK: [[VAR_36_7_:%.+]] = memref.load [[VAR_35_11_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi8> // CHECK: linalg.yield [[VAR_36_7_]] : i8 // CHECK: } -> tensor<16xi8> @@ -158,12 +158,12 @@ module { // CHECK: [[VAR_31_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_30_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_30_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_37_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_38_:%.+]]: i32, [[IN_39_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_14_:%.+]] = arith.muli [[IN_38_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_8_:%.+]] = tptr.ptradd [[IN_37_]] [[VAR_35_14_]] : <#tptr.default_memory_space>, i32 
to <#tptr.default_memory_space> +// CHECK: [[VAR_36_8_:%.+]] = ptr.ptr_add [[IN_37_]] [[VAR_35_14_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_8_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_32_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_31_]], [[VAR_29_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi1>) outs([[VAR_7_]] : tensor<16xi32>) { // CHECK: ^bb0([[IN_40_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_41_:%.+]]: i1, [[IN_42_:%.+]]: i32): -// CHECK-DAG: [[VAR_35_15_:%.+]] = tptr.to_memref [[IN_40_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK-DAG: [[VAR_35_15_:%.+]] = ptr.from_ptr [[IN_40_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK-DAG: [[VAR_36_9_:%.+]] = scf.if [[IN_41_]] -> (i32) { // CHECK: [[LOAD_VAR_35_15_MEM_:%.+]] = memref.load [[VAR_35_15_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK: scf.yield [[LOAD_VAR_35_15_MEM_]] : i32 @@ -176,13 +176,13 @@ module { // CHECK: [[VAR_34_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_33_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_33_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_43_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_44_:%.+]]: i32, [[IN_45_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_35_16_:%.+]] = arith.muli [[IN_44_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_10_:%.+]] = tptr.ptradd [[IN_43_]] [[VAR_35_16_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_36_10_:%.+]] = ptr.ptr_add [[IN_43_]] [[VAR_35_16_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_36_10_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> 
tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_34_]], [[VAR_32_]], [[VAR_29_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>, tensor<16xi1>) { // CHECK: ^bb0([[IN_46_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_47_:%.+]]: i32, [[IN_48_:%.+]]: i1): // CHECK: scf.if [[IN_48_]] { -// CHECK: [[VAR_35_17_:%.+]] = tptr.to_memref [[IN_46_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_35_17_:%.+]] = ptr.from_ptr [[IN_46_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: memref.store [[IN_47_]], [[VAR_35_17_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK: } // CHECK: linalg.yield diff --git a/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir b/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir index 810056cc..fcd12823 100644 --- a/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir +++ b/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir @@ -42,7 +42,7 @@ module { // CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%{{.+}}, [[MASK]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi1>) // CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i1, %out: i32): -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[MEMREF:%.+]] = ptr.from_ptr %in : <#tptr.default_memory_space> to memref<1xi32> // CHECK: [[SCF_IF:%.+]] = scf.if %in_0 -> (i32) { // CHECK: [[LOAD:%.+]] = memref.load [[MEMREF]][%c0] : memref<1xi32> // CHECK: scf.yield [[LOAD]] : i32 @@ -55,7 +55,7 @@ module { // CHECK: linalg.generic // CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i32, %in_1: i1): // CHECK: scf.if %in_1 { -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[MEMREF:%.+]] = ptr.from_ptr %in : <#tptr.default_memory_space> 
to memref<1xi32> // CHECK: memref.store %in_0, [[MEMREF]][%c0] : memref<1xi32> // CHECK: } // CHECK: linalg.yield diff --git a/test/Conversion/TritonToPtr/masked_load_store.mlir b/test/Conversion/TritonToPtr/masked_load_store.mlir index 0180f4b7..55044ead 100644 --- a/test/Conversion/TritonToPtr/masked_load_store.mlir +++ b/test/Conversion/TritonToPtr/masked_load_store.mlir @@ -30,7 +30,7 @@ module { // CHECK: linalg.generic // CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i1, %out: i32): -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[MEMREF:%.+]] = ptr.from_ptr %in : <#tptr.default_memory_space> to memref<1xi32> // CHECK: [[SCF_IF:%.+]] = scf.if %in_0 -> (i32) { // CHECK: [[LOAD:%.+]] = memref.load [[MEMREF]][%c0] : memref<1xi32> // CHECK: scf.yield [[LOAD]] : i32 @@ -43,7 +43,7 @@ module { // CHECK: linalg.generic // CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i32, %in_1: i1): // CHECK: scf.if %in_1 { -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[MEMREF:%.+]] = ptr.from_ptr %in : <#tptr.default_memory_space> to memref<1xi32> // CHECK: memref.store %in_0, [[MEMREF]][%c0] : memref<1xi32> // CHECK: } // CHECK: linalg.yield diff --git a/test/Conversion/TritonToPtr/regular_load_store.mlir b/test/Conversion/TritonToPtr/regular_load_store.mlir index d1612e92..1d8fe437 100644 --- a/test/Conversion/TritonToPtr/regular_load_store.mlir +++ b/test/Conversion/TritonToPtr/regular_load_store.mlir @@ -25,8 +25,8 @@ module { // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset f32 : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tptr.type_offset 
i32 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = ptr.type_offset f32 : i32 +// CHECK-DAG: [[VAR_1_:%.+]] = ptr.type_offset i32 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> // CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> // CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<1024xi32> @@ -41,19 +41,19 @@ module { // CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>, tensor<1024xi32>) outs([[VAR_7_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_1_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_14_1_:%.+]] = arith.muli [[IN_2_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_15_1_:%.+]] = tptr.ptradd [[IN_1_]] [[VAR_14_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_15_1_:%.+]] = ptr.ptr_add [[IN_1_]] [[VAR_14_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_15_1_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_9_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_6_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]], [[VAR_5_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>, tensor<1024xi32>) outs([[VAR_9_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_4_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: 
!ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_14_2_:%.+]] = arith.muli [[IN_5_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_15_2_:%.+]] = tptr.ptradd [[IN_4_]] [[VAR_14_2_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_15_2_:%.+]] = ptr.ptr_add [[IN_4_]] [[VAR_14_2_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_15_2_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_8_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_4_]] : tensor<1024xi32>) { // CHECK: ^bb0([[IN_7_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_8_:%.+]]: i32): -// CHECK: [[VAR_14_3_:%.+]] = tptr.to_memref [[IN_7_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_14_3_:%.+]] = ptr.from_ptr [[IN_7_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: [[VAR_15_2_:%.+]] = memref.load [[VAR_14_3_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_15_2_]] : i32 // CHECK: } -> tensor<1024xi32> @@ -65,7 +65,7 @@ module { // CHECK: } -> tensor<1024xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]], [[VAR_13_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>, tensor<1024xf32>) { // CHECK: ^bb0([[IN_11_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_12_:%.+]]: f32): -// CHECK: [[VAR_14_5_:%.+]] = tptr.to_memref [[IN_11_]] : <#tptr.default_memory_space> to memref<1xf32> +// CHECK: [[VAR_14_5_:%.+]] = ptr.from_ptr [[IN_11_]] : <#tptr.default_memory_space> to memref<1xf32> // CHECK: memref.store [[IN_12_]], [[VAR_14_5_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> // CHECK: linalg.yield // CHECK: } diff --git a/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir b/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir index 
dfcab2b2..369a0fb0 100644 --- a/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir +++ b/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir @@ -43,9 +43,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @tensor_ptr // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset i64 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = ptr.type_offset i64 : i32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_1_:%.+]] = tptr.type_offset i32 : i32 +// CHECK-DAG: [[VAR_1_:%.+]] = ptr.type_offset i32 : i32 // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> // CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> @@ -63,12 +63,12 @@ module { // CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_8_]], [[VAR_6_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_8_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_1_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_22_1_:%.+]] = arith.muli [[IN_2_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_23_1_:%.+]] = tptr.ptradd [[IN_1_]] [[VAR_22_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_23_1_:%.+]] = ptr.ptr_add [[IN_1_]] [[VAR_22_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_23_1_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> 
tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_4_]] : tensor<16xi32>) { // CHECK: ^bb0([[IN_4_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_5_:%.+]]: i32): -// CHECK: [[VAR_22_2_:%.+]] = tptr.to_memref [[IN_4_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_22_2_:%.+]] = ptr.from_ptr [[IN_4_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: [[VAR_23_1_:%.+]] = memref.load [[VAR_22_2_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_23_1_]] : i32 // CHECK: } -> tensor<16xi32> @@ -85,7 +85,7 @@ module { // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_14_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_13_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_4_]] : tensor<16xi32>) { // CHECK: ^bb0([[IN_10_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_11_:%.+]]: i32): -// CHECK: [[VAR_22_5_:%.+]] = tptr.to_memref [[IN_10_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_22_5_:%.+]] = ptr.from_ptr [[IN_10_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: [[VAR_23_1_1_:%.+]] = memref.load [[VAR_22_5_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_23_1_1_]] : i32 // CHECK: } -> tensor<16xi32> @@ -93,7 +93,7 @@ module { // CHECK: [[VAR_16_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_15_]], [[VAR_6_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_15_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_12_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_13_:%.+]]: i32, [[IN_14_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_22_6_:%.+]] = arith.muli [[IN_13_]], [[VAR_1_]] : i32 
-// CHECK: [[VAR_23_2_:%.+]] = tptr.ptradd [[IN_12_]] [[VAR_22_6_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_23_2_:%.+]] = ptr.ptr_add [[IN_12_]] [[VAR_22_6_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_23_2_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_17_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_11_]] : tensor<16xi64>) { @@ -114,7 +114,7 @@ module { // CHECK: [[VAR_20_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]], [[VAR_5_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_16_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { // CHECK: ^bb0([[IN_22_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_23_:%.+]]: i32, [[IN_24_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): // CHECK: [[VAR_22_10_:%.+]] = arith.muli [[IN_23_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_23_3_:%.+]] = tptr.ptradd [[IN_22_]] [[VAR_22_10_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK: [[VAR_23_3_:%.+]] = ptr.ptr_add [[IN_22_]] [[VAR_22_10_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> // CHECK: linalg.yield [[VAR_23_3_]] : !ptr.ptr<#tptr.default_memory_space> // CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> // CHECK: [[VAR_21_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_19_]] : tensor<16xi64>) outs([[VAR_4_]] : tensor<16xi32>) { @@ -124,7 +124,7 @@ module { // CHECK: } -> tensor<16xi32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_20_]], [[VAR_21_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) { // CHECK: 
^bb0([[IN_27_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_28_:%.+]]: i32): -// CHECK: [[VAR_22_12_:%.+]] = tptr.to_memref [[IN_27_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_22_12_:%.+]] = ptr.from_ptr [[IN_27_]] : <#tptr.default_memory_space> to memref<1xi32> // CHECK: memref.store [[IN_28_]], [[VAR_22_12_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield // CHECK: } diff --git a/tools/triton-shared-opt/CMakeLists.txt b/tools/triton-shared-opt/CMakeLists.txt index 260c3e5d..c47bf833 100644 --- a/tools/triton-shared-opt/CMakeLists.txt +++ b/tools/triton-shared-opt/CMakeLists.txt @@ -20,3 +20,27 @@ target_link_libraries(triton-shared-opt PRIVATE ) mlir_check_all_link_libraries(triton-shared-opt) + +add_llvm_executable(triton-shared-lsp triton-shared-lsp.cpp PARTIAL_SOURCES_INTENDED) + +llvm_update_compile_flags(triton-shared-lsp) +target_link_libraries(triton-shared-lsp PRIVATE + TritonTransforms + TritonSharedAnalysis + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialect + TritonAMDGPUTestAnalysis + TritonTestProton + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-shared-lsp) \ No newline at end of file diff --git a/tools/triton-shared-opt/RegisterTritonSharedDialects.h b/tools/triton-shared-opt/RegisterTritonSharedDialects.h index 6d92953f..0edffea5 100644 --- a/tools/triton-shared-opt/RegisterTritonSharedDialects.h +++ b/tools/triton-shared-opt/RegisterTritonSharedDialects.h @@ -1,14 +1,24 @@ #pragma once #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include
"mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "triton-shared/Conversion/StructuredToMemref/Passes.h" #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" @@ -47,5 +57,11 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry &registry) { mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::linalg::LinalgDialect, mlir::func::FuncDialect, mlir::tensor::TensorDialect, mlir::memref::MemRefDialect, - mlir::bufferization::BufferizationDialect>(); + mlir::bufferization::BufferizationDialect, + mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, + mlir::triton::gpu::TritonGPUDialect, + mlir::triton::instrument::TritonInstrumentDialect, mlir::gpu::GPUDialect, + mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect, + mlir::triton::nvws::NVWSDialect, mlir::ROCDL::ROCDLDialect, + mlir::triton::gluon::GluonDialect>(); } diff --git a/tools/triton-shared-opt/triton-shared-lsp.cpp b/tools/triton-shared-opt/triton-shared-lsp.cpp new file mode 100644 index 00000000..6a8e9d2b --- /dev/null +++ b/tools/triton-shared-opt/triton-shared-lsp.cpp @@ -0,0 +1,10 @@ +#include "./RegisterTritonSharedDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonSharedDialects(registry); + + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git
a/triton-hash.txt b/triton-hash.txt index 80f8308f..361f2dfa 100644 --- a/triton-hash.txt +++ b/triton-hash.txt @@ -1 +1 @@ -e44bd1c83c1c3e8deac7c4f02683cfb3cc395c8b +acd81049917c62aa156fff2669ae25664048ac77 \ No newline at end of file