diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 860401735896..058014cd7233 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -149,9 +149,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_set reduce_set; for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) { const VarNode* v = call->args[i].as(); - CHECK(v); - reduce_set.insert(v); + // The simplify optimization replaces an iteration variable with a constant + // when the extent of the iteration is 1. As a threaded IterVar always starts from 0, + // we can just ignore this variable in this case. + if (v) { + reduce_set.insert(v); + } else { + CHECK(call->args[i].as() && call->args[i].as()->value == 0) + << "arg" << i << "should be a VarNode or IntImmNode"; + } } + size_t nmatch = 0; std::vector vred, vpar; for (const AttrStmtNode* attr : thread_extents_) { @@ -165,6 +173,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const auto* ptr = attr->value.as(); CHECK(ptr) << "Need constant extent for reduce set " << iv; e.extent = static_cast(ptr->value); + // Ignore variables that are always equal to 0 (extent of 1). + if (e.extent == 1) { + continue; + } + if (reduce_set.count(iv->var.get())) { vred.push_back(e); ++nmatch;