From 3055e2a28b2ac6f25a7a844ce7c476535f9da01d Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Tue, 2 Jan 2024 15:00:53 -0800 Subject: [PATCH 1/8] Stronger chain detection in LoopCarry --- src/LoopCarry.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 5f4d7bb519d3..cca01947f8de 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -301,9 +301,9 @@ class LoopCarryOverLoop : public IRMutator { } if (loads[i][0]->name == loads[j][0]->name && next_indices[j].defined() && - graph_equal(indices[i], next_indices[j]) && + (graph_equal(indices[i], next_indices[j]) || can_prove(indices[i] == next_indices[j])) && next_predicates[j].defined() && - graph_equal(predicates[i], next_predicates[j])) { + (graph_equal(predicates[i], next_predicates[j]) || can_prove(predicates[i] == next_predicates[j]))) { chains.push_back({j, i}); debug(3) << "Found carried value:\n" << i << ": -> " << Expr(loads[i][0]) << "\n" From 7ed555ba20f1b9e6f5848781af78ec9e91b80c0c Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Tue, 2 Jan 2024 20:12:03 -0800 Subject: [PATCH 2/8] Make sure that types are the same --- src/LoopCarry.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index cca01947f8de..fd60b7eb3dfa 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -301,9 +301,11 @@ class LoopCarryOverLoop : public IRMutator { } if (loads[i][0]->name == loads[j][0]->name && next_indices[j].defined() && - (graph_equal(indices[i], next_indices[j]) || can_prove(indices[i] == next_indices[j])) && + (graph_equal(indices[i], next_indices[j]) || + ((indices[i].type() == next_indices[j].type()) && can_prove(indices[i] == next_indices[j]))) && next_predicates[j].defined() && - (graph_equal(predicates[i], next_predicates[j]) || can_prove(predicates[i] == next_predicates[j]))) { + (graph_equal(predicates[i], next_predicates[j]) || + ((predicates[i].type() == 
next_predicates[j].type()) && can_prove(predicates[i] == next_predicates[j])))) { chains.push_back({j, i}); debug(3) << "Found carried value:\n" << i << ": -> " << Expr(loads[i][0]) << "\n" From b71d889c2415e99645ca224927a3b753bb65cf18 Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Tue, 2 Jan 2024 20:15:00 -0800 Subject: [PATCH 3/8] Add a comment --- src/LoopCarry.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index fd60b7eb3dfa..5caacf38f5d9 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -299,6 +299,9 @@ class LoopCarryOverLoop : public IRMutator { if (i == j) { continue; } + // can_prove is stronger than graph_equal, because it doesn't require index expressions to be + // exactly the same, but evalutate to the same value. We keep the graph_equal check, because + // it's faster and should be executed before the more expensive check. if (loads[i][0]->name == loads[j][0]->name && next_indices[j].defined() && (graph_equal(indices[i], next_indices[j]) || From 8540d8b3fd53542cac32a9cf3f6af9de42c22c5f Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Thu, 4 Jan 2024 14:26:08 -0800 Subject: [PATCH 4/8] Run CSE before calling can_prove --- src/LoopCarry.cpp | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 5caacf38f5d9..5e90a9c94a74 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -283,11 +283,34 @@ class LoopCarryOverLoop : public IRMutator { // For each load, move the load index forwards by one loop iteration vector<Expr> indices, next_indices, predicates, next_predicates; + // CSE-d versions of the above, so can_prove can be safely used on them. 
+ vector<Expr> indices_csed, next_indices_csed, predicates_csed, next_predicates_csed; for (const vector<const Load *> &v : loads) { indices.push_back(v[0]->index); next_indices.push_back(step_forwards(v[0]->index, linear)); predicates.push_back(v[0]->predicate); next_predicates.push_back(step_forwards(v[0]->predicate, linear)); + + if (indices.back().defined()) { + indices_csed.push_back(common_subexpression_elimination(indices.back())); + } else { + indices_csed.push_back(Expr()); + } + if (next_indices.back().defined()) { + next_indices_csed.push_back(common_subexpression_elimination(next_indices.back())); + } else { + next_indices_csed.push_back(Expr()); + } + if (predicates.back().defined()) { + predicates_csed.push_back(common_subexpression_elimination(predicates.back())); + } else { + predicates_csed.push_back(Expr()); + } + if (next_predicates.back().defined()) { + next_predicates_csed.push_back(common_subexpression_elimination(next_predicates.back())); + } else { + next_predicates_csed.push_back(Expr()); + } } // Find loads done on this loop iteration that will be @@ -300,15 +323,15 @@ class LoopCarryOverLoop : public IRMutator { continue; } // can_prove is stronger than graph_equal, because it doesn't require index expressions to be - // exactly the same, but evalutate to the same value. We keep the graph_equal check, because + // exactly the same, but evaluate to the same value. We keep the graph_equal check, because // it's faster and should be executed before the more expensive check. 
if (loads[i][0]->name == loads[j][0]->name && next_indices[j].defined() && (graph_equal(indices[i], next_indices[j]) || - ((indices[i].type() == next_indices[j].type()) && can_prove(indices[i] == next_indices[j]))) && + ((indices[i].type() == next_indices[j].type()) && can_prove(indices_csed[i] == next_indices_csed[j]))) && next_predicates[j].defined() && (graph_equal(predicates[i], next_predicates[j]) || - ((predicates[i].type() == next_predicates[j].type()) && can_prove(predicates[i] == next_predicates[j])))) { + ((predicates[i].type() == next_predicates[j].type()) && can_prove(predicates_csed[i] == next_predicates_csed[j])))) { chains.push_back({j, i}); debug(3) << "Found carried value:\n" << i << ": -> " << Expr(loads[i][0]) << "\n" From 1fceaa3b567b29c460ef2f69bf63be02901f485e Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Thu, 4 Jan 2024 14:27:06 -0800 Subject: [PATCH 5/8] Test for loop carry --- test/correctness/CMakeLists.txt | 1 + test/correctness/loop_carry.cpp | 69 +++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 test/correctness/loop_carry.cpp diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 07921a347425..cd66f21a346e 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -199,6 +199,7 @@ tests(GROUPS correctness likely.cpp load_library.cpp logical.cpp + loop_carry.cpp loop_invariant_extern_calls.cpp loop_level_generator_param.cpp lossless_cast.cpp diff --git a/test/correctness/loop_carry.cpp b/test/correctness/loop_carry.cpp new file mode 100644 index 000000000000..8efad1dfc413 --- /dev/null +++ b/test/correctness/loop_carry.cpp @@ -0,0 +1,69 @@ +#include "Halide.h" +#include <stdio.h> + +// This file demonstrates two example custom lowering passes. The +// first just makes sure the IR passes some test, and doesn't modify +// it. The second actually changes the IR in some useful way. 
+ +using namespace Halide; +using namespace Halide::Internal; + +// Verify that all floating point divisions by constants have been +// converted to float multiplication. +class LoopCarryWrapper : public IRMutator { + using IRMutator::visit; + + int register_count_; + Stmt mutate(const Stmt &stmt) { + return simplify(loop_carry(stmt, register_count_)); + } + +public: + LoopCarryWrapper(int register_count) + : register_count_(register_count) { + } +}; + +int main(int argc, char **argv) { + Func input; + Func g; + Func h; + Func f; + Var x, y, xo, yo, xi, yi; + + input(x, y) = x + y; + + Expr sum_expr = 0; + for (int ix = -100; ix <= 100; ix++) { + // Generate two chains of sums, but only one of them will be carried. + sum_expr += input(x, y + ix); + sum_expr += input(x + 13, y + 2 * ix); + } + g(x, y) = sum_expr; + h(x, y) = g(x, y) + 12; + f(x, y) = h(x, y); + + // Make a maximum number of the carried values very large for the purpose + // of this test. + constexpr int kMaxRegisterCount = 1024; + f.add_custom_lowering_pass(new LoopCarryWrapper(kMaxRegisterCount)); + + const int size = 128; + f.compute_root() + .bound(x, 0, size) + .bound(y, 0, size); + + h.compute_root() + .tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::RoundUp); + + g.compute_at(h, xo) + .reorder(y, x) + .vectorize(x, 4); + + input.compute_root(); + + f.realize({size, size}); + + printf("Success!\n"); + return 0; +} From 09e48f8ed67604c807e8d658ee487af8741a5881 Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Thu, 4 Jan 2024 15:29:02 -0800 Subject: [PATCH 6/8] clang-tidy --- src/LoopCarry.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 5e90a9c94a74..050cdfbfc8d9 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -294,22 +294,22 @@ class LoopCarryOverLoop : public IRMutator { if (indices.back().defined()) { indices_csed.push_back(common_subexpression_elimination(indices.back())); } else { - 
indices_csed.push_back(Expr()); + indices_csed.emplace_back(); } if (next_indices.back().defined()) { next_indices_csed.push_back(common_subexpression_elimination(next_indices.back())); } else { - next_indices_csed.push_back(Expr()); + next_indices_csed.emplace_back(); } if (predicates.back().defined()) { predicates_csed.push_back(common_subexpression_elimination(predicates.back())); } else { - predicates_csed.push_back(Expr()); + predicates_csed.emplace_back(); } if (next_predicates.back().defined()) { next_predicates_csed.push_back(common_subexpression_elimination(next_predicates.back())); } else { - next_predicates_csed.push_back(Expr()); + next_predicates_csed.emplace_back(); } } From 8ba85a22cd52648abd671566c1642b6b455ac5a3 Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Fri, 5 Jan 2024 10:28:53 -0800 Subject: [PATCH 7/8] Add missing override --- test/correctness/loop_carry.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/correctness/loop_carry.cpp b/test/correctness/loop_carry.cpp index 8efad1dfc413..4f37b5d7c129 100644 --- a/test/correctness/loop_carry.cpp +++ b/test/correctness/loop_carry.cpp @@ -14,7 +14,7 @@ class LoopCarryWrapper : public IRMutator { using IRMutator::visit; int register_count_; - Stmt mutate(const Stmt &stmt) { + Stmt mutate(const Stmt &stmt) override { return simplify(loop_carry(stmt, register_count_)); } From 89c74af1a0d6c73f36158d3a2373cabbb8ac0296 Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Mon, 8 Jan 2024 11:03:58 -0800 Subject: [PATCH 8/8] Update comments --- test/correctness/loop_carry.cpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/correctness/loop_carry.cpp b/test/correctness/loop_carry.cpp index 4f37b5d7c129..4cfba7d25f3f 100644 --- a/test/correctness/loop_carry.cpp +++ b/test/correctness/loop_carry.cpp @@ -1,15 +1,10 @@ #include "Halide.h" #include <stdio.h> -// This file demonstrates two example custom lowering passes. 
The -// first just makes sure the IR passes some test, and doesn't modify -// it. The second actually changes the IR in some useful way. - using namespace Halide; using namespace Halide::Internal; -// Verify that all floating point divisions by constants have been -// converted to float multiplication. +// Wrapper class to call loop_carry on a given statement. class LoopCarryWrapper : public IRMutator { using IRMutator::visit;