Skip to content

Commit

Permalink
[mlir][scf] Relax requirements for loops fusion
Browse files Browse the repository at this point in the history
Enable the fusion of parallel loops also when the 1st loop
contains multiple write accesses to the same buffer,
if the accesses are always on the same indices.
Fix LIT test cases whose loops were not being fused.

Signed-off-by: Fabrizio Indirli <[email protected]>
  • Loading branch information
fabrizio-indirli committed Jan 30, 2024
1 parent 32073b8 commit bdfc74f
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 63 deletions.
13 changes: 10 additions & 3 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
if (write == bufferStores.end())
return WalkResult::advance();

// Allow only single write access per buffer.
if (write->second.size() != 1)
// Check that at least one store was retrieved.
if (!write->second.size())
return WalkResult::interrupt();

auto storeIndices = write->second.front();

// Multiple writes to the same memref are allowed only on the same indices
for (const auto &othStoreIndices : write->second) {
if (othStoreIndices != storeIndices)
return WalkResult::interrupt();
}

// Check that the load indices of secondPloop coincide with store indices of
// firstPloop for the same memrefs.
auto storeIndices = write->second.front();
auto loadIndices = load.getIndices();
if (storeIndices.size() != loadIndices.size())
return WalkResult::interrupt();
Expand Down
213 changes: 153 additions & 60 deletions mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,106 +24,106 @@ func.func @fuse_empty_loops() {

// -----

func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1fp = arith.constant 1.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
%sum_elem = arith.addf %B_elem, %c1fp : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_two
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[C1FP:%.*]] = arith.constant 1.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]

// -----

func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
%result: memref<100x10xf32>) {
%c100 = arith.constant 100 : index
%c10 = arith.constant 10 : index
func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%broadcast_rhs = memref.alloc() : memref<100x10xf32>
%diff = memref.alloc() : memref<100x10xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%rhs_elem = memref.load %rhs[%i] : memref<100xf32>
memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
%c1fp = arith.constant 1.0 : f32
%c2fp = arith.constant 2.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
%prod = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %c1fp : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
%broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
%diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %c2fp : f32
memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
%exp_elem = math.exp %diff_elem : f32
memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
scf.reduce
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%res_elem = arith.addf %A_elem, %c2fp : f32
memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
}
memref.dealloc %broadcast_rhs : memref<100x10xf32>
memref.dealloc %diff : memref<100x10xf32>
memref.dealloc %sum : memref<2x2xf32>
memref.dealloc %prod : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_three
// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
// CHECK: [[C100:%.*]] = arith.constant 100 : index
// CHECK: [[C10:%.*]] = arith.constant 10 : index
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[BROADCAST_RHS:%.*]] = memref.alloc()
// CHECK: [[DIFF:%.*]] = memref.alloc()
// CHECK: [[C1FP:%.*]] = arith.constant 1.
// CHECK: [[C2FP:%.*]] = arith.constant 2.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: [[PROD:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
// CHECK: memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[BROADCAST_RHS]]
// CHECK: memref.dealloc [[DIFF]]
// CHECK: memref.dealloc [[SUM]]
// CHECK: memref.dealloc [[PROD]]

// -----

Expand Down Expand Up @@ -310,49 +310,48 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {

// -----

func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1fp = arith.constant 1.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
%sum_elem = arith.addf %B_elem, %c1fp : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @nested_fuse
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[C1FP:%.*]] = arith.constant 1.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: }
Expand Down Expand Up @@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}

// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel

// -----

// Positive test: the first loop stores to %sum twice, and both stores use the
// same indices [%i, %j]. The second loop loads %sum at [%i, %j] as well, so
// the CHECK lines that follow expect a single fused scf.parallel.
func.func @fuse_when_1st_has_multiple_stores(
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0fp = arith.constant 0.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
// First loop: two writes to %sum, both at [%i, %j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
// First write: stores the constant 0.0 into %sum[%i, %j].
memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %B_elem : f32
// Second write: overwrites the same element with %B[%i, %j] + %B[%i, %j].
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
// Second loop: same bounds and steps; reads %sum at the indices written above.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0F32:%.*]] = arith.constant 0.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]

// -----

// Negative test: the first loop stores to %sum twice, but at different
// indices ([%i, %j] and [%c0, %j]). The CHECK lines that follow expect both
// scf.parallel ops to remain, i.e. the loops are not fused.
func.func @do_not_fuse_multiple_stores_on_diff_indices(
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0fp = arith.constant 0.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
// First loop: two writes to %sum whose index lists differ.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
// First write: stores the constant 0.0 into %sum[%i, %j].
memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %B_elem : f32
// Second write: uses %c0 instead of %i as the first index, so it does not
// match the indices of the first write.
memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
scf.reduce
}
// Second loop: same bounds and steps; reads %sum[%i, %j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0F32:%.*]] = arith.constant 0.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
// CHECK: scf.reduce
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]

0 comments on commit bdfc74f

Please sign in to comment.