Skip to content

Commit

Permalink
[Refactor] Making schedule based on IRModule not PrimFunc (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and Hzfengsy committed Mar 26, 2021
1 parent 724d728 commit bc6e609
Show file tree
Hide file tree
Showing 51 changed files with 816 additions and 521 deletions.
1 change: 0 additions & 1 deletion include/tvm/tir/schedule/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ class DepEdge : public runtime::ObjectRef {
/*! \brief An object recording the producer-consumer dependency between child blocks of a scope */
class BlockScopeNode : public runtime::Object {
public:
// TODO(@junrushao1994): Change std::unordered_map to Map
/*! \brief The forward dependency edges of the block */
std::unordered_map<StmtSRef, Array<DepEdge>, ObjectPtrHash, ObjectPtrEqual> forward_edges;
/*! \brief The backward dependency edges of the block */
Expand Down
27 changes: 23 additions & 4 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,21 @@ class Schedule;
* \brief The user-facing abstract schedule class
*/
class ScheduleNode : public runtime::Object {
public:
/*! \brief The internal state of scheduling */
ScheduleState state;
friend class Schedule;

public:
virtual ~ScheduleNode() = default;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);

public:
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*!
* \brief Take the IRModule out of the schedule
*/
virtual IRModule Module() const = 0;
virtual IRModule mod() const { return state()->mod; }
/*!
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random
Expand Down Expand Up @@ -160,6 +161,22 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding block/loop sref
*/
virtual StmtSRef GetSRef(const Stmt& stmt) const;
/******** Remove random variables ********/
/*!
* \brief Remove a random variable from the symbol table
* \param block_rv The symbol to be removed
*/
virtual void RemoveRV(const BlockRV& block_rv) = 0;
/*!
* \brief Remove a random variable from the symbol table
* \param loop_rv The symbol to be removed
*/
virtual void RemoveRV(const LoopRV& loop_rv) = 0;
/*!
* \brief Remove a random variable from the symbol table
* \param var_rv The symbol to be removed
*/
virtual void RemoveRV(const VarRV& var_rv) = 0;

public:
/******** Sampling ********/
Expand Down Expand Up @@ -390,7 +407,9 @@ class ScheduleNode : public runtime::Object {
class Schedule : public runtime::ObjectRef {
public:
// NOTE(review): the difference between `Concrete` and `Meta` is not visible from this header;
// presumably they select the plain TIR schedule and the meta-schedule implementation
// respectively -- confirm against the definitions.
/*!
* \brief Create a schedule via the `Concrete` factory from a single PrimFunc
* \param func The PrimFunc to be scheduled
* \param seed The random seed (NOTE(review): presumably -1 means device random, as documented
* for the seeding method of ScheduleNode -- confirm)
* \param debug_mode The debug flag (NOTE(review): presumably forwarded to ScheduleState, where
* it enables additional checks after each mutation -- confirm)
* \return The schedule created
*/
TVM_DLL static Schedule Concrete(PrimFunc func, int64_t seed, bool debug_mode);
/*!
* \brief Create a schedule via the `Concrete` factory from an IRModule
* \param func The IRModule to be scheduled (note: the parameter is named `func` although it is
* a module)
* \param seed The random seed, see the PrimFunc overload
* \param debug_mode The debug flag, see the PrimFunc overload
* \return The schedule created
*/
TVM_DLL static Schedule Concrete(IRModule func, int64_t seed, bool debug_mode);
/*!
* \brief Create a schedule via the `Meta` factory from a single PrimFunc
* \param func The PrimFunc to be scheduled
* \param seed The random seed, see `Concrete`
* \param debug_mode The debug flag, see `Concrete`
* \return The schedule created
*/
TVM_DLL static Schedule Meta(PrimFunc func, int64_t seed, bool debug_mode);
/*!
* \brief Create a schedule via the `Meta` factory from an IRModule
* \param func The IRModule to be scheduled (note: the parameter is named `func` although it is
* a module)
* \param seed The random seed, see `Concrete`
* \param debug_mode The debug flag, see `Concrete`
* \return The schedule created
*/
TVM_DLL static Schedule Meta(IRModule func, int64_t seed, bool debug_mode);
// Declares the object-reference boilerplate tying `Schedule` to its node type `ScheduleNode`.
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
21 changes: 13 additions & 8 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,24 @@
namespace tvm {
namespace tir {

// TODO(@junrushao1994): change `std::unordered_map` to `Map`?

/*!
* \brief The state of scheduling, which provides a primitive `Replace` as an interface of all the
* scheduling primitives to transform the TensorIR.
*/
class ScheduleStateNode : public runtime::Object {
public:
/*! \brief The function to be scheduled */
PrimFunc func; // TODO(@junrushao1994): change to IRModule
/*! \brief The module to be scheduled */
IRModule mod;
/*! \brief The block scopes of each block sref */
std::unordered_map<StmtSRef, BlockScope, ObjectPtrHash, ObjectPtrEqual> scopes;
Map<StmtSRef, BlockScope> scopes;
/*! \brief The mapping from block/for stmt to its sref */
std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
/*! \brief In debug mode, we do extra correctness checking after each replacement */
bool debug_mode;

void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
// `scopes` is not visited
v->Visit("mod", &mod);
v->Visit("scopes", &scopes);
// `stmt2ref` is not visited
v->Visit("debug_mode", &debug_mode);
}
Expand Down Expand Up @@ -95,9 +93,16 @@ class ScheduleStateNode : public runtime::Object {
*/
class ScheduleState : public runtime::ObjectRef {
public:
/*!
* \brief Construct a schedule state from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode When turned on, additional checks will be performed after each mutation
*/
TVM_DLL explicit ScheduleState(IRModule mod, bool debug_mode);
/*!
* \brief Construct a schedule state from a PrimFunc
* \param func The PrimFunc to be created
* \param func The PrimFunc to be scheduled
* \param debug_mode When turned on, additional checks will be performed after each mutation
*/
TVM_DLL explicit ScheduleState(PrimFunc func, bool debug_mode);

Expand Down
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class Schedule(TIRSchedule):
"""

state: ScheduleState
orig_func: tir.PrimFunc
trace: Trace

def __init__( # pylint: disable=super-init-not-called
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Any, Callable, List, Tuple

import psutil
from tvm import arith, ir, rpc
from tvm import ir, rpc
from tvm._ffi import register_func
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import ndk as build_func_ndk
Expand Down Expand Up @@ -432,7 +432,7 @@ def timed_func() -> BuildResult.TYPE:
filename = os.path.join(tempfile.mkdtemp(), "tmp_func." + build_func.output_format)
try:
func = tvm_build(
measure_input.sch.module,
measure_input.sch.mod["main"],
target=measure_input.task.target,
target_host=measure_input.task.target_host,
)
Expand Down Expand Up @@ -545,7 +545,7 @@ def rpc_runner_run(
This only has an effect on CPU tasks.
f_create_args: Callable[[TVMContext], List[NDArray]] = None
Optional callback to create arguments for functions to measure. This can be used for sparse
workloads when we cannot use random tensors for measurment.
workloads when we cannot use random tensors for measurement.
verbose: int = 1
Verbosity level. 0 for silent, 1 to output information during program measuring.
Expand Down
26 changes: 13 additions & 13 deletions python/tvm/tir/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from tvm._ffi import register_object as _register_object
from tvm.ir import PrimExpr
from tvm.ir import PrimExpr, IRModule
from tvm.runtime import Object, String

from . import _ffi_api_schedule
Expand Down Expand Up @@ -99,22 +99,20 @@ def get_successor(self, block: StmtSRef) -> List[DepEdge]:
class ScheduleState(Object):
"""The state of scheduling"""

func: PrimFunc
mod: IRModule
scopes: Dict[StmtSRef, BlockScope]
debug_mode: bool

def __init__(self, func: PrimFunc, debug_mode: bool):
def __init__(self, func_or_mod: Union[PrimFunc, IRModule], debug_mode: bool):
self.__init_handle_by_constructor__(
_ffi_api_schedule.ScheduleState, # pylint: disable=no-member
func,
func_or_mod,
debug_mode,
)

def get_sref(self, stmt: Stmt) -> Optional[StmtSRef]:
return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt) # pylint: disable=no-member

def scope(self, block: StmtSRef) -> BlockScope:
return _ffi_api_schedule.ScheduleStateGetScope(self, block) # pylint: disable=no-member

def replace(
self,
src_sref: StmtSRef,
Expand Down Expand Up @@ -152,19 +150,21 @@ class BlockRV(Object):
class Schedule(Object):
"""The schedule node for TIR"""

state: ScheduleState

def __init__(self, func: PrimFunc, debug_mode: bool = False):
def __init__(self, func_or_mod: Union[PrimFunc, IRModule], debug_mode: bool = False):
self.__init_handle_by_constructor__(
_ffi_api_schedule.Schedule, # pylint: disable=no-member
func,
func_or_mod,
-1, # seed
debug_mode,
)

@property
def module(self) -> PrimFunc:
return self.state.func
def mod(self) -> IRModule:
return _ffi_api_schedule.ScheduleModule(self) # pylint: disable=no-member

@property
def state(self) -> ScheduleState:
return _ffi_api_schedule.ScheduleGetState(self) # pylint: disable=no-member

def show(self, rand_var: Union[LoopRV, BlockRV, ExprRV]) -> str:
# TODO(@junrushao1994): complete it
Expand Down
9 changes: 5 additions & 4 deletions src/meta_schedule/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
}

bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
tir::StmtSRef parent_block_sref = GetScopeSRef(block_sref);
tir::StmtSRef parent_block_sref = GetScopeRoot(block_sref);
return parent_block_sref->parent == nullptr;
}

Expand Down Expand Up @@ -95,14 +95,15 @@ bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref)
}

bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
tir::StmtSRef parent_sref = tir::GetScopeSRef(block_sref);
tir::StmtSRef parent_sref = tir::GetScopeRoot(block_sref);
const auto* block = block_sref->GetStmt<tir::BlockNode>();
const auto* parent = parent_sref->GetStmt<tir::BlockNode>();
ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
ICHECK(parent) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
if (parent_sref->parent == nullptr) {
const tir::PrimFuncNode* func = tir::GetRootPrimFunc(self, parent_sref);
for (const tir::BufferRegion& write : block->writes) {
for (const auto& kv : self->func->buffer_map) {
for (const auto& kv : func->buffer_map) {
if (write->buffer.get() == kv.second.get()) {
return true;
}
Expand Down Expand Up @@ -224,7 +225,7 @@ Optional<Array<Bool>> GetReadPattern(const Array<tir::IterVar>& block_vars,
bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& producer_sref,
const tir::StmtSRef& consumer_sref) {
// Assume consumer is the only consumer of the producer
tir::StmtSRef parent_sref = tir::GetScopeSRef(producer_sref);
tir::StmtSRef parent_sref = tir::GetScopeRoot(producer_sref);
const auto* producer = producer_sref->GetStmt<tir::BlockNode>();
const auto* consumer = consumer_sref->GetStmt<tir::BlockNode>();
ICHECK(producer) << "TypeError: Expects Block, but gets: " << producer_sref->stmt->GetTypeKey();
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/feature/per_block_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,7 @@ runtime::NDArray PerBlockFeature(const Schedule& sch, int max_num_buffer_access_
size_t kNumFeature = kNumFeatureGroup1 +
kNumFeatureGroup2Subgroup * max_num_buffer_access_features +
kNumFeatureGroup3 + kNumFeatureGroup5;
tir::PrimFunc func = GetOnlyFunc(sch->Module());
tir::PrimFunc func = GetOnlyFunc(sch->mod());
std::vector<FeatureSet> feature_map = PerBlockFeatureExtractor::Extract(func);

DoubleNDArrayPusher ret(
Expand Down
3 changes: 3 additions & 0 deletions src/meta_schedule/sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ std::vector<int> Sampler::SamplePerfectTile(int n_splits, int extent) {
}

std::vector<int> Sampler::SamplePerfectTile(int n_splits, int extent, int max_innermost_factor) {
if (max_innermost_factor == -1) {
return this->SamplePerfectTile(n_splits, extent);
}
CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits";
std::vector<int> innermost_candidates;
innermost_candidates.reserve(max_innermost_factor);
Expand Down
25 changes: 17 additions & 8 deletions src/meta_schedule/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,27 @@ std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, Sampler* sampler
} else if (decision->defined()) {
// Case 2. Use previous decision
result = AsVector<Integer, int64_t>(decision->value());
int n = result.size();
ICHECK_GE(n, 2);
int64_t len = extent;
for (int i = n - 1; i > 0; --i) {
int64_t& l = result[i];
// A previous decision could become invalid because of the change of outer tiles.
// To handle this case properly, we check if the tiling strategy is still perfect.
// If not, we use a trivial default solution (1, 1, ..., 1, L) for the rest of the tiles.
if (len % l != 0) {
l = len;
}
len /= l;
}
result[0] = len;
} else {
// Case 3. Use fresh new sampling result
std::vector<int> sampled = sampler->SamplePerfectTile(n, extent);
std::vector<int> sampled = sampler->SamplePerfectTile(n, extent, max_innermost_factor);
result = std::vector<int64_t>(sampled.begin(), sampled.end());
ICHECK_LE(sampled.back(), max_innermost_factor);
}
// Record the new decision
Array<Integer> new_decision;
new_decision.reserve(result.size());
for (int64_t i : result) {
new_decision.push_back(Integer(i));
}
*decision = new_decision;
*decision = AsArray<int64_t, Integer>(result);
return result;
}

Expand Down
Loading

0 comments on commit bc6e609

Please sign in to comment.