Skip to content

Commit

Permalink
Stronger chain detection in LoopCarry pass (#8016)
Browse files Browse the repository at this point in the history
* Stronger chain detection in LoopCarry

* Make sure that types are the same

* Add a comment

* Run CSE before calling can_prove

* Test for loop carry

* clang-tidy

* Add missing override

* Update comments
  • Loading branch information
vksnk authored Jan 9, 2024
1 parent cdebeb8 commit 91b063d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 2 deletions.
32 changes: 30 additions & 2 deletions src/LoopCarry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,34 @@ class LoopCarryOverLoop : public IRMutator {

// For each load, move the load index forwards by one loop iteration
vector<Expr> indices, next_indices, predicates, next_predicates;
// CSE-d versions of the above, so can_prove can be safely used on them.
vector<Expr> indices_csed, next_indices_csed, predicates_csed, next_predicates_csed;
for (const vector<const Load *> &v : loads) {
indices.push_back(v[0]->index);
next_indices.push_back(step_forwards(v[0]->index, linear));
predicates.push_back(v[0]->predicate);
next_predicates.push_back(step_forwards(v[0]->predicate, linear));

if (indices.back().defined()) {
indices_csed.push_back(common_subexpression_elimination(indices.back()));
} else {
indices_csed.emplace_back();
}
if (next_indices.back().defined()) {
next_indices_csed.push_back(common_subexpression_elimination(next_indices.back()));
} else {
next_indices_csed.emplace_back();
}
if (predicates.back().defined()) {
predicates_csed.push_back(common_subexpression_elimination(predicates.back()));
} else {
predicates_csed.emplace_back();
}
if (next_predicates.back().defined()) {
next_predicates_csed.push_back(common_subexpression_elimination(next_predicates.back()));
} else {
next_predicates_csed.emplace_back();
}
}

// Find loads done on this loop iteration that will be
Expand All @@ -299,11 +322,16 @@ class LoopCarryOverLoop : public IRMutator {
if (i == j) {
continue;
}
// can_prove is stronger than graph_equal, because it doesn't require index expressions to be
// exactly the same, but evaluate to the same value. We keep the graph_equal check, because
// it's faster and should be executed before the more expensive check.
if (loads[i][0]->name == loads[j][0]->name &&
next_indices[j].defined() &&
graph_equal(indices[i], next_indices[j]) &&
(graph_equal(indices[i], next_indices[j]) ||
((indices[i].type() == next_indices[j].type()) && can_prove(indices_csed[i] == next_indices_csed[j]))) &&
next_predicates[j].defined() &&
graph_equal(predicates[i], next_predicates[j])) {
(graph_equal(predicates[i], next_predicates[j]) ||
((predicates[i].type() == next_predicates[j].type()) && can_prove(predicates_csed[i] == next_predicates_csed[j])))) {
chains.push_back({j, i});
debug(3) << "Found carried value:\n"
<< i << ": -> " << Expr(loads[i][0]) << "\n"
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ tests(GROUPS correctness
likely.cpp
load_library.cpp
logical.cpp
loop_carry.cpp
loop_invariant_extern_calls.cpp
loop_level_generator_param.cpp
lossless_cast.cpp
Expand Down
64 changes: 64 additions & 0 deletions test/correctness/loop_carry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;
using namespace Halide::Internal;

// Wrapper class to call loop_carry on a given statement.
class LoopCarryWrapper : public IRMutator {
using IRMutator::visit;

int register_count_;
Stmt mutate(const Stmt &stmt) override {
return simplify(loop_carry(stmt, register_count_));
}

public:
LoopCarryWrapper(int register_count)
: register_count_(register_count) {
}
};

int main(int argc, char **argv) {
Func input;
Func g;
Func h;
Func f;
Var x, y, xo, yo, xi, yi;

input(x, y) = x + y;

Expr sum_expr = 0;
for (int ix = -100; ix <= 100; ix++) {
// Generate two chains of sums, but only one of them will be carried.
sum_expr += input(x, y + ix);
sum_expr += input(x + 13, y + 2 * ix);
}
g(x, y) = sum_expr;
h(x, y) = g(x, y) + 12;
f(x, y) = h(x, y);

// Make a maximum number of the carried values very large for the purpose
// of this test.
constexpr int kMaxRegisterCount = 1024;
f.add_custom_lowering_pass(new LoopCarryWrapper(kMaxRegisterCount));

const int size = 128;
f.compute_root()
.bound(x, 0, size)
.bound(y, 0, size);

h.compute_root()
.tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::RoundUp);

g.compute_at(h, xo)
.reorder(y, x)
.vectorize(x, 4);

input.compute_root();

f.realize({size, size});

printf("Success!\n");
return 0;
}

0 comments on commit 91b063d

Please sign in to comment.