[MetaSchedule] Add search rule "RuleAddRfactor" (#306)
* [MetaSchedule] change after refactor

* [MetaSchedule] fix typo

* [MetaSchedule] add rule "AddRfactor"

* [MetaSchedule] Move "NeedsInline"

* [MetaSchedule] Update "NeedsRfactor"

* [MetaSchedule] add reorder

* [MetaSchedule] LoopNode -> ForNode after rebasing

* [MetaSchedule] simply fix some ce after rebasing

* [MetaSchedule] little refactor

* [MetaSchedule] use `Schedule::Copy` to make copy

* fix ce after rebasing

* [MetaSchedule] use SamplePerfectTile. Successfully align!

* [MetaSchedule] rename

* [MetaSchedule] add the original schedule

* [MetaSchedule] move helper functions

* sorry...

* [MetaSchedule] optimize imports

* [MetaSchedule] refactor & document & use TVM_SREF_TO_BLOCK/FOR

* [MetaSchedule] fix comments

* [MetaSchedule] fix comments

* [MetaSchedule] fix comments

* [MetaSchedule] rebase
MasterJH5574 authored and Hzfengsy committed Mar 26, 2021
1 parent bc6e609 commit d68997f
Showing 11 changed files with 444 additions and 46 deletions.
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/instruction.py
@@ -180,3 +180,8 @@ class ParallelAttrs(InstAttrs):
@register_object("meta_schedule.attrs.VectorizeAttrs")
class VectorizeAttrs(InstAttrs):
"""Attrs of the instruction that applies vectorize"""


@register_object("meta_schedule.attrs.RFactorAttrs")
class RFactorAttrs(InstAttrs):
"""Attrs of the instruction that applies rfactor"""
29 changes: 26 additions & 3 deletions python/tvm/meta_schedule/search_rule.py
@@ -223,8 +223,8 @@ def parallelize_vectorize_unroll(
----------
max_jobs_per_core: int
The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
- parallism, i.e. `num_cores * max_jobs_per_core`.
- Use -1 to disable parallism.
+ parallelism, i.e. `num_cores * max_jobs_per_core`.
+ Use -1 to disable parallelism.
max_vectorize_extent: int
The maximum extent to be vectorized. It sets the upper limit of the CPU vectorization.
Use -1 to disable vectorization.
@@ -278,4 +278,27 @@ def simplify_compute_with_const_tensor(max_innermost_factor: int = 16) -> Search
rule: SearchRule
The rule created
"""
- return _ffi_api_search_rule.SimplifyComputeWithConstTensor(max_innermost_factor) # pylint: disable=no-member
+ return _ffi_api_search_rule.SimplifyComputeWithConstTensor(max_innermost_factor) # pylint: disable=no-member


def add_rfactor(
max_jobs_per_core: int = 16,
max_innermost_factor: int = 64,
) -> SearchRule:
"""Add rfactor to some blocks if needed
Parameters
----------
max_jobs_per_core: int
The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
parallelism, i.e. `num_cores * max_jobs_per_core`.
Use -1 to disable parallelism.
max_innermost_factor: int
The maximum size of the innermost factor
Returns
-------
rule: SearchRule
The rule created
"""
return _ffi_api_search_rule.AddRFactor(max_jobs_per_core, max_innermost_factor) # pylint: disable=no-member
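
As a usage note, `add_rfactor` constructs a rule object the same way as the other helpers in this module. A minimal sketch, assuming this experimental branch's module layout (how rules are then combined into a search strategy is branch-specific and not shown here):

from tvm.meta_schedule import search_rule

# Try rfactor on qualifying reduction blocks, capping CPU parallelism at
# num_cores * 16 and innermost tile factors at 64 (the defaults above).
rule = search_rule.add_rfactor(max_jobs_per_core=16, max_innermost_factor=64)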
159 changes: 130 additions & 29 deletions src/meta_schedule/analysis.cc
@@ -16,23 +16,18 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "./analysis.h" // NOLINT(build/include)

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>

#include "../tir/schedule/analysis.h"
#include "../tir/schedule/primitives/primitives.h"
#include "./utils.h"
#include "analysis.h"

namespace tvm {
namespace meta_schedule {

bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
tir::BlockRealize realize = tir::GetBlockRealize(block_sref);
Array<tir::StmtSRef> loops = tir::GetAxes(self, block_sref);
const Array<PrimExpr>& bindings = realize->binding_values;
@@ -42,8 +37,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
int n = loops.size();
for (int i = 0; i < n; ++i) {
const PrimExpr& bind = bindings[i];
- const auto* loop = loops[i]->GetStmt<tir::ForNode>();
- ICHECK(loop) << "TypeError: Expects Loop, but gets: " << loops[i]->stmt->GetTypeKey();
+ const auto* loop = TVM_SREF_TO_FOR(loop, loops[i]);
if (bind.as<tir::VarNode>() != loop->loop_var.get()) {
return false;
}
@@ -57,7 +51,7 @@ bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_s
}

bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
bool no_child = true;
tir::PreOrderVisit(block->body, [&no_child](const ObjectRef& obj) -> bool {
if (!no_child) {
@@ -73,8 +67,7 @@ bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref
}

Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
Array<Integer> result;
for (const tir::IterVar& iter_var : block->iter_vars) {
int iter_type = iter_var->iter_type;
@@ -84,8 +77,7 @@ Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtS
}

bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
for (const tir::IterVar& iter_var : block->iter_vars) {
if (iter_var->iter_type != tir::IterVarType::kDataPar) {
return false;
@@ -96,10 +88,8 @@ bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref)

bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
tir::StmtSRef parent_sref = tir::GetScopeRoot(block_sref);
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- const auto* parent = parent_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
- ICHECK(parent) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ const auto* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
if (parent_sref->parent == nullptr) {
const tir::PrimFuncNode* func = tir::GetRootPrimFunc(self, parent_sref);
for (const tir::BufferRegion& write : block->writes) {
@@ -122,8 +112,7 @@ bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sr
}

int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const Op& op) {
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
int count = 0;
tir::PostOrderVisit(block->body, [&count, &op](const ObjectRef& obj) {
if (const auto* call = obj.as<tir::CallNode>()) {
@@ -136,8 +125,7 @@ int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, con
}

bool HasBranch(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
bool has_branch = false;
arith::Analyzer analyzer;
auto f_visit = [&has_branch, &analyzer](const ObjectRef& obj) -> bool {
@@ -226,10 +214,8 @@ bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& pro
const tir::StmtSRef& consumer_sref) {
// Assume consumer is the only consumer of the producer
tir::StmtSRef parent_sref = tir::GetScopeRoot(producer_sref);
- const auto* producer = producer_sref->GetStmt<tir::BlockNode>();
- const auto* consumer = consumer_sref->GetStmt<tir::BlockNode>();
- ICHECK(producer) << "TypeError: Expects Block, but gets: " << producer_sref->stmt->GetTypeKey();
- ICHECK(consumer) << "TypeError: Expects Block, but gets: " << consumer_sref->stmt->GetTypeKey();
+ const auto* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
+ const auto* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
if (producer->writes.empty()) {
return false;
}
@@ -299,8 +285,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
if (!IsTrivialBinding(self, block_sref)) {
return false;
}
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
- ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Assume complete/reduction block
if (block->writes.size() != 1) {
return false;
@@ -348,7 +333,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&

bool IsStrictlyInlineable(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
static const Op& op_tir_exp = Op::Get("tir.exp");
- const auto* block = block_sref->GetStmt<tir::BlockNode>();
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Const tensors are strictly inlineable
if (block->reads.empty()) {
return true;
@@ -742,6 +727,122 @@ double CountFlop(const tir::PrimFunc& func) {
return cnt;
}

std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref) {
Array<tir::StmtSRef> loops = tir::GetAxes(self, block_sref);
int64_t cum_space_len = 1, cum_reduce_len = 1;
/*
 * Return (-1, -1) if either:
 * 1. some loop has an iterator type other than kDataPar or kCommReduce;
 * 2. some loop has a dynamic (non-constant) extent.
 */
for (const tir::StmtSRef& loop_sref : loops) {
tir::IterVarType type = GetLoopIterType(self, loop_sref);
if (type == tir::kDataPar) {
int64_t extent = GetLoopIntExtent(loop_sref);
if (extent != -1) {
cum_space_len *= extent;
} else {
return std::make_pair(-1, -1);
}
} else if (type == tir::kCommReduce) {
int64_t extent = GetLoopIntExtent(loop_sref);
if (extent != -1) {
cum_reduce_len *= extent;
} else {
return std::make_pair(-1, -1);
}
} else {
return std::make_pair(-1, -1);
}
}
return std::make_pair(cum_space_len, cum_reduce_len);
}
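
// Editorial sketch, not part of this commit: the same accumulation logic on a
// toy (iter-type, extent) list, with extent == -1 standing in for a dynamic
// loop. For {(kDataPar, 512), (kCommReduce, 64)} it returns (512, 64); any
// dynamic extent or unsupported iterator type collapses the result to (-1, -1).
// The names ToyIterType/ToyCumulativeLength are hypothetical, and the snippet
// assumes <cstdint>, <utility>, and <vector> are available.
enum class ToyIterType { kDataPar, kCommReduce, kOther };

std::pair<int64_t, int64_t> ToyCumulativeLength(
    const std::vector<std::pair<ToyIterType, int64_t>>& loops) {
  int64_t space = 1, reduce = 1;
  for (const auto& loop : loops) {
    if (loop.second == -1) {
      return {-1, -1};  // dynamic extent
    }
    if (loop.first == ToyIterType::kDataPar) {
      space *= loop.second;
    } else if (loop.first == ToyIterType::kCommReduce) {
      reduce *= loop.second;
    } else {
      return {-1, -1};  // unsupported iterator type
    }
  }
  return {space, reduce};
}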

bool NeedsRFactor(const tir::ScheduleState& self, const tir::StmtSRef& block_sref,
const SearchTask& task, const int& max_jobs_per_core,
std::atomic<int>* warned_num_cores_missing) {
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
Array<tir::StmtSRef> loops = tir::GetAxes(self, block_sref);

// Cond 1. If the block is a reduction block, it must have a trivial binding.
if (self->scopes.at(GetScopeRoot(block_sref))->IsReduction(block_sref)
&& !IsTrivialBinding(self, block_sref)) {
return false;
}

// Cond 2. Every loop axis must be either a spatial axis or a reduction axis.
for (const tir::StmtSRef& loop_sref : loops) {
const tir::IterVarType& type = GetLoopIterType(self, loop_sref);
if (type != tir::kDataPar && type != tir::kCommReduce) {
return false;
}
}

// Cond 3. There is at least one reduction loop.
// Cond 4. The loops are contiguous: each loop's body is exactly the next loop,
// and the innermost loop's body is exactly the block.
bool has_reduction_loop = false;
for (int i = 0; i < static_cast<int>(loops.size()); ++i) {
// Cond 3.
if (GetLoopIterType(self, loops[i]) == tir::kCommReduce) {
has_reduction_loop = true;
}

// Cond 4.
const auto* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
if (i < static_cast<int>(loops.size()) - 1) {
const auto* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
if (loop_i->body.get() != loop_i1) {
return false;
}
} else {
const auto* block_realize = loop_i->body.as<tir::BlockRealizeNode>();
if (!block_realize || block_realize->block.get() != block) {
return false;
}
}
}
if (!has_reduction_loop) {
return false;
}

// Cond 5. The cumulative loop lengths can be computed successfully.
int64_t cum_space_len, cum_reduce_len;
std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref);
if (cum_space_len == -1 || cum_reduce_len == -1) {
return false;
}

// Cond 6. The block can profit from rfactor given the target's parallelism budget.
int target_num_cores = GetTargetNumCores(task->target, warned_num_cores_missing);
if (NeedsMultiLevelTiling(self, block_sref)) {
// Do not use rfactor if we have enough parallelism on spatial loops.
if (cum_space_len > cum_reduce_len || cum_space_len > target_num_cores * max_jobs_per_core) {
return false;
} else {
return true;
}
} else if (cum_reduce_len > 1) {
// Always try rfactor for other reduction blocks.
return cum_reduce_len > target_num_cores;
}

return false;
}
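
// Worked example (editorial, not part of this commit): with target_num_cores
// = 16 and max_jobs_per_core = 16, the parallelism budget is 16 * 16 = 256.
// A multi-level-tilable block with cum_space_len = 8 and cum_reduce_len = 4096
// satisfies Cond 6 (8 <= 4096 and 8 <= 256), so rfactor is attempted; with
// cum_space_len = 1024 the spatial loops already exceed the budget
// (1024 > 256), so rfactor is skipped. For a non-tilable reduction block,
// rfactor is attempted whenever cum_reduce_len > 16.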

bool HasCacheWriteBlock(const Schedule& sch, const BlockRV& block_rv, const int& i) {
for (const Instruction& inst : sch->trace->insts) {
if (const auto inst_attr = inst->inst_attrs.as<CacheWriteAttrs>()) {
CHECK_EQ(inst->inputs.size(), 1);
const BlockRV& input_rv = Downcast<BlockRV>(inst->inputs[0]);
if (block_rv.same_as(input_rv) && inst_attr->i == i) {
return true;
}
}
}
return false;
}

TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);

TVM_REGISTER_GLOBAL("meta_schedule.analysis.IsTrivialBinding").set_body_typed(IsTrivialBinding);