[mlir][scf] Relax requirements for loops fusion #79187
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf

Author: None (fabrizio-indirli)

Changes: Enable the fusion of parallel loops also when the 1st loop contains multiple write accesses to the same buffer, if the accesses are always on the same indices.

Full diff: https://github.com/llvm/llvm-project/pull/79187.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7a..88e22b104bcfc74 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -21,6 +21,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/SetVector.h"
+
namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -63,10 +65,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
llvm::function_ref<bool(Value, Value)> mayAlias) {
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
SmallVector<Value> bufferStoresVec;
+ llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
firstPloop.getBody()->walk([&](memref::StoreOp store) {
bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
+ const auto storeMemRef = store.getMemRef();
+ bufferStoresVec.emplace_back(storeMemRef);
+ if (llvm::isa<BlockArgument>(storeMemRef))
+ writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
+ });
+ secondPloop.getBody()->walk([&](memref::StoreOp store) {
+ const auto storeMemRef = store.getMemRef();
+ if (llvm::isa<BlockArgument>(storeMemRef))
+ writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
});
+
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
Value loadMem = load.getMemRef();
// Stop if the memref is defined in secondPloop body. Careful alias analysis
@@ -76,20 +88,28 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
return WalkResult::interrupt();
for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
+ if ((store != loadMem) && (writtenArgs.size() > 1) &&
+ mayAlias(store, loadMem))
return WalkResult::interrupt();
auto write = bufferStores.find(loadMem);
if (write == bufferStores.end())
return WalkResult::advance();
- // Allow only single write access per buffer.
- if (write->second.size() != 1)
+    // Check that at least one store was retrieved.
+ if (!write->second.size())
return WalkResult::interrupt();
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
// Check that the load indices of secondPloop coincide with store indices of
// firstPloop for the same memrefs.
- auto storeIndices = write->second.front();
auto loadIndices = load.getIndices();
if (storeIndices.size() != loadIndices.size())
return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e5247178..d62f5ed91dec8fb 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
@@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
@@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}
-
// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0f32 = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
+// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
+// CHECK: [[C0:%.*]] = arith.constant 0 : index
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[C2:%.*]] = arith.constant 2 : index
+// CHECK: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK: scf.parallel
+// CHECK: scf.parallel
\ No newline at end of file
This patch attempts to enable the fusion of parallel loops in more cases, by relaxing some of the requirements: fusion is now allowed when the first loop contains multiple writes to a buffer that is then read by the second loop, as long as the writes always use the same indices. Before this patch, such loop pairs would not be fused, because the first loop contains multiple stores to the same buffer that is then read by the second loop; an illustrative example is sketched below. (EDIT: removed the second part.)
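A minimal sketch of the case described above, modeled on the fuse_when_1st_has_multiple_stores test added by this patch (the constants %c0, %c1, %c2, %c0f32 and the enclosing function are omitted for brevity):

scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
  // First store to %sum.
  memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
  %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
  %sum_elem = arith.addf %B_elem, %B_elem : f32
  // Second store to %sum, at the same indices [%i, %j] as the first one.
  memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
  scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
  // The second loop reads %sum at the indices written by the first loop.
  %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
  %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
  %product_elem = arith.mulf %sum_elem, %A_elem : f32
  memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
  scf.reduce
}

With this change, the two loops above can be fused into a single scf.parallel loop, since every read of %sum in the second loop is matched only by writes at the same indices in the first loop.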
Thanks for the PR, I need some more time to think about these cases.
I think the first case, with multiple writes to the same indices, is fine. If %A and %B alias the same array with a different offset, the fusion will not be legal. Regarding the broken tests, they can probably be fixed by making some memrefs local allocations instead.
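For illustration, a minimal sketch of the aliasing situation described above, assuming %A and %B are views of the same allocation created at different offsets with memref.subview (the names %buf, %A, %B are hypothetical):

%buf = memref.alloc() : memref<4x2xf32>
// %A covers rows 0-1 of %buf.
%A = memref.subview %buf[0, 0] [2, 2] [1, 1]
    : memref<4x2xf32> to memref<2x2xf32, strided<[2, 1]>>
// %B covers rows 1-2 of %buf, overlapping %A at a different offset.
%B = memref.subview %buf[1, 0] [2, 2] [1, 1]
    : memref<4x2xf32> to memref<2x2xf32, strided<[2, 1], offset: 2>>

In this setup a store to %A[1, %j] in the first loop and a load from %B[0, %j] in the second loop touch the same underlying element of %buf, so fusing the loops could move that load before the store it depends on.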
Thank you for having a look @Hardcode84.
Enable the fusion of parallel loops also when the 1st loop contains multiple write accesses to the same buffer, if the accesses are always on the same indices. Fix LIT test cases whose loops were not being fused. Signed-off-by: Fabrizio Indirli <[email protected]>
Thanks for the review @Hardcode84! Note: I noticed that the CI check on Windows x64 is failing, but the failure seems to be unrelated to my change and rather due to the CI environment: CMake is returning the error "Could NOT find Python3".
Merged, thanks for the contribution! The Windows bot is quite unreliable now, yes.