Skip to content

Commit

Permalink
fix apache#5686: remove a overstrict assert in MakeAllreduce (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
majiang31312 authored and trevor-m committed Jun 18, 2020
1 parent 45719a8 commit 8dd8d9e
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::unordered_set<const VarNode*> reduce_set;
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const VarNode* v = call->args[i].as<VarNode>();
CHECK(v);
reduce_set.insert(v);
// The simply optimization replace a iteration variable with a constant
// when extent of the iteration is 1. As threaded IterVar always started from 0,
// we can just ignore this variable in this case.
if (v) {
reduce_set.insert(v);
} else {
CHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
<< "arg" << i << "should be a VarNode or IntImmNode";
}
}

size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
for (const AttrStmtNode* attr : thread_extents_) {
Expand All @@ -165,6 +173,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
const auto* ptr = attr->value.as<IntImmNode>();
CHECK(ptr) << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
// ignore variables equal to 0
if (e.extent == 1) {
continue;
}

if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
Expand Down

0 comments on commit 8dd8d9e

Please sign in to comment.