From 270f7d17743be24f653efa47a093a86413d19eac Mon Sep 17 00:00:00 2001 From: "huangchenhui.yellow" Date: Wed, 7 Aug 2024 11:30:34 +0800 Subject: [PATCH] [compiler] fix remove-copy returns redundant copy --- .../Dialect/MemRef/Transforms/RemoveCopy.cpp | 12 ++++++++-- .../test/E2E/Host/AliasLike/01_HostOpt.mlir | 22 +++++++++---------- .../test/E2E/Host/AliasLike/02a_ByreHost.mlir | 8 +++---- .../test/E2E/Host/AliasLike/02b_ToLLVM.mlir | 8 +++---- compiler/test/E2E/Host/AliasLike/Output.mlir | 8 +++---- 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp b/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp index 7b81af4e2..44523eb03 100644 --- a/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp +++ b/compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp @@ -165,16 +165,19 @@ class RemoveCopyPattern : public OpRewritePattern { return false; }; + bool targetUsedInTerm = false; if (target.getType() != src.getType()) { // skip copy when it is used in a terminator if (auto srcAlloc = src.getDefiningOp()) { if (allocUseInTerminator(srcAlloc)) { + LLVM_DEBUG(llvm::dbgs() << "src is used in a terminator"); return failure(); } } if (auto targetAlloc = target.getDefiningOp()) { if (allocUseInTerminator(targetAlloc)) { - return failure(); + LLVM_DEBUG(llvm::dbgs() << "target is used in a terminator"); + targetUsedInTerm = true; } } } @@ -254,8 +257,10 @@ class RemoveCopyPattern : public OpRewritePattern { } // now it is legal to rewrite. + LLVM_DEBUG(llvm::dbgs() << "it is legal to rewrite " << copyOp << "\n"); // we prefer target alloc over src alloc in this implementation if (auto targetAlloc = target.getDefiningOp()) { + LLVM_DEBUG(llvm::dbgs() << "match target alloc: " << targetAlloc << "\n"); if (auto srcDef = src.getDefiningOp()) { if (isa(srcDef)) @@ -269,11 +274,13 @@ class RemoveCopyPattern : public OpRewritePattern { return failure(); } - if (!anyIncompatibleUse(target, src)) { + LLVM_DEBUG(llvm::dbgs() << "check anyIncompatibleUse\n"); + if (!anyIncompatibleUse(target, src) && !targetUsedInTerm) { replaceUsesAndPropagateType(rewriter, targetAlloc, src); return success(); } + LLVM_DEBUG(llvm::dbgs() << "check anyIncompatibleUseWithCast\n"); if (!anyIncompatibleUseWithCast(target, src)) { // The memref of source and target are contiguous, cast source value to // the same type with target. As `byre.alias` could handle source with @@ -324,6 +331,7 @@ class RemoveCopyPattern : public OpRewritePattern { } if (auto srcAlloc = src.getDefiningOp()) { + LLVM_DEBUG(llvm::dbgs() << "match src alloc: " << srcAlloc << "\n"); if (auto targetDef = target.getDefiningOp()) { if (isa(targetDef)) hoistUpOpInBlock(targetDef, domInfo); diff --git a/compiler/test/E2E/Host/AliasLike/01_HostOpt.mlir b/compiler/test/E2E/Host/AliasLike/01_HostOpt.mlir index daa346a43..173f39fcd 100644 --- a/compiler/test/E2E/Host/AliasLike/01_HostOpt.mlir +++ b/compiler/test/E2E/Host/AliasLike/01_HostOpt.mlir @@ -21,19 +21,17 @@ module { } func.func @main(%arg0: memref<512x200xf32>, %arg1: memref<512x200xf32>) -> (memref<128x2x100xf32>, memref<128x2x100xf32>, memref<1x100xf32>, memref<1x100xf32>, memref<512x200xf32>) attributes {__placeholder__byre.entry_point} { %subview = memref.subview %arg0[0, 0] [128, 200] [1, 1] : memref<512x200xf32> to memref<128x200xf32, strided<[200, 1]>> - %subview_0 = memref.subview %arg1[10, 0] [128, 200] [1, 1] : memref<512x200xf32> to memref<128x200xf32, strided<[200, 1], offset: 2000>> - %expand_shape = memref.expand_shape %subview [[0], [1, 2]] output_shape [128, 2, 100] : memref<128x200xf32, strided<[200, 1]>> into memref<128x2x100xf32, strided<[200, 100, 1]>> - %expand_shape_1 = memref.expand_shape %subview_0 [[0], [1, 2]] output_shape [128, 2, 100] : memref<128x200xf32, strided<[200, 1], offset: 2000>> into memref<128x2x100xf32, strided<[200, 100, 1], offset: 2000>> - %subview_2 = memref.subview %arg0[0, 0] [1, 100] [1, 1] : memref<512x200xf32> to memref<1x100xf32, strided<[200, 1]>> - %subview_3 = memref.subview %arg1[10, 100] [1, 100] [1, 1] : memref<512x200xf32> to memref<1x100xf32, strided<[200, 1], offset: 2100>> + %subview_0 = memref.subview %arg1[10, 100] [1, 100] [1, 1] : memref<512x200xf32> to memref<1x100xf32, strided<[200, 1], offset: 2100>> + %subview_1 = memref.subview %arg1[10, 0] [128, 200] [1, 1] : memref<512x200xf32> to memref<128x200xf32, strided<[200, 1], offset: 2000>> + %expand_shape = memref.expand_shape %subview_1 [[0], [1, 2]] output_shape [128, 2, 100] : memref<128x200xf32, strided<[200, 1], offset: 2000>> into memref<128x2x100xf32, strided<[200, 100, 1], offset: 2000>> + %expand_shape_2 = memref.expand_shape %subview [[0], [1, 2]] output_shape [128, 2, 100] : memref<128x200xf32, strided<[200, 1]>> into memref<128x2x100xf32, strided<[200, 100, 1]>> %0 = call @Unknown0(%arg0, %arg1) : (memref<512x200xf32>, memref<512x200xf32>) -> memref<512x200xf32> - %cast = memref.cast %expand_shape : memref<128x2x100xf32, strided<[200, 100, 1]>> to memref<128x2x100xf32> + %cast = memref.cast %expand_shape_2 : memref<128x2x100xf32, strided<[200, 100, 1]>> to memref<128x2x100xf32> %alloc = memref.alloc() : memref<128x2x100xf32> - memref.copy %expand_shape_1, %alloc : memref<128x2x100xf32, strided<[200, 100, 1], offset: 2000>> to memref<128x2x100xf32> - %alloc_4 = memref.alloc() : memref<1x100xf32> - memref.copy %subview_2, %alloc_4 : memref<1x100xf32, strided<[200, 1]>> to memref<1x100xf32> - %alloc_5 = memref.alloc() : memref<1x100xf32> - memref.copy %subview_3, %alloc_5 : memref<1x100xf32, strided<[200, 1], offset: 2100>> to memref<1x100xf32> - return %cast, %alloc, %alloc_4, %alloc_5, %0 : memref<128x2x100xf32>, memref<128x2x100xf32>, memref<1x100xf32>, memref<1x100xf32>, memref<512x200xf32> + memref.copy %expand_shape, %alloc : memref<128x2x100xf32, strided<[200, 100, 1], offset: 2000>> to memref<128x2x100xf32> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 100], strides: [100, 1] : memref<512x200xf32> to memref<1x100xf32> + %alloc_3 = memref.alloc() : memref<1x100xf32> + memref.copy %subview_0, %alloc_3 : memref<1x100xf32, strided<[200, 1], offset: 2100>> to memref<1x100xf32> + return %cast, %alloc, %reinterpret_cast, %alloc_3, %0 : memref<128x2x100xf32>, memref<128x2x100xf32>, memref<1x100xf32>, memref<1x100xf32>, memref<512x200xf32> } } \ No newline at end of file diff --git a/compiler/test/E2E/Host/AliasLike/02a_ByreHost.mlir b/compiler/test/E2E/Host/AliasLike/02a_ByreHost.mlir index 801f3622a..07a2f07e2 100644 --- a/compiler/test/E2E/Host/AliasLike/02a_ByreHost.mlir +++ b/compiler/test/E2E/Host/AliasLike/02a_ByreHost.mlir @@ -22,14 +22,14 @@ module attributes {byre.container_module} { } func.func @main(%arg0: memref<512x200xf32, "cpu"> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<512x200xf32, "cpu"> {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg2: memref<128x2x100xf32, "cpu"> {byre.argname = "Output0", byre.argtype = 2 : i32}, %arg3: memref<128x2x100xf32, "cpu"> {byre.argname = "Output1", byre.argtype = 2 : i32}, %arg4: memref<1x100xf32, "cpu"> {byre.argname = "Output2", byre.argtype = 2 : i32}, %arg5: memref<1x100xf32, "cpu"> {byre.argname = "Output3", byre.argtype = 2 : i32}, %arg6: memref<512x200xf32, "cpu"> {byre.argname = "Output4", byre.argtype = 2 : i32}) attributes {byre.entry_point} { byre.compute @LLVMJITOp(%arg0, %arg1, %arg6) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 2 : i32]} : memref<512x200xf32, "cpu">, memref<512x200xf32, "cpu">, memref<512x200xf32, "cpu"> - %0 = "byre.alias"(%arg1) <{offset = 2000 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> - byre.copy(%0, %arg3) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> - %1 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> - byre.copy(%1, %arg4) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> + %0 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> + %1 = "byre.alias"(%arg1) <{offset = 2000 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> + byre.copy(%1, %arg3) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> %2 = "byre.alias"(%arg1) <{offset = 2100 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> byre.copy(%2, %arg5) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> %3 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> byre.copy(%3, %arg2) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> + byre.copy(%0, %arg4) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> return } } \ No newline at end of file diff --git a/compiler/test/E2E/Host/AliasLike/02b_ToLLVM.mlir b/compiler/test/E2E/Host/AliasLike/02b_ToLLVM.mlir index f1d36fe6c..bcc65d919 100644 --- a/compiler/test/E2E/Host/AliasLike/02b_ToLLVM.mlir +++ b/compiler/test/E2E/Host/AliasLike/02b_ToLLVM.mlir @@ -22,14 +22,14 @@ module attributes {byre.container_module} { } func.func @main(%arg0: memref<512x200xf32, "cpu"> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<512x200xf32, "cpu"> {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg2: memref<128x2x100xf32, "cpu"> {byre.argname = "Output0", byre.argtype = 2 : i32}, %arg3: memref<128x2x100xf32, "cpu"> {byre.argname = "Output1", byre.argtype = 2 : i32}, %arg4: memref<1x100xf32, "cpu"> {byre.argname = "Output2", byre.argtype = 2 : i32}, %arg5: memref<1x100xf32, "cpu"> {byre.argname = "Output3", byre.argtype = 2 : i32}, %arg6: memref<512x200xf32, "cpu"> {byre.argname = "Output4", byre.argtype = 2 : i32}) attributes {byre.entry_point} { byre.compute @LLVMJITOp(%arg0, %arg1, %arg6) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 2 : i32]} : memref<512x200xf32, "cpu">, memref<512x200xf32, "cpu">, memref<512x200xf32, "cpu"> - %0 = "byre.alias"(%arg1) <{offset = 2000 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> - byre.copy(%0, %arg3) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> - %1 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> - byre.copy(%1, %arg4) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> + %0 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> + %1 = "byre.alias"(%arg1) <{offset = 2000 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> + byre.copy(%1, %arg3) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> %2 = "byre.alias"(%arg1) <{offset = 2100 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> byre.copy(%2, %arg5) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> %3 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> byre.copy(%3, %arg2) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> + byre.copy(%0, %arg4) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> return } } \ No newline at end of file diff --git a/compiler/test/E2E/Host/AliasLike/Output.mlir b/compiler/test/E2E/Host/AliasLike/Output.mlir index 894219267..90000a017 100644 --- a/compiler/test/E2E/Host/AliasLike/Output.mlir +++ b/compiler/test/E2E/Host/AliasLike/Output.mlir @@ -5,14 +5,14 @@ module attributes {byre.container_module} { func.func @main(%arg0: memref<512x200xf32, "cpu"> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<512x200xf32, "cpu"> {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg2: memref<128x2x100xf32, "cpu"> {byre.argname = "Output0", byre.argtype = 2 : i32}, %arg3: memref<128x2x100xf32, "cpu"> {byre.argname = "Output1", byre.argtype = 2 : i32}, %arg4: memref<1x100xf32, "cpu"> {byre.argname = "Output2", byre.argtype = 2 : i32}, %arg5: memref<1x100xf32, "cpu"> {byre.argname = "Output3", byre.argtype = 2 : i32}, %arg6: memref<512x200xf32, "cpu"> {byre.argname = "Output4", byre.argtype = 2 : i32}) attributes {byre.entry_point, device_file_name = "your_file"} { byre.compute @LLVMJITOp(%arg0, %arg1, %arg6) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 2 : i32]} : memref<512x200xf32, "cpu">, memref<512x200xf32, "cpu">, memref<512x200xf32, "cpu"> - %0 = "byre.alias"(%arg1) <{offset = 2000 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> - byre.copy(%0, %arg3) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> - %1 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> - byre.copy(%1, %arg4) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> + %0 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> + %1 = "byre.alias"(%arg1) <{offset = 2000 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> + byre.copy(%1, %arg3) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> %2 = "byre.alias"(%arg1) <{offset = 2100 : i64}> : (memref<512x200xf32, "cpu">) -> memref<1x100xf32, "cpu"> byre.copy(%2, %arg5) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> %3 = "byre.alias"(%arg0) <{offset = 0 : i64}> : (memref<512x200xf32, "cpu">) -> memref<128x2x100xf32, "cpu"> byre.copy(%3, %arg2) {callee = "cpu2cpu"} : memref<128x2x100xf32, "cpu">, memref<128x2x100xf32, "cpu"> + byre.copy(%0, %arg4) {callee = "cpu2cpu"} : memref<1x100xf32, "cpu">, memref<1x100xf32, "cpu"> return } } \ No newline at end of file