diff --git a/src/ModulusRemainder.h b/src/ModulusRemainder.h index c0341b75abf6..cbcdce10b98c 100644 --- a/src/ModulusRemainder.h +++ b/src/ModulusRemainder.h @@ -7,6 +7,8 @@ #include +#include "Util.h" + namespace Halide { struct Expr; @@ -83,8 +85,8 @@ ModulusRemainder modulus_remainder(const Expr &e, const Scope /** Reduce an expression modulo some integer. Returns true and assigns * to remainder if an answer could be found. */ ///@{ -bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder); -bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope &scope); +HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder); +HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope &scope); ///@} void modulus_remainder_test(); diff --git a/src/Solve.cpp b/src/Solve.cpp index a08eedadbd27..22bd14e44412 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -44,18 +44,22 @@ class SolveExpression : public IRMutator { map::iterator iter = cache.find(e); if (iter == cache.end()) { // Not in the cache, call the base class version. - debug(4) << "Mutating " << e << " (" << uses_var << ")\n"; + debug(4) << "Mutating " << e << " (" << uses_var << ", " << failed << ")\n"; bool old_uses_var = uses_var; uses_var = false; + bool old_failed = failed; + failed = false; Expr new_e = IRMutator::mutate(e); - CacheEntry entry = {new_e, uses_var}; + CacheEntry entry = {new_e, uses_var, failed}; uses_var = old_uses_var || uses_var; + failed = old_failed || failed; cache[e] = entry; - debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ")\n"; + debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ", " << failed << ")\n"; return new_e; } else { // Cache hit. uses_var = uses_var || iter->second.uses_var; + failed = failed || iter->second.failed; debug(4) << "(Hit) Rewrote " << e << " -> " << iter->second.expr << " (" << uses_var << ")\n"; return iter->second.expr; } @@ -75,7 +79,7 @@ class SolveExpression : public IRMutator { // stateless, so we can cache everything. struct CacheEntry { Expr expr; - bool uses_var; + bool uses_var, failed; }; map cache; @@ -388,16 +392,25 @@ class SolveExpression : public IRMutator { const Mul *mul_a = a.as(); Expr expr; if (a_uses_var && !b_uses_var) { + const int64_t *ib = as_const_int(b); + auto is_multiple_of_b = [&](const Expr &e) { + if (ib) { + int64_t r = 0; + return reduce_expr_modulo(e, *ib, &r) && r == 0; + } else { + return can_prove(e / b * b == e); + } + }; if (add_a && !a_failed && - can_prove(add_a->a / b * b == add_a->a)) { + is_multiple_of_b(add_a->a)) { // (f(x) + a) / b -> f(x) / b + a / b expr = mutate(simplify(add_a->a / b) + add_a->b / b); } else if (sub_a && !a_failed && - can_prove(sub_a->a / b * b == sub_a->a)) { + is_multiple_of_b(sub_a->a)) { // (f(x) - a) / b -> f(x) / b - a / b expr = mutate(simplify(sub_a->a / b) - sub_a->b / b); } else if (mul_a && !a_failed && no_overflow_int(op->type) && - can_prove(mul_a->b / b * b == mul_a->b)) { + is_multiple_of_b(mul_a->b)) { // (f(x) * a) / b -> f(x) * (a / b) expr = mutate(mul_a->a * (mul_a->b / b)); } @@ -776,6 +789,7 @@ class SolveExpression : public IRMutator { } else if (scope.contains(op->name)) { CacheEntry e = scope.get(op->name); uses_var = uses_var || e.uses_var; + failed = failed || e.failed; return e.expr; } else if (external_scope.contains(op->name)) { Expr e = external_scope.get(op->name); @@ -790,11 +804,14 @@ class SolveExpression : public IRMutator { Expr visit(const Let *op) override { bool old_uses_var = uses_var; + bool old_failed = failed; uses_var = false; + failed = false; Expr value = mutate(op->value); - CacheEntry e = {value, uses_var}; - + CacheEntry e = {value, uses_var, failed}; uses_var = old_uses_var; + failed = old_failed; + ScopedBinding bind(scope, op->name, e); return mutate(op->body); }