From 64caf31759932bce9d024027976292f85757bb33 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 28 Apr 2024 14:38:54 -0700 Subject: [PATCH] Faster vars used tracking in simplify let visitor (#8205) * Speed up the vars_used visitor in the simplifier let visitor This visitor shows up as the main cost of lowering in very large pipelines. This visitor is for tracking which lets are actually used for real inside the body of a let block (as opposed to the tracking we do when mutating, which is approximate, because we could construct and Expr that uses a Var and then discard it in a later mutation). The old implementation made a map of all variables referenced, and then checked each let name against that map one by one. If there are a small number of lets outside a huge Stmt, this is bad, because the data structure has to hold a number of names proportional to the stmt size instead of proportional to the number of lets. This new implementation instead makes a hash set of the let names, and than traverses the Stmt, removing names from the set as they are encountered. This is a big speed-up. We then make the speed-up larger by about the same factor again doing the following: 1) Only add names to the map that might be used based on the recursive mutate call. These are very very likely to be used, because we saw them at least once, and mutations that remove *all* uses of a Var are rare. 2) The visitor should early out when the map becomes empty. The let variables are often all used immediately, so this is frequent. Speeds up lowering of local laplacian by 1.44x, 2.6x, and 4.8x respectively for 20, 50, and 100 pyramid levels. Speeds up lowering of resnet50 by 1.04x. Speeds up lowering of lens blur by 1.06x * Exploit the ref count of the replacement Expr * Fix is_sole_reference logic in Simplify_Let.cpp * Reduce hash map size --- src/IntrusivePtr.h | 8 +++++ src/Simplify_Let.cpp | 85 ++++++++++++++++++++++++++++++++------------ 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/src/IntrusivePtr.h b/src/IntrusivePtr.h index f233420c8009..d265c3dcec8a 100644 --- a/src/IntrusivePtr.h +++ b/src/IntrusivePtr.h @@ -32,6 +32,9 @@ class RefCount { bool is_const_zero() const { return count == 0; } + int atomic_get() const { + return count; + } }; /** @@ -173,6 +176,11 @@ struct IntrusivePtr { bool operator<(const IntrusivePtr &other) const { return ptr < other.ptr; } + + HALIDE_ALWAYS_INLINE + bool is_sole_reference() const { + return ptr && ref_count(ptr).atomic_get() == 1; + } }; } // namespace Internal diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 4f1862abf6ac..342281fa6639 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -1,6 +1,8 @@ #include "Simplify_Internal.h" #include "Substitute.h" +#include + namespace Halide { namespace Internal { @@ -9,34 +11,50 @@ using std::vector; namespace { -class CountVarUses : public IRVisitor { - std::map &var_uses; +class FindVarUses : public IRVisitor { + std::unordered_set &unused_vars; void visit(const Variable *var) override { - var_uses[var->name]++; + unused_vars.erase(var->name); } void visit(const Load *op) override { - var_uses[op->name]++; - IRVisitor::visit(op); + if (!unused_vars.empty()) { + unused_vars.erase(op->name); + IRVisitor::visit(op); + } } void visit(const Store *op) override { - var_uses[op->name]++; - IRVisitor::visit(op); + if (!unused_vars.empty()) { + unused_vars.erase(op->name); + IRVisitor::visit(op); + } + } + + void visit(const Block *op) override { + // Early out at Block nodes if we've already seen every name we're + // interested in. In principal we could early-out at every node, but + // blocks, loads, and stores seem to be enough. + if (!unused_vars.empty()) { + op->first.accept(this); + if (!unused_vars.empty()) { + op->rest.accept(this); + } + } } using IRVisitor::visit; public: - CountVarUses(std::map &var_uses) - : var_uses(var_uses) { + FindVarUses(std::unordered_set &unused_vars) + : unused_vars(unused_vars) { } }; template -void count_var_uses(StmtOrExpr x, std::map &var_uses) { - CountVarUses counter(var_uses); +void find_var_uses(StmtOrExpr x, std::unordered_set &unused_vars) { + FindVarUses counter(unused_vars); x.accept(&counter); } @@ -49,10 +67,11 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { // the call stack where it could overflow onto an explicit stack. struct Frame { const LetOrLetStmt *op; - Expr value, new_value; + Expr value, new_value, new_var; string new_name; bool new_value_alignment_tracked = false, new_value_bounds_tracked = false; bool value_alignment_tracked = false, value_bounds_tracked = false; + VarInfo info; Frame(const LetOrLetStmt *op) : op(op) { } @@ -189,6 +208,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { // Nothing to substitute f.new_value = Expr(); replacement = Expr(); + new_var = Expr(); } else { debug(4) << "new let " << f.new_name << " = " << f.new_value << " in ... " << replacement << " ...\n"; } @@ -197,6 +217,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { info.old_uses = 0; info.new_uses = 0; info.replacement = replacement; + f.new_var = new_var; var_info.push(op->name, info); @@ -226,14 +247,35 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { result = mutate_let_body(result, bounds); - // TODO: var_info and vars_used are pretty redundant; however, at the time + // TODO: var_info and unused_vars are pretty redundant; however, at the time // of writing, both cover cases that the other does not: // - var_info prevents duplicate lets from being generated, even // from different Frame objects. - // - vars_used avoids dead lets being generated in cases where vars are + // - unused_vars avoids dead lets being generated in cases where vars are // seen as used by var_info, and then later removed. - std::map vars_used; - count_var_uses(result, vars_used); + + std::unordered_set unused_vars(frames.size()); + // Insert everything we think *might* be used, and then visit the body, + // removing things from the set as we find uses of them. + for (auto &f : frames) { + f.info = var_info.get(f.op->name); + // Drop any reference to new_var held by the replacement expression so + // that the only references are either f.new_var, or ones in the body or + // new_values of other lets. + f.info.replacement = Expr(); + if (f.new_var.is_sole_reference()) { + // Any new_uses must have been eliminated by later mutations. + f.info.new_uses = 0; + } + var_info.pop(f.op->name); + if (f.info.old_uses) { + internal_assert(f.info.new_uses == 0); + unused_vars.insert(f.op->name); + } else if (f.info.new_uses && f.new_value.defined()) { + unused_vars.insert(f.new_name); + } + } + find_var_uses(result, unused_vars); for (auto it = frames.rbegin(); it != frames.rend(); it++) { if (it->value_bounds_tracked) { @@ -243,20 +285,17 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) { bounds_and_alignment_info.pop(it->new_name); } - VarInfo info = var_info.get(it->op->name); - var_info.pop(it->op->name); - - if (it->new_value.defined() && (info.new_uses > 0 && vars_used.count(it->new_name) > 0)) { + if (it->new_value.defined() && (it->info.new_uses > 0 && !unused_vars.count(it->new_name))) { // The new name/value may be used result = LetOrLetStmt::make(it->new_name, it->new_value, result); - count_var_uses(it->new_value, vars_used); + find_var_uses(it->new_value, unused_vars); } if ((!remove_dead_code && std::is_same::value) || - (info.old_uses > 0 && vars_used.count(it->op->name) > 0)) { + (it->info.old_uses > 0 && !unused_vars.count(it->op->name))) { // The old name is still in use. We'd better keep it as well. result = LetOrLetStmt::make(it->op->name, it->value, result); - count_var_uses(it->value, vars_used); + find_var_uses(it->value, unused_vars); } const LetOrLetStmt *new_op = result.template as();