Skip to content

Commit

Permalink
Faster vars used tracking in simplify let visitor (#8205)
Browse files Browse the repository at this point in the history
* Speed up the vars_used visitor in the simplifier let visitor

This visitor shows up as the main cost of lowering in very large
pipelines.

This visitor is for tracking which lets are actually used for real
inside the body of a let block (as opposed to the tracking we do when
mutating, which is approximate, because we could construct an Expr that
uses a Var and then discard it in a later mutation).

The old implementation made a map of all variables referenced, and then
checked each let name against that map one by one. If there are a small
number of lets outside a huge Stmt, this is bad, because the data
structure has to hold a number of names proportional to the stmt size
instead of proportional to the number of lets.

This new implementation instead makes a hash set of the let names, and
then traverses the Stmt, removing names from the set as they are
encountered. This is a big speed-up.

We then make the speed-up larger by about the same factor again by doing
the following:

1) Only add names to the map that might be used based on the recursive
mutate call. These are very very likely to be used, because we saw them
at least once, and mutations that remove *all* uses of a Var are rare.

2) The visitor should early out when the map becomes empty. The let
variables are often all used immediately, so this is frequent.

Speeds up lowering of local laplacian by 1.44x, 2.6x, and 4.8x
respectively for 20, 50, and 100 pyramid levels.

Speeds up lowering of resnet50 by 1.04x. Speeds up lowering of lens blur
by 1.06x

* Exploit the ref count of the replacement Expr

* Fix is_sole_reference logic in Simplify_Let.cpp

* Reduce hash map size
  • Loading branch information
abadams authored Apr 28, 2024
1 parent 302aa1c commit 64caf31
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 23 deletions.
8 changes: 8 additions & 0 deletions src/IntrusivePtr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class RefCount {
bool is_const_zero() const {
return count == 0;
}
/** Return the current value of the reference count.
 * NOTE(review): the name suggests `count` is an atomic and that this is
 * an atomic load; the declaration of `count` is not visible here, so
 * confirm against the class definition. */
int atomic_get() const {
return count;
}
};

/**
Expand Down Expand Up @@ -173,6 +176,11 @@ struct IntrusivePtr {
bool operator<(const IntrusivePtr<T> &other) const {
return ptr < other.ptr;
}

/** Returns true when this pointer is non-null and the reference count
 * of the object it points to is exactly 1. A null pointer yields
 * false. */
HALIDE_ALWAYS_INLINE
bool is_sole_reference() const {
return ptr && ref_count(ptr).atomic_get() == 1;
}
};

} // namespace Internal
Expand Down
85 changes: 62 additions & 23 deletions src/Simplify_Let.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "Simplify_Internal.h"
#include "Substitute.h"

#include <unordered_set>

