Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ void populateMIGraphXToLinalgBoundaryDialectConversion(
/// migraphx.mlir.as_logical_shape and migraphx.mlir.as_underlying_shape.
void populateMIGraphXFuncBoundaryToLinalgConversionPatterns(
RewritePatternSet &target, TypeConverter &typeConverter);

/// Populates conversion patterns that legalize mhal.launch operations at
/// function boundaries (converting their operand and result types).
void populateMIGraphXToLinalgMHALLauncherConversion(
RewritePatternSet &target, TypeConverter &typeConverter);
} // namespace migraphx
} // namespace mlir

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/RocMLIRPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def MIGraphXToLinalgPass : Pass<"migraphx-to-linalg", "::mlir::func::FuncOp"> {
}];

let dependentDialects = ["arith::ArithDialect", "tensor::TensorDialect",
"linalg::LinalgDialect"];
"linalg::LinalgDialect", "rock::RockDialect"];
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ def ConvOpBwdWeightType : I32EnumAttrCase<"BwdWeight", 2, "conv_bwd_weight">;
def ConvOpTypes : Rock_I32Enum<"ConvOpType", "The type of a convolution operation",
[ConvOpType, ConvOpBwdDataType, ConvOpBwdWeightType]>;

/// LinalgConvType: enumerates the grouped-convolution layouts (one case per
/// spatial rank, 1D/2D/3D) produced by the MIGraphX-to-Linalg conversion.
def LinalgConv_1D : I32EnumAttrCase<"Conv1dNgchGfch", 0, "conv1d_ngch_gfch">;
def LinalgConv_2D
    : I32EnumAttrCase<"Conv2dNgchwGfchw", 1, "conv2d_ngchw_gfchw">;
def LinalgConv_3D
    : I32EnumAttrCase<"Conv3dNgchwdGfchwd", 2, "conv3d_ngchwd_gfchwd">;

def LinalgConvType
    : Rock_I32Enum<"LinalgConvType",
                   "The layout of a grouped convolution operation",
                   [LinalgConv_1D, LinalgConv_2D, LinalgConv_3D]>;

def LinalgConvTypeAttr : EnumAttr<Rock_Dialect, LinalgConvType, "LinalgConvType">;

/// Kerneltype
def KernelTypeConv : I32EnumAttrCase<"Conv", 0>;
def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;
Expand Down
309 changes: 308 additions & 1 deletion mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
Expand Down Expand Up @@ -104,6 +105,307 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite(
"input shape is non standard or broadcast; cannot convert this shape");
}

namespace {
/// Lowers migraphx.convolution to a grouped linalg.generic convolution
/// (1D/2D/3D). The emitted op is tagged with rock attributes (pad, group,
/// stride, dilation, conv_op) so a later lowering can recover the original
/// convolution configuration.
struct ConvConverter final
    : public OpConversionPattern<migraphx::ConvolutionOp> {
  using OpConversionPattern<migraphx::ConvolutionOp>::OpConversionPattern;
  using OpConversionPattern<migraphx::ConvolutionOp>::getTypeConverter;
  using OpAdaptor =
      typename OpConversionPattern<migraphx::ConvolutionOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  // Emits the grouped linalg conv; expects input/filter already expanded to
  // NGC*/GFC* form (see matchAndRewrite step 2).
  LogicalResult emitConv(ConversionPatternRewriter &rewriter,
                         migraphx::ConvolutionOp op, Value input,
                         Value filter) const;
};
} // namespace

/// Region builder for the grouped-convolution linalg.generic: multiplies the
/// input and filter elements and accumulates the product into the output.
static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) {
  Value in = blockArgs[0];
  Value flt = blockArgs[1];
  Value acc = blockArgs[2];
  Value prod = arith::MulFOp::create(b, loc, in, flt);
  Value sum = arith::AddFOp::create(b, loc, acc, prod);
  linalg::YieldOp::create(b, loc, sum);
}

