From 23afcc9cd73f03ef69a499fcb1da5e9a81be5ebc Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 5 Dec 2023 16:00:46 -0800 Subject: [PATCH 1/7] Track whether or not let expressions failed to solve in solver After mutating an expression, the solver needs to know two things: 1) Did the expression contain the variable we're solving for 2) Was the expression successfully "solved" for the variable. I.e. the variable only appears once in the leftmost position. We need to know this to know property 1 of any subexpressions (i.e. does the right child of the expression contain the variable). This drives what transformations we do in ways that are guaranteed to terminate and not take exponential time. We were tracking property 1 through lets but not property 2, and this meant we were doing unhelpful transformations in some cases. I found a case in the wild where this made a pipeline take > 1 hour to compile (I killed it after an hour). It may have been in an infinite transformation loop, or it might have just been exponential. Not sure. --- src/Solve.cpp | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/Solve.cpp b/src/Solve.cpp index a08eedadbd27..af079e4e9f65 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -44,18 +44,22 @@ class SolveExpression : public IRMutator { map::iterator iter = cache.find(e); if (iter == cache.end()) { // Not in the cache, call the base class version. - debug(4) << "Mutating " << e << " (" << uses_var << ")\n"; + debug(4) << "Mutating " << e << " (" << uses_var << ", " << failed << ")\n"; bool old_uses_var = uses_var; uses_var = false; + bool old_failed = failed; + failed = false; Expr new_e = IRMutator::mutate(e); - CacheEntry entry = {new_e, uses_var}; + CacheEntry entry = {new_e, uses_var, failed}; uses_var = old_uses_var || uses_var; + failed = old_failed || failed; cache[e] = entry; - debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ")\n"; + debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ", " << failed << ")\n"; return new_e; } else { // Cache hit. uses_var = uses_var || iter->second.uses_var; + failed = failed || iter->second.failed; debug(4) << "(Hit) Rewrote " << e << " -> " << iter->second.expr << " (" << uses_var << ")\n"; return iter->second.expr; } @@ -75,7 +79,7 @@ class SolveExpression : public IRMutator { // stateless, so we can cache everything. struct CacheEntry { Expr expr; - bool uses_var; + bool uses_var, failed; }; map cache; @@ -388,16 +392,26 @@ class SolveExpression : public IRMutator { const Mul *mul_a = a.as(); Expr expr; if (a_uses_var && !b_uses_var) { + const int64_t *ib = as_const_int(b); + auto is_multiple_of_b = [&](const Expr &e) { + if (ib) { + int64_t r; + reduce_expr_modulo(e, *ib, &r); + return r == 0; + } else { + return can_prove(e / b * b == e); + } + }; if (add_a && !a_failed && - can_prove(add_a->a / b * b == add_a->a)) { + is_multiple_of_b(add_a->a)) { // (f(x) + a) / b -> f(x) / b + a / b expr = mutate(simplify(add_a->a / b) + add_a->b / b); } else if (sub_a && !a_failed && - can_prove(sub_a->a / b * b == sub_a->a)) { + is_multiple_of_b(sub_a->a)) { // (f(x) - a) / b -> f(x) / b - a / b expr = mutate(simplify(sub_a->a / b) - sub_a->b / b); } else if (mul_a && !a_failed && no_overflow_int(op->type) && - can_prove(mul_a->b / b * b == mul_a->b)) { + is_multiple_of_b(mul_a->b)) { // (f(x) * a) / b -> f(x) * (a / b) expr = mutate(mul_a->a * (mul_a->b / b)); } @@ -776,6 +790,7 @@ class SolveExpression : public IRMutator { } else if (scope.contains(op->name)) { CacheEntry e = scope.get(op->name); uses_var = uses_var || e.uses_var; + failed = failed || e.failed; return e.expr; } else if (external_scope.contains(op->name)) { Expr e = external_scope.get(op->name); @@ -790,11 +805,14 @@ class SolveExpression : public IRMutator { Expr visit(const Let *op) override { bool old_uses_var = uses_var; + bool old_failed = failed; uses_var = false; + failed = false; Expr value = mutate(op->value); - CacheEntry e = {value, uses_var}; - + CacheEntry e = {value, uses_var, failed}; uses_var = old_uses_var; + failed = old_failed; + ScopedBinding bind(scope, op->name, e); return mutate(op->body); } From 224672105d17c6be27d72d742d5f5fced173940b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 22 Jan 2024 10:27:43 -0800 Subject: [PATCH 2/7] Remove surplus comma --- src/IRPrinter.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index dc07d0e0f010..52cb3714268c 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -1109,7 +1109,6 @@ void IRPrinter::visit(const VectorReduce *op) { stream << "(" << op->type << ")vector_reduce_" << op->op << "(" - << ", " << op->value << ")"; } From c5578b254b4d6ee696a1256e1a1655b6a909275e Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 22 Jan 2024 10:49:23 -0800 Subject: [PATCH 3/7] Fix use of uninitialized value that could cause bad transformation --- src/ModulusRemainder.h | 6 ++++-- src/Solve.cpp | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/ModulusRemainder.h b/src/ModulusRemainder.h index c0341b75abf6..cbcdce10b98c 100644 --- a/src/ModulusRemainder.h +++ b/src/ModulusRemainder.h @@ -7,6 +7,8 @@ #include +#include "Util.h" + namespace Halide { struct Expr; @@ -83,8 +85,8 @@ ModulusRemainder modulus_remainder(const Expr &e, const Scope /** Reduce an expression modulo some integer. Returns true and assigns * to remainder if an answer could be found. */ ///@{ -bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder); -bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope &scope); +HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder); +HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope &scope); ///@} void modulus_remainder_test(); diff --git a/src/Solve.cpp b/src/Solve.cpp index af079e4e9f65..22bd14e44412 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -395,9 +395,8 @@ class SolveExpression : public IRMutator { const int64_t *ib = as_const_int(b); auto is_multiple_of_b = [&](const Expr &e) { if (ib) { - int64_t r; - reduce_expr_modulo(e, *ib, &r); - return r == 0; + int64_t r = 0; + return reduce_expr_modulo(e, *ib, &r) && r == 0; } else { return can_prove(e / b * b == e); } From 85d7eddbef457b4981cbd29adf8ded7271a02b77 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Wed, 24 Jan 2024 18:43:49 -0800 Subject: [PATCH 4/7] trigger buildbots From ccea846d0526f618601b521cf695c192da0b23a9 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Wed, 24 Jan 2024 21:47:28 -0800 Subject: [PATCH 5/7] trigger buildbots From 9bd13a50618b6371a1b66bed321c6692470244c9 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Thu, 25 Jan 2024 09:25:11 -0800 Subject: [PATCH 6/7] trigger buildbots From dae1bfbbbdb348f00b8345de7fc3f1171f695409 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Thu, 25 Jan 2024 17:50:08 -0800 Subject: [PATCH 7/7] trigger buildbots