Skip to content

Commit

Permalink
More aggressively unify duplicate lets (#8204)
Browse files Browse the repository at this point in the history
* Make unify_duplicate_lets more aggressive

The simplifier can also clean up most of these, but it's harder for it
because it has to consider that other mutations may have taken place.
Beefing this up has no impact on lowering times for most apps, but
something pathological was going on for local_laplacian. At 20 pyramid
levels, this speeds up lowering by 1.3x. At 50 pyramid levels it's 2.3x.
At 100 pyramid levels it's 4.1x.

It also slightly reduces binary size.

* Clarify comment; Avoid double-lookup into the scope

Looking up with an Expr key and deep equality is expensive, so this was
bad.

* Add a std::move
  • Loading branch information
abadams authored Apr 28, 2024
1 parent 64caf31 commit 8202163
Showing 1 changed file with 44 additions and 35 deletions.
79 changes: 44 additions & 35 deletions src/UnifyDuplicateLets.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "UnifyDuplicateLets.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "Simplify.h"
#include <map>

namespace Halide {
Expand All @@ -14,31 +15,32 @@ namespace {
class UnifyDuplicateLets : public IRMutator {
using IRMutator::visit;

map<Expr, string, IRDeepCompare> scope;
map<string, string> rewrites;
string producing;
// Map from Exprs to a Variable in the let name that first introduced that
// Expr.
map<Expr, Expr, IRDeepCompare> scope;

// Map from Vars to the Expr they should be replaced with.
map<string, Expr> rewrites;

public:
using IRMutator::mutate;

Expr mutate(const Expr &e) override {
if (e.defined()) {
map<Expr, string, IRDeepCompare>::iterator iter = scope.find(e);
if (iter != scope.end()) {
return Variable::make(e.type(), iter->second);
} else {
return IRMutator::mutate(e);
}
} else {
return Expr();
Expr new_e = IRMutator::mutate(e);

if (auto iter = scope.find(new_e);
iter != scope.end()) {
return iter->second;
}

return new_e;
}

protected:
Expr visit(const Variable *op) override {
map<string, string>::iterator iter = rewrites.find(op->name);
auto iter = rewrites.find(op->name);
if (iter != rewrites.end()) {
return Variable::make(op->type, iter->second);
return iter->second;
} else {
return op;
}
Expand All @@ -56,36 +58,41 @@ class UnifyDuplicateLets : public IRMutator {
return IRMutator::visit(op);
}

Stmt visit(const ProducerConsumer *op) override {
if (op->is_producer) {
string old_producing = producing;
producing = op->name;
Stmt stmt = IRMutator::visit(op);
producing = old_producing;
return stmt;
} else {
return IRMutator::visit(op);
}
}

template<typename LetStmtOrLet>
auto visit_let(const LetStmtOrLet *op) -> decltype(op->body) {
is_impure = false;
Expr value = mutate(op->value);
Expr simplified = simplify(value);
auto body = op->body;

bool should_pop = false;
bool should_erase = false;

if (!is_impure) {
map<Expr, string, IRDeepCompare>::iterator iter = scope.find(value);
if (iter == scope.end()) {
scope[value] = op->name;
should_pop = true;
} else {
value = Variable::make(value.type(), iter->second);
rewrites[op->name] = iter->second;
if (simplified.as<Variable>() ||
simplified.as<IntImm>()) {
// The RHS collapsed to just a Var or a constant, so uses of
// this should be rewritten to that value and we should drop
// this let. The LetStmts at this point in lowering that we're
// trying to remove are generally bounds inference expressions,
// so it's not worth checking for other types of constant.
rewrites[op->name] = simplified;
should_erase = true;
} else {
Expr var = Variable::make(value.type(), op->name);

// The mutate implementation above checks Exprs
// post-mutation but without simplification, so we should
// put the unsimplified version of the Expr into the scope.
auto [it, inserted] = scope.emplace(value, std::move(var));

if (inserted) {
should_pop = true;
} else {
// We have the same RHS as some earlier Let
should_erase = true;
rewrites[op->name] = it->second;
}
}
}

Expand All @@ -96,12 +103,14 @@ class UnifyDuplicateLets : public IRMutator {
}
if (should_erase) {
rewrites.erase(op->name);
// We no longer need this let.
return body;
}

if (value.same_as(op->value) && body.same_as(op->body)) {
if (simplified.same_as(op->value) && body.same_as(op->body)) {
return op;
} else {
return LetStmtOrLet::make(op->name, value, body);
return LetStmtOrLet::make(op->name, simplified, body);
}
}

Expand Down

0 comments on commit 8202163

Please sign in to comment.