From 82021630e8e5bb545d77afd6271c150c60f1cfcd Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 28 Apr 2024 14:39:41 -0700 Subject: [PATCH] More aggressively unify duplicate lets (#8204) * Make unify_duplicate_lets more aggressive The simplifier can also clean up most of these, but it's harder for it because it has to consider that other mutations may have taken place. Beefing this up has no impact on lowering times for most apps, but something pathological was going on for local_laplacian. At 20 pyramid levels, this speeds up lowering by 1.3x. At 50 pyramid levels it's 2.3x. At 100 pyramid levels it's 4.1x. It also slightly reduces binary size. * Clarify comment; Avoid double-lookup into the scope Looking up with an Expr key and deep equality is expensive, so this was bad. * Add a std::move --- src/UnifyDuplicateLets.cpp | 79 +++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/src/UnifyDuplicateLets.cpp b/src/UnifyDuplicateLets.cpp index 5f6e120d6e76..fe163ae20778 100644 --- a/src/UnifyDuplicateLets.cpp +++ b/src/UnifyDuplicateLets.cpp @@ -1,6 +1,7 @@ #include "UnifyDuplicateLets.h" #include "IREquality.h" #include "IRMutator.h" +#include "Simplify.h" #include namespace Halide { @@ -14,31 +15,32 @@ namespace { class UnifyDuplicateLets : public IRMutator { using IRMutator::visit; - map scope; - map rewrites; - string producing; + // Map from Exprs to a Variable in the let name that first introduced that + // Expr. + map scope; + + // Map from Vars to the Expr they should be replaced with. + map rewrites; public: using IRMutator::mutate; Expr mutate(const Expr &e) override { - if (e.defined()) { - map::iterator iter = scope.find(e); - if (iter != scope.end()) { - return Variable::make(e.type(), iter->second); - } else { - return IRMutator::mutate(e); - } - } else { - return Expr(); + Expr new_e = IRMutator::mutate(e); + + if (auto iter = scope.find(new_e); + iter != scope.end()) { + return iter->second; } + + return new_e; } protected: Expr visit(const Variable *op) override { - map::iterator iter = rewrites.find(op->name); + auto iter = rewrites.find(op->name); if (iter != rewrites.end()) { - return Variable::make(op->type, iter->second); + return iter->second; } else { return op; } @@ -56,36 +58,41 @@ class UnifyDuplicateLets : public IRMutator { return IRMutator::visit(op); } - Stmt visit(const ProducerConsumer *op) override { - if (op->is_producer) { - string old_producing = producing; - producing = op->name; - Stmt stmt = IRMutator::visit(op); - producing = old_producing; - return stmt; - } else { - return IRMutator::visit(op); - } - } - template auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) { is_impure = false; Expr value = mutate(op->value); + Expr simplified = simplify(value); auto body = op->body; bool should_pop = false; bool should_erase = false; if (!is_impure) { - map::iterator iter = scope.find(value); - if (iter == scope.end()) { - scope[value] = op->name; - should_pop = true; - } else { - value = Variable::make(value.type(), iter->second); - rewrites[op->name] = iter->second; + if (simplified.as() || + simplified.as()) { + // The RHS collapsed to just a Var or a constant, so uses of + // this should be rewritten to that value and we should drop + // this let. The LetStmts at this point in lowering that we're + // trying to remove are generally bounds inference expressions, + // so it's not worth checking for other types of constant. + rewrites[op->name] = simplified; should_erase = true; + } else { + Expr var = Variable::make(value.type(), op->name); + + // The mutate implementation above checks Exprs + // post-mutation but without simplification, so we should + // put the unsimplified version of the Expr into the scope. + auto [it, inserted] = scope.emplace(value, std::move(var)); + + if (inserted) { + should_pop = true; + } else { + // We have the same RHS as some earlier Let + should_erase = true; + rewrites[op->name] = it->second; + } } } @@ -96,12 +103,14 @@ class UnifyDuplicateLets : public IRMutator { } if (should_erase) { rewrites.erase(op->name); + // We no longer need this let. + return body; } - if (value.same_as(op->value) && body.same_as(op->body)) { + if (simplified.same_as(op->value) && body.same_as(op->body)) { return op; } else { - return LetStmtOrLet::make(op->name, value, body); + return LetStmtOrLet::make(op->name, simplified, body); } }