Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a caching version of stmt_uses_vars in TightenProducerConsumer nodes #8102

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 68 additions & 12 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,67 @@ class InitializeSemaphores : public IRMutator {
}
};

// A class to support stmt_uses_vars queries that repeatedly hit the same
// sub-stmts. Used to support TightenProducerConsumerNodes below.
class CachingStmtUsesVars : public IRMutator {
const Scope<> &query;
bool found_use = false;
std::map<Stmt, bool> cache;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not std::set<> instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, just fer grins, would unordered_map or unordered_set be any better here? (Surely the order doesn't matter)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are three possible states:

  • I've seen this stmt before and it doesn't contain the vars (map contains false)
  • I've seen this stmt before and it does contain the vars (map contains true)
  • I haven't seen this stmt before and it needs to be analyzed (not in the map)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally avoid unordered_set and unordered_map because every time I've benchmarked them within Halide they haven't been enough faster to outweigh a really annoying property that I've been burned by: If you have a bug where you accidentally depend on the order of the keys, it can fail on some machines but not others depending on the standard library used, which makes debugging it really annoying.


using IRMutator::visit;
Expr visit(const Variable *op) override {
found_use |= query.contains(op->name);
return op;
}

Expr visit(const Call *op) override {
found_use |= query.contains(op->name);
IRMutator::visit(op);
return op;
}

Stmt visit(const Provide *op) override {
found_use |= query.contains(op->name);
IRMutator::visit(op);
return op;
}

public:
CachingStmtUsesVars(const Scope<> &q)
: query(q) {
}

using IRMutator::mutate;
Stmt mutate(const Stmt &s) override {
auto it = cache.find(s);
if (it != cache.end()) {
found_use |= it->second;
} else {
bool old = found_use;
found_use = false;
Stmt stmt = IRMutator::mutate(s);
if (found_use) {
cache.emplace(s, true);
} else {
cache.emplace(s, false);
}
found_use |= old;
}
return s;
}

bool check_stmt(const Stmt &s) {
found_use = false;
mutate(s);
return found_use;
}
};

// Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
class TightenProducerConsumerNodes : public IRMutator {
using IRMutator::visit;

Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<int> &scope) {
Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) {
if (const LetStmt *let = body.as<LetStmt>()) {
Stmt orig = body;
// 'orig' is only used to keep a reference to the let
Expand All @@ -595,7 +651,7 @@ class TightenProducerConsumerNodes : public IRMutator {
body = ProducerConsumer::make(name, is_producer, body);
} else {
// Recurse onto a non-let-node
body = make_producer_consumer(name, is_producer, body, scope);
body = make_producer_consumer(name, is_producer, body, scope, uses_vars);
}

for (auto it = containing_lets.rbegin(); it != containing_lets.rend(); it++) {
Expand All @@ -611,44 +667,44 @@ class TightenProducerConsumerNodes : public IRMutator {
vector<Stmt> sub_stmts;
Stmt rest;
do {
Stmt first = block->first;
sub_stmts.push_back(block->first);
rest = block->rest;
block = rest.as<Block>();
} while (block);
sub_stmts.push_back(rest);

for (Stmt &s : sub_stmts) {
if (stmt_uses_vars(s, scope)) {
s = make_producer_consumer(name, is_producer, s, scope);
if (uses_vars.check_stmt(s)) {
s = make_producer_consumer(name, is_producer, s, scope, uses_vars);
}
}

return Block::make(sub_stmts);
} else if (const ProducerConsumer *pc = body.as<ProducerConsumer>()) {
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope));
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope, uses_vars));
} else if (const Realize *r = body.as<Realize>()) {
return Realize::make(r->name, r->types, r->memory_type,
r->bounds, r->condition,
make_producer_consumer(name, is_producer, r->body, scope));
make_producer_consumer(name, is_producer, r->body, scope, uses_vars));
} else {
return ProducerConsumer::make(name, is_producer, body);
}
}

Stmt visit(const ProducerConsumer *op) override {
Stmt body = mutate(op->body);
Scope<int> scope;
scope.push(op->name, 0);
Scope<> scope;
scope.push(op->name);
Function f = env.find(op->name)->second;
if (f.outputs() == 1) {
scope.push(op->name + ".buffer", 0);
scope.push(op->name + ".buffer");
} else {
for (int i = 0; i < f.outputs(); i++) {
scope.push(op->name + "." + std::to_string(i) + ".buffer", 0);
scope.push(op->name + "." + std::to_string(i) + ".buffer");
}
}
return make_producer_consumer(op->name, op->is_producer, body, scope);
CachingStmtUsesVars uses_vars{scope};
return make_producer_consumer(op->name, op->is_producer, body, scope, uses_vars);
}

const map<string, Function> &env;
Expand Down
2 changes: 1 addition & 1 deletion src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void lower_impl(const vector<Function> &output_funcs,
debug(1) << "Simplifying...\n";
s = simplify(s);
s = unify_duplicate_lets(s);
log("Lowering after second simplifcation:", s);
log("Lowering after second simplification:", s);

debug(1) << "Reduce prefetch dimension...\n";
s = reduce_prefetch_dimension(s, t);
Expand Down
Loading