diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index d7184ad0bad2c7..8f2ab5f5e6dc13 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -83,13 +83,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex( if (write == bufferStores.end()) return WalkResult::advance(); - // Allow only single write access per buffer. - if (write->second.size() != 1) + // Check that at last one store was retrieved + if (!write->second.size()) return WalkResult::interrupt(); + auto storeIndices = write->second.front(); + + // Multiple writes to the same memref are allowed only on the same indices + for (const auto &othStoreIndices : write->second) { + if (othStoreIndices != storeIndices) + return WalkResult::interrupt(); + } + // Check that the load indices of secondPloop coincide with store indices of // firstPloop for the same memrefs. - auto storeIndices = write->second.front(); auto loadIndices = load.getIndices(); if (storeIndices.size() != loadIndices.size()) return WalkResult::interrupt(); diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index 9fd33b4e524717..110168ba6eca52 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -13,9 +13,9 @@ func.func @fuse_empty_loops() { return } // CHECK-LABEL: func @fuse_empty_loops -// CHECK: [[C2:%.*]] = arith.constant 2 : index -// CHECK: [[C0:%.*]] = arith.constant 0 : index -// CHECK: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { // CHECK: scf.reduce @@ -24,16 +24,15 @@ func.func @fuse_empty_loops() { // ----- -func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>, - %C: memref<2x2xf32>, %result: memref<2x2xf32>) { +func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 %sum = memref.alloc() : memref<2x2xf32> scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - %C_elem = memref.load %C[%i, %j] : memref<2x2xf32> - %sum_elem = arith.addf %B_elem, %C_elem : f32 + %sum_elem = arith.addf %B_elem, %c1fp : f32 memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> scf.reduce } @@ -41,89 +40,90 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 - memref.store %product_elem, %result[%i, %j] : memref<2x2xf32> + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> scf.reduce } memref.dealloc %sum : memref<2x2xf32> return } // CHECK-LABEL: func @fuse_two -// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}}, -// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) { -// CHECK: [[C2:%.*]] = arith.constant 2 : index -// CHECK: [[C0:%.*]] = arith.constant 0 : index -// CHECK: [[C1:%.*]] = arith.constant 1 : index +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. // CHECK: [[SUM:%.*]] = memref.alloc() // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { // CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] -// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]] -// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] // CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel // CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] -// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] // CHECK: scf.reduce // CHECK: } // CHECK: memref.dealloc [[SUM]] // ----- -func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>, - %result: memref<100x10xf32>) { - %c100 = arith.constant 100 : index - %c10 = arith.constant 10 : index +func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { + %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %broadcast_rhs = memref.alloc() : memref<100x10xf32> - %diff = memref.alloc() : memref<100x10xf32> - scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { - %rhs_elem = memref.load %rhs[%i] : memref<100xf32> - memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32> + %c1fp = arith.constant 1.0 : f32 + %c2fp = arith.constant 2.0 : f32 + %sum = memref.alloc() : memref<2x2xf32> + %prod = memref.alloc() : memref<2x2xf32> + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %c1fp : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> scf.reduce } - scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { - %lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32> - %broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32> - %diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32 - memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32> + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %c2fp : f32 + memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32> scf.reduce } - scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { - %diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32> - %exp_elem = math.exp %diff_elem : f32 - memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32> - scf.reduce + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %res_elem = arith.addf %A_elem, %c2fp : f32 + memref.store %res_elem, %B[%i, %j] : memref<2x2xf32> } - memref.dealloc %broadcast_rhs : memref<100x10xf32> - memref.dealloc %diff : memref<100x10xf32> + memref.dealloc %sum : memref<2x2xf32> + memref.dealloc %prod : memref<2x2xf32> return } // CHECK-LABEL: func @fuse_three -// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>, -// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) { -// CHECK: [[C100:%.*]] = arith.constant 100 : index -// CHECK: [[C10:%.*]] = arith.constant 10 : index -// CHECK: [[C0:%.*]] = arith.constant 0 : index -// CHECK: [[C1:%.*]] = arith.constant 1 : index -// CHECK: [[BROADCAST_RHS:%.*]] = memref.alloc() -// CHECK: [[DIFF:%.*]] = memref.alloc() +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. +// CHECK-DAG: [[C2FP:%.*]] = arith.constant 2. +// CHECK: [[SUM:%.*]] = memref.alloc() +// CHECK: [[PROD:%.*]] = memref.alloc() // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) -// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) { -// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]] -// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]] -// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]] -// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]] -// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]] -// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]] -// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]] -// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]] -// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]] +// CHECK: memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]] // CHECK: scf.reduce // CHECK: } -// CHECK: memref.dealloc [[BROADCAST_RHS]] -// CHECK: memref.dealloc [[DIFF]] +// CHECK: memref.dealloc [[SUM]] +// CHECK: memref.dealloc [[PROD]] // ----- @@ -310,17 +310,16 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() { // ----- -func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>, - %C: memref<2x2xf32>, %result: memref<2x2xf32>) { +func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index + %c1fp = arith.constant 1.0 : f32 %sum = memref.alloc() : memref<2x2xf32> scf.parallel (%k) = (%c0) to (%c2) step (%c1) { scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - %C_elem = memref.load %C[%i, %j] : memref<2x2xf32> - %sum_elem = arith.addf %B_elem, %C_elem : f32 + %sum_elem = arith.addf %B_elem, %c1fp : f32 memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> scf.reduce } @@ -328,7 +327,7 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 - memref.store %product_elem, %result[%i, %j] : memref<2x2xf32> + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> scf.reduce } } @@ -336,23 +335,23 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>, return } // CHECK-LABEL: func @nested_fuse -// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}}, -// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) { -// CHECK: [[C2:%.*]] = arith.constant 2 : index -// CHECK: [[C0:%.*]] = arith.constant 0 : index -// CHECK: [[C1:%.*]] = arith.constant 1 : index +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. // CHECK: [[SUM:%.*]] = memref.alloc() // CHECK: scf.parallel // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { // CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] -// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]] -// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] // CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel // CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] -// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] // CHECK: scf.reduce // CHECK: } // CHECK: } @@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>, } return } - // %sum and %result may alias with other args, do not fuse loops // CHECK-LABEL: func @do_not_fuse_alias // CHECK: scf.parallel // CHECK: scf.parallel + +// ----- + +func.func @fuse_when_1st_has_multiple_stores( + %A: memref<2x2xf32>, %B: memref<2x2xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0fp = arith.constant 0.0 : f32 + %sum = memref.alloc() : memref<2x2xf32> + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32> + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %B_elem : f32 + memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2x2xf32> + return +} +// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0. +// CHECK: [[SUM:%.*]] = memref.alloc() +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK-NOT: scf.parallel +// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } +// CHECK: memref.dealloc [[SUM]] + +// ----- + +func.func @do_not_fuse_multiple_stores_on_diff_indices( + %A: memref<2x2xf32>, %B: memref<2x2xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0fp = arith.constant 0.0 : f32 + %sum = memref.alloc() : memref<2x2xf32> + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32> + %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> + %sum_elem = arith.addf %B_elem, %B_elem : f32 + memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32> + scf.reduce + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> + %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> + %product_elem = arith.mulf %sum_elem, %A_elem : f32 + memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> + scf.reduce + } + memref.dealloc %sum : memref<2x2xf32> + return +} +// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices +// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0. +// CHECK: [[SUM:%.*]] = memref.alloc() +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] +// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]] +// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]] +// CHECK: scf.reduce +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] +// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] +// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf +// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] +// CHECK: scf.reduce +// CHECK: } +// CHECK: memref.dealloc [[SUM]]