
[mlir][scf] Relax requirements for loops fusion #79187

Merged: 1 commit into llvm:main on Jan 30, 2024

Conversation

@fabrizio-indirli (Contributor) commented Jan 23, 2024

Enable the fusion of parallel loops even when the 1st loop contains multiple write accesses to the same buffer, provided the accesses always use the same indices.
Fix LIT test cases whose loops were not being fused.

@llvmbot commented Jan 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (fabrizio-indirli)

Changes

Enable the fusion of parallel loops even when the 1st loop contains multiple write accesses to the same buffer, provided the accesses always use the same indices.
Avoid failing on possible aliasing when only one memref from the function args is being written.


Full diff: https://github.com/llvm/llvm-project/pull/79187.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+25-5)
  • (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+80-1)
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7a..88e22b104bcfc74 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -21,6 +21,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/SetVector.h"
+
 namespace mlir {
 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -63,10 +65,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
     llvm::function_ref<bool(Value, Value)> mayAlias) {
   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
   SmallVector<Value> bufferStoresVec;
+  llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
   firstPloop.getBody()->walk([&](memref::StoreOp store) {
     bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
+    const auto storeMemRef = store.getMemRef();
+    bufferStoresVec.emplace_back(storeMemRef);
+    if (llvm::isa<BlockArgument>(storeMemRef))
+      writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
+  });
+  secondPloop.getBody()->walk([&](memref::StoreOp store) {
+    const auto storeMemRef = store.getMemRef();
+    if (llvm::isa<BlockArgument>(storeMemRef))
+      writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
   });
+
   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
     Value loadMem = load.getMemRef();
     // Stop if the memref is defined in secondPloop body. Careful alias analysis
@@ -76,20 +88,28 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
       return WalkResult::interrupt();
 
     for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
+      if ((store != loadMem) && (writtenArgs.size() > 1) &&
+          mayAlias(store, loadMem))
         return WalkResult::interrupt();
 
     auto write = bufferStores.find(loadMem);
     if (write == bufferStores.end())
       return WalkResult::advance();
 
-    // Allow only single write access per buffer.
-    if (write->second.size() != 1)
+    // Check that at least one store was retrieved
+    if (!write->second.size())
       return WalkResult::interrupt();
 
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
     // Check that the load indices of secondPloop coincide with store indices of
     // firstPloop for the same memrefs.
-    auto storeIndices = write->second.front();
     auto loadIndices = load.getIndices();
     if (storeIndices.size() != loadIndices.size())
       return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e5247178..d62f5ed91dec8fb 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
 // CHECK:        [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
 // CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
@@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
 // CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
 // CHECK:        [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
 // CHECK:        memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
 // CHECK:        [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
 // CHECK:        memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
 // CHECK:        memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
@@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
   }
   return
 }
-
 // %sum and %result may alias with other args, do not fuse loops
 // CHECK-LABEL: func @do_not_fuse_alias
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0f32 = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C0:%.*]] = arith.constant 0 : index
+// CHECK:      [[C1:%.*]] = arith.constant 1 : index
+// CHECK:      [[C2:%.*]] = arith.constant 2 : index
+// CHECK:      [[C0F32:%.*]] = arith.constant 0.
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+// CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK:        scf.parallel
+// CHECK:        scf.parallel
\ No newline at end of file

@fabrizio-indirli (Contributor, Author) commented Jan 24, 2024

This patch attempts to enable the fusion of parallel loops in more cases, by relaxing some of the requirements:

Allowing fusion when the 1st loop contains multiple writes to a buffer that is then read, as long as the writes always use the same indices

Before this patch, the following loops would not be fused:

scf.parallel (%arg0) = (%c0) to (%c5) step (%c1) {
  ...
  memref.store %c2, %alloc[%arg0]
  ...
  memref.store %c2, %alloc[%arg0]
  scf.reduce
}
scf.parallel (%arg0) = (%c0) to (%c5) step (%c1) {
  ...
  %2 = memref.load %alloc[%arg0]
}

... because the first loop contains multiple stores to the same buffer, which is then read by the 2nd loop.
However, I believe there should be no problem in fusing these loops when the stores access exactly the same indices, as in the example above.
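
For illustration, a hand-written sketch of what the fused loop could look like (with concrete types and constants filled in, since the snippet above elides them):

func.func @fused_sketch() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c5 = arith.constant 5 : index
  %c2f32 = arith.constant 2.0 : f32
  %alloc = memref.alloc() : memref<5xf32>
  scf.parallel (%arg0) = (%c0) to (%c5) step (%c1) {
    // Both stores and the load use the same index %arg0, so the load still
    // observes the value written by the last store of this iteration,
    // exactly as in the unfused version.
    memref.store %c2f32, %alloc[%arg0] : memref<5xf32>
    memref.store %c2f32, %alloc[%arg0] : memref<5xf32>
    %0 = memref.load %alloc[%arg0] : memref<5xf32>
    scf.reduce
  }
  memref.dealloc %alloc : memref<5xf32>
  return
}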

EDIT: removed the second part

Allowing fusion when only one function argument is being written

Before this patch, the following loops would not be fused:

func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
                    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  ...  // constants definitions
  %sum = memref.alloc()  : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %C_elem : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}

because they were deemed to have potential aliasing issues, even though only one of the function arguments is being written.
This test case is actually taken from the LIT suite that was already in place, but the failure wasn't being picked up due to a missing check in the test: the loops weren't being fused, yet the test passed anyway.
In the previous version of this patch, I was attempting to handle this case by checking that at most one buffer from the function's arguments was written across the two loops; however, this was based on an incorrect assumption, so this part of the change was removed.

@Hardcode84 (Contributor) commented Jan 24, 2024

Thanks for the PR, I need some more time to think about these cases.

@Hardcode84 self-requested a review on January 24, 2024 at 12:26
@fabrizio-indirli (Contributor, Author) commented Jan 24, 2024

For the 2nd case, it's been pointed out to me that there could be an issue with global variables: to exclude potential aliasing when only one function operand is being written, we would also need to make sure that the function does not access any globals to which that operand could be aliased. I'll try to see if I can add this check as well.
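
To make the concern concrete, the situation could look roughly like this (an illustrative, hand-written example; the global @g_buf, the function name, and the argument names are made up): the only written function argument may still alias a global buffer accessed inside the function, so looking at the arguments alone is not enough.

memref.global "private" @g_buf : memref<2x2xf32> = uninitialized

func.func @writes_single_arg(%in: memref<2x2xf32>, %out: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %g = memref.get_global @g_buf : memref<2x2xf32>
  %local = memref.alloc() : memref<2x2xf32>
  // %out is the only function argument written by the loops...
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %v = memref.load %in[%i, %j] : memref<2x2xf32>
    memref.store %v, %out[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  // ...but if the caller passed @g_buf as %out, this load (at transposed
  // indices) can observe an element the first loop has not yet written in a
  // fused loop, whereas in the unfused program the first loop has already
  // finished, so fusing the two would change the result.
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %w = memref.load %g[%j, %i] : memref<2x2xf32>
    memref.store %w, %local[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %local : memref<2x2xf32>
  return
}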
On second thought, I believe this shouldn't be an issue to consider in the fusion of the loops: it should be fine as long as the semantics and the data dependencies are maintained after the fusion.

@fabrizio-indirli force-pushed the scf-fuse-parloops-relax branch 2 times, most recently from 375b519 to 4aa1aff on January 29, 2024 at 16:30
@Hardcode84 (Contributor) commented:

I think the first case, with multiple writes to the same indices, is fine,
but I believe I have a counter-example for the second case with function arguments:

func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
    %temp = memref.alloc() : memref<2x2xf32>
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
        %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
        memref.store %A_elem, %temp[%i, %j] : memref<2x2xf32>
        scf.reduce
    }
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
        %elem = memref.load %temp[%i, %j] : memref<2x2xf32>
        memref.store %elem, %B[%i, %j] : memref<2x2xf32>
        scf.reduce
    }
    return
}

If %A and %B alias the same array at a different offset, the fusion will not be legal.
And generally, I would prefer to be conservative about handling function arguments.
What you probably want is some sort of restrict attribute on function args, so the user can
guarantee the args do not alias. Upstream alias analysis doesn't support this, but this is the
reason I made the aliasing hook configurable.
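
For illustration, a hand-written sketch (with the memref<2x2xf32> types and index constants assumed, as above) of the single loop that fusion would produce from this example, and why it breaks under aliasing:

func.func @fused_counter_example(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %temp = memref.alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    memref.store %A_elem, %temp[%i, %j] : memref<2x2xf32>
    %elem = memref.load %temp[%i, %j] : memref<2x2xf32>
    // If %B overlaps %A at an offset, this store can clobber an element of
    // %A that another iteration has not loaded yet; in the unfused program
    // every load of %A happens before any store to %B.
    memref.store %elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %temp : memref<2x2xf32>
  return
}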

Regarding the broken tests, they can probably be fixed by making some memrefs local allocations instead
of function args.

@fabrizio-indirli (Contributor, Author) commented Jan 30, 2024

Thank you for having a look @Hardcode84,
I believe you're correct; I had missed those potential aliasing cases.
I have removed that part of the patch (the one handling the 2nd case) and modified the LIT tests.
Some of the test cases I've written are a bit "artificial" (e.g. to avoid the aliasing issue, the 1st loop now reads from the output buffer, which is probably uncommon in real scenarios), but they should pass; feel free to suggest improvements ;)

Two review threads on mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (outdated, resolved).

Commit message:
Enable the fusion of parallel loops also when the 1st loop
contains multiple write accesses to the same buffer,
if the accesses are always on the same indices.
Fix LIT test cases whose loops were not being fused.

Signed-off-by: Fabrizio Indirli <[email protected]>
@fabrizio-indirli (Contributor, Author) commented:

Thanks for the review @Hardcode84!
Unfortunately I don't have write access; could you please merge this patch, if no further modifications are needed?

Note: I noticed that the CI check on Windows x64 is failing, but the failure seems unrelated to my change and rather due to the CI environment: CMake is returning a "Could NOT find Python3" error.

@Hardcode84 merged commit d17b005 into llvm:main on Jan 30, 2024 (3 of 4 checks passed).
@Hardcode84 (Contributor) commented:
Merged, thanks for the contribution!

Windows bot is quite unreliable now, yes.
