diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp index 10240298..785c74b9 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -807,10 +807,13 @@ struct LoadConverter : public OpConversionPattern { SmallVector mixedDims = op.getMixedMaskDims(); // Fill load destination with other value - if (Value other = op.getOther()) { - fillWithValue(loc, alloc, other, tensorType.getShape(), - op.getMixedMaskDims(), op.getStaticMaskDims(), rewriter); + Value other = op.getOther(); + if (!other) { + other = rewriter.create( + loc, rewriter.getZeroAttr(elemType)); } + fillWithValue(loc, alloc, other, tensorType.getShape(), + op.getMixedMaskDims(), op.getStaticMaskDims(), rewriter); auto ptrDefiningOp = ptr.getDefiningOp(); if (ptrDefiningOp->hasAttr(WRAP_SIDE_BY_SIDE) || diff --git a/python/examples/test_mm.py b/python/examples/test_mm.py index 5bbdcd5d..e1d1bd1c 100644 --- a/python/examples/test_mm.py +++ b/python/examples/test_mm.py @@ -62,12 +62,10 @@ def mm_kernel( a = tl.load( A + (ram[:, None] * stride_am + rk[None, :] * stride_ak), mask=mask_k[None, :], - other=0.0 ) b = tl.load( B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn), mask=mask_k[:, None], - other=0.0 ) if a.dtype != b.dtype: a = a.to(C.dtype.element_ty) diff --git a/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir b/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir index 4d41da34..315decbd 100644 --- a/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir +++ b/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir @@ -225,6 +225,10 @@ module { // CHECK: [[VAR_24_2_:%.+]] = arith.maxsi [[VAR_23_2_]], [[VAR_20_5_]] : index // CHECK-DAG: [[VAR_25_2_:%.+]] = arith.subi [[VAR_24_2_]], [[VAR_20_5_]] : index // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<256xf32> +// CHECK: [[CMPI_:%.+]] = arith.cmpi slt, [[VAR_25_2_]], [[CST_256_1_]] : index +// CHECK: scf.if [[CMPI_]] { +// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_2_]] : memref<256xf32>) +// CHECK: } // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_2_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> // CHECK-DAG: [[VAR_subview_6_2_:%.+]] = memref.subview [[RES_2_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref> @@ -232,6 +236,9 @@ module { // CHECK-DAG: [[VAR_26_2_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<256xf32> // CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_3_]] to offset: {{.}}[[VAR_20_5_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref<256xf32> +// CHECK: scf.if [[CMPI_]] { +// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_3_]] : memref<256xf32>) +// CHECK: } // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> // CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_3_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref> @@ -241,8 +248,7 @@ module { // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_reinterpret_cast_11_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_28_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref<256xf32> -// CHECK-DAG: [[VAR_29_2_:%.+]] = arith.cmpi slt, [[VAR_25_2_]], [[CST_256_1_]] : index -// CHECK: scf.if [[VAR_29_2_]] { +// CHECK: scf.if [[CMPI_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_4_]] : memref<256xf32>) // CHECK: } // CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>