/// Emit convolution attributes on the newly created operation.
/// `convOp` is the result value of the freshly created linalg op; the
/// attributes record the original migraphx convolution configuration
/// (padding, group, stride, dilation, layout) for later lowering stages.
// TODO(review): define these attribute names in RockAttrDefs and reference
// them by name here and in linalg-to-rock instead of hardcoding strings.
static void emitConvAttributes(migraphx::ConvolutionOp op, Value convOp,
                               Attribute strides, Attribute dilation,
                               Attribute pad, Attribute convOpName) {
  Operation *newOp = convOp.getDefiningOp();
  newOp->setAttr("pad", pad);
  newOp->setAttr("group", op.getGroupAttr());
  newOp->setAttr("stride", strides);
  newOp->setAttr("dilation", dilation);

  // Forward the optional perf_config attribute when present. The `template`
  // disambiguator is unnecessary here: this is not a dependent context.
  if (auto attr = (*op).getAttrOfType<StringAttr>("perf_config"))
    newOp->setAttr("perf_config", attr);
  newOp->setAttr("conv_op", convOpName);
}

/// Emit a grouped convolution of any spatial rank (1D, 2D, or 3D) as a
/// linalg.generic over the iteration space below.
/// Input shape: (batch, group, channel, spatial...),
/// filter shape: (group, filter, channel, kernel_spatial...)
///
/// clang-format off
/// for n in batch:
///   for g in group:
///     for f in filters:
///       for oh_0 in output_spatial_0:
///         for oh_1 in output_spatial_1:
///           // ...
///           for oh_{dim-1} in output_spatial_{dim-1}:
///             for c in channels: // reduction
///               for kh_0 in kernel_spatial_0: // reduction
///                 for kh_1 in kernel_spatial_1: // reduction
///                   // ...
/// clang-format on
static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc,
                             RankedTensorType resultType, Value input,
                             Value filter, Value zero,
                             ArrayAttr strides,
                             ArrayAttr dilation) {
  MLIRContext *ctx = rewriter.getContext();
  // Spatial rank: the input is NGC* so the first three dims are non-spatial.
  int64_t dim = cast<RankedTensorType>(input.getType()).getRank() - 3;
  SmallVector<int64_t, 4> strideVals;
  SmallVector<int64_t, 4> dilationVals;
  llvm::transform(strides.getValue(), std::back_inserter(strideVals), [](Attribute attr){
    return cast<IntegerAttr>(attr).getInt();
  });
  llvm::transform(dilation.getValue(), std::back_inserter(dilationVals), [](Attribute attr){
    return cast<IntegerAttr>(attr).getInt();
  });

  // Iteration domain layout:
  //   parallel:  batch, group, filter, oh_0 .. oh_{dim-1}
  //   reduction: channel, kh_0 .. kh_{dim-1}
  // NOTE: the affine expressions below index into `d` with this exact
  // ordering; d[3+i] are output spatial dims, d[4+dim+i] are kernel dims.
  int64_t totalDims = 4 + 2 * dim;
  SmallVector<AffineExpr> d;
  for (int64_t i = 0; i < totalDims; ++i)
    d.push_back(getAffineDimExpr(i, ctx));

  AffineExpr batch = d[0], group = d[1], filterExpr = d[2];
  AffineExpr channel = d[3 + dim];

  // Input access: (n, g, c, oh_i * stride_i + kh_i * dilation_i ...).
  SmallVector<AffineExpr> inputExprs = {batch, group, channel};
  for (int64_t i = 0; i < dim; ++i)
    inputExprs.push_back(d[3 + i] * strideVals[i] +
                         d[4 + dim + i] * dilationVals[i]);

  // Filter access: (g, f, c, kh_i ...).
  SmallVector<AffineExpr> filterExprs = {group, filterExpr, channel};
  for (int64_t i = 0; i < dim; ++i)
    filterExprs.push_back(d[4 + dim + i]);

  // Output access: (n, g, f, oh_i ...).
  SmallVector<AffineExpr> outputExprs = {batch, group, filterExpr};
  for (int64_t i = 0; i < dim; ++i)
    outputExprs.push_back(d[3 + i]);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(totalDims, 0, inputExprs, ctx),
      AffineMap::get(totalDims, 0, filterExprs, ctx),
      AffineMap::get(totalDims, 0, outputExprs, ctx)};

  // 3 + dim parallel loops (n, g, f, spatial), 1 + dim reductions (c, kernel).
  SmallVector<utils::IteratorType> iteratorTypes(3 + dim,
                                                 utils::IteratorType::parallel);
  iteratorTypes.append(1 + dim, utils::IteratorType::reduction);

  return linalg::GenericOp::create(rewriter, loc, resultType,
                                   ValueRange{input, filter}, zero,
                                   indexingMaps, iteratorTypes, convBodyBuilder)
      .getResult(0);
}

