diff --git a/src/UnifyDuplicateLets.cpp b/src/UnifyDuplicateLets.cpp index 5f6e120d6e76..fe163ae20778 100644 --- a/src/UnifyDuplicateLets.cpp +++ b/src/UnifyDuplicateLets.cpp @@ -1,6 +1,7 @@ #include "UnifyDuplicateLets.h" #include "IREquality.h" #include "IRMutator.h" +#include "Simplify.h" #include namespace Halide { @@ -14,31 +15,32 @@ namespace { class UnifyDuplicateLets : public IRMutator { using IRMutator::visit; - map scope; - map rewrites; - string producing; + // Map from Exprs to a Variable in the let name that first introduced that + // Expr. + map scope; + + // Map from Vars to the Expr they should be replaced with. + map rewrites; public: using IRMutator::mutate; Expr mutate(const Expr &e) override { - if (e.defined()) { - map::iterator iter = scope.find(e); - if (iter != scope.end()) { - return Variable::make(e.type(), iter->second); - } else { - return IRMutator::mutate(e); - } - } else { - return Expr(); + Expr new_e = IRMutator::mutate(e); + + if (auto iter = scope.find(new_e); + iter != scope.end()) { + return iter->second; } + + return new_e; } protected: Expr visit(const Variable *op) override { - map::iterator iter = rewrites.find(op->name); + auto iter = rewrites.find(op->name); if (iter != rewrites.end()) { - return Variable::make(op->type, iter->second); + return iter->second; } else { return op; } @@ -56,36 +58,41 @@ class UnifyDuplicateLets : public IRMutator { return IRMutator::visit(op); } - Stmt visit(const ProducerConsumer *op) override { - if (op->is_producer) { - string old_producing = producing; - producing = op->name; - Stmt stmt = IRMutator::visit(op); - producing = old_producing; - return stmt; - } else { - return IRMutator::visit(op); - } - } - template auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { is_impure = false; Expr value = mutate(op->value); + Expr simplified = simplify(value); auto body = op->body; bool should_pop = false; bool should_erase = false; if (!is_impure) { - map::iterator iter = scope.find(value); - if (iter == scope.end()) { - scope[value] = op->name; - should_pop = true; - } else { - value = Variable::make(value.type(), iter->second); - rewrites[op->name] = iter->second; + if (simplified.as() || + simplified.as()) { + // The RHS collapsed to just a Var or a constant, so uses of + // this should be rewritten to that value and we should drop + // this let. The LetStmts at this point in lowering that we're + // trying to remove are generally bounds inference expressions, + // so it's not worth checking for other types of constant. + rewrites[op->name] = simplified; should_erase = true; + } else { + Expr var = Variable::make(value.type(), op->name); + + // The mutate implementation above checks Exprs + // post-mutation but without simplification, so we should + // put the unsimplified version of the Expr into the scope. + auto [it, inserted] = scope.emplace(value, std::move(var)); + + if (inserted) { + should_pop = true; + } else { + // We have the same RHS as some earlier Let + should_erase = true; + rewrites[op->name] = it->second; + } } } @@ -96,12 +103,14 @@ class UnifyDuplicateLets : public IRMutator { } if (should_erase) { rewrites.erase(op->name); + // We no longer need this let. + return body; } - if (value.same_as(op->value) && body.same_as(op->body)) { + if (simplified.same_as(op->value) && body.same_as(op->body)) { return op; } else { - return LetStmtOrLet::make(op->name, value, body); + return LetStmtOrLet::make(op->name, simplified, body); } }