namespace Halide {
namespace Internal {

Expand All @@ -9,34 +11,50 @@ using std::vector;

namespace {

class CountVarUses : public IRVisitor {
std::map<std::string, int> &var_uses;
class FindVarUses : public IRVisitor {
std::unordered_set<std::string> &unused_vars;

void visit(const Variable *var) override {
var_uses[var->name]++;
unused_vars.erase(var->name);
}

void visit(const Load *op) override {
var_uses[op->name]++;
IRVisitor::visit(op);
if (!unused_vars.empty()) {
unused_vars.erase(op->name);
IRVisitor::visit(op);
}
}

void visit(const Store *op) override {
var_uses[op->name]++;
IRVisitor::visit(op);
if (!unused_vars.empty()) {
unused_vars.erase(op->name);
IRVisitor::visit(op);
}
}

// Visit a Block, skipping whatever remains once unused_vars is empty:
// `first` is only visited if the set is non-empty on entry, and `rest`
// is only visited if it is still non-empty after visiting `first`.
void visit(const Block *op) override {
// Early out at Block nodes if we've already seen every name we're
// interested in. In principle we could early-out at every node, but
// blocks, loads, and stores seem to be enough.
if (!unused_vars.empty()) {
op->first.accept(this);
if (!unused_vars.empty()) {
op->rest.accept(this);
}
}
}

using IRVisitor::visit;

public:
CountVarUses(std::map<std::string, int> &var_uses)
: var_uses(var_uses) {
FindVarUses(std::unordered_set<std::string> &unused_vars)
: unused_vars(unused_vars) {
}
};

template<typename StmtOrExpr>
void count_var_uses(StmtOrExpr x, std::map<std::string, int> &var_uses) {
CountVarUses counter(var_uses);
void find_var_uses(StmtOrExpr x, std::unordered_set<std::string> &unused_vars) {
FindVarUses counter(unused_vars);
x.accept(&counter);
}

Expand All @@ -49,10 +67,11 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
// the call stack where it could overflow onto an explicit stack.
struct Frame {
const LetOrLetStmt *op;
Expr value, new_value;
Expr value, new_value, new_var;
string new_name;
bool new_value_alignment_tracked = false, new_value_bounds_tracked = false;
bool value_alignment_tracked = false, value_bounds_tracked = false;
VarInfo info;
Frame(const LetOrLetStmt *op)
: op(op) {
}
Expand Down Expand Up @@ -189,6 +208,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
// Nothing to substitute
f.new_value = Expr();
replacement = Expr();
new_var = Expr();
} else {
debug(4) << "new let " << f.new_name << " = " << f.new_value << " in ... " << replacement << " ...\n";
}
Expand All @@ -197,6 +217,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
info.old_uses = 0;
info.new_uses = 0;
info.replacement = replacement;
f.new_var = new_var;

var_info.push(op->name, info);

Expand Down Expand Up @@ -226,14 +247,35 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {

result = mutate_let_body(result, bounds);

// TODO: var_info and vars_used are pretty redundant; however, at the time
// TODO: var_info and unused_vars are pretty redundant; however, at the time
// of writing, both cover cases that the other does not:
// - var_info prevents duplicate lets from being generated, even
// from different Frame objects.
// - vars_used avoids dead lets being generated in cases where vars are
// - unused_vars avoids dead lets being generated in cases where vars are
// seen as used by var_info, and then later removed.
std::map<std::string, int> vars_used;
count_var_uses(result, vars_used);

std::unordered_set<std::string> unused_vars(frames.size());
// Insert everything we think *might* be used, and then visit the body,
// removing things from the set as we find uses of them.
for (auto &f : frames) {
f.info = var_info.get(f.op->name);
// Drop any reference to new_var held by the replacement expression so
// that the only references are either f.new_var, or ones in the body or
// new_values of other lets.
f.info.replacement = Expr();
if (f.new_var.is_sole_reference()) {
// Any new_uses must have been eliminated by later mutations.
f.info.new_uses = 0;
}
var_info.pop(f.op->name);
if (f.info.old_uses) {
internal_assert(f.info.new_uses == 0);
unused_vars.insert(f.op->name);
} else if (f.info.new_uses && f.new_value.defined()) {
unused_vars.insert(f.new_name);
}
}
find_var_uses(result, unused_vars);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
if (it->value_bounds_tracked) {
Expand All @@ -243,20 +285,17 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
bounds_and_alignment_info.pop(it->new_name);
}

VarInfo info = var_info.get(it->op->name);
var_info.pop(it->op->name);

if (it->new_value.defined() && (info.new_uses > 0 && vars_used.count(it->new_name) > 0)) {
if (it->new_value.defined() && (it->info.new_uses > 0 && !unused_vars.count(it->new_name))) {
// The new name/value may be used
result = LetOrLetStmt::make(it->new_name, it->new_value, result);
count_var_uses(it->new_value, vars_used);
find_var_uses(it->new_value, unused_vars);
}

if ((!remove_dead_code && std::is_same<LetOrLetStmt, LetStmt>::value) ||
(info.old_uses > 0 && vars_used.count(it->op->name) > 0)) {
(it->info.old_uses > 0 && !unused_vars.count(it->op->name))) {
// The old name is still in use. We'd better keep it as well.
result = LetOrLetStmt::make(it->op->name, it->value, result);
count_var_uses(it->value, vars_used);
find_var_uses(it->value, unused_vars);
}

const LetOrLetStmt *new_op = result.template as<LetOrLetStmt>();
Expand Down

0 comments on commit 64caf31

Please sign in to comment.