/// Build the grouped linalg convolution once the operands are group-expanded.
LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter,
                                      migraphx::ConvolutionOp op, Value input,
                                      Value filter) const {
  // Input and filter are already in NGC* and GFC* form (group dimension
  // expanded). Build the result type as NGF* (with explicit G), emit the
  // grouped linalg conv (1D/2D/3D), then collapse back to NF* for the type
  // converter.
  Location loc = op.getLoc();
  int64_t group = op.getGroupAttr().getInt();
  // Spatial rank: exclude batch (N), group (G), channel (C).
  int64_t dim = cast<RankedTensorType>(input.getType()).getRank() - 3;
  assert(dim >= 1 && dim <= 3 && "this should be checked at matchAndRewrite");

  // Result type from the op is NF*; expand to NGF* for the linalg conv.
  RankedTensorType resultType =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getResult()));
  ArrayRef<int64_t> resultShape = resultType.getShape();
  // Report a proper conversion failure (rather than asserting) when the
  // output channel count is not evenly divisible by the group count.
  if (resultType.getDimSize(1) % group != 0)
    return op.emitError("output channel must be divisible by group");
  int64_t n = resultType.getDimSize(0);
  int64_t newF = resultType.getDimSize(1) / group;
  SmallVector<int64_t, 4> newShape;
  newShape.push_back(n);
  newShape.push_back(group);
  newShape.push_back(newF);
  newShape.insert(newShape.end(), std::next(resultShape.begin(), 2),
                  resultShape.end());
  auto newResultType =
      RankedTensorType::get(newShape, resultType.getElementType());
  // Zero-initialized accumulator for the reduction.
  Value zero = arith::ConstantOp::create(rewriter, loc, newResultType,
                                         rewriter.getZeroAttr(newResultType));

  ArrayAttr strides = op.getStride();
  ArrayAttr dilation = op.getDilation();

  // Pick the layout enum matching the spatial rank.
  rock::LinalgConvType convLayout =
      (dim == 1)   ? rock::LinalgConvType::Conv1dNgchGfch
      : (dim == 2) ? rock::LinalgConvType::Conv2dNgchwGfchw
                   : rock::LinalgConvType::Conv3dNgchwdGfchwd;
  auto resultConvOpName =
      rock::LinalgConvTypeAttr::get(rewriter.getContext(), convLayout);
  Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter,
                                 zero, strides, dilation);

  emitConvAttributes(op, result, strides, dilation, op.getPaddingAttr(),
                     resultConvOpName);

  // Collapse NGF* back to NF* to match what the type converter expects.
  SmallVector<ReassociationIndices, 4> reassociation{{0}, {1, 2}};
  llvm::for_each(llvm::seq<int64_t>(3, dim + 3),
                 [&](int64_t index) { reassociation.push_back({index}); });
  auto finalResult =
      tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation);

  rewriter.replaceOp(op, finalResult);
  return success();
}

/// Entry point: lower migraphx.convolution to linalg in three steps
/// (pad, group-expand, emit grouped conv).
LogicalResult
ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  // Forward convolution is lowered in three steps:
  // 1. Apply padding to the input when the op has non-zero padding.
  // 2. Expand the channel dimension into (group, channel_per_group),
  //    introducing a group dimension G. Input becomes NGC* (e.g. NGCL,
  //    NGCHW, NGCDHW) and filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW),
  //    matching the group attr.
  // 3. Emit the grouped linalg convolution (1D/2D/3D), then collapse the
  //    result back to the original NFHW/NFDHW shape for the type converter.
  Location loc = op.getLoc();
  Value input = adaptor.getInput();
  Value filter = adaptor.getFilter();
  ArrayAttr padAttr = adaptor.getPaddingAttr();
  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
  // Spatial rank of the (not yet group-expanded) NC* input.
  int64_t dim = inputType.getRank() - 2;
  int64_t group = op.getGroupAttr().getInt();

  if (dim > 3 || dim < 1) {
    return op.emitError(Twine(dim) + "D conv is not supported for now");
  }

  // For now, the linalg.generic region doesn't support type casting
  // (convBodyBuilder does a plain mulf/addf), so reject mixed element types.
  if (inputType.getElementType() != op.getFilter().getType().getElementType() ||
      inputType.getElementType() != op.getResult().getType().getElementType()) {
    return op.emitError(
        "type casting between operands and result is unsupported for now");
  }

  // Step 1: apply padding when any padding value is non-zero.
  if (!llvm::all_of(padAttr, [](Attribute pad) {
        return cast<IntegerAttr>(pad).getValue() == 0;
      })) {
    // Apply symmetric padding to spatial dimensions; non-spatial dims get 0.
    SmallVector<OpFoldResult, 4> low(inputType.getRank(),
                                     rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult, 4> high(inputType.getRank(),
                                      rewriter.getIndexAttr(0));
    // insert padding to inputs
    assert(2 * dim == (int64_t)padAttr.size() && "padding is symmetric");

    // MIGraphX padAttr is [dim0_low, dim1_low,..., dim0_high, dim1_high, ...]
    SmallVector<int64_t, 4> newShape(inputType.getShape());
    auto lowAttrs = padAttr.getValue().drop_back(dim);
    auto highAttrs = padAttr.getValue().drop_front(dim);
    // The first spatial dimension (H) is always located at index 2 in the
    // NC* layout (after batch and channel), regardless of convolution rank.
    int64_t dimHOffset = 2;
    llvm::for_each(llvm::seq<int64_t>(dim), [&](int64_t index) {
      int64_t lowPad = cast<IntegerAttr>(lowAttrs[index]).getInt();
      int64_t highPad = cast<IntegerAttr>(highAttrs[index]).getInt();
      newShape[dimHOffset + index] += lowPad + highPad;
      low[dimHOffset + index] = rewriter.getIndexAttr(lowPad);
      high[dimHOffset + index] = rewriter.getIndexAttr(highPad);
    });

    RankedTensorType newInputType =
        RankedTensorType::get(newShape, inputType.getElementType());
    // Pad with the element type's zero value.
    Value padValue = arith::ConstantOp::create(
        rewriter, loc, rewriter.getZeroAttr(inputType.getElementType()));
    input = tensor::PadOp::create(rewriter, loc, newInputType, input, low, high,
                                  padValue)
                .getResult();
  }

  // Expands the group dimension out of the channel (input) or the output
  // filter count (filter) via tensor.expand_shape.
  auto expandGroupDim = [&](Value input, bool isFilter) -> Value {
    RankedTensorType originalType = cast<RankedTensorType>(input.getType());
    ArrayRef<int64_t> originalShape = originalType.getShape();
    SmallVector<int64_t, 4> newShape;

    if (isFilter) {
      // FCHW into GFCHW: split dim 0 (F) into (G, F/G).
      int64_t newF = originalType.getDimSize(0) / group;
      assert(originalType.getDimSize(0) % group == 0 &&
             "output channel must be divisible by group");
      newShape.push_back(group);
      newShape.push_back(newF);
      newShape.push_back(originalType.getDimSize(1));
      newShape.insert(newShape.end(), std::next(originalShape.begin(), 2),
                      originalShape.end());
      RankedTensorType newType =
          RankedTensorType::get(newShape, originalType.getElementType());

      // Reassociation: {0,1} merges into old dim 0; spatial dims pass through.
      SmallVector<ReassociationIndices, 4> reassociation;
      reassociation.push_back({0, 1});
      llvm::for_each(llvm::seq<int64_t>(2, dim + 3),
                     [&](int64_t i) { reassociation.push_back({i}); });
      return tensor::ExpandShapeOp::create(rewriter, loc, newType, input,
                                           reassociation);
    } else {
      // Convert NCHW into NGCHW: split dim 1 (C) into (G, C/G).
      int64_t newC = originalType.getDimSize(1) / group;
      assert(originalType.getDimSize(1) % group == 0 &&
             "input channel must be divisible by group");
      newShape.push_back(originalType.getDimSize(0));
      newShape.push_back(group);
      newShape.push_back(newC);
      newShape.insert(newShape.end(), std::next(originalShape.begin(), 2),
                      originalShape.end());

      RankedTensorType newType =
          RankedTensorType::get(newShape, originalType.getElementType());
      SmallVector<ReassociationIndices, 4> reassociation;
      reassociation.push_back({0});
      reassociation.push_back({1, 2});
      llvm::for_each(llvm::seq<int64_t>(3, dim + 3),
                     [&](int64_t i) { reassociation.push_back({i}); });
      return tensor::ExpandShapeOp::create(rewriter, loc, newType, input,
                                           reassociation);
    }
  };

  // Step 2: expand group dimension (NCHW -> NGCHW, FCHW -> GFCHW). We
  // want expand in group dimension because linalg.conv2d_ngchw_gfchw
  // expects the layout to have the group dimension. It also makes for
  // a nicer linalg.generic loop
  input = expandGroupDim(input, false);
  filter = expandGroupDim(filter, true);
  // Step 3: emit linalg conv and collapse result to match type converter.
  return emitConv(rewriter, op, input, filter);
}

// TODO: add support for scaled gemms, and migraphx::DeQuantizeLinearConverter
//===----------------------------------------------------------------------===//
// Base kernels (gemm)
Expand Down Expand Up @@ -396,13 +698,18 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
ElementwiseConverter<migraphx::SqrtOp, linalg::SqrtOp>,
ElementwiseConverter<migraphx::TanhOp, linalg::TanhOp>,
ElementwiseConverter<migraphx::RecipOp, linalg::ReciprocalOp>,
ReluConverter, ClipConverter>(converter, patterns.getContext());
ReluConverter, ClipConverter, ConvConverter>(converter,
patterns.getContext());
}

void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns(
RewritePatternSet &patterns, TypeConverter &typeConverter) {
patterns.add<AsUnderlyingShapeConverter, AsLogicalShapeOpConverter>(
typeConverter, patterns.getContext());

// mhal.launch can be generated through rocmlir-gen, so we need a way to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand, can you give an example to show why we need this? Please add it to the description to motivate why we need it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During e2e testing, the rocmlir-gen seems to generate a wrapper that has mhal.launch.

Example (from test case below):

root@14298acabab2 ~/rocMLIR/build-release (pr-template-migraphx-to-linalg-conv)$ cat main.mlir
func.func @conv_1d_group(%in: !migraphx.shaped<10x8x123xf32, 984x123x1>, %fil: !migraphx.shaped<12x2x7xf32, 14x7x1>) -> !migraphx.shaped<10x12x53xf32, 636x53x1> {
  %out = migraphx.convolution %in, %fil {dilation = [4], group = 4 : i64, padding = [3,3], padding_mode = 0 : i64, stride = [2]} :
    <10x8x123xf32, 984x123x1>, <12x2x7xf32, 14x7x1> -> <10x12x53xf32,  636x53x1>
  func.return %out : !migraphx.shaped<10x12x53xf32,  636x53x1>
}
root@14298acabab2 ~/rocMLIR/build-release (pr-template-migraphx-to-linalg-conv)$ ./bin/rocmlir-gen main.mlir --clone-harness --arch gfx950 -fut conv_1d_group
module {
  func.func @conv_1d_group(%arg0: !migraphx.shaped<10x8x123xf32, 984x123x1> {mhal.read_access}, %arg1: !migraphx.shaped<12x2x7xf32, 14x7x1> {mhal.read_access}) -> (!migraphx.shaped<10x12x53xf32, 636x53x1> {mhal.write_access}) {
    %0 = migraphx.convolution %arg0, %arg1 {dilation = [4], group = 4 : i64, padding = [3, 3], padding_mode = 0 : i64, stride = [2]} : <10x8x123xf32, 984x123x1>, <12x2x7xf32, 14x7x1> -> <10x12x53xf32, 636x53x1>
    return %0 : !migraphx.shaped<10x12x53xf32, 636x53x1>
  }
  func.func @conv_1d_group_wrapper(%arg0: !migraphx.shaped<10x8x123xf32, 984x123x1>, %arg1: !migraphx.shaped<12x2x7xf32, 14x7x1>) -> !migraphx.shaped<10x12x53xf32, 636x53x1> {
    %token, %results = mhal.launch @conv_1d_group (%arg0, %arg1) : (!migraphx.shaped<10x8x123xf32, 984x123x1>, !migraphx.shaped<12x2x7xf32, 14x7x1>) -> !migraphx.shaped<10x12x53xf32, 636x53x1>
    mhal.await %token : !mhal.token
    return %results : !migraphx.shaped<10x12x53xf32, 636x53x1>
  }
  module @__xmodule_ attributes {mhal.arch = "gfx950", mhal.module} {
    func.func @conv_1d_group(%arg0: !migraphx.shaped<10x8x123xf32, 984x123x1> {mhal.read_access}, %arg1: !migraphx.shaped<12x2x7xf32, 14x7x1> {mhal.read_access}) -> (!migraphx.shaped<10x12x53xf32, 636x53x1> {mhal.write_access}) attributes {kernel, original_func = @conv_1d_group} {
      %0 = migraphx.convolution %arg0, %arg1 {dilation = [4], group = 4 : i64, padding = [3, 3], padding_mode = 0 : i64, stride = [2]} : <10x8x123xf32, 984x123x1>, <12x2x7xf32, 14x7x1> -> <10x12x53xf32, 636x53x1>
      return %0 : !migraphx.shaped<10x12x53xf32, 636x53x1>
    }
  }
}

If we don't have this line above, passing through ./bin/rocmlir-driver seems to result in failure due to type not being legalized yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you post the rocmlir-driver command as well? I cannot reproduce your issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applying this diff:

diff --git a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp
index 742d2b29056d..61f746ae9414 100644
--- a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp
+++ b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp
@@ -799,7 +799,7 @@ void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns(

   // mhal.launch can be generated through rocmlir-gen, so we need a way to
   // legalize it
-  populateMIGraphXToLinalgMHALLauncherConversion(patterns, typeConverter);
+  // populateMIGraphXToLinalgMHALLauncherConversion(patterns, typeConverter);
   populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);
   populateReturnOpTypeConversionPattern(patterns, typeConverter);
   populateCallOpTypeConversionPattern(patterns, typeConverter);
diff --git a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp
index 75f5588d2024..9fd00ff82ed6 100644
--- a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp
+++ b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp
@@ -52,11 +52,11 @@ void mlir::migraphx::populateMIGraphXToLinalgBoundaryDialectConversion(
   target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
     return typeConverter.isSignatureLegal(op.getFunctionType());
   });
-  target.addDynamicallyLegalOp<mhal::LaunchOp>(
-      [=](mhal::LaunchOp op) -> std::optional<bool> {
-        return typeConverter.isLegal(op.getResultTypes()) &&
-               typeConverter.isLegal(op.getOperandTypes());
-      });
+  // target.addDynamicallyLegalOp<mhal::LaunchOp>(
+  //     [=](mhal::LaunchOp op) -> std::optional<bool> {
+  //       return typeConverter.isLegal(op.getResultTypes()) &&
+  //              typeConverter.isLegal(op.getOperandTypes());
+  //     });
   target.addDynamicallyLegalOp<func::ReturnOp>(
       [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
   target.addDynamicallyLegalOp<func::CallOp>(

Running:

./bin/rocmlir-gen -fut conv_3d -arch gfx950 --clone-harness <filename> | ./bin/rocmlir-driver --host-pipeline=migraphx-linalg,highlevel --kernel-pipeline=migraphx-linalg,highlevel

On:

func.func @conv_3d(%arg0: !migraphx.shaped<2x4x2x2x2xf32, 32x8x4x2x1>, %arg1: !migraphx.shaped<2x3x5x5x5xf32, 375x125x25x5x1>, %arg2: !migraphx.shaped<4x3x2x2x2xf32, 24x8x4x2x1>) -> !migraphx.shaped<2x4x2x2x2xf32, 32x8x4x2x1>  {
  %0 = migraphx.convolution %arg1, %arg2 {dilation = [2, 2, 2], group = 1 : i64, padding = [0, 0, 0, 0, 0, 0], padding_mode = 0 : i64, stride = [2, 2, 2]} : <2x3x5x5x5xf32, 375x125x25x5x1>, <4x3x2x2x2xf32, 24x8x4x2x1> -> <2x4x2x2x2xf32, 32x8x4x2x1>
  %1 = migraphx.add %0, %arg0 : <2x4x2x2x2xf32, 32x8x4x2x1>, <2x4x2x2x2xf32, 32x8x4x2x1> -> <2x4x2x2x2xf32, 32x8x4x2x1>
  return %1 : !migraphx.shaped<2x4x2x2x2xf32, 32x8x4x2x1>
}

I got this on my end?

loc("<stdin>":7:138): error: failed to legalize unresolved materialization from ('tensor<96xf32>') to ('!migraphx.shaped<4x3x2x2x2xf32, 24x8x4x2x1>') that remained live after conversion
Lowering failed.

// legalize it
populateMIGraphXToLinalgMHALLauncherConversion(patterns, typeConverter);
populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
Expand Down
Loading