Skip to content

Commit

Permalink
[mlir][scf] Relax requirements for loops fusion
Browse files Browse the repository at this point in the history
Enable the fusion of parallel loops also when the 1st loop
contains multiple write accesses to the same buffer,
if the accesses are always on the same indices.
Avoid failing on possible aliasing when only one memref
from the function args is being written.

Signed-off-by: Fabrizio Indirli <[email protected]>
  • Loading branch information
fabrizio-indirli committed Jan 23, 2024
1 parent 32073b8 commit e4faeae
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
30 changes: 25 additions & 5 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/SetVector.h"

namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
Expand Down Expand Up @@ -63,10 +65,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
llvm::function_ref<bool(Value, Value)> mayAlias) {
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
SmallVector<Value> bufferStoresVec;
llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
firstPloop.getBody()->walk([&](memref::StoreOp store) {
bufferStores[store.getMemRef()].push_back(store.getIndices());
bufferStoresVec.emplace_back(store.getMemRef());
const auto storeMemRef = store.getMemRef();
bufferStoresVec.emplace_back(storeMemRef);
if (llvm::isa<BlockArgument>(storeMemRef))
writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
});
secondPloop.getBody()->walk([&](memref::StoreOp store) {
const auto storeMemRef = store.getMemRef();
if (llvm::isa<BlockArgument>(storeMemRef))
writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
});

auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
Value loadMem = load.getMemRef();
// Stop if the memref is defined in secondPloop body. Careful alias analysis
Expand All @@ -76,20 +88,28 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
return WalkResult::interrupt();

for (Value store : bufferStoresVec)
if (store != loadMem && mayAlias(store, loadMem))
if ((store != loadMem) && (writtenArgs.size() > 1) &&
mayAlias(store, loadMem))
return WalkResult::interrupt();

auto write = bufferStores.find(loadMem);
if (write == bufferStores.end())
return WalkResult::advance();

// Allow only single write access per buffer.
if (write->second.size() != 1)
// Check that at least one store was retrieved
if (!write->second.size())
return WalkResult::interrupt();

auto storeIndices = write->second.front();

// Multiple writes to the same memref are allowed only on the same indices
for (const auto &othStoreIndices : write->second) {
if (othStoreIndices != storeIndices)
return WalkResult::interrupt();
}

// Check that the load indices of secondPloop coincide with store indices of
// firstPloop for the same memrefs.
auto storeIndices = write->second.front();
auto loadIndices = load.getIndices();
if (storeIndices.size() != loadIndices.size())
return WalkResult::interrupt();
Expand Down
81 changes: 80 additions & 1 deletion mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
Expand Down Expand Up @@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
Expand Down Expand Up @@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}

// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel

// -----

// Test input: two scf.parallel loops over the same 2x2 iteration space.
// The first loop writes %sum[%i, %j] twice (a zero initialisation followed by
// the computed value) and the second loop reads %sum at the same [%i, %j].
// The FileCheck directives after this function expect a single fused loop.
func.func @fuse_when_1st_has_multiple_stores(
%A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0f32 = arith.constant 0.0 : f32
// %sum is a locally allocated temporary, not a function argument.
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
// First store to %sum at [%i, %j].
memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %B_elem : f32
// Second store to %sum, using the same indices as the first one.
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
// Load from %sum at the same indices the first loop stored to.
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[C2:%.*]] = arith.constant 2 : index
// CHECK: [[C0F32:%.*]] = arith.constant 0.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]

// -----

// Test input: two scf.parallel loops over the same 2x2 iteration space.
// The first loop writes %sum twice, but at different indices ([%i, %j] and
// [%c0, %j]); the second loop reads %sum[%i, %j].
// The FileCheck directives after this function expect both loops to remain.
func.func @do_not_fuse_multiple_stores_on_diff_indices(
%A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0_f32 = arith.constant 0.0 : f32
// %sum is a locally allocated temporary, not a function argument.
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
// First store to %sum at [%i, %j].
memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32>
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %B_elem : f32
// Second store uses row index %c0 instead of %i, so the two stores to %sum
// do not share the same indices.
memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
// Load from %sum at [%i, %j].
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
// CHECK: scf.parallel
// CHECK: scf.parallel

0 comments on commit e4faeae

Please sign in to comment.