diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h
index e479f2026c..c708528288 100644
--- a/include/tvm/tir/schedule/block_scope.h
+++ b/include/tvm/tir/schedule/block_scope.h
@@ -151,7 +151,6 @@ class DepEdge : public runtime::ObjectRef {
 /*! \brief An object recording the producer-consumer dependency between child blocks of a scope */
 class BlockScopeNode : public runtime::Object {
  public:
-  // TODO(@junrushao1994): Change std::unordered_map to Map
   /*! \brief The forward dependency edges of the block */
   std::unordered_map<StmtSRef, Array<DepEdge>, ObjectPtrHash, ObjectPtrEqual> forward_edges;
   /*! \brief The backward dependency edges of the block */
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 6e3b25e736..f6e8dc5f08 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -82,20 +82,21 @@ class Schedule;
 /*!
  * \brief The user-facing abstract schedule class
  */
 class ScheduleNode : public runtime::Object {
- public:
-  /*! \brief The internal state of scheduling */
-  ScheduleState state;
+  friend class Schedule;
 
+ public:
   virtual ~ScheduleNode() = default;
 
   static constexpr const char* _type_key = "tir.Schedule";
   TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);
 
  public:
+  /*! \return The internal state of scheduling */
+  virtual ScheduleState state() const = 0;
   /*!
-   * \brief Take the PrimFunc out of the schedule
+   * \brief Take the IRModule out of the schedule
    */
-  virtual IRModule Module() const = 0;
+  virtual IRModule mod() const { return state()->mod; }
   /*!
    * \brief Seed the randomness
    * \param seed The new random seed, -1 if use device random
@@ -160,6 +161,22 @@ class ScheduleNode : public runtime::Object {
    * \return The corresponding block/loop sref
    */
   virtual StmtSRef GetSRef(const Stmt& stmt) const;
+  /******** Remove random variables ********/
+  /*!
+   * \brief Remove a random variable from the symbol table
+   * \param block_rv The symbol to be removed
+   */
+  virtual void RemoveRV(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Remove a random variable from the symbol table
+   * \param loop_rv The symbol to be removed
+   */
+  virtual void RemoveRV(const LoopRV& loop_rv) = 0;
+  /*!
+   * \brief Remove a random variable from the symbol table
+   * \param var_rv The symbol to be removed
+   */
+  virtual void RemoveRV(const VarRV& var_rv) = 0;
 
  public:
   /******** Sampling ********/
@@ -390,7 +407,9 @@ class ScheduleNode : public runtime::Object {
 class Schedule : public runtime::ObjectRef {
  public:
   TVM_DLL static Schedule Concrete(PrimFunc func, int64_t seed, bool debug_mode);
+  TVM_DLL static Schedule Concrete(IRModule mod, int64_t seed, bool debug_mode);
   TVM_DLL static Schedule Meta(PrimFunc func, int64_t seed, bool debug_mode);
+  TVM_DLL static Schedule Meta(IRModule mod, int64_t seed, bool debug_mode);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
 };
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 48831b022e..5a5ed8055a 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -40,26 +40,24 @@ namespace tvm {
 namespace tir {
 
-// TODO(@junrushao1994): change `std::unordered_map` to `Map`?
-
 /*!
 * \brief The state of scheduling, which provides a primitive `Replace` as an interface of all the
 * scheduling primitives to transform the TensorIR.
 */
 class ScheduleStateNode : public runtime::Object {
  public:
-  /*! \brief The function to be scheduled */
-  PrimFunc func;  // TODO(@junrushao1994): change to IRModule
+  /*! \brief The module to be scheduled */
+  IRModule mod;
   /*! \brief The block scopes of each block sref */
-  std::unordered_map<StmtSRef, BlockScope, ObjectPtrHash, ObjectPtrEqual> scopes;
+  Map<StmtSRef, BlockScope> scopes;
   /*! \brief The mapping from block/for stmt to its sref */
   std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
   /*! \brief In debug mode, we do extra correctness checking after each replacement */
   bool debug_mode;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("func", &func);
-    // `scopes` is not visited
+    v->Visit("mod", &mod);
+    v->Visit("scopes", &scopes);
     // `stmt2ref` is not visited
     v->Visit("debug_mode", &debug_mode);
   }
@@ -95,9 +93,16 @@ class ScheduleStateNode : public runtime::Object {
 */
 class ScheduleState : public runtime::ObjectRef {
  public:
+  /*!
+   * \brief Construct a schedule state from an IRModule
+   * \param mod The IRModule to be scheduled
+   * \param debug_mode When turned on, additional checks will be performed after each mutation
+   */
+  TVM_DLL explicit ScheduleState(IRModule mod, bool debug_mode);
   /*!
-   * \brief Construct a schedule from a PrimFunc
-   * \param func The PrimFunc to be created
+   * \brief Construct a schedule state from a PrimFunc
+   * \param func The PrimFunc to be scheduled
+   * \param debug_mode When turned on, additional checks will be performed after each mutation
   */
   TVM_DLL explicit ScheduleState(PrimFunc func, bool debug_mode);
diff --git a/python/tvm/meta_schedule/schedule.py b/python/tvm/meta_schedule/schedule.py
index 67708f28a5..b36d62643d 100644
--- a/python/tvm/meta_schedule/schedule.py
+++ b/python/tvm/meta_schedule/schedule.py
@@ -44,7 +44,6 @@ class Schedule(TIRSchedule):
     """
 
     state: ScheduleState
-    orig_func: tir.PrimFunc
     trace: Trace
 
     def __init__(  # pylint: disable=super-init-not-called
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 9ba1e1b59e..2b25a952a0 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -28,7 +28,7 @@ from typing import Any, Callable, List, Tuple
 
 import psutil
-from tvm import arith, ir, rpc
+from tvm import ir, rpc
 from tvm._ffi import register_func
 from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
 from tvm.contrib import ndk as build_func_ndk
@@ -432,7 +432,7 @@ def timed_func() -> BuildResult.TYPE:
         filename = os.path.join(tempfile.mkdtemp(), "tmp_func." + build_func.output_format)
         try:
             func = tvm_build(
-                measure_input.sch.module,
+                measure_input.sch.mod["main"],
                 target=measure_input.task.target,
                 target_host=measure_input.task.target_host,
             )
@@ -545,7 +545,7 @@ def rpc_runner_run(
         This only has an effect on CPU tasks.
     f_create_args: Callable[[TVMContext], List[NDArray]] = None
         Optional callback to create arguments for functions to measure. This can be used for sparse
-        workloads when we cannot use random tensors for measurment.
+        workloads when we cannot use random tensors for measurement.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program measuring.
diff --git a/python/tvm/tir/schedule.py b/python/tvm/tir/schedule.py
index 21af369549..a6364c7bfe 100644
--- a/python/tvm/tir/schedule.py
+++ b/python/tvm/tir/schedule.py
@@ -20,7 +20,7 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 from tvm._ffi import register_object as _register_object
-from tvm.ir import PrimExpr
+from tvm.ir import PrimExpr, IRModule
 from tvm.runtime import Object, String
 
 from . import _ffi_api_schedule
@@ -99,22 +99,20 @@ def get_successor(self, block: StmtSRef) -> List[DepEdge]:
 class ScheduleState(Object):
     """The state of scheduling"""
 
-    func: PrimFunc
+    mod: IRModule
+    scopes: Dict[StmtSRef, BlockScope]
     debug_mode: bool
 
-    def __init__(self, func: PrimFunc, debug_mode: bool):
+    def __init__(self, func_or_mod: Union[PrimFunc, IRModule], debug_mode: bool):
         self.__init_handle_by_constructor__(
             _ffi_api_schedule.ScheduleState,  # pylint: disable=no-member
-            func,
+            func_or_mod,
             debug_mode,
         )
 
     def get_sref(self, stmt: Stmt) -> Optional[StmtSRef]:
         return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt)  # pylint: disable=no-member
 
-    def scope(self, block: StmtSRef) -> BlockScope:
-        return _ffi_api_schedule.ScheduleStateGetScope(self, block)  # pylint: disable=no-member
-
     def replace(
         self,
         src_sref: StmtSRef,
@@ -152,19 +150,21 @@ class BlockRV(Object):
 class Schedule(Object):
     """The schedule node for TIR"""
 
-    state: ScheduleState
-
-    def __init__(self, func: PrimFunc, debug_mode: bool = False):
+    def __init__(self, func_or_mod: Union[PrimFunc, IRModule], debug_mode: bool = False):
         self.__init_handle_by_constructor__(
             _ffi_api_schedule.Schedule,  # pylint: disable=no-member
-            func,
+            func_or_mod,
             -1,  # seed
             debug_mode,
         )
 
     @property
-    def module(self) -> PrimFunc:
-        return self.state.func
+    def mod(self) -> IRModule:
+        return _ffi_api_schedule.ScheduleModule(self)  # pylint: disable=no-member
+
+    @property
+    def state(self) -> ScheduleState:
+        return _ffi_api_schedule.ScheduleGetState(self)  # pylint: disable=no-member
 
     def show(self, rand_var: Union[LoopRV, BlockRV, ExprRV]) -> str:
         # TODO(@junrushao1994): complete it
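Taken together, the Python hunks above mean a schedule is now built from either a `PrimFunc` or an `IRModule`, with `mod` and `state` exposed as properties. A minimal usage sketch, assuming this patch is applied; `some_prim_func` stands in for any `tir.PrimFunc`:

```python
from tvm import tir
from tvm.ir import IRModule

# Either construction path yields the same schedule: a bare PrimFunc is
# wrapped into an IRModule under the "main" GlobalVar internally.
sch = tir.Schedule(some_prim_func, debug_mode=True)
sch = tir.Schedule(IRModule({"main": some_prim_func}), debug_mode=True)

# `sch.mod` replaces the old `sch.module` and returns the whole IRModule,
# so downstream code indexes "main" explicitly (cf. the utils.py hunk).
main_func = sch.mod["main"]
state = sch.state  # now an FFI-backed property rather than a stored field
```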
diff --git a/src/meta_schedule/analysis.cc b/src/meta_schedule/analysis.cc
index 03a58d678c..1e0a6f85c7 100644
--- a/src/meta_schedule/analysis.cc
+++ b/src/meta_schedule/analysis.cc
@@ -52,7 +52,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
 }
 
 bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  tir::StmtSRef parent_block_sref = GetScopeSRef(block_sref);
+  tir::StmtSRef parent_block_sref = GetScopeRoot(block_sref);
   return parent_block_sref->parent == nullptr;
 }
 
@@ -95,14 +95,15 @@ bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref)
 }
 
 bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
-  tir::StmtSRef parent_sref = tir::GetScopeSRef(block_sref);
+  tir::StmtSRef parent_sref = tir::GetScopeRoot(block_sref);
   const auto* block = block_sref->GetStmt<tir::BlockNode>();
   const auto* parent = parent_sref->GetStmt<tir::BlockNode>();
   ICHECK(block) << "TypeError: Expects Block, but gets: " << block_sref->stmt->GetTypeKey();
   ICHECK(parent) << "TypeError: Expects Block, but gets: " << parent_sref->stmt->GetTypeKey();
   if (parent_sref->parent == nullptr) {
+    const tir::PrimFuncNode* func = tir::GetRootPrimFunc(self, parent_sref);
     for (const tir::BufferRegion& write : block->writes) {
-      for (const auto& kv : self->func->buffer_map) {
+      for (const auto& kv : func->buffer_map) {
         if (write->buffer.get() == kv.second.get()) {
           return true;
         }
@@ -224,7 +225,7 @@ Optional> GetReadPattern(const Array& block_vars,
 bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& producer_sref,
                         const tir::StmtSRef& consumer_sref) {
   // Assume consumer is the only consumer of the producer
-  tir::StmtSRef parent_sref = tir::GetScopeSRef(producer_sref);
+  tir::StmtSRef parent_sref = tir::GetScopeRoot(producer_sref);
   const auto* producer = producer_sref->GetStmt<tir::BlockNode>();
   const auto* consumer = consumer_sref->GetStmt<tir::BlockNode>();
   ICHECK(producer) << "TypeError: Expects Block, but gets: " << producer_sref->stmt->GetTypeKey();
diff --git a/src/meta_schedule/feature/per_block_feature.cc b/src/meta_schedule/feature/per_block_feature.cc
index 49004094b7..eabab5d434 100644
--- a/src/meta_schedule/feature/per_block_feature.cc
+++ b/src/meta_schedule/feature/per_block_feature.cc
@@ -1086,7 +1086,7 @@ runtime::NDArray PerBlockFeature(const Schedule& sch, int max_num_buffer_access_
   size_t kNumFeature = kNumFeatureGroup1 +
                        kNumFeatureGroup2Subgroup * max_num_buffer_access_features +
                        kNumFeatureGroup3 + kNumFeatureGroup5;
-  tir::PrimFunc func = GetOnlyFunc(sch->Module());
+  tir::PrimFunc func = GetOnlyFunc(sch->mod());
   std::vector feature_map = PerBlockFeatureExtractor::Extract(func);
 
   DoubleNDArrayPusher ret(
diff --git a/src/meta_schedule/sampler.cc b/src/meta_schedule/sampler.cc
index ed00ab3269..6be4a36fb5 100644
--- a/src/meta_schedule/sampler.cc
+++ b/src/meta_schedule/sampler.cc
@@ -318,6 +318,9 @@ std::vector<int> Sampler::SamplePerfectTile(int n_splits, int extent) {
 }
 
 std::vector<int> Sampler::SamplePerfectTile(int n_splits, int extent, int max_innermost_factor) {
+  if (max_innermost_factor == -1) {
+    return this->SamplePerfectTile(n_splits, extent);
+  }
   CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits";
   std::vector<int> innermost_candidates;
   innermost_candidates.reserve(max_innermost_factor);
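In `sampler.cc` above, `-1` now explicitly means "no cap on the innermost factor" and falls back to the two-argument sampler. A self-contained Python sketch of that contract (illustrative only: it mirrors the constraint, not the C++ sampler's exact distribution):

```python
import random

def sample_perfect_tile(n_splits, extent, max_innermost_factor=-1, rng=random):
    """Return n_splits factors whose product is exactly `extent`; the
    innermost factor is capped unless the cap is -1 (uncapped)."""
    if n_splits == 1:
        return [extent]
    if max_innermost_factor == -1:
        # Mirrors the early-exit above: fall back to the uncapped variant.
        max_innermost_factor = extent
    candidates = [d for d in range(1, max_innermost_factor + 1) if extent % d == 0]
    innermost = rng.choice(candidates)
    return sample_perfect_tile(n_splits - 1, extent // innermost) + [innermost]

tiles = sample_perfect_tile(3, 96, max_innermost_factor=4)
assert tiles[0] * tiles[1] * tiles[2] == 96 and tiles[-1] <= 4
```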
diff --git a/src/meta_schedule/sampling.cc b/src/meta_schedule/sampling.cc
index 03c94daa01..981ca12a31 100644
--- a/src/meta_schedule/sampling.cc
+++ b/src/meta_schedule/sampling.cc
@@ -40,18 +40,27 @@ std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, Sampler* sampler
   } else if (decision->defined()) {
     // Case 2. Use the previous decision
     result = AsVector(decision->value());
+    int n = result.size();
+    ICHECK_GE(n, 2);
+    int64_t len = extent;
+    for (int i = n - 1; i > 0; --i) {
+      int64_t& l = result[i];
+      // A previous decision could become invalid because of a change of the outer tiles.
+      // To handle this case properly, we check if the tiling strategy is still perfect.
+      // If not, we fall back to a trivial default solution (1, 1, ..., 1, L) for the rest
+      // of the tiles.
+      if (len % l != 0) {
+        l = len;
+      }
+      len /= l;
+    }
+    result[0] = len;
   } else {
     // Case 3. Use a fresh sampling result
-    std::vector<int> sampled = sampler->SamplePerfectTile(n, extent);
+    std::vector<int> sampled = sampler->SamplePerfectTile(n, extent, max_innermost_factor);
     result = std::vector<int64_t>(sampled.begin(), sampled.end());
+    ICHECK(max_innermost_factor == -1 || sampled.back() <= max_innermost_factor);
   }
-  // Record the new decision
-  Array<Integer> new_decision;
-  new_decision.reserve(result.size());
-  for (int64_t i : result) {
-    new_decision.push_back(Integer(i));
-  }
-  *decision = new_decision;
+  *decision = AsArray(result);
   return result;
 }
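The Case-2 branch above is the subtle part of this hunk: a decision replayed from a trace may no longer tile the loop perfectly once an earlier mutation changes the extent. A self-contained Python sketch of the repair loop:

```python
def repair_decision(tiles, extent):
    """Walk the tiles innermost-first; any tile that no longer divides the
    remaining length is clamped to it (degenerating toward 1, 1, ..., 1, L),
    and the outermost tile absorbs whatever length is left."""
    tiles = list(tiles)
    length = extent
    for i in range(len(tiles) - 1, 0, -1):
        if length % tiles[i] != 0:
            tiles[i] = length
        length //= tiles[i]
    tiles[0] = length
    return tiles

assert repair_decision([4, 8, 4], 128) == [4, 8, 4]   # still perfect: kept as-is
assert repair_decision([4, 8, 4], 96) == [3, 8, 4]    # outer tile re-balanced
assert repair_decision([4, 8, 4], 100) == [1, 25, 4]  # degenerates toward (1, ..., L)
```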
diff --git a/src/meta_schedule/schedule.cc b/src/meta_schedule/schedule.cc
index 474b1a4ad6..b617b275f0 100644
--- a/src/meta_schedule/schedule.cc
+++ b/src/meta_schedule/schedule.cc
@@ -33,48 +33,49 @@ namespace tir {
 tir::Schedule tir::Schedule::Meta(tir::PrimFunc func, int64_t seed, bool debug_mode) {
   return meta_schedule::Schedule(func, seed, debug_mode);
 }
+tir::Schedule tir::Schedule::Meta(IRModule mod, int64_t seed, bool debug_mode) {
+  return meta_schedule::Schedule(mod, seed, debug_mode);
+}
 }  // namespace tir
 
 namespace meta_schedule {
 
-Schedule::Schedule(tir::PrimFunc func, int64_t seed, bool debug_mode) {
+Schedule::Schedule(tir::PrimFunc func, int64_t seed, bool debug_mode)
+    : Schedule(IRModule({{GlobalVar("main"), func}}), seed, debug_mode) {}
+
+Schedule::Schedule(IRModule mod, int64_t seed, bool debug_mode) {
   ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
-  n->state = tir::ScheduleState(func, debug_mode);
-  n->symbol_table = {};
+  n->state_ = tir::ScheduleState(mod, debug_mode);
+  n->symbol_table_ = {};
+  n->analyzer_ = std::make_unique<arith::Analyzer>();
   if (seed != -1) {
     n->sampler.Seed(seed);
   }
-  n->orig_func = func;
   n->trace = Trace();
   this->data_ = std::move(n);
 }
 
-/**************** Copy ****************/
+/**************** Utility ****************/
 
-Schedule ScheduleNode::Copy(int new_seed) const {
-  tir::Schedule parent = tir::ConcreteScheduleNode::Copy();
-  const auto* p = parent.as<ScheduleNode>();
-  ICHECK(p != nullptr);
+Schedule ScheduleNode::Copy(int64_t new_seed) const {
   ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
-  n->state = std::move(p->state);
-  n->symbol_table = std::move(p->symbol_table);
-  n->orig_func = orig_func;
+  tir::ConcreteScheduleNode::MakeCopy(&n->state_, &n->symbol_table_);
+  n->analyzer_ = std::make_unique<arith::Analyzer>();
   n->trace = Trace(this->trace->insts, this->trace->decisions);
   n->sampler.Seed(new_seed);
   return Schedule(std::move(n));
 }
 
-/**************** Sampling ****************/
+void ScheduleNode::Seed(int64_t seed) { this->sampler.Seed(seed); }
 
-using tir::FromRV;
-using tir::SetRV;
+/**************** Sampling ****************/
 
 Array<tir::Var> ScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
                                                 int max_innermost_factor,
                                                 Optional<Array<Integer>> decision) {
   std::vector<int64_t> result = meta_schedule::SamplePerfectTile(
-      this->state, &this->sampler, this->GetSRef(loop_rv), n, max_innermost_factor, &decision);
-  Array<tir::Var> result_rvs = SetRV(this, AsArray(result));
+      state_, &this->sampler, this->GetSRef(loop_rv), n, max_innermost_factor, &decision);
+  Array<tir::Var> result_rvs = SetRV(AsArray(result));
   // Record the instruction
   this->trace->Append(SamplePerfectTileAttrs::Make(loop_rv, n, max_innermost_factor, result_rvs),
                       decision);
@@ -85,16 +86,16 @@ tir::Var ScheduleNode::SampleCategorical(const Array<Integer>& candidates,  //
                                          const Array<FloatImm>& probs,      //
                                          Optional<Integer> decision) {
   int64_t result =
-      meta_schedule::SampleCategorical(this->state, &this->sampler, candidates, probs, &decision);
-  tir::Var result_rv = SetRV(this, result);
+      meta_schedule::SampleCategorical(state_, &this->sampler, candidates, probs, &decision);
+  tir::Var result_rv = SetRV(result);
   this->trace->Append(SampleCategoricalAttrs::Make(candidates, probs, result_rv), decision);
   return result_rv;
 }
 
 LoopRV ScheduleNode::SampleComputeLocation(const BlockRV& block_rv, Optional<Integer> decision) {
-  tir::StmtSRef result = meta_schedule::SampleComputeLocation(this->state, &this->sampler,
+  tir::StmtSRef result = meta_schedule::SampleComputeLocation(state_, &this->sampler,
                                                               this->GetSRef(block_rv), &decision);
-  LoopRV result_rv = SetRV(this, result);
+  LoopRV result_rv = SetRV(result);
   this->trace->Append(SampleComputeLocationAttrs::Make(block_rv, result_rv), decision);
   return result_rv;
 }
@@ -161,14 +162,12 @@ void ScheduleNode::Reorder(const Array<LoopRV>& order) {
 void ScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
                              bool preserve_unit_loop) {
   tir::ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loop);
-  // TODO
   this->trace->Append(ComputeAtAttrs::Make(block_rv, loop_rv));
 }
 
 void ScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
                                     bool preserve_unit_loop) {
   tir::ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loop);
-  // TODO
   this->trace->Append(ReverseComputeAtAttrs::Make(block_rv, loop_rv));
 }
 
@@ -275,7 +274,7 @@ void ScheduleNode::MarkLoop(const LoopRV& loop_rv, const String& ann_key, const
   ICHECK(ann_val->IsInstance<tir::StringImmNode>() || ann_val->IsInstance<IntImmNode>())
       << "TypeError: Only StringImm and IntImm are supported for now, but gets: "
      << ann_val->GetTypeKey();
-  AddAnn(this->state, this->GetSRef(loop_rv), ann_key, ann_val);
+  AddAnn(state_, this->GetSRef(loop_rv), ann_key, ann_val);
   this->trace->Append(MarkLoopAttrs::Make(loop_rv, ann_key, ann_val));
 }
 
@@ -283,8 +282,7 @@ void ScheduleNode::MarkBlock(const BlockRV& block_rv, const String& ann_key,
                              const PrimExpr& ann_val) {
   PrimExpr value = this->Get(ann_val);
   const auto* int_imm = TVM_TYPE_AS(int_imm, value, IntImmNode);
-  AddAnn(this->state, this->GetSRef(block_rv), ann_key,
-         tir::StringImm(std::to_string(int_imm->value)));
+  AddAnn(state_, this->GetSRef(block_rv), ann_key, tir::StringImm(std::to_string(int_imm->value)));
   this->trace->Append(MarkBlockAttrs::Make(block_rv, ann_key, ann_val));
 }
 
@@ -297,8 +295,17 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleMarkLoop")
     .set_body_method<Schedule>(&ScheduleNode::MarkLoop);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleMarkBlock")
     .set_body_method<Schedule>(&ScheduleNode::MarkBlock);
-TVM_REGISTER_GLOBAL("meta_schedule.Schedule")  //
-    .set_body_typed(tir::Schedule::Meta);
+TVM_REGISTER_GLOBAL("meta_schedule.Schedule")
+    .set_body_typed([](ObjectRef obj, int64_t seed, bool debug_mode) -> Schedule {
+      if (const auto* func = obj.as<tir::PrimFuncNode>()) {
+        return Schedule(GetRef<tir::PrimFunc>(func), seed, debug_mode);
+      }
+      if (const auto* mod = obj.as<IRModuleNode>()) {
+        return Schedule(GetRef<IRModule>(mod), seed, debug_mode);
+      }
+      LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey();
+      throw;
+    });
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleCopy").set_body_typed([](Schedule self, int new_seed) {
   return self->Copy(new_seed);
 });
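With the delegating constructor above, the `PrimFunc` overload is now a thin wrapper that packs the function into an `IRModule` under the `main` GlobalVar, and the FFI entry dispatches on the runtime type. A sketch of the intended equivalence from Python (`prim_func` is a placeholder and the `seed` keyword spelling is assumed, not shown in this patch):

```python
import tvm
from tvm import meta_schedule as ms
from tvm.ir import IRModule

sch_a = ms.Schedule(prim_func, seed=42)                      # PrimFunc overload
sch_b = ms.Schedule(IRModule({"main": prim_func}), seed=42)  # what it expands to
assert tvm.ir.structural_equal(sch_a.mod, sch_b.mod)
```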
diff --git a/src/meta_schedule/schedule.h b/src/meta_schedule/schedule.h
index 20d444a214..8a321b270e 100644
--- a/src/meta_schedule/schedule.h
+++ b/src/meta_schedule/schedule.h
@@ -34,29 +34,29 @@ class Schedule;
 /*! \brief The meta schedule class */
 class ScheduleNode : public tir::ConcreteScheduleNode {
  private:
-  using tir::ScheduleNode::Copy;
+  friend class tir::Schedule;
+  using tir::ConcreteScheduleNode::Copy;
 
- public:
+ protected:
   friend class Schedule;
   using TSymbolTable = tir::ConcreteScheduleNode::TSymbolTable;
   /*! \brief The schedule state */
-  using tir::ConcreteScheduleNode::state;
+  using tir::ConcreteScheduleNode::state_;
   /*! \brief The symbol table */
-  using tir::ConcreteScheduleNode::symbol_table;
-  /*! \brief The random number sampler */
-  Sampler sampler;
-  /*! \brief The original TIR PrimFunc to be scheduled */
-  tir::PrimFunc orig_func;
+  using tir::ConcreteScheduleNode::symbol_table_;
+
+ public:
   /*! \brief The trace of the program execution */
   Trace trace;
+  /*! \brief The random number sampler */
+  Sampler sampler;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("state", &state);
-    v->Visit("symbol_table", &symbol_table);
-    // `sampler` is not visited
-    v->Visit("orig_func", &orig_func);
+    // `state_` is not visited
+    // `symbol_table_` is not visited
     v->Visit("trace", &trace);
+    // `sampler` is not visited
   }
 
   static constexpr const char* _type_key = "meta_schedule.Schedule";
@@ -69,7 +69,10 @@ class ScheduleNode : public tir::ConcreteScheduleNode {
   * original schedule, and vice versa.
   * \return A new schedule.
   */
-  Schedule Copy(int new_seed) const;
+  Schedule Copy(int64_t new_seed) const;
+
+  void Seed(int64_t seed = -1) final;
+
   /**************** Sampling ****************/
   /*!
   * \brief Apply the instruction SamplePerfectTile
@@ -114,7 +117,6 @@ class ScheduleNode : public tir::ConcreteScheduleNode {
   * \brief Get the child blocks of a specific parent block/loop
   * \param block_rv The random variable that points to the parent block
   * \return A list of child blocks
-  * TODO(@junrushao1994): revisit
   */
   Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
   /*!
@@ -305,9 +307,8 @@ class ScheduleNode : public tir::ConcreteScheduleNode {
 class Schedule : public tir::Schedule {
  public:
   using TSymbolTable = ScheduleNode::TSymbolTable;
   explicit Schedule(tir::PrimFunc func, int64_t seed = -1, bool debug_mode = false);
-
+  explicit Schedule(IRModule mod, int64_t seed = -1, bool debug_mode = false);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, tir::Schedule, ScheduleNode);
 };
diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc
index 5a240639f7..ffe37c0872 100644
--- a/src/meta_schedule/space/post_order_apply.cc
+++ b/src/meta_schedule/space/post_order_apply.cc
@@ -119,13 +119,13 @@ class BlockCollector : public tir::StmtVisitor {
  public:
   /*! \brief Constructor */
   explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {
-    const auto* realize = GetOnlyFunc(sch->Module())->body.as<tir::BlockRealizeNode>();
+    const auto* realize = GetOnlyFunc(sch->mod())->body.as<tir::BlockRealizeNode>();
     root_block_ = realize->block.get();
   }
 
   /*! \brief Entry point */
   Array Run() {
-    VisitStmt(GetOnlyFunc(sch_->Module())->body);
+    VisitStmt(GetOnlyFunc(sch_->mod())->body);
     Array result = std::move(result_);
     return result;
   }
@@ -174,7 +174,7 @@ Array<Schedule> PostOrderApplyNode::GetSupport(const SearchTask& task, Sampler*
     ICHECK(block) << "TypeError: Expects BlockNode, but gets: " << block_sref->stmt->GetTypeKey();
     // TODO(@junrushao1994): replace this quick hack
-    if (!tir::GetBlocks(sch->state, block->name_hint).empty()) {
+    if (!tir::GetBlocks(sch->state(), block->name_hint).empty()) {
       // apply the rule to the block
       Array<Schedule> applied =
          rule->Apply(task, sch, /*block=*/sch->GetBlock(block->name_hint));
diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc
index 50ae2538be..b977036092 100644
--- a/src/meta_schedule/space/postproc.cc
+++ b/src/meta_schedule/space/postproc.cc
@@ -55,7 +55,7 @@ class PostprocRewriteTensorize {
   Optional<tir::StmtSRef> FindTensorized(const Schedule& sch) {
     Optional<tir::StmtSRef> result = NullOpt;
-    tir::PrimFunc func = GetOnlyFunc(sch->Module());
+    tir::PrimFunc func = GetOnlyFunc(sch->mod());
     tir::PreOrderVisit(func->body, [&result, &sch](const ObjectRef& obj) -> bool {
       if (const auto* block = obj.as<tir::BlockNode>()) {
         tir::StmtSRef block_sref = sch->GetSRef(block);
@@ -72,7 +72,7 @@ class PostprocRewriteTensorize {
   bool CanTensorize(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
                     const tir::TensorIntrin& intrin) {
     Optional<tir::TensorizeInfo> opt_tensorize_info =
-        GetTensorizeLoopMapping(sch->state, block_sref, intrin->description);
+        GetTensorizeLoopMapping(sch->state(), block_sref, intrin->description);
     if (!opt_tensorize_info.defined()) {
       return false;
     }
@@ -93,9 +93,9 @@ class PostprocRewriteTensorize {
     while (Optional<tir::StmtSRef> opt_block_sref = FindTensorized(sch)) {
       tir::StmtSRef block_sref = opt_block_sref.value();
       // Remove the annotation
-      DelAnn(sch->state, block_sref, tir::attr::auto_tensorize);
+      DelAnn(sch->state(), block_sref, tir::attr::auto_tensorize);
       // Get the surrounding loops
-      Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state, block_sref);
+      Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state(), block_sref);
       // Decompose Reduction
       {
         // (TODO) bohan
       }
@@ -103,7 +103,7 @@
       // Tensorize
       for (const tir::TensorIntrin& intrin : tensor_intrins) {
         if (CanTensorize(sch, block_sref, intrin)) {
-          tir::schedule::Tensorize(sch->state, loop_srefs[0], intrin);
+          tir::schedule::Tensorize(sch->state(), loop_srefs[0], intrin);
           return true;
         }
       }
@@ -144,7 +144,7 @@ class PostprocRewriteCooperativeFetch {
   bool Proc(const Schedule& sch) const {
     int idx = 0;
     while (Optional<tir::StmtSRef> opt_block_sref =
-               FindBlockSRef(sch->state, MakeBlockFinder(sch, &idx))) {
+               FindBlockSRef(sch->state(), MakeBlockFinder(sch, &idx))) {
       // Extract block info
       tir::StmtSRef block_sref = opt_block_sref.value();
       const auto* block = block_sref->GetStmt<tir::BlockNode>();
@@ -156,7 +156,7 @@ class PostprocRewriteCooperativeFetch {
       LoopRV loop_rv = loop_rvs[n_loops - 1 - idx];
       tir::StmtSRef loop_sref = sch->GetSRef(loop_rv);
       // Remove the annotation
-      DelAnn(sch->state, loop_sref, tir::attr::loop_type);
+      DelAnn(sch->state(), loop_sref, tir::attr::loop_type);
       // Find the threadIdx.x binding
       PrimExpr thread_idx_extent{nullptr};
       for (const tir::StmtSRefNode* sref = loop_sref->parent;; sref = sref->parent) {
@@ -233,16 +233,16 @@ class PostprocRewriteParallelizeVectorizeUnroll {
   static void RemoveParsedAnn(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
                               const Parsed& parsed) {
     if (parsed.max_parallel_extent != -1) {
-      DelAnn(sch->state, block_sref, tir::attr::auto_parallel_extent);
+      DelAnn(sch->state(), block_sref, tir::attr::auto_parallel_extent);
     }
     if (parsed.max_vectorize_extent != -1) {
-      DelAnn(sch->state, block_sref, tir::attr::auto_vectorize_extent);
+      DelAnn(sch->state(), block_sref, tir::attr::auto_vectorize_extent);
     }
     if (parsed.unroll_explicit != -1) {
-      DelAnn(sch->state, block_sref, tir::attr::auto_unroll_explicit);
+      DelAnn(sch->state(), block_sref, tir::attr::auto_unroll_explicit);
     }
     if (parsed.unroll_implicit != -1) {
-      DelAnn(sch->state, block_sref, tir::attr::auto_unroll_implicit);
+      DelAnn(sch->state(), block_sref, tir::attr::auto_unroll_implicit);
     }
   }
 
@@ -260,7 +260,7 @@ class PostprocRewriteParallelizeVectorizeUnroll {
     loop_types.reserve(n_loops);
     for (const LoopRV& loop_rv : loop_rvs) {
       loop_srefs.push_back(sch->GetSRef(loop_rv));
-      loop_types.push_back(GetLoopIterType(sch->state, loop_srefs.back()));
+      loop_types.push_back(GetLoopIterType(sch->state(), loop_srefs.back()));
     }
   }
   // Calculate the parallelize extent
@@ -330,7 +330,7 @@ class PostprocRewriteParallelizeVectorizeUnroll {
   bool Proc(const Schedule& sch) const {
     Parsed parsed;
     while (Optional<tir::StmtSRef> opt_block_sref =
-               FindBlockSRef(sch->state, MakeAnnParser(&parsed))) {
+               FindBlockSRef(sch->state(), MakeAnnParser(&parsed))) {
       // Extract block info
       tir::StmtSRef block_sref = opt_block_sref.value();
       RemoveParsedAnn(sch, block_sref, parsed);
@@ -393,14 +393,14 @@ class PostprocRewriteUnboundBlocks {
    /*! \brief Find the first block that is not bound to any thread axes */
    static Optional<tir::StmtSRef> Find(const tir::Schedule& sch) {
      UnboundBlockFinder finder(sch);
-      finder(GetOnlyFunc(sch->Module())->body);
+      finder(GetOnlyFunc(sch->mod())->body);
      return finder.block_ == nullptr ? Optional<tir::StmtSRef>(NullOpt)
                                      : sch->GetSRef(finder.block_);
    }
 
   private:
    explicit UnboundBlockFinder(const tir::Schedule& sch) : sch_(sch), block_(nullptr) {
-      const auto* realize = GetOnlyFunc(sch->Module())->body.as<tir::BlockRealizeNode>();
+      const auto* realize = GetOnlyFunc(sch->mod())->body.as<tir::BlockRealizeNode>();
      root_block_ = realize->block.get();
    }
 
@@ -408,7 +408,7 @@ class PostprocRewriteUnboundBlocks {
      if (block_) {
        return;
      }
-      if (Optional<String> opt_ann = GetAnn(sch_->GetSRef(loop), tir::attr::loop_type)) {
+      if (Optional<String> opt_ann = GetBinding(sch_->GetSRef(loop))) {
        String ann = opt_ann.value();
        if (ann == "threadIdx.x" || ann == "blockIdx.x" || ann == "vthread") {
          return;
@@ -442,8 +442,8 @@ class PostprocRewriteUnboundBlocks {
    int n_spatial_loops = 0;
    for (const LoopRV& loop_rv : loop_rvs) {
      tir::StmtSRef loop_sref = sch->GetSRef(loop_rv);
-      tir::IterVarType iter_type = GetLoopIterType(sch->state, loop_sref);
-      if (iter_type != tir::kDataPar || GetAnn(loop_sref, tir::attr::loop_type).defined()) {
+      tir::IterVarType iter_type = GetLoopIterType(sch->state(), loop_sref);
+      if (iter_type != tir::kDataPar || GetBinding(loop_sref).defined()) {
        break;
      }
      ++n_spatial_loops;
@@ -512,24 +512,16 @@ class PostprocRewriteReductionBlock {
   bool Proc(const Schedule& sch) const {
-    while (const tir::BlockNode* block = Find(GetOnlyFunc(sch->Module())->body)) {
+    while (const tir::BlockNode* block = Find(GetOnlyFunc(sch->mod())->body)) {
      BlockRV block_rv = sch->GetBlock(block->name_hint);
      Array<LoopRV> loop_rvs = sch->GetAxes(block_rv);
      int n_loops = loop_rvs.size();
      for (int i = 0; i < n_loops; ++i) {
        const LoopRV& loop_rv = loop_rvs[i];
        tir::StmtSRef loop_sref = sch->GetSRef(loop_rv);
-        if (GetLoopIterType(sch->state, loop_sref) != tir::kDataPar) {
+        if (GetLoopIterType(sch->state(), loop_sref) != tir::kDataPar) {
          // Insert the initializing block above the first loop which is not data parallel.
-          BlockRV init = sch->DecomposeReduction(block_rv, loop_rvs[i]);
-          Array<LoopRV> loops = sch->GetAxes(init);
-          if (!loops.empty()) {
-            const LoopRV& last_loop = loops.back();
-            const tir::StmtSRef& loop_sref = sch->GetSRef(last_loop);
-            if (HasSingleChild(loop_sref)) {
-              sch->Vectorize(last_loop);
-            }
-          }
+          sch->DecomposeReduction(block_rv, loop_rvs[i]);
          break;
        }
      }
@@ -563,7 +555,7 @@ class PostprocDisallowDynamicLoops {
      }
      return true;
    };
-    tir::PreOrderVisit(GetOnlyFunc(sch->Module())->body, f_visit);
+    tir::PreOrderVisit(GetOnlyFunc(sch->mod())->body, f_visit);
    return !has_dyn_ext;
  }
 };
@@ -606,7 +598,6 @@ class PostprocVerifyGPUCode {
        {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")},
        {"max_local_memory_per_block", Extract(target, "registers_per_block")},
        {"max_threads_per_block", Extract(target, "max_threads_per_block")},
-        {"max_vector_bytes", Extract(target, "vector_unit_bytes")},
        {"max_vthread", Integer(8)},
    };
    return tir::VerifyGPUCode(func, constraints);
@@ -614,14 +605,13 @@ class PostprocVerifyGPUCode {
 
  bool Proc(const SearchTask& task, const Schedule& sch) const {
    static tir::transform::Sequential passes = MakePasses();
-    GlobalVar main_func("main");
-    IRModule mod = sch->Module();
+    IRModule mod = sch->mod();
    try {
      mod = passes(std::move(mod));
    } catch (const dmlc::Error& e) {
      return false;
    }
-    return VerifyGPU(Downcast<tir::PrimFunc>(mod->Lookup(main_func)), task->target);
+    return VerifyGPU(GetOnlyFunc(mod), task->target);
  }
 };
diff --git a/src/meta_schedule/space/search_rule.cc b/src/meta_schedule/space/search_rule.cc
index 6c1e65e4a7..2856fdb991 100644
--- a/src/meta_schedule/space/search_rule.cc
+++ b/src/meta_schedule/space/search_rule.cc
@@ -77,16 +77,16 @@ class RuleInlinePureSpatial {
   static bool NeedsInline(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
                           bool strict_mode) {
-    if (!IsSpatial(sch->state, block_sref)) {
+    if (!IsSpatial(sch->state(), block_sref)) {
       return false;
     }
-    if (IsOutputBlock(sch->state, block_sref)) {
+    if (IsOutputBlock(sch->state(), block_sref)) {
       return false;
     }
-    if (strict_mode && !IsStrictlyInlineable(sch->state, block_sref)) {
+    if (strict_mode && !IsStrictlyInlineable(sch->state(), block_sref)) {
       return false;
     }
-    Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state, block_sref);
+    Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state(), block_sref);
     for (const tir::StmtSRef& loop_sref : loop_srefs) {
       if (!HasSingleChild(loop_sref)) {
         return false;
@@ -99,7 +99,7 @@ class RuleInlinePureSpatial {
   Array<Schedule> Apply(const SearchTask& task, const Schedule& sch,
                         const BlockRV& block_rv) const {
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-    if (IsSubrootBlock(sch->state, block_sref) && NeedsInline(sch, block_sref, strict_mode)) {
+    if (IsSubrootBlock(sch->state(), block_sref) && NeedsInline(sch, block_sref, strict_mode)) {
       sch->ComputeInline(block_rv);
     }
     return {sch};
@@ -360,15 +360,15 @@ class RuleMultiLevelTiling {
     }
     BlockRV consumer_rv = consumers[0];
     tir::StmtSRef consumer_sref = sch->GetSRef(consumer_rv);
-    if (!IsSpatial(sch->state, consumer_sref)) {
+    if (!IsSpatial(sch->state(), consumer_sref)) {
       break;
     }
-    if (!IsElementWiseMatch(sch->state, sch->GetSRef(current_block_rv), consumer_sref)) {
+    if (!IsElementWiseMatch(sch->state(), sch->GetSRef(current_block_rv), consumer_sref)) {
       break;
     }
     // Then `consumer_rv` must be an elementwise-matched consumer of `block_rv`
     if (!RuleInlinePureSpatial::NeedsInline(sch, consumer_sref, this->consumer_inline_strict)) {
-      if (IsOutputBlock(sch->state, consumer_sref)) {
+      if (IsOutputBlock(sch->state(), consumer_sref)) {
         result.push_back(consumer_rv);
       }
       break;
@@ -428,7 +428,7 @@ class RuleMultiLevelTiling {
     std::vector<std::vector<LoopRV>> tiles(structure.size());
     // Get block vars and loop axes
     // TODO: fix
-    Array iter_types = GetBlockVarTypes(sch->state, sch->GetSRef(block_rv));
+    Array iter_types = GetBlockVarTypes(sch->state(), sch->GetSRef(block_rv));
     Array<LoopRV> axes = sch->GetAxes(block_rv);
     ICHECK_EQ(axes.size(), iter_types.size());
     // For each loop axis, tile it
@@ -508,7 +508,7 @@ class RuleMultiLevelTiling {
     if (HasAnyAnn(block_sref)) {
       return {sch};
     }
-    if (!NeedsMultiLevelTiling(sch->state, block_sref)) {
+    if (!NeedsMultiLevelTiling(sch->state(), block_sref)) {
       return {sch};
     }
     // States
@@ -556,13 +556,13 @@ SearchRule MultiLevelTiling(String structure, int max_innermost_factor, bool mus
 class RuleRandomComputeLocation {
  public:
   bool IsFreeBlock(const tir::Schedule sch, const tir::StmtSRef& block_sref) const {
-    if (!IsSubrootBlock(sch->state, block_sref)) {
+    if (!IsSubrootBlock(sch->state(), block_sref)) {
       return false;
     }
-    if (!sch->state->scopes.at(tir::GetScopeSRef(block_sref))->IsComplete(block_sref)) {
+    if (!sch->state()->scopes.at(tir::GetScopeRoot(block_sref))->IsComplete(block_sref)) {
       return false;
     }
-    Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state, block_sref);
+    Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state(), block_sref);
     for (const tir::StmtSRef& loop_sref : loop_srefs) {
       if (!HasSingleChild(loop_sref)) {
         return false;
@@ -595,11 +595,10 @@ class RuleRandomComputeLocation {
       sch->ComputeAt(block_rv, compute_at_loc, true);
     } catch (const dmlc::Error& e) {
       // ComputeAt fails; clean up the following before retrying:
-      // 1) sym_tab
-      // 2) decisions
-      // 3) trace
+      // 1) trace: instruction & decisions
+      // 2) sym_tab
       sch->trace->Pop();
-      sch->symbol_table.erase(compute_at_loc);
+      sch->RemoveRV(compute_at_loc);
       continue;
     }
     break;
@@ -643,7 +642,7 @@ class RuleParallelizeVectorizeUnroll {
         warned_num_cores_missing(static_cast(other.warned_num_cores_missing)) {}
 
   static bool IsLeftmostSubroot(const tir::Schedule& sch, tir::StmtSRef block_sref) {
-    if (!IsSubrootBlock(sch->state, block_sref)) {
+    if (!IsSubrootBlock(sch->state(), block_sref)) {
       return false;
     }
     tir::StmtSRefNode* child_sref = block_sref.operator->();
@@ -672,7 +671,7 @@ class RuleParallelizeVectorizeUnroll {
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
     // Check if the block is root and leaf
     bool is_leftmost_root = IsLeftmostSubroot(sch, block_sref);
-    bool is_leaf = IsLeafBlock(sch->state, block_sref);
+    bool is_leaf = IsLeafBlock(sch->state(), block_sref);
     // Parallelization
     if (max_jobs_per_core != -1 && is_leftmost_root) {
       int max_extent =
@@ -792,7 +791,7 @@ class RuleMarkTensorize {
     }
     Schedule cur_sch = next_sch.value();
     if (Optional<tir::TensorizeInfo> opt_tensorize_info =
-            GetTensorizeLoopMapping(cur_sch->state, block_sref, intrin->description)) {
+            GetTensorizeLoopMapping(cur_sch->state(), block_sref, intrin->description)) {
       BlockizeAndMark(cur_sch, block_rv, intrin->description, opt_tensorize_info.value().get());
       result.push_back(cur_sch);
       next_sch = NullOpt;
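The `RuleRandomComputeLocation` hunk above replaces direct symbol-table surgery with the new `RemoveRV` interface, so a failed retry cannot leave a dangling random variable behind. A Python-flavored sketch of the pattern (`trace.pop` and `remove_rv` are hypothetical spellings standing in for the C++ `trace->Pop()` and `RemoveRV()`, which are not exposed as Python bindings in this patch):

```python
import tvm

def try_compute_at(sch, block_rv, loc_rv):
    """Attempt compute_at; on failure, roll back both the recorded
    instruction and the random variable it introduced."""
    try:
        sch.compute_at(block_rv, loc_rv, preserve_unit_loop=True)
        return True
    except tvm.TVMError:
        sch.trace.pop()        # 1) trace: instruction & decisions
        sch.remove_rv(loc_rv)  # 2) sym_tab
        return False
```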
diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc
index 4861e50306..e167aaa630 100644
--- a/src/meta_schedule/strategy/mutator.cc
+++ b/src/meta_schedule/strategy/mutator.cc
@@ -185,7 +185,7 @@ class MutatorComputeLocation {
     BlockRV block_rv = Downcast<BlockRV>(inputs[0]);
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
     // Extract locations that can be computed at
-    Array<tir::StmtSRef> loop_srefs = CollectComputeLocation(sch->state, block_sref);
+    Array<tir::StmtSRef> loop_srefs = CollectComputeLocation(sch->state(), block_sref);
     std::vector<int> locs{-2, -1};
     {
       int i = 0;
@@ -366,10 +366,10 @@ class MutatorParallel {
     // Step 2. Fetch the block and the loops above it. Furthermore, get their loop types.
     BlockRV block_rv = Downcast<BlockRV>(inputs[0]);
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-    Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state, block_sref);
+    Array<tir::StmtSRef> loop_srefs = tir::GetAxes(sch->state(), block_sref);
     std::vector<tir::IterVarType> loop_types;
     for (const tir::StmtSRef& loop_sref : loop_srefs) {
-      loop_types.emplace_back(GetLoopIterType(sch->state, loop_sref));
+      loop_types.emplace_back(GetLoopIterType(sch->state(), loop_sref));
     }
     // Step 3. Get the original parallel extent.
     int ori_extent = inst->inputs[1].as<IntImmNode>()->value;
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index d7b7530bb0..9dc9058646 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -154,7 +154,7 @@ inline bool DomainEqual(const Array<Range>& lhs, const Array<Range>& rhs) {
 template <class FPredicate>
 inline Optional<tir::StmtSRef> FindBlockSRef(const tir::ScheduleState& sch, FPredicate predicate) {
   Optional<tir::StmtSRef> result = NullOpt;
-  tir::PreOrderVisit(sch->func->body, [&sch, &result, &predicate](const ObjectRef& obj) -> bool {
+  auto f_visit = [&sch, &result, &predicate](const ObjectRef& obj) -> bool {
     if (result.defined()) {
       return false;
     }
@@ -165,21 +165,47 @@ inline Optional<tir::StmtSRef> FindBlockSRef(const tir::ScheduleState& sch, FPre
       }
     }
     return true;
-  });
+  };
+  for (const auto& kv : sch->mod->functions) {
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+      tir::PreOrderVisit(func->body, f_visit);
+    }
+  }
   return result;
 }
 
 /**************** TIR Annotation ****************/
 
-inline bool HasBinding(const tir::StmtSRef& sref, const String& thread_binding) {
-  const auto* loop = sref->GetStmt<tir::ForNode>();
-  ICHECK(loop) << "ValueError: Expect loop sref here";
-  if (loop->thread_binding) {
-    ICHECK(loop->thread_binding.value()->iter_type == tir::IterVarType::kThreadIndex);
-    return loop->thread_binding.value()->thread_tag == thread_binding;
-  } else {
+inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag) {
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  if (!loop->thread_binding.defined()) {
+    return false;
+  }
+  tir::IterVar binding = loop->thread_binding.value();
+  if (binding->iter_type != tir::IterVarType::kThreadIndex) {
     return false;
   }
+  return binding->thread_tag == thread_tag;
+}
+
+inline Optional<String> GetBinding(const tir::StmtSRef& loop_sref) {
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Check the loop kind first: parallel/vectorized/unrolled loops carry no
+  // thread binding, so testing `thread_binding` before the kind would make
+  // the three branches below unreachable.
+  if (loop->kind == tir::ForKind::kParallel) {
+    return String("parallel");
+  } else if (loop->kind == tir::ForKind::kVectorized) {
+    return String("vectorized");
+  } else if (loop->kind == tir::ForKind::kUnrolled) {
+    return String("unrolled");
+  }
+  if (!loop->thread_binding.defined()) {
+    return NullOpt;
+  }
+  tir::IterVar binding = loop->thread_binding.value();
+  if (binding->iter_type != tir::IterVarType::kThreadIndex) {
+    return NullOpt;
+  }
+  return binding->thread_tag;
+}
 
 inline Optional<String> GetAnn(const tir::StmtSRef& sref, const String& ann_key) {
diff --git a/src/tir/schedule/analysis.cc b/src/tir/schedule/analysis.cc
index 6769833a40..77e45a5122 100644
--- a/src/tir/schedule/analysis.cc
+++ b/src/tir/schedule/analysis.cc
@@ -81,7 +81,7 @@ void VerifyRegionCover(const ScheduleState& self, const StmtSRef& consumer_block
     return;
   }
   const auto* consumer_block = consumer_block_sref->GetStmt<BlockNode>();
-  const StmtSRef& parent_block_sref = GetScopeSRef(consumer_block_sref);
+  const StmtSRef& parent_block_sref = GetScopeRoot(consumer_block_sref);
   // Gather all the producers
   struct Producer {
     /*! \brief The block that writes the buffer */
@@ -171,7 +171,12 @@ void VerifySRefTree(const ScheduleState& self) {
         n_block_sref_visited_(0) {}
 
   void Verify() {
-    VisitStmt(self_->func->body);
+    for (const auto& kv : self_->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        VisitStmt(func->body);
+      }
+    }
     ICHECK_EQ(n_sref_visited_, static_cast<int>(self_->stmt2ref.size()));
     for (const auto& kv : self_->scopes) {
       const StmtSRef& sref = kv.first;
@@ -278,7 +283,7 @@ void VerifySRefTree(const ScheduleState& self) {
   SRefTreeVerifier::Verify(self.get());
 }
 
-StmtSRef GetScopeSRef(const StmtSRef& sref) {
+StmtSRef GetScopeRoot(const StmtSRef& sref) {
   for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) {
     if (p->stmt->IsInstance<BlockNode>()) {
       return GetRef<StmtSRef>(p);
     }
@@ -336,7 +341,7 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
 
 Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
   Array<DepEdge> pred_edges = self->scopes
-                                  .at(GetScopeSRef(block_sref))  //
+                                  .at(GetScopeRoot(block_sref))  //
                                  ->GetPredecessors(block_sref);
   Array<StmtSRef> results;
   results.reserve(pred_edges.size());
@@ -350,7 +355,7 @@ Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sr
 
 Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
   Array<DepEdge> succ_edges = self->scopes
-                                  .at(GetScopeSRef(block_sref))  //
+                                  .at(GetScopeRoot(block_sref))  //
                                  ->GetSuccessors(block_sref);
   Array<StmtSRef> results;
   results.reserve(succ_edges.size());
@@ -456,5 +461,25 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
   return GetRef<StmtSRef>(p);
 }
 
+const PrimFuncNode* GetRootPrimFunc(const ScheduleState& self, const StmtSRef& sref) {
+  const StmtSRefNode* p = sref.get();
+  for (; p->parent != nullptr; p = p->parent) {
+  }
+  for (const auto& kv : self->mod->functions) {
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == p->stmt) {
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of "
+                "the statement:\n"
+             << GetRef<Stmt>(sref->stmt);
+  throw;
+}
+
 }  // namespace tir
 }  // namespace tvm
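`GetRootPrimFunc` above answers the new question raised by multi-function modules: which `PrimFunc` owns a given sref? A rough Python rendering of the same walk (attribute spellings approximate the Python bindings and are not APIs added by this patch):

```python
def get_root_prim_func(state, sref):
    """Climb the sref tree to its root block, then find the PrimFunc in
    state.mod whose root BlockRealize carries exactly that block."""
    while sref.parent is not None:
        sref = sref.parent
    for _, base_func in state.mod.functions.items():
        body = getattr(base_func, "body", None)  # PrimFuncs only
        if body is not None and hasattr(body, "block") and body.block.same_as(sref.stmt):
            return base_func
    raise IndexError("sref does not belong to any PrimFunc in the module")
```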
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index f8d07f557c..c799befb11 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -71,7 +71,7 @@ IterVarType GetLoopIterType(const ScheduleState& self, const StmtSRef& loop_sref
 * \param sref The block or loop sref to be retrieved
 * \return The sref to the scope root block
 */
-StmtSRef GetScopeSRef(const StmtSRef& sref);
+StmtSRef GetScopeRoot(const StmtSRef& sref);
 
 /******** Block-loop relation ********/
 /*!
@@ -118,6 +118,14 @@ bool HasSingleChild(const StmtSRef& loop_or_block_sref);
 Array<StmtSRef> CollectComputeLocation(const ScheduleState& self, const StmtSRef& block_sref);
 
+/*!
+ * \brief Get a pointer to the PrimFunc that the statement pointed to by `sref` belongs to
+ * \param self The state of scheduling
+ * \param sref The sref to the statement in the query
+ * \return A pointer to the PrimFunc the statement belongs to
+ */
+const PrimFuncNode* GetRootPrimFunc(const ScheduleState& self, const StmtSRef& sref);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index fabe277d60..92ecbc20df 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -26,9 +26,14 @@ namespace tvm {
 namespace tir {
 
 Schedule Schedule::Concrete(PrimFunc func, int64_t seed, bool debug_mode) {
+  return Schedule::Concrete(IRModule({{GlobalVar("main"), func}}), seed, debug_mode);
+}
+
+Schedule Schedule::Concrete(IRModule mod, int64_t seed, bool debug_mode) {
   ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
-  n->state = ScheduleState(func, debug_mode);
-  n->symbol_table = {};
+  n->state_ = ScheduleState(mod, debug_mode);
+  n->symbol_table_ = {};
+  n->analyzer_ = std::make_unique<arith::Analyzer>();
   return Schedule(std::move(n));
 }
 
@@ -115,9 +120,8 @@ struct SRefTranslator {
   }
 
   /*! \brief Translate SMap<StmtSRef, BlockScope> */
-  SMap<StmtSRef, BlockScope> Trans(const SMap<StmtSRef, BlockScope>& scopes) {
-    SMap<StmtSRef, BlockScope> result;
-    result.reserve(scopes.size());
+  Map<StmtSRef, BlockScope> Trans(const Map<StmtSRef, BlockScope>& scopes) {
+    Map<StmtSRef, BlockScope> result;
     for (const auto& kv : scopes) {
       const StmtSRef& old_sref = kv.first;
       const BlockScope& old_scope = kv.second;
       ObjectPtr<BlockScopeNode> scope = make_object<BlockScopeNode>();
       scope->forward_edges = Trans(old_scope->forward_edges);
       scope->backward_edges = Trans(old_scope->backward_edges);
       scope->buffer_writers = Trans(old_scope->buffer_writers);
-      result[Trans(old_sref)] = BlockScope(std::move(scope));
+      result.Set(Trans(old_sref), BlockScope(std::move(scope)));
     }
     return result;
   }
@@ -159,26 +163,24 @@ struct SRefTranslator {
   std::unordered_map<const StmtSRefNode*, StmtSRef> trans_;
 };
 
-Schedule ConcreteScheduleNode::Copy() const {
-  const ScheduleState& src_state = this->state;
+void ConcreteScheduleNode::MakeCopy(ScheduleState* new_state,
+                                    TSymbolTable* new_symbol_table) const {
+  const ScheduleState& src_state = state_;
   SRefTranslator trans(src_state);
   ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
-  n->func = src_state->func;
+  n->mod = src_state->mod;
   n->scopes = trans.Trans(src_state->scopes);
   n->stmt2ref = trans.Trans(src_state->stmt2ref);
   n->debug_mode = src_state->debug_mode;
-  ObjectPtr<ConcreteScheduleNode> p = make_object<ConcreteScheduleNode>();
-  p->state = ScheduleState(std::move(n));
-  p->symbol_table = trans.Trans(this->symbol_table);
-  return Schedule(std::move(p));
+  *new_state = ScheduleState(std::move(n));
+  *new_symbol_table = trans.Trans(this->symbol_table_);
 }
 
-void ConcreteScheduleNode::Seed(int64_t seed) {
-  // do nothing
-}
-
-IRModule ConcreteScheduleNode::Module() const {
-  return IRModule({{GlobalVar("main"), this->state->func}});
+Schedule ConcreteScheduleNode::Copy() const {
+  ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
+  MakeCopy(&n->state_, &n->symbol_table_);
+  n->analyzer_ = std::make_unique<arith::Analyzer>();
+  return Schedule(std::move(n));
 }
 
 /******** Lookup random variables ********/
 
@@ -196,8 +198,8 @@ For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
 }
 
 int64_t ConcreteScheduleNode::Get(const Var& var_rv) const {
-  auto it = this->symbol_table.find(var_rv);
-  if (it == this->symbol_table.end()) {
+  auto it = this->symbol_table_.find(var_rv);
+  if (it == this->symbol_table_.end()) {
     LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << var_rv;
   }
   const ObjectRef& obj = (*it).second;
@@ -215,12 +217,12 @@ PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
     int64_t result = this->Get(var);
     return Integer(result);
   });
-  return analyzer.Simplify(transformed);
+  return this->analyzer_->Simplify(transformed);
 }
 
 StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const {
-  auto it = this->symbol_table.find(block_rv);
-  if (it == this->symbol_table.end()) {
+  auto it = this->symbol_table_.find(block_rv);
+  if (it == this->symbol_table_.end()) {
     LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv;
   }
   const ObjectRef& obj = (*it).second;
@@ -238,8 +240,8 @@ StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const {
 StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const {
   static StmtSRef inline_mark = StmtSRef::InlineMark();
   static StmtSRef root_mark = StmtSRef::RootMark();
-  auto it = this->symbol_table.find(loop_rv);
-  if (it == this->symbol_table.end()) {
+  auto it = this->symbol_table_.find(loop_rv);
+  if (it == this->symbol_table_.end()) {
     LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv;
   }
   const ObjectRef& obj = (*it).second;
@@ -260,64 +262,69 @@ StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const {
   return GetRef<StmtSRef>(sref);
 }
 
-StmtSRef ConcreteScheduleNode::GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
+void ConcreteScheduleNode::RemoveRV(const BlockRV& block_rv) { RemoveFromSymbolTable(block_rv); }
+
+void ConcreteScheduleNode::RemoveRV(const LoopRV& loop_rv) { RemoveFromSymbolTable(loop_rv); }
+
+void ConcreteScheduleNode::RemoveRV(const VarRV& var_rv) { RemoveFromSymbolTable(var_rv); }
 
-StmtSRef ConcreteScheduleNode::GetSRef(const StmtNode* stmt) const {
-  auto it = this->state->stmt2ref.find(stmt);
-  if (it == this->state->stmt2ref.end()) {
-    LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR";
+void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) {
+  auto it = this->symbol_table_.find(obj);
+  if (it != this->symbol_table_.end()) {
+    this->symbol_table_.erase(obj);
+  } else {
+    LOG(FATAL) << "IndexError: Cannot find the object in the symbol table: " << obj;
+    throw;
   }
-  return it->second;
 }
 
 /******** Block/Loop relation ********/
 
 BlockRV ConcreteScheduleNode::GetBlock(const String& name) {
-  Array<StmtSRef> blocks = tir::GetBlocks(this->state, name);
+  Array<StmtSRef> blocks = tir::GetBlocks(state_, name);
   CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size()
                              << " blocks with the name: " << name;
-  return SetRV<BlockRV>(this, blocks[0]);
+  return SetRV<BlockRV>(blocks[0]);
 }
 
 Array<LoopRV> ConcreteScheduleNode::GetAxes(const BlockRV& block_rv) {
-  return SetRV<LoopRV>(this, tir::GetAxes(this->state, this->GetSRef(block_rv)));
+  return SetRV<LoopRV>(tir::GetAxes(state_, this->GetSRef(block_rv)));
 }
 
 Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
-  return SetRV<BlockRV>(this, tir::GetChildBlocks(this->state, this->GetSRef(block_rv), false));
+  return SetRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(block_rv), false));
 }
 
 Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
-  return SetRV<BlockRV>(this, tir::GetChildBlocks(this->state, this->GetSRef(loop_rv), false));
+  return SetRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(loop_rv), false));
 }
 
 Array<BlockRV> ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
-  return SetRV<BlockRV>(this, tir::GetProducers(this->state, this->GetSRef(block_rv)));
+  return SetRV<BlockRV>(tir::GetProducers(state_, this->GetSRef(block_rv)));
 }
 
 Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
-  return SetRV<BlockRV>(this, tir::GetConsumers(this->state, this->GetSRef(block_rv)));
+  return SetRV<BlockRV>(tir::GetConsumers(state_, this->GetSRef(block_rv)));
 }
 
 /******** Schedule: loops ********/
 
 LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
   CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
-  Array<StmtSRef> loop_srefs = FromRV(this, loop_rvs);
+  Array<StmtSRef> loop_srefs = FromRV(loop_rvs);
   while (loop_srefs.size() >= 2) {
     StmtSRef inner_sref = loop_srefs.back();
     loop_srefs.pop_back();
     StmtSRef outer_sref = loop_srefs.back();
     loop_srefs.pop_back();
-    StmtSRef fused = schedule::Fuse(this->state, outer_sref, inner_sref);
+    StmtSRef fused = schedule::Fuse(state_, outer_sref, inner_sref);
     loop_srefs.push_back(fused);
   }
-  return SetRV<LoopRV>(this, loop_srefs[0]);
+  return SetRV<LoopRV>(loop_srefs[0]);
 }
 
 Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
                                           const Array<Optional<ExprRV>>& factor_rvs) {
-  arith::Analyzer analyzer;
   // Prepare for the splitting
   StmtSRef loop_sref = this->GetSRef(loop_rv);
   const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
@@ -330,7 +337,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   int p = -1;
   for (int i = 0; i < n; ++i) {
     PrimExpr factor = this->Get(factor_rvs[i].value_or(Integer(-1)));
-    if (analyzer.CanProve(factor == -1)) {
+    if (analyzer_->CanProve(factor == -1)) {
       CHECK_EQ(p, -1) << "ValueError: `split` requires at most one `None` factor, but gets: "
                       << factor_rvs;
       p = i;
@@ -344,7 +351,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
     for (int i = 1; i < n; ++i) {
       prod = prod * factors[i];
     }
-    if (analyzer.CanProve(prod == len)) {
+    if (analyzer_->CanProve(prod == len)) {
       p = 0;
       factors[0] = Integer(-1);
     } else {
@@ -357,8 +364,8 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   for (int i = n - 1; i > p; --i) {
     PrimExpr inner_len = factors[i];
     PrimExpr outer_len = floordiv(len + inner_len - 1, inner_len);
-    Array<StmtSRef> parts = schedule::Split(this->state,  //
-                                            loop_sref,    //
+    Array<StmtSRef> parts = schedule::Split(state_,     //
+                                            loop_sref,  //
                                             outer_len, inner_len);
     ICHECK_EQ(parts.size(), 2);
     loop_sref = parts[0];
@@ -369,8 +376,8 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   for (int i = 0; i < p; ++i) {
     PrimExpr outer_len = factors[i];
     PrimExpr inner_len = floordiv(len + outer_len - 1, outer_len);
-    Array<StmtSRef> parts = schedule::Split(this->state,  //
-                                            loop_sref,    //
+    Array<StmtSRef> parts = schedule::Split(state_,     //
+                                            loop_sref,  //
                                             outer_len, inner_len);
     ICHECK_EQ(parts.size(), 2);
     results[i] = parts[0];
@@ -378,11 +385,11 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
     len = inner_len;
   }
   results[p] = loop_sref;
-  return SetRV<LoopRV>(this, Array<StmtSRef>{results.begin(), results.end()});
+  return SetRV<LoopRV>(Array<StmtSRef>{results.begin(), results.end()});
 }
 
 void ConcreteScheduleNode::Reorder(const Array<LoopRV>& order) {
-  schedule::Reorder(this->state, FromRV(this, order));
+  schedule::Reorder(state_, FromRV(order));
 }
 
 /******** Schedule: compute location ********/
 
@@ -395,9 +402,9 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
   if (loop_sref.same_as(root_mark)) {
     // do nothing
   } else if (loop_sref.same_as(inline_mark)) {
-    schedule::ComputeInline(this->state, this->GetSRef(block_rv));
+    schedule::ComputeInline(state_, this->GetSRef(block_rv));
   } else {
-    schedule::ComputeAt(this->state,              //
+    schedule::ComputeAt(state_,                   //
                         this->GetSRef(block_rv),  //
                         loop_sref,                //
                         preserve_unit_loop);
@@ -412,9 +419,9 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR
   if (loop_sref.same_as(root_mark)) {
     // do nothing
   } else if (loop_sref.same_as(inline_mark)) {
-    schedule::ReverseComputeInline(this->state, this->GetSRef(block_rv));
+    schedule::ReverseComputeInline(state_, this->GetSRef(block_rv));
   } else {
-    schedule::ReverseComputeAt(this->state,              //
+    schedule::ReverseComputeAt(state_,                   //
                                this->GetSRef(block_rv),  //
                                loop_sref,                //
                                preserve_unit_loop);
   }
 }
 
@@ -422,29 +429,29 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR
 void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
-  schedule::ComputeInline(this->state, this->GetSRef(block_rv));
+  schedule::ComputeInline(state_, this->GetSRef(block_rv));
 }
 
 void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
-  schedule::ReverseComputeInline(this->state, this->GetSRef(block_rv));
+  schedule::ReverseComputeInline(state_, this->GetSRef(block_rv));
 }
 
 /******** Schedule: parallelize / annotate ********/
 
 void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) {
-  schedule::Vectorize(this->state, this->GetSRef(loop_rv));
+  schedule::Vectorize(state_, this->GetSRef(loop_rv));
 }
 
 void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
-  schedule::Parallel(this->state, this->GetSRef(loop_rv));
+  schedule::Parallel(state_, this->GetSRef(loop_rv));
 }
 
 void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
-  schedule::Unroll(this->state, this->GetSRef(loop_rv));
+  schedule::Unroll(state_, this->GetSRef(loop_rv));
 }
 
 void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const IterVar& thread) {
-  schedule::Bind(this->state, this->GetSRef(loop_rv), thread);
+  schedule::Bind(state_, this->GetSRef(loop_rv), thread);
 }
 
 void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread) {
@@ -452,16 +459,16 @@ void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread) {
                    Var(thread),   //
                    kThreadIndex,  //
                    thread);
-  schedule::Bind(this->state, this->GetSRef(loop_rv), iter_var);
+  schedule::Bind(state_, this->GetSRef(loop_rv), iter_var);
 }
 
 void ConcreteScheduleNode::DoubleBuffer(const BlockRV& block_rv) {
-  schedule::DoubleBuffer(this->state, this->GetSRef(block_rv));
+  schedule::DoubleBuffer(state_, this->GetSRef(block_rv));
 }
 
 void ConcreteScheduleNode::Pragma(const LoopRV& loop_rv, const String& pragma_type,
                                   const ExprRV& pragma_value) {
-  schedule::Pragma(this->state,             //
+  schedule::Pragma(state_,                  //
                    this->GetSRef(loop_rv),  //
                    pragma_type,             //
                    this->Get(pragma_value));
@@ -471,24 +478,24 @@ void ConcreteScheduleNode::Pragma(const LoopRV& loop_rv, const String& pragma_ty
 
 BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int i,
                                         const String& storage_scope) {
-  return SetRV<BlockRV>(this, schedule::CacheRead(this->state,              //
-                                                  this->GetSRef(block_rv),  //
-                                                  i,                        //
-                                                  storage_scope));
+  return SetRV<BlockRV>(schedule::CacheRead(state_,                   //
+                                            this->GetSRef(block_rv),  //
+                                            i,                        //
+                                            storage_scope));
 }
 
 BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int i,
                                          const String& storage_scope) {
-  return SetRV<BlockRV>(this, schedule::CacheWrite(this->state,              //
-                                                   this->GetSRef(block_rv),  //
-                                                   i,                        //
-                                                   storage_scope));
+  return SetRV<BlockRV>(schedule::CacheWrite(state_,                   //
+                                             this->GetSRef(block_rv),  //
+                                             i,                        //
+                                             storage_scope));
 }
 
 /******** Schedule: reduction ********/
 
 BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
-  return SetRV<BlockRV>(this, schedule::RFactor(this->state, this->GetSRef(loop_rv), factor_axis));
+  return SetRV<BlockRV>(schedule::RFactor(state_, this->GetSRef(loop_rv), factor_axis));
 }
 
 BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv,
@@ -496,15 +503,14 @@ BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv,
   Optional<StmtSRef> opt_loop_sref = opt_loop_rv.defined()                //
                                          ? this->GetSRef(opt_loop_rv.value())  //
                                          : Optional<StmtSRef>(NullOpt);
-  return SetRV<BlockRV>(this,
-                        schedule::DecomposeReduction(this->state,              //
+  return SetRV<BlockRV>(schedule::DecomposeReduction(state_,                   //
                                                      this->GetSRef(block_rv),  //
                                                      opt_loop_sref));
 }
 
 void ConcreteScheduleNode::MergeReduction(const BlockRV& init_block_rv,
                                           const BlockRV& update_block_rv) {
-  schedule::MergeReduction(this->state,                   //
+  schedule::MergeReduction(state_,                        //
                            this->GetSRef(init_block_rv),  //
                            this->GetSRef(update_block_rv));
 }
 
@@ -512,21 +518,31 @@ void ConcreteScheduleNode::MergeReduction(const BlockRV& init_block_rv,
 /******** Schedule: blockize / tensorize ********/
 
 BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, const String& exec_scope) {
-  return SetRV<BlockRV>(this, schedule::Blockize(this->state, this->GetSRef(loop_rv), exec_scope));
+  return SetRV<BlockRV>(schedule::Blockize(state_, this->GetSRef(loop_rv), exec_scope));
 }
 
 void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) {
-  schedule::Tensorize(this->state, this->GetSRef(loop_rv), intrin);
+  schedule::Tensorize(state_, this->GetSRef(loop_rv), intrin);
 }
 
 void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) {
-  schedule::Tensorize(this->state, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name));
+  schedule::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name));
 }
 
 /******** FFI ********/
 
 TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode);
-TVM_REGISTER_GLOBAL("tir.schedule.Schedule").set_body_typed(Schedule::Concrete);
+TVM_REGISTER_GLOBAL("tir.schedule.Schedule")
+    .set_body_typed([](ObjectRef obj, int64_t seed, bool debug_mode) -> Schedule {
+      if (const auto* func = obj.as<PrimFuncNode>()) {
+        return Schedule::Concrete(GetRef<PrimFunc>(func), seed, debug_mode);
+      }
+      if (const auto* mod = obj.as<IRModuleNode>()) {
+        return Schedule::Concrete(GetRef<IRModule>(mod), seed, debug_mode);
+      }
+      LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey();
+      throw;
+    });
 
 }  // namespace tir
 }  // namespace tvm
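The `Split` logic above accepts at most one `None` factor and infers it from the loop extent with outward-rounding division; it also recognizes the case where the given factors already multiply to the extent. A tiny self-contained model of the inference step:

```python
from math import prod

def infer_split_factors(extent, factors):
    """Replace the single allowed `None` factor with ceildiv(extent, rest),
    so imperfect splits round outward, as in ConcreteScheduleNode::Split."""
    assert factors.count(None) <= 1, "`split` requires at most one `None` factor"
    if None in factors:
        rest = prod(f for f in factors if f is not None)
        factors = [-(-extent // rest) if f is None else f for f in factors]
    return factors

assert infer_split_factors(128, [None, 32]) == [4, 32]
assert infer_split_factors(100, [None, 32]) == [4, 32]  # 4 * 32 = 128 covers 100
```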
   Optional<StmtSRef> opt_loop_sref = opt_loop_rv.defined() ?           //
                                          this->GetSRef(opt_loop_rv.value())  //
                                                            : Optional<StmtSRef>(NullOpt);
-  return SetRV<BlockRV>(this,
-                        schedule::DecomposeReduction(this->state,              //
+  return SetRV<BlockRV>(schedule::DecomposeReduction(state_,                   //
                                                      this->GetSRef(block_rv),  //
                                                      opt_loop_sref));
 }

 void ConcreteScheduleNode::MergeReduction(const BlockRV& init_block_rv,
                                           const BlockRV& update_block_rv) {
-  schedule::MergeReduction(this->state,                   //
+  schedule::MergeReduction(state_,                        //
                            this->GetSRef(init_block_rv),  //
                            this->GetSRef(update_block_rv));
 }
@@ -512,21 +518,31 @@ void ConcreteScheduleNode::MergeReduction(const BlockRV& init_block_rv,

 /******** Schedule: blockize / tensorize ********/

 BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, const String& exec_scope) {
-  return SetRV<BlockRV>(this, schedule::Blockize(this->state, this->GetSRef(loop_rv), exec_scope));
+  return SetRV<BlockRV>(schedule::Blockize(state_, this->GetSRef(loop_rv), exec_scope));
 }

 void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) {
-  schedule::Tensorize(this->state, this->GetSRef(loop_rv), intrin);
+  schedule::Tensorize(state_, this->GetSRef(loop_rv), intrin);
 }

 void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) {
-  schedule::Tensorize(this->state, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name));
+  schedule::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name));
 }

 /******** FFI ********/

 TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode);
-TVM_REGISTER_GLOBAL("tir.schedule.Schedule").set_body_typed(Schedule::Concrete);
+TVM_REGISTER_GLOBAL("tir.schedule.Schedule")
+    .set_body_typed([](ObjectRef obj, int64_t seed, bool debug_mode) -> Schedule {
+      if (const auto* func = obj.as<PrimFuncNode>()) {
+        return Schedule::Concrete(GetRef<PrimFunc>(func), seed, debug_mode);
+      }
+      if (const auto* mod = obj.as<IRModuleNode>()) {
+        return Schedule::Concrete(GetRef<IRModule>(mod), seed, debug_mode);
+      }
+      LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey();
+      throw;
+    });

 }  // namespace tir
 }  // namespace tvm
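With the dispatch above, the same packed function now backs schedules created from either a single PrimFunc or a whole IRModule. A minimal usage sketch from the Python side, assuming the `tir.Schedule` wrapper simply forwards its first argument to this FFI entry (`matmul` stands for any TVM-script PrimFunc; it is a placeholder, not part of this patch):

    import tvm
    from tvm import tir

    # Either form is accepted; a bare PrimFunc is wrapped as IRModule({"main": func}).
    sch = tir.Schedule(matmul, debug_mode=True)
    sch2 = tir.Schedule(tvm.IRModule({"main": matmul}), debug_mode=True)
    # `mod` replaces the old `Module()` accessor and always returns an IRModule.
    print(tvm.script.asscript(sch.mod["main"]))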
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 1bc4766bb7..2aabea3a08 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -18,27 +18,40 @@
  */
 #ifndef TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_
 #define TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_
-
 #include
 #include
+#include
+
+namespace tvm {
+namespace meta_schedule {
+class ScheduleNode;
+}  // namespace meta_schedule
+}  // namespace tvm
+
 namespace tvm {
 namespace tir {

 class ConcreteScheduleNode : public ScheduleNode {
+  friend class Schedule;
+  friend class meta_schedule::ScheduleNode;
+
  public:
   using TSymbolTable = Map<ObjectRef, ObjectRef>;

- public:
+ protected:
+  /*! \brief The internal state of scheduling */
+  ScheduleState state_;
   /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */
-  TSymbolTable symbol_table;
-
-  mutable arith::Analyzer analyzer;
+  TSymbolTable symbol_table_;
+  /*! \brief A persistent stateless arithmetic analyzer. */
+  std::unique_ptr<arith::Analyzer> analyzer_;

+ public:
   void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("state", &state);
-    v->Visit("symbol_table", &symbol_table);
-    // `analyzer` is not visitied
+    // `state_` is not visited
+    // `symbol_table_` is not visited
+    // `analyzer_` is not visited
   }

   virtual ~ConcreteScheduleNode() = default;
@@ -47,11 +60,13 @@ class ConcreteScheduleNode : public ScheduleNode {
   TVM_DECLARE_BASE_OBJECT_INFO(ConcreteScheduleNode, ScheduleNode);

  public:
-  Schedule Copy() const override;
+  ScheduleState state() const final { return state_; }

-  void Seed(int64_t seed = -1) override;
+  Schedule Copy() const override;

-  IRModule Module() const override;
+  void Seed(int64_t seed = -1) override {
+    // do nothing
+  }

  public:
   /******** Lookup random variables ********/
@@ -67,9 +82,13 @@ class ConcreteScheduleNode : public ScheduleNode {

   StmtSRef GetSRef(const LoopRV& loop_rv) const final;

-  StmtSRef GetSRef(const Stmt& stmt) const final;
+  void RemoveRV(const BlockRV& block_rv) final;
+
+  void RemoveRV(const LoopRV& loop_rv) final;
+
+  void RemoveRV(const VarRV& var_rv) final;

-  StmtSRef GetSRef(const StmtNode* stmt) const final;
+  using ScheduleNode::GetSRef;

  public:
   /******** Sampling ********/
@@ -167,55 +186,59 @@ class ConcreteScheduleNode : public ScheduleNode {
   void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) override;
   void Tensorize(const LoopRV& loop_rv, const String& intrin_name) override;
-};

-/******** Utility functions ********/
+  /******** Utility functions ********/
+ protected:
+  void MakeCopy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const;
+
+  void RemoveFromSymbolTable(const ObjectRef& rv);
+
+  template <class T>
+  inline Array<T> SetRV(const Array<StmtSRef>& srefs) {
+    Array<T> result;
+    result.reserve(srefs.size());
+    for (const StmtSRef& sref : srefs) {
+      T rv;
+      this->symbol_table_.Set(rv, sref);
+      result.push_back(rv);
+    }
+    return result;
+  }

-template <class T>
-inline Array<T> SetRV(ConcreteScheduleNode* self, const Array<StmtSRef>& srefs) {
-  Array<T> result;
-  result.reserve(srefs.size());
-  for (const StmtSRef& sref : srefs) {
+  template <class T>
+  inline T SetRV(const StmtSRef& sref) {
     T rv;
-    self->symbol_table.Set(rv, sref);
-    result.push_back(rv);
+    this->symbol_table_.Set(rv, sref);
+    return rv;
   }
-  return result;
-}
-
-template <class T>
-inline T SetRV(ConcreteScheduleNode* self, const StmtSRef& sref) {
-  T rv;
-  self->symbol_table.Set(rv, sref);
-  return rv;
-}
-
-inline Var SetRV(ConcreteScheduleNode* self, int64_t number) {
-  Var rv;
-  self->symbol_table.Set(rv, Integer(number));
-  return rv;
-}
-
-inline Array<Var> SetRV(ConcreteScheduleNode* self, const Array<Integer>& numbers) {
-  Array<Var> result;
-  result.reserve(numbers.size());
-  for (int64_t number : numbers) {
+
+  inline Var SetRV(int64_t number) {
     Var rv;
-    self->symbol_table.Set(rv, Integer(number));
-    result.push_back(rv);
+    this->symbol_table_.Set(rv, Integer(number));
+    return rv;
   }
-  return result;
-}
-
-template <class T>
-inline Array<StmtSRef> FromRV(const ConcreteScheduleNode* self, const Array<T>& rvs) {
-  Array<StmtSRef> result;
-  result.reserve(rvs.size());
-  for (const T& rv : rvs) {
-    result.push_back(self->GetSRef(rv));
+
+  inline Array<Var> SetRV(const Array<Integer>& numbers) {
+    Array<Var> result;
+    result.reserve(numbers.size());
+    for (int64_t number : numbers) {
+      Var rv;
+      this->symbol_table_.Set(rv, Integer(number));
+      result.push_back(rv);
+    }
+    return result;
   }
-  return result;
-}
+
+  template <class T>
+  inline Array<StmtSRef> FromRV(const Array<T>& rvs) {
+    Array<StmtSRef> result;
+    result.reserve(rvs.size());
+    for (const T& rv : rvs) {
+      result.push_back(this->GetSRef(rv));
+    }
+    return result;
+  }
+};

 }  // namespace tir
 }  // namespace tvm
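Because `symbol_table_` is no longer a public member, random variables now round-trip exclusively through the schedule's accessors (`GetSRef`, `Get`, and the new `RemoveRV`). A small Python sketch of that round trip, patterned on the tests later in this patch (`matmul` is a placeholder PrimFunc):

    sch = tir.Schedule(matmul, debug_mode=True)
    block = sch.get_block("matmul")   # BlockRV, recorded in the symbol table
    i, j, k = sch.get_axes(block)     # LoopRVs, recorded the same way
    sref = sch.get_sref(i)            # resolved through GetSRef, not the raw table
    assert sref.stmt is not None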
diff --git a/src/tir/schedule/primitives/bind_annotate.cc b/src/tir/schedule/primitives/bind_annotate.cc
index 10d5f4750a..ab990e5076 100644
--- a/src/tir/schedule/primitives/bind_annotate.cc
+++ b/src/tir/schedule/primitives/bind_annotate.cc
@@ -125,7 +125,7 @@ void ParallelCompute(ScheduleState self, const StmtSRef& loop_sref, const ForKin
   // Now only support:
   //   1. All the blocks are complete below
   //   2. A single block below the loop
-  const BlockScope& scope = self->scopes.at(GetScopeSRef(loop_sref));
+  const BlockScope& scope = self->scopes.at(GetScopeRoot(loop_sref));
   bool is_compact_dataflow =
       scope->IsCompactDataFlow(loop_sref, GetChildBlocks(self, loop_sref, false));
   if (!is_compact_dataflow) {
@@ -202,7 +202,7 @@ void Pragma(ScheduleState self, const StmtSRef& loop_sref, const String& pragma_

 void DoubleBuffer(ScheduleState self, const StmtSRef& block_sref) {
   const auto* block_ptr = block_sref->GetStmt<BlockNode>();
   CHECK(block_ptr) << "TypeError: double_buffer expects 'block' as its argument";
-  const StmtSRef& parent_block_sref = GetScopeSRef(block_sref);
+  const StmtSRef& parent_block_sref = GetScopeRoot(block_sref);
   const auto* parent_block = parent_block_sref->GetStmt<BlockNode>();
   const BlockScope& scope = self->scopes.at(parent_block_sref);
   CHECK(scope->IsComplete(block_sref))
diff --git a/src/tir/schedule/primitives/blockize_tensorize.cc b/src/tir/schedule/primitives/blockize_tensorize.cc
index 233f6948ba..c70b48d0f1 100644
--- a/src/tir/schedule/primitives/blockize_tensorize.cc
+++ b/src/tir/schedule/primitives/blockize_tensorize.cc
@@ -558,7 +558,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, const String& e
   auto outer_realize = BlockRealize(outer_bindings, division.back()->outer_extent, outer_block);

   self->Replace(loop_sref, outer_realize, {{inner_block, block}});
-  UpdateScope(self, GetScopeSRef(self->stmt2ref.at(outer_block.get())));
+  UpdateScope(self, GetScopeRoot(self->stmt2ref.at(outer_block.get())));
   UpdateScope(self, self->stmt2ref.at(outer_block.get()));

   // Check loop binding
@@ -579,7 +579,8 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, const String& e
     };
     BindingValidator validator;
     validator.self = self;
-    validator(self->func->body);
+    const PrimFuncNode* func = GetRootPrimFunc(self, loop_sref);
+    validator(func->body);
   }
   return self->stmt2ref.at(outer_block.get());
 }
diff --git a/src/tir/schedule/primitives/cache_read_write.cc b/src/tir/schedule/primitives/cache_read_write.cc
index c72e68f695..069ab5ebb2 100644
--- a/src/tir/schedule/primitives/cache_read_write.cc
+++ b/src/tir/schedule/primitives/cache_read_write.cc
@@ -554,7 +554,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& _block_sref, int i,
   BufferRegion cache_region(nullptr);
   if (!block_sref.same_as(root)) {
     // Find the parent scope
-    scope_sref = GetScopeSRef(block_sref);
+    scope_sref = GetScopeRoot(block_sref);
     // Check the block is not a output block
     ICHECK(!IsOutputBlock(block_sref, scope_sref));
     // Find the region to be cache_read
@@ -599,7 +599,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int i,
   ICHECK(block_sref->parent != nullptr)
       << "ValueError: `cache_write` cannot be applied to an input buffer";
   // Find the parent scope
-  StmtSRef scope_sref = GetScopeSRef(block_sref);
+  StmtSRef scope_sref = GetScopeRoot(block_sref);
   CacheLocDetector::Detect(self, block_sref, scope_sref, &info);
   // Generate cache buffer
   Block cache_write_stage = MakeCacheStage(
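The `GetScopeSRef` to `GetScopeRoot` rename above pairs with `scopes` becoming a visitable `Map`, so scope information is reachable from Python by plain indexing instead of the removed `ScheduleStateGetScope` FFI getter. A sketch mirroring `test_block_dependency.py` further down (the block names are placeholders for whatever the scheduled PrimFunc defines):

    s = tir.Schedule(func, debug_mode=True)
    root = s.get_sref(s.get_block("root"))
    block_b = s.get_sref(s.get_block("B"))
    scope = s.state.scopes[root]              # Map indexing replaces state.scope(root)
    (edge,) = scope.get_successor(block_b)    # dependency edges work as before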
diff --git a/src/tir/schedule/primitives/compute_location.cc b/src/tir/schedule/primitives/compute_location.cc
index 3155f6c475..eaf6128c84 100644
--- a/src/tir/schedule/primitives/compute_location.cc
+++ b/src/tir/schedule/primitives/compute_location.cc
@@ -567,13 +567,13 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l
                           << block_sref->stmt->GetTypeKey();
   CHECK(loop != nullptr) << "TypeError: 'compute_at' expects 'loop' to be a loop, but get type: "
                          << loop_sref->stmt->GetTypeKey();
-  const StmtSRef& parent_block_sref = GetScopeSRef(block_sref);
+  const StmtSRef& parent_block_sref = GetScopeRoot(block_sref);
   const auto* parent_block = parent_block_sref->GetStmt<BlockNode>();
   const BlockScope& scope = self->scopes.at(parent_block_sref);
   Array<DepEdge> edges_to_pred = scope->GetPredecessors(block_sref);
   Array<DepEdge> edges_to_succ = scope->GetSuccessors(block_sref);
   // Cond 0. `block` and `loop` are in the same scope
-  CHECK_EQ(parent_block_sref.get(), GetScopeSRef(loop_sref).get())
+  CHECK_EQ(parent_block_sref.get(), GetScopeRoot(loop_sref).get())
       << "ValueError: 'compute_at' expects 'block' and 'loop' be in the same block";
   // Cond 1. 'block' is complete/reduction block
   CHECK(scope->IsComplete(block_sref) || scope->IsReduction(block_sref))
@@ -668,13 +668,13 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt
   CHECK(loop != nullptr)
       << "TypeError: 'reverse_compute_at' expects 'loop' to be a loop, but get type: "
       << loop_sref->stmt->GetTypeKey();
-  const StmtSRef& parent_block_sref = GetScopeSRef(block_sref);
+  const StmtSRef& parent_block_sref = GetScopeRoot(block_sref);
   const auto* parent_block = parent_block_sref->GetStmt<BlockNode>();
   const BlockScope& scope = self->scopes.at(parent_block_sref);
   Array<DepEdge> edges_to_pred = scope->GetPredecessors(block_sref);
   Array<DepEdge> edges_to_succ = scope->GetSuccessors(block_sref);
   // Cond 0. `block` and `loop` are in the same scope
-  CHECK_EQ(parent_block_sref.get(), GetScopeSRef(loop_sref).get())
+  CHECK_EQ(parent_block_sref.get(), GetScopeRoot(loop_sref).get())
       << "ValueError: 'reverse_compute_at' expects 'block' and 'loop' be in the same block";
   // Cond 1. 'block' is complete/reduction block
   CHECK(scope->IsComplete(block_sref) || scope->IsReduction(block_sref))
@@ -755,7 +755,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& block_sref) {
    *   2. block_sref if a complete Block
    */
   const auto* block = block_sref->GetStmt<BlockNode>();
-  const StmtSRef& scope_block_sref = GetScopeSRef(block_sref);
+  const StmtSRef& scope_block_sref = GetScopeRoot(block_sref);
   const auto* scope_block = scope_block_sref->GetStmt<BlockNode>();
   const BlockScope& scope = self->scopes.at(scope_block_sref);
   CHECK(block->body.as<BufferStoreNode>())
@@ -790,7 +790,7 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref) {
   CHECK(block != nullptr)
       << "TypeError: 'reverse_compute_at' expects 'block' to be a block, but get type: "
       << block_sref->stmt->GetTypeKey();
-  const StmtSRef& scope_block_sref = GetScopeSRef(block_sref);
+  const StmtSRef& scope_block_sref = GetScopeRoot(block_sref);
   const auto* scope_block = scope_block_sref->GetStmt<BlockNode>();
   const BlockScope& scope = self->scopes.at(scope_block_sref);
   // Cond 1. Check block_sref is complete
diff --git a/src/tir/schedule/primitives/fuse_split_reorder.cc b/src/tir/schedule/primitives/fuse_split_reorder.cc
index 009094fb19..d29c613afc 100644
--- a/src/tir/schedule/primitives/fuse_split_reorder.cc
+++ b/src/tir/schedule/primitives/fuse_split_reorder.cc
@@ -178,7 +178,7 @@ StmtSRef Fuse(ScheduleState self, const StmtSRef& outer_sref, const StmtSRef& in
   Array<Stmt> outer_children = GetChildren(GetRef<Stmt>(outer));
   CHECK(outer_children.size() == 1 && outer_children[0].get() == inner)
       << "ValueError: 'fuse' expects 'inner' to be the only child of 'outer'";
-  CHECK(GetScopeSRef(outer_sref).get() == GetScopeSRef(inner_sref).get())
+  CHECK(GetScopeRoot(outer_sref).get() == GetScopeRoot(inner_sref).get())
       << "ValueError: 'fuse' expects 'inner' and 'outer' to be in the same block scope";
   // Step 2. Create fused loop var and replace the loop var used in inner and outer loop
   arith::Analyzer analyzer;
@@ -244,7 +244,7 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& order) {
   std::unordered_map successor;
   // Gather all the loops under parent_block
   int n_loops_not_found = order.size();
-  for (const StmtSRefNode* loop : GetLoopsPostOrder(self, GetScopeSRef(order[0]))) {
+  for (const StmtSRefNode* loop : GetLoopsPostOrder(self, GetScopeRoot(order[0]))) {
     bool is_in_reorder_list = loops.count(loop);
     bool has_inner_loop = successor.count(loop);
     if (is_in_reorder_list || has_inner_loop) {
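For reference, the `split` path shown earlier is the one that lets callers leave one factor unspecified: the `analyzer_->CanProve(factor == -1)` branch in `ConcreteScheduleNode::Split` infers that factor from the loop extent. A sketch in the style of the meta-schedule tests below (`matmul` is a placeholder):

    sch = ms.Schedule(func=matmul)
    block = sch.get_block("matmul")
    i, j, k = sch.get_axes(block)
    i_0, i_1, i_2 = sch.split(i, factors=[-1, 8, 32])  # the -1 (None) factor is inferred
    sch.reorder(i_2, i_1, i_0)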
diff --git a/src/tir/schedule/primitives/reduction.cc b/src/tir/schedule/primitives/reduction.cc
index ea70db2d11..32e0dc140d 100644
--- a/src/tir/schedule/primitives/reduction.cc
+++ b/src/tir/schedule/primitives/reduction.cc
@@ -68,7 +68,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
   CHECK(ListContainsElement(loops, loop_sref))
       << "ValueError: 'decompose_reduction' expect the loop to be an ancestor of block";
   // Cond 1. Check block is reduction
-  CHECK(self->scopes.at(GetScopeSRef(block_sref))->IsReduction(block_sref))
+  CHECK(self->scopes.at(GetScopeRoot(block_sref))->IsReduction(block_sref))
       << "decompose_reduction expect the block to be a reduction block";
   // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
   for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
@@ -247,10 +247,10 @@ void MergeReduction(ScheduleState self, const StmtSRef& init_sref, const StmtSRe
   ExprDeepEqual equal;
   CHECK(equal(init_realize->predicate, update_realize->predicate))
       << "ValueError: 'merge_reduction' expects the predicate of init and update to be the same";
-  const StmtSRef& scope = GetScopeSRef(init_sref);
+  const StmtSRef& scope = GetScopeRoot(init_sref);
   StmtSRef lca = LowestCommonAncestor({init_sref, update_sref}, scope);
   // Cond 1. Check init_block is under the same scope with update_sref
-  CHECK_EQ(scope.get(), GetScopeSRef(update_sref).get())
+  CHECK_EQ(scope.get(), GetScopeRoot(update_sref).get())
       << "TypeError: 'merge_reduction' expects the 'init' and 'update' to be under the same scope";
   // Cond 3. Write region of 'init' is the same as that of 'update' under LCA
   {
@@ -334,7 +334,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
   StmtSRef block_sref = child_blocks[0];
   BlockRealize block_realize = GetBlockRealize(block_sref);
   Block block = block_realize->block;
-  BlockScope scope = self->scopes.at(GetScopeSRef(block_sref));
+  BlockScope scope = self->scopes.at(GetScopeRoot(block_sref));
   CHECK(scope->IsReduction(block_sref)) << "ValueError: can only rfactor a reduction block";

   // Collect the information of loops and blocks.
@@ -623,7 +623,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
   }

   // Insert the rfactor buffer into the scope block's allocation.
-  StmtSRef scope_sref = GetScopeSRef(block_sref);
+  StmtSRef scope_sref = GetScopeRoot(block_sref);
   Block scope_block = GetRef<Block>(scope_sref->GetStmt<BlockNode>()),
         new_scope_block = scope_block;
   new_scope_block.CopyOnWrite()->alloc_buffers.push_back(rf_buf);
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 4dd30526bb..f0d9d0ff6f 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -30,14 +30,15 @@ LoopRV::LoopRV() { this->data_ = make_object<LoopRVNode>(); }
 /**************** GetSRef ****************/

 StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const {
-  auto it = this->state->stmt2ref.find(stmt);
-  if (it == this->state->stmt2ref.end()) {
+  ScheduleState state = this->state();
+  auto it = state->stmt2ref.find(stmt);
+  if (it == state->stmt2ref.end()) {
     LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR";
   }
   return it->second;
 }

-StmtSRef ScheduleNode::GetSRef(const Stmt& stmt) const { return GetSRef(stmt.get()); }
+StmtSRef ScheduleNode::GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }

 /**************** FFI ****************/

@@ -46,7 +47,9 @@ TVM_REGISTER_NODE_TYPE(LoopRVNode);
 TVM_REGISTER_OBJECT_TYPE(ScheduleNode);

 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleModule")  //
-    .set_body_method<Schedule>(&ScheduleNode::Module);
+    .set_body_method<Schedule>(&ScheduleNode::mod);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState")  //
+    .set_body_method<Schedule>(&ScheduleNode::state);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
     .set_body_method<Schedule>(&ScheduleNode::Seed);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy")  //
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 41cdcbd285..82c4c88f31 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -76,6 +76,33 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new
   sref->stmt = new_stmt;
 }

+/*!
+ * \brief Get the PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The PrimFunc that the root block belongs to
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
 /**************** Creation ****************/
 /*! \brief A helper class to create a new ScheduleStateNode */
@@ -85,21 +112,27 @@ class StateCreator : private StmtVisitor {
    * \brief The entry function
    * \param self The schedule state to be completed
    */
-  static ObjectPtr<ScheduleStateNode> Create(PrimFunc func, bool debug_mode) {
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, bool debug_mode) {
     ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
     ScheduleStateNode* self = n.get();
-    // Set `n->func`
-    n->func = std::move(func);
+    // Set `n->mod`
+    n->mod = std::move(mod);
     // Set `n->debug_mode`
     n->debug_mode = debug_mode;
-    // Set `n->stmt2ref`
-    // Set `n->scopes`
-    (StateCreator(self)).VisitStmt(self->func->body);
+    // Set `n->stmt2ref` and `n->scopes`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
     return n;
   }

  private:
-  explicit StateCreator(ScheduleStateNode* self) : self_(self), srefs_{} {}
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {}

   /*!
    * \brief Add a new statement to the stack, which becomes the current scope
@@ -139,7 +172,7 @@ class StateCreator : private StmtVisitor {
     sref->binding_valid =
         ValidateBlockBinding(GetRef<BlockRealize>(realizes_.back()), loop_var_ranges);
     // Collect `scopes` info
-    self_->scopes[sref] = BlockScope(std::move(block_frames_.back().leaf_blocks));
+    self_->scopes.Set(sref, BlockScope(std::move(block_frames_.back().leaf_blocks)));
     block_frames_.pop_back();
     // Update parent scope if exists
     if (!block_frames_.empty()) {
@@ -183,8 +216,8 @@ class StateCreator : private StmtVisitor {

 /**************** Constructor ****************/

-ScheduleState::ScheduleState(PrimFunc func, bool debug_mode) {
-  data_ = StateCreator::Create(func, debug_mode);
+ScheduleState::ScheduleState(IRModule mod, bool debug_mode) {
+  data_ = StateCreator::Create(mod, debug_mode);
   // Verify the region cover
   ScheduleState self = GetRef<ScheduleState>(get());
   for (const auto& it : self->scopes) {
@@ -193,6 +226,9 @@ ScheduleState::ScheduleState(PrimFunc func, bool debug_mode) {
   }
 }

+ScheduleState::ScheduleState(PrimFunc func, bool debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
 /**************** Replace ****************/

 /*!
@@ -442,7 +478,7 @@ class SRefUpdater : public StmtVisitor {
     VisitStmt(op->body);
     parents_.pop_back();
     // Additionally, need to update the scope because the block is changed
-    self_->scopes[sref] = BlockScope(tir::GetChildBlocks(self_, sref));
+    self_->scopes.Set(sref, BlockScope(tir::GetChildBlocks(self_, sref)));
   }

   void VisitStmt_(const SeqStmtNode* seq_stmt) final {
@@ -616,20 +652,34 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_
   // The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
   // Along the way, because we create a new ancestor path,
   // we need to update those sref points from old ancestors to newly created ones
-  // `num_copy_steps` is the maximum number of hops until we need to copy
-  // To reach a node that can be mutated in-place, it needs `num_copy_steps + 1` hops
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that can
+  // be mutated in-place, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
   int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
   {
     int i = 0;
-    for (const StmtSRefNode* ptr = src_sref.get(); ptr != nullptr; ptr = ptr->parent, ++i) {
-      if (!ptr->stmt->unique()) {
+    const StmtSRefNode* p = src_sref.get();
+    for (;;) {
+      if (!p->stmt->unique()) {
         num_copy_steps = i;
       }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
     }
-    // If the function itself is not unique, then we assume the root is not unique
-    if (!this->func.unique()) {
-      num_copy_steps = i;
-    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
   }
   // Loop invariant:
   //
@@ -646,8 +696,8 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_
   //    3) `tgt_stmt` is of type Loop or Block
   StmtSRefNode* child_sref = src_sref.get();
   Stmt child_tgt_stmt = std::move(tgt_stmt);
-  for (int i = 0; i <= num_copy_steps && child_sref->parent != nullptr; ++i) {
-    bool parent_unique = (i == num_copy_steps);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_cow_parent = !need_module_copy && i == num_copy_steps;
     // replacing `child_sref->stmt` to `child_tgt_stmt`.
     const StmtNode* parent_stmt = child_sref->parent->stmt;
     const StmtNode* child_src_stmt = child_sref->stmt;
@@ -662,9 +712,9 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_
     // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`,
     Stmt new_parent_stmt = ChildReplacer::Mutate(parent_stmt, child_src_stmt, child_tgt_stmt,
                                                  /*seq_index=*/child_sref->seq_index,
-                                                 /*allow_copy_on_write=*/parent_unique);
+                                                 /*allow_copy_on_write=*/can_cow_parent);
     // Step 2.3. Go to next parent
-    if (parent_unique) {
+    if (can_cow_parent) {
       // If the node can be directly mutated inplace,
       // then there is no need to update its parent and the function
       break;
@@ -673,22 +723,35 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_
     child_sref = child_sref->parent;
   }
   // Step 3. Handle the case that we mutate the root
-  if (child_sref->parent == nullptr) {
+  if (need_module_copy) {
     // From the loop invariant, upon exit, while its subtree is properly set,
    // `child_sref` is not properly mapped to `child_tgt_stmt` yet.
     if (src_sref->parent != nullptr) {
       // Not replacing a root
       UpdateSRef(this, child_sref, child_tgt_stmt.get());
     }
-    // Update the body of the `this->func`
-    PrimFuncNode* new_func = this->func.CopyOnWrite();
-    // Assign `child_tgt_stmt`, which is a Block, to the root block
-    const auto* realize = TVM_TYPE_AS(realize, func->body, BlockRealizeNode);
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc that the sref belongs to, while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Then, move `ref_new_func` back
+    new_map->at(g_var) = std::move(ref_new_func);
+    // Update the body of the function the sref belongs to
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
     const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
     ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
     new_realize->block = GetRef<Block>(child_block);
     new_func->body = BlockRealize(std::move(new_realize));
-    this->func = GetRef<PrimFunc>(new_func);
+    this->mod = GetRef<IRModule>(new_mod);
   }
   if (this->debug_mode) {
     VerifySRefTree(GetRef<ScheduleState>(this));
@@ -699,7 +762,16 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_

 TVM_REGISTER_NODE_TYPE(ScheduleStateNode);

 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
-    .set_body_typed([](PrimFunc func, bool debug_mode) { return ScheduleState(func, debug_mode); });
+    .set_body_typed([](ObjectRef obj, bool debug_mode) {
+      if (const auto* func = obj.as<PrimFuncNode>()) {
+        return ScheduleState(GetRef<PrimFunc>(func), debug_mode);
+      }
+      if (const auto* mod = obj.as<IRModuleNode>()) {
+        return ScheduleState(GetRef<IRModule>(mod), debug_mode);
+      }
+      LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey();
+      throw;
+    });
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace")
     .set_body_method<ScheduleState>(&ScheduleStateNode::Replace);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef")
@@ -707,11 +779,6 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef")
       auto it = self->stmt2ref.find(stmt.get());
       return it != self->stmt2ref.end() ? it->second : Optional<StmtSRef>(NullOpt);
     });
-TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetScope")
-    .set_body_typed([](ScheduleState self, StmtSRef block_sref) -> Optional<BlockScope> {
-      auto it = self->scopes.find(block_sref);
-      return it != self->scopes.end() ? it->second : Optional<BlockScope>(NullOpt);
-    });

 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index c7fa86d9e3..57fac1707b 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -54,7 +54,7 @@ namespace tir {
       << "` points to `Loop`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None");
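The `need_module_copy` logic above is what keeps `Replace` functional at module granularity: when the module, its function map, or the enclosing PrimFunc is shared, only the function being scheduled is copied on write, and every other function keeps pointing at the same objects. A rough Python-level illustration of the observable behavior; this is a sketch only (`matmul` is a placeholder PrimFunc, and the exact aliasing depends on who else holds references):

    sch = tir.Schedule(matmul, debug_mode=True)
    i, _, _ = sch.get_axes(sch.get_block("matmul"))
    sch.split(i, factor=32)
    scheduled = sch.mod["main"]   # mutated copy owned by the schedule state
    # The original PrimFunc held by the caller is left untouched.
    assert not tvm.ir.structural_equal(matmul, scheduled)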
 #define TVM_TYPE_AS_E(Result, From, Type) \
-  From.as<Type>();                        \
+  (From).as<Type>();                      \
   ICHECK(Result)

 #define TVM_TYPE_AS(Result, From, Type) \
@@ -76,7 +76,7 @@ inline String Repr(const PrimFunc& func) {
   return s;
 }

-inline String Repr(const Schedule& self) { return Repr(self->Module()); }
+inline String Repr(const Schedule& self) { return Repr(self->mod()); }

 /*!
  * \brief Convert a tvm::runtime::Array to std::vector
@@ -198,7 +198,7 @@ BufferRegion RelaxRegion(const StmtSRef& block_sref, const StmtSRef& root,
  * \brief remove the AST leaf and its parent subtree which has only one leaf
  * \param sref The sref of Block/Loop to be removed
  * \param root The AST root
- * \return The orginal stmt and the removed stmt of the subtree rooted by the parent node
+ * \return The original stmt and the removed stmt of the subtree rooted by the parent node
  */
 std::pair<Stmt, Stmt> RemoveLeaf(StmtSRef sref, const StmtSRef& root);

@@ -227,7 +227,7 @@ bool StmtExprContainsVar(const ObjectRef& obj, const std::vector& vars);
 bool StmtExprContainsVar(const ObjectRef& obj, const std::unordered_set& vars);

 inline void UpdateScope(ScheduleState self, const StmtSRef& sref) {
-  self->scopes[sref] = BlockScope(tir::GetChildBlocks(self, sref));
+  self->scopes.Set(sref, BlockScope(tir::GetChildBlocks(self, sref)));
 }

 class StmtReplacer : public StmtMutator {
diff --git a/tests/python/meta_schedule/test_gemm_end_to_end.py b/tests/python/meta_schedule/test_gemm_end_to_end.py
index 38d68314bd..8ffda35a3d 100644
--- a/tests/python/meta_schedule/test_gemm_end_to_end.py
+++ b/tests/python/meta_schedule/test_gemm_end_to_end.py
@@ -182,7 +182,7 @@ def schedule_matmul(sch):
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 @pytest.mark.skip(reason="needs RPC")
@@ -210,7 +210,7 @@ def schedule_matmul(sch: ms.Schedule):
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 @pytest.mark.skip(reason="needs RPC")
@@ -261,7 +261,7 @@ def schedule_conv2d(sch):
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 @pytest.mark.skip(reason="needs RPC")
@@ -375,7 +375,7 @@ def test_matmul_evolutionary_end_to_end():
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 @pytest.mark.skip(reason="needs RPC")
@@ -429,7 +429,7 @@ def test_matmul_evolutionary_xgb():
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 if __name__ == "__main__":
diff --git a/tests/python/meta_schedule/test_integration_cpu.py b/tests/python/meta_schedule/test_integration_cpu.py
index 252bcd167d..60c7b49be3 100644
--- a/tests/python/meta_schedule/test_integration_cpu.py
+++ b/tests/python/meta_schedule/test_integration_cpu.py
@@ -97,7 +97,7 @@ def test_matmul_post_order_apply():
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 @pytest.mark.skip(reason="needs RPC")
@@ -136,7 +136,7 @@ def test_matmul_relu_post_order_apply():
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 @pytest.mark.skip(reason="needs RPC")
@@ -176,7 +176,7 @@ def test_conv1d_post_order_apply():
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 if __name__ == "__main__":
diff --git a/tests/python/meta_schedule/test_integration_cuda.py b/tests/python/meta_schedule/test_integration_cuda.py
index 8279155d4a..cdd77fbacd 100644
--- a/tests/python/meta_schedule/test_integration_cuda.py
+++ b/tests/python/meta_schedule/test_integration_cuda.py
@@ -29,9 +29,9 @@ logging.basicConfig()
 logging.getLogger("meta_schedule").setLevel(logging.DEBUG)

-RPC_KEY = "jetson-agx-xavier"
-TARGET = tvm.target.Target("nvidia/jetson-agx-xavier")
-TARGET_HOST = tvm.target.Target("llvm -mcpu=carmel -mtriple=aarch64-linux-gnu")
+RPC_KEY = "rtx-3070"
+TARGET = tvm.target.Target("nvidia/geforce-rtx-3070")
+TARGET_HOST = tvm.target.Target("llvm")
 SPACE = ms.space.PostOrderApply(
     stages=[
         ms.rule.multi_level_tiling(
@@ -101,7 +101,7 @@ def test_integration_matmul():
     if sch is None:
         print("No valid schedule found")
     else:
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 if __name__ == "__main__":
diff --git a/tests/python/meta_schedule/test_integration_cuda_tensorcore.py b/tests/python/meta_schedule/test_integration_cuda_tensorcore.py
index f551ee5ec6..1211d11006 100644
--- a/tests/python/meta_schedule/test_integration_cuda_tensorcore.py
+++ b/tests/python/meta_schedule/test_integration_cuda_tensorcore.py
@@ -118,7 +118,7 @@ def fetch_to_shared(block, idx, ndim):

     sch = ms.Schedule(func=workload, seed=1024)
     schedule(sch)
-    print(tvm.script.asscript(sch.module))
+    print(tvm.script.asscript(sch.mod))


 def test_integration_conv2d_nchwc():
@@ -217,7 +217,7 @@ def schedule(sch: ms.Schedule):
         # Decompose reduction
         sch.decompose_reduction(block, thread_idx)
         # sch.tensorize(i_tc, "test.tensorcore.wmma")
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))

     sch = ms.Schedule(func=workload)
     schedule(sch)
diff --git a/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py b/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py
index 72bd9833b9..d2419a9278 100644
--- a/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py
+++ b/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py
@@ -195,7 +195,7 @@ def f_create_args(ctx):
         return [X, W_data, W_indices, W_indptr, Y]

     sch = meta_schedule_sparse_dense_llvm(func, f_create_args)
-    func = sch.module
+    func = sch.mod
     func = tvm.build(func)

     Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
diff --git a/tests/python/meta_schedule/test_meta_schedule_class.py b/tests/python/meta_schedule/test_meta_schedule_class.py
index 14e22cb448..3cd9b18715 100644
--- a/tests/python/meta_schedule/test_meta_schedule_class.py
+++ b/tests/python/meta_schedule/test_meta_schedule_class.py
@@ -363,7 +363,7 @@ def _check_serialization(sch, func) -> ms.Schedule:
     record = sch.trace.serialize()
     new_sch = ms.Schedule(func)
     ms.Trace.deserialize(record, new_sch)
-    assert tvm.ir.structural_equal(new_sch.module, sch.module)
+    assert tvm.ir.structural_equal(new_sch.mod["main"], sch.mod["main"])
     py_repr = "\n".join(sch.trace.as_python())
     new_py_repr = "\n".join(new_sch.trace.as_python())
     assert py_repr == new_py_repr
@@ -372,8 +372,8 @@ def _check_serialization(sch, func) -> ms.Schedule:

 def test_meta_schedule_creation():
     sch = ms.Schedule(func=matmul)
-    assert tvm.ir.structural_equal(sch.orig_func, sch.module)
     assert len(sch.trace.insts) == 0
+    assert len(sch.trace.decisions) == 0
     _check_serialization(sch, func=matmul)
@@ -535,9 +535,9 @@ def test_meta_schedule_split():
     sch = ms.Schedule(func=matmul)
     i, _, _ = sch.get_axes(sch.get_block("matmul"))
     i_0, i_1, i_2 = [sch.get_sref(i).stmt for i in sch.split(i, factors=[-1, 8, 32])]
-    assert tvm.ir.structural_equal(i_0, sch.module.body.block.body)
-    assert tvm.ir.structural_equal(i_1, sch.module.body.block.body.body)
-    assert tvm.ir.structural_equal(i_2, sch.module.body.block.body.body.body)
+    assert tvm.ir.structural_equal(i_0, sch.mod["main"].body.block.body)
+    assert tvm.ir.structural_equal(i_1, sch.mod["main"].body.block.body.body)
+    assert tvm.ir.structural_equal(i_2, sch.mod["main"].body.block.body.body.body)
     _check_serialization(sch, func=matmul)

@@ -547,7 +547,7 @@ def test_meta_schedule_reorder():
     sch.reorder(i_2, i_1, i_0)
     i_0, i_1, i_2 = [sch.get_sref(i).stmt for i in [i_0, i_1, i_2]]

-    tir_sch = tir.Schedule(func=matmul, debug_mode=True)
+    tir_sch = tir.Schedule(matmul, debug_mode=True)
     ti_0, ti_1, ti_2 = tir_sch.get_axes(tir_sch.get_block("matmul"))
     tir_sch.reorder(ti_2, ti_1, ti_0)

@@ -563,7 +563,7 @@ def test_meta_schedule_compute_at():
     matmul_block = sch.get_block("matmul")
     _, _, i_2 = sch.get_axes(matmul_block)
     sch.compute_at(plus_one_block, i_2)
-    assert tvm.ir.structural_equal(sch.module, plus_one_matmul_fused)
+    assert tvm.ir.structural_equal(sch.mod["main"], plus_one_matmul_fused)
     _check_serialization(sch, func=plus_one_matmul)

@@ -573,7 +573,7 @@ def test_meta_schedule_reverse_compute_at():
     matmul_block = sch.get_block("matmul")
     _, i_1, _ = sch.get_axes(matmul_block)
     sch.reverse_compute_at(relu_block, i_1)
-    assert tvm.ir.structural_equal(sch.module, matmul_relu_fused)
+    assert tvm.ir.structural_equal(sch.mod["main"], matmul_relu_fused)
     _check_serialization(sch, func=matmul_relu)

@@ -581,7 +581,7 @@ def test_meta_schedule_compute_inline():
     sch = ms.Schedule(func=elementwise)
     block = sch.get_block(name="B")
     sch.compute_inline(block=block)
-    assert tvm.ir.structural_equal(sch.module, elementwise_inlined)
+    assert tvm.ir.structural_equal(sch.mod["main"], elementwise_inlined)
     _check_serialization(sch, func=elementwise)

@@ -590,7 +590,7 @@ def test_meta_schedule_cache_read():
     block = sch.get_block("matmul")
     sch.cache_read(block, i=1, storage_scope="local")
     sch.cache_read(block, i=2, storage_scope="local")
-    assert tvm.ir.structural_equal(sch.module, matmul_cache_read)
+    assert tvm.ir.structural_equal(sch.mod["main"], matmul_cache_read)
     _check_serialization(sch, func=matmul)

@@ -598,7 +598,7 @@ def test_meta_schedule_cache_write():
     sch = ms.Schedule(func=matmul)
     block = sch.get_block("matmul")
     sch.cache_write(block, i=0, storage_scope="local")
-    assert tvm.ir.structural_equal(sch.module, matmul_cache_write)
+    assert tvm.ir.structural_equal(sch.mod["main"], matmul_cache_write)
     _check_serialization(sch, func=matmul)

@@ -607,7 +607,7 @@ def test_meta_schedule_blockize():
     block = sch.get_block("matmul")
     _, _, k = sch.get_axes(block)
     sch.blockize(k)
-    assert tvm.ir.structural_equal(sch.module, matmul_blockized)
+    assert tvm.ir.structural_equal(sch.mod["main"], matmul_blockized)
     _check_serialization(sch, func=matmul)

@@ -616,7 +616,7 @@ def test_meta_schedule_decompose_reduction():
     block = sch.get_block("matmul")
     _, _, k = sch.get_axes(block)
     sch.decompose_reduction(block, k)
-    assert tvm.ir.structural_equal(sch.module, matmul_decomposed)
+    assert tvm.ir.structural_equal(sch.mod["main"], matmul_decomposed)
     _check_serialization(sch, func=matmul)

@@ -631,7 +631,7 @@ def test_meta_schedule_tensorize():
     sch.reorder(i_o, j_o, k_o, i_i, j_i, k_i)
     sch.decompose_reduction(block, k_o)
     sch.tensorize(i_i, "ms_test.tensor_intrin")
-    assert tvm.ir.structural_equal(sch.module, matmul_tensorized)
+    assert tvm.ir.structural_equal(sch.mod["main"], matmul_tensorized)
     _check_serialization(sch, func=matmul)
diff --git a/tests/python/meta_schedule/test_meta_schedule_sketch_cpu.py b/tests/python/meta_schedule/test_meta_schedule_sketch_cpu.py
index ead9520ac4..6815182054 100644
--- a/tests/python/meta_schedule/test_meta_schedule_sketch_cpu.py
+++ b/tests/python/meta_schedule/test_meta_schedule_sketch_cpu.py
@@ -44,6 +44,7 @@
 def _fix_sampling_tile_size(
     sch: ms.Schedule,
+    func: tir.PrimFunc,
     possible_decisions: List[List[List[int]]],
     expected: List[tir.PrimFunc],
 ):
@@ -62,9 +63,9 @@ def _fix_sampling_tile_size(
         for inst, decision in zip(insts, decisions):
             new_decisions[inst] = decision
         trace = ms.Trace(sch.trace.insts, new_decisions)
-        new_sch = ms.Schedule(sch.orig_func)
+        new_sch = ms.Schedule(func)
         trace.apply(new_sch)
-        results = [tvm.ir.structural_equal(new_sch.module, i) for i in expected]
+        results = [tvm.ir.structural_equal(new_sch.mod["main"], i) for i in expected]
         if sum(results) >= 1:
             return
     assert False
@@ -77,7 +78,7 @@ def _get_support(func: tir.PrimFunc, task_name: str):
 def _debug(support: List[ms.Schedule]):
     for i, sch in enumerate(support):
         print(f"###### {i}")
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod["main"]))
         for inst in sch.trace.insts:
             if inst in sch.trace.decisions:
                 print(sch.trace.decisions[inst], ",")
@@ -222,16 +223,19 @@ def test_meta_schedule_sketch_cpu_matmul():
     assert len(support) == 3
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[1],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[2],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
@@ -385,16 +389,19 @@ def test_meta_schedule_sketch_cpu_matmul_relu():
     ]
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[1],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[2],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
@@ -615,16 +622,19 @@ def test_meta_schedule_sketch_cpu_conv2d_nchw():
     ]
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[1],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[2],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
@@ -853,16 +863,19 @@ def test_meta_schedule_sketch_cpu_conv2d_nchw_bias_bn_relu():  # pylint: disable
     ]
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
    )
     _fix_sampling_tile_size(
         sch=support[1],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
     _fix_sampling_tile_size(
         sch=support[2],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
@@ -884,6 +897,7 @@ def test_meta_schedule_sketch_cpu_max_pool2d_nchw():
     possible_decisions = [[]]
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
diff --git a/tests/python/meta_schedule/test_meta_schedule_sketch_cpu_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_sketch_cpu_tensorize.py
index cef859fc86..879399427c 100644
--- a/tests/python/meta_schedule/test_meta_schedule_sketch_cpu_tensorize.py
+++ b/tests/python/meta_schedule/test_meta_schedule_sketch_cpu_tensorize.py
@@ -59,7 +59,7 @@ def test_meta_schedule_sketch_cpu_matmul_dot():
     schs = space.get_support(task=task)
     for sch in schs:
         space.postprocess(task, sch)
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod))


 if __name__ == "__main__":
diff --git a/tests/python/meta_schedule/test_meta_schedule_sketch_cuda.py b/tests/python/meta_schedule/test_meta_schedule_sketch_cuda.py
index cdccd930e5..fa850a6cc6 100644
--- a/tests/python/meta_schedule/test_meta_schedule_sketch_cuda.py
+++ b/tests/python/meta_schedule/test_meta_schedule_sketch_cuda.py
@@ -54,6 +54,7 @@
 def _fix_sampling_tile_size(
     sch: ms.Schedule,
+    func: tir.PrimFunc,
     possible_decisions: List[List[List[int]]],
     expected: List[tir.PrimFunc],
 ):
@@ -72,9 +73,9 @@ def _fix_sampling_tile_size(
         for inst, decision in zip(insts, decisions):
             new_decisions[inst] = decision
         trace = ms.Trace(sch.trace.insts, new_decisions)
-        new_sch = ms.Schedule(sch.orig_func)
+        new_sch = ms.Schedule(func)
         trace.apply(new_sch)
-        results = [tvm.ir.structural_equal(new_sch.module, i) for i in expected]
+        results = [tvm.ir.structural_equal(new_sch.mod["main"], i) for i in expected]
         if sum(results) >= 1:
             return
     assert False
@@ -94,7 +95,7 @@ def _get_support(func: tir.PrimFunc, task_name: str):
 def _debug(support: List[ms.Schedule]):
     for i, sch in enumerate(support):
         print(f"###### {i}")
-        print(tvm.script.asscript(sch.module))
+        print(tvm.script.asscript(sch.mod["main"]))
         for inst in sch.trace.insts:
             if inst in sch.trace.decisions:
                 print(sch.trace.decisions[inst], ",")
@@ -173,6 +174,7 @@ def test_meta_schedule_sketch_cuda_matmul():
     assert len(support) == 1
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
@@ -282,6 +284,7 @@ def test_meta_schedule_sketch_cuda_conv2d_nchw_bias_bn_relu():  # pylint: disabl
     assert len(support) == 1
     _fix_sampling_tile_size(
         sch=support[0],
+        func=func,
         possible_decisions=possible_decisions,
         expected=expected,
     )
diff --git a/tests/python/meta_schedule/test_meta_schedule_tensorize_rule.py b/tests/python/meta_schedule/test_meta_schedule_tensorize_rule.py
index d6769ba237..8e91b1cfd5 100644
--- a/tests/python/meta_schedule/test_meta_schedule_tensorize_rule.py
+++ b/tests/python/meta_schedule/test_meta_schedule_tensorize_rule.py
@@ -35,7 +35,7 @@ def _check_sketch(result, expected):
     for x in result:
         found = False
         for y in expected:
-            if tvm.ir.structural_equal(x.module, y):
+            if tvm.ir.structural_equal(x.mod["main"], y):
                 found = True
                 break
         assert found
diff --git a/tests/python/meta_schedule/test_resnet_end_to_end.py b/tests/python/meta_schedule/test_resnet_end_to_end.py
index 1014a1ad9b..9ef9c7b529 100644
--- a/tests/python/meta_schedule/test_resnet_end_to_end.py
+++ b/tests/python/meta_schedule/test_resnet_end_to_end.py
@@ -159,7 +159,7 @@ def test_end_to_end_resnet(log):
             ]
         ),
     )
-    tuned_result[target][func] = sch.module
+    tuned_result[target][func] = sch.mod

     with tvm.transform.PassContext(config={"relay.with_tir_schedule": True}):
         lib = relay.build_module.build(mod, TARGET, params=params, tune_result=tuned_result)
diff --git a/tests/python/tir/conv_tensorcore_demo.py b/tests/python/tir/conv_tensorcore_demo.py
index 05eeac5ac4..0ae917aeb7 100644
--- a/tests/python/tir/conv_tensorcore_demo.py
+++ b/tests/python/tir/conv_tensorcore_demo.py
@@ -333,9 +333,9 @@ def test_tensorcore():
     s.tensorize(s.get_axes(AF)[-2], tir.TensorIntrin(load_a_desc, load_a_intrin))
     s.tensorize(s.get_axes(WF)[-2], tir.TensorIntrin(load_b_desc, load_b_intrin))

-    print(tvm.script.asscript(s.module))
-    print(tvm.lower(s.module, None, simple_mode=True))
-    build_and_test(conv, s.module)
+    print(tvm.script.asscript(s.mod["main"]))
+    print(tvm.lower(s.mod["main"], None, simple_mode=True))
+    build_and_test(conv, s.mod["main"])


 if __name__ == "__main__":
diff --git a/tests/python/tir/gemm_cpu_demo.py b/tests/python/tir/gemm_cpu_demo.py
index 87d527a154..629e6cc289 100644
--- a/tests/python/tir/gemm_cpu_demo.py
+++ b/tests/python/tir/gemm_cpu_demo.py
@@ -61,12 +61,12 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 c = tvm.nd.array(np.zeros((M, N)).astype("float32"))


-def build_and_test(func):
-    build_func = tvm.build(func, target=target)
+def build_and_test(mod):
+    build_func = tvm.build(mod["main"], target=target)
     build_func(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
     evaluator = build_func.time_evaluator(build_func.entry_name, ctx, number=1)
-    print(tvm.script.asscript(func))
+    print(tvm.script.asscript(mod))
     return evaluator(a, b, c).mean


@@ -91,10 +91,10 @@ def build_and_test(func):
     j_o, j_i = s.split(j, factor=bn)
     k_o, k_i = s.split(k, factor=4)
     s.reorder(i_o, j_o, k_o, k_i, i_i, j_i)
-func_opt1 = s.module
+func_opt1 = s.mod

 # s.decompose_reduction(update, j_o)
-print("Opt1: %f" % build_and_test(s.module))
+print("Opt1: %f" % build_and_test(s.mod))

 ################################################################################################
 # Vectorization
@@ -111,10 +111,10 @@ def build_and_test(func):
     i_o, j_o, k_o, k_i, i_i, j_i = s.get_axes(update)
     s.vectorize(j_i)

-func_opt2 = s.module
+func_opt2 = s.mod
 s.decompose_reduction(update, j_o)

-print("Opt2: %f" % build_and_test(s.module))
+print("Opt2: %f" % build_and_test(s.mod))

 ################################################################################################
 # Loop Permutation
@@ -130,11 +130,11 @@ def build_and_test(func):
     i_o, j_o, k_o, k_i, i_i, j_i = s.get_axes(update)
     s.reorder(i_o, j_o, k_o, i_i, k_i, j_i)

-func_opt3 = s.module
+func_opt3 = s.mod
 s.decompose_reduction(update, j_o)

-print("Opt3: %f" % build_and_test(s.module))
+print("Opt3: %f" % build_and_test(s.mod))

 ################################################################################################
@@ -194,10 +194,10 @@ def matmul_packed(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
     k_o, k_i = s.split(k, factor=4)
     s.reorder(i_o, j_o, k_o, i_i, k_i, j_i)
     s.vectorize(j_i)
-func_opt3 = s.module
+func_opt3 = s.mod
 s.decompose_reduction(update, j_o)

-print("Opt4: %f" % build_and_test(s.module))
+print("Opt4: %f" % build_and_test(s.mod))

 ################################################################################################
 # Write cache for blocks
@@ -227,10 +227,10 @@ def matmul_packed(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
     x, y, z = s.get_axes(packedB)
     s.vectorize(z)
     s.parallel(x)
-func_opt5 = s.module
+func_opt5 = s.mod
 s.decompose_reduction(cached_update, j_o)

-print("Opt5: %f" % build_and_test(s.module))
+print("Opt5: %f" % build_and_test(s.mod))

 ###################################################################################################
 # Parallel
@@ -244,7 +244,7 @@ def matmul_packed(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 s.parallel(i_o)
 s.decompose_reduction(cached_update, j_o)

-print("Opt6: %f" % build_and_test(s.module))
+print("Opt6: %f" % build_and_test(s.mod))

 ###################################################################################################
diff --git a/tests/python/tir/gemm_gpu_demo.py b/tests/python/tir/gemm_gpu_demo.py
index 4621b76d24..bf41940b57 100644
--- a/tests/python/tir/gemm_gpu_demo.py
+++ b/tests/python/tir/gemm_gpu_demo.py
@@ -18,7 +18,7 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:

 n = 2048
-device = 'cuda'
+device = "cuda"
 ctx = tvm.context(device, 0)

 mod = tvm.script.create_module({"matmul": matmul})
@@ -31,12 +31,12 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 c = tvm.nd.array(np.zeros((n, n)).astype("float32"), ctx)


-def build_and_test(func):
+def build_and_test(mod):
     if not ctx.exist:
         print("Skip because %s is not enabled" % device)
         return
-    f = tvm.build(func, target=device)
-    print(tvm.script.asscript(func))
+    f = tvm.build(mod["main"], target=device)
+    print(tvm.script.asscript(mod))
     f(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np.T, b_np), rtol=1e-5)

@@ -121,5 +121,11 @@ def build_and_test(func):
 s.decompose_reduction(CC, decompose_pos)

 with tvm.transform.PassContext(
-    config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}):
-    build_and_test(s.module)
+    config={
+        "tir.UnrollLoop": {
+            "auto_max_step": 128,
+            "explicit_unroll": device != "cuda",
+        },
+    }
+):
+    build_and_test(s.mod)
diff --git a/tests/python/tir/test_block_dependency.py b/tests/python/tir/test_block_dependency.py
index 0a5a5c9192..77d3d75da0 100644
--- a/tests/python/tir/test_block_dependency.py
+++ b/tests/python/tir/test_block_dependency.py
@@ -50,11 +50,11 @@ def test_element_wise_dependency():
     block_b = s.get_sref(s.get_block("B"))
     block_c = s.get_sref(s.get_block("C"))
     # Check get_predecessors
-    (predecessor_c,) = s.state.scope(root).get_predecessors(block_c)
+    (predecessor_c,) = s.state.scopes[root].get_predecessors(block_c)
     assert predecessor_c.dst.same_as(block_b)
     assert predecessor_c.type == tir.schedule.DepEdge.kRAW
     # Check get_successor
-    (successor_b,) = s.state.scope(root).get_successor(block_b)
+    (successor_b,) = s.state.scopes[root].get_successor(block_b)
     assert successor_b.dst.same_as(block_c)
     assert predecessor_c.type == tir.schedule.DepEdge.kRAW

@@ -66,7 +66,7 @@ def test_matmul_dependency():
     init = s.get_sref(s.get_block("init"))
     update = s.get_sref(s.get_block("update"))
     # Check predecessors
-    p0, p1 = s.state.scope(root).get_predecessors(update)
+    p0, p1 = s.state.scopes[root].get_predecessors(update)
     assert p0.dst.same_as(init)
     assert p1.dst.same_as(init)
     # WAW and RAW
@@ -74,7 +74,7 @@ def test_matmul_dependency():
         p0.type == tir.schedule.DepEdge.kWAW and p1.type == tir.schedule.DepEdge.kRAW
     )
     # Check successors
-    p0, p1 = s.state.scope(root).get_successor(init)
+    p0, p1 = s.state.scopes[root].get_successor(init)
     assert p0.dst == update
     assert p1.dst == update
     # WAW and RAW
diff --git a/tests/python/tir/test_schedule_primitive.py b/tests/python/tir/test_schedule_primitive.py
index ddada9249f..a0cdb8c293 100644
--- a/tests/python/tir/test_schedule_primitive.py
+++ b/tests/python/tir/test_schedule_primitive.py
@@ -425,7 +425,7 @@ def test_fuse():
     s.fuse(outer, inner)
     mod = tvm.script.create_module({"fused_element_wise": fused_element_wise})
     fused_func = mod["fused_element_wise"]
-    tvm.ir.assert_structural_equal(fused_func, s.module)
+    tvm.ir.assert_structural_equal(fused_func, s.mod["main"])

 def test_split_fuse():
@@ -440,7 +440,7 @@ def test_split_fuse():
     s.split(inner, nparts=10)
     mod = tvm.script.create_module({"split_element_wise": split_element_wise})
     split_func = mod["split_element_wise"]
-    tvm.ir.assert_structural_equal(split_func, s.module)
+    tvm.ir.assert_structural_equal(split_func, s.mod["main"])


 def test_compute_at():
@@ -453,7 +453,7 @@ def test_compute_at():
     s.compute_at(B, outer)
     mod = tvm.script.create_module({"compute_at_element_wise": compute_at_element_wise})
     split_func = mod["compute_at_element_wise"]
-    tvm.ir.assert_structural_equal(split_func, s.module)
+    tvm.ir.assert_structural_equal(split_func, s.mod["main"])


 def test_reverse_compute_at():
@@ -467,7 +467,7 @@ def test_reverse_compute_at():
     j1, j2 = s.split(j, factor=16)
     s.reorder(i1, j1, i2, j2)
     s.reverse_compute_at(C, i2)
-    tvm.ir.assert_structural_equal(reverse_compute_at_element_wise, s.module)
+    tvm.ir.assert_structural_equal(reverse_compute_at_element_wise, s.mod["main"])


 def test_fuse_loop_sref():
@@ -483,7 +483,7 @@ def test_fuse_loop_sref():

     mod = tvm.script.create_module({"predicate_fuse": predicate_fuse})
     predicate_fuse_func = mod["predicate_fuse"]
-    tvm.ir.assert_structural_equal(s.module, predicate_fuse_func)
+    tvm.ir.assert_structural_equal(s.mod["main"], predicate_fuse_func)


 def test_reorder_normal():
@@ -498,7 +498,7 @@ def test_reorder_normal():
     s.fuse(i, j)
     mod = tvm.script.create_module({"matmul_reorder": matmul_reorder})
     matmul_reorder_func = mod["matmul_reorder"]
-    tvm.ir.assert_structural_equal(s.module, matmul_reorder_func)
+    tvm.ir.assert_structural_equal(s.mod["main"], matmul_reorder_func)


 def test_compute_inline():
@@ -511,7 +511,7 @@ def test_compute_inline():

     inlined_func = inline_element_wise
-    tvm.ir.assert_structural_equal(inlined_func, s.module)
+    tvm.ir.assert_structural_equal(inlined_func, s.mod["main"])


 def test_reverse_compute_inline():
@@ -521,7 +521,7 @@ def test_reverse_compute_inline():
     s = tir.Schedule(func, debug_mode=True)
     C = s.get_block("C")
     s.reverse_compute_inline(C)
-    tvm.ir.assert_structural_equal(element_wise_reverse_inline, s.module)
+    tvm.ir.assert_structural_equal(element_wise_reverse_inline, s.mod["main"])


 def test_compute_at_fail():
@@ -554,7 +554,7 @@ def test_reduction():
     s.decompose_reduction(update, k)
     mod = tvm.script.create_module({"matmul_reduction": matmul_reduction})
     matmul_reduction_func = mod["matmul_reduction"]
-    tvm.ir.assert_structural_equal(s.module, matmul_reduction_func)
+    tvm.ir.assert_structural_equal(s.mod["main"], matmul_reduction_func)


 def test_cache_read():
@@ -568,7 +568,7 @@ def test_cache_read():
     _ = s.cache_read(B, 0, "local")
     mod = tvm.script.create_module({"cache_read": cache_read})
     cached_func = mod["cache_read"]
-    tvm.ir.assert_structural_equal(cached_func, s.module)
+    tvm.ir.assert_structural_equal(cached_func, s.mod["main"])


 def test_cache_write():
@@ -579,7 +579,7 @@ def test_cache_write():
     _ = s.cache_write(C, 0, "local")
     mod = tvm.script.create_module({"cache_write": cache_write})
     cached_func = mod["cache_write"]
-    tvm.ir.assert_structural_equal(cached_func, s.module)
+    tvm.ir.assert_structural_equal(cached_func, s.mod["main"])


 def test_blockize():
@@ -595,7 +595,7 @@ def test_blockize():
     s.blockize(xi)
     mod = tvm.script.create_module({"blockize": blockize})
     blockized_func = mod["blockize"]
-    tvm.ir.assert_structural_equal(blockized_func, s.module)
+    tvm.ir.assert_structural_equal(blockized_func, s.mod["main"])


 def test_cache_read_write():
@@ -604,13 +604,13 @@ def test_cache_read_write():
     s = tir.Schedule(func, debug_mode=True)
     blockA = s.get_block("A")
     s.cache_read(blockA, 0, "local")
-    tvm.ir.assert_structural_equal(test_func_cache_read, s.module)
+    tvm.ir.assert_structural_equal(test_func_cache_read, s.mod["main"])

     # schedule cache write
     s = tir.Schedule(func, debug_mode=True)
     blockA = s.get_block("A")
     s.cache_write(blockA, 0, "local")
-    tvm.ir.assert_structural_equal(test_func_cache_write, s.module)
+    tvm.ir.assert_structural_equal(test_func_cache_write, s.mod["main"])


 def test_blockize_schedule():
@@ -626,7 +626,7 @@ def test_blockize_schedule():
     s.blockize(xi)
     s.reverse_compute_at(C, yo)
     s.blockize(s.get_axes(C)[-2])
-    tvm.ir.assert_structural_equal(s.module, blockize_schedule_1)
+    tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1)
     # test 2
     s = tir.Schedule(func, debug_mode=True)
     B = s.get_block("B")
@@ -638,7 +638,7 @@ def test_blockize_schedule():
     s.blockize(xi)
     s.compute_at(B, yo)
     s.blockize(s.get_axes(B)[-2])
-    tvm.ir.assert_structural_equal(s.module, blockize_schedule_1)
+    tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1)
     # test 3
     s = tir.Schedule(func, debug_mode=True)
     B = s.get_block("B")
@@ -653,7 +653,7 @@ def test_blockize_schedule():
     yCo, yCi = s.split(yC, factor=32)
     s.reorder(xCo, yCo, xCi, yCi)
     s.compute_at(b_outer, yCo)
-    tvm.ir.assert_structural_equal(s.module, blockize_schedule_2)
+    tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2)


 def test_pragma():
@@ -663,7 +663,7 @@ def test_pragma():
     i, _, _ = s.get_axes(C)
     s.pragma(i, "auto_unroll_max_step", 16)
     s.pragma(i, "unroll_explicit", False)
-    tvm.ir.assert_structural_equal(matmul_pragma, s.module)
+    tvm.ir.assert_structural_equal(matmul_pragma, s.mod["main"])


 def test_double_buffer():
@@ -671,7 +671,7 @@ def test_double_buffer():
     s = tir.Schedule(func, debug_mode=True)
     B = s.get_block("B")
     s.double_buffer(B)
-    tvm.ir.assert_structural_equal(element_wise_double_buffer, s.module)
+    tvm.ir.assert_structural_equal(element_wise_double_buffer, s.mod["main"])


 if __name__ == "__main__":
diff --git a/tests/python/tir/test_schedule_reduction.py b/tests/python/tir/test_schedule_reduction.py
index aebe81df4f..e770b68fc5 100644
--- a/tests/python/tir/test_schedule_reduction.py
+++ b/tests/python/tir/test_schedule_reduction.py
@@ -264,12 +264,12 @@ def test_reduction_decompose():
     C = s.get_block("update")
     i, _, _ = s.get_axes(C)
     s.decompose_reduction(C, i)
-    tvm.ir.assert_structural_equal(matmul_decompose0, s.module)
+    tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"])
     # Test 2
     s = tir.Schedule(matmul, debug_mode=True)
     C = s.get_block("update")
     s.decompose_reduction(C, loop=None)
-    tvm.ir.assert_structural_equal(matmul_decompose1, s.module)
+    tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"])


 def test_reduction_merge():
@@ -277,7 +277,7 @@ def test_reduction_merge():
     init = s.get_block("init")
     update = s.get_block("update")
     s.merge_reduction(init, update)
-    tvm.ir.assert_structural_equal(matmul, s.module)
+    tvm.ir.assert_structural_equal(matmul, s.mod["main"])


 def test_reduction_blockize():
@@ -285,14 +285,14 @@ def test_reduction_blockize():
     C = s.get_block("update")
     _, j, _ = s.get_axes(C)
     s.blockize(j)
-    tvm.ir.assert_structural_equal(matmul_blockized, s.module)
+    tvm.ir.assert_structural_equal(matmul_blockized, s.mod["main"])


 def test_reduction_compute_inline():
     s = tir.Schedule(matmul_scale, debug_mode=True)
     D = s.get_block("D")
     s.compute_inline(D)
-    tvm.ir.assert_structural_equal(s.module, matmul_scale_inline)
+    tvm.ir.assert_structural_equal(s.mod["main"],
matmul_scale_inline) def test_reduction_rfactor(): @@ -304,9 +304,9 @@ def test_reduction_rfactor(): _, ki = s.split(k, factor=32) _, kii = s.split(ki, factor=4) _ = s.rfactor(kii, factor=0) - tvm.ir.assert_structural_equal(s.module, matmul_rfactor) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) - f = tvm.build(s.module, target="llvm") + f = tvm.build(s.mod["main"], target="llvm") a_np = np.random.uniform(size=(128, 128)).astype("float32") b_np = np.random.uniform(size=(128, 128)).astype("float32") a = tvm.nd.array(a_np) @@ -321,9 +321,9 @@ def test_reduction_rfactor(): C = s.get_block("C") b, i, j = s.get_axes(C) _ = s.rfactor(j, 1) - tvm.ir.assert_structural_equal(s.module, square_sum_rfactor) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor) - f = tvm.build(s.module, target="llvm") + f = tvm.build(s.mod["main"], target="llvm") a_np = np.random.uniform(size=(16, 256, 256)).astype("float32") a = tvm.nd.array(a_np) c = tvm.nd.array(np.zeros((16,), dtype="float32")) @@ -338,9 +338,9 @@ def test_reduction_rfactor(): fuse = s.fuse(i, j) _, fi = s.split(fuse, factor=1) _ = s.rfactor(fi, 0) - tvm.ir.assert_structural_equal(s.module, square_sum_square_root_rfactor) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor) - f = tvm.build(s.module, target="llvm") + f = tvm.build(s.mod["main"], target="llvm") a_np = np.random.uniform(size=(16, 256, 256)).astype("float32") a = tvm.nd.array(a_np) c = tvm.nd.array(np.zeros((16,), dtype="float32")) @@ -349,7 +349,6 @@ def test_reduction_rfactor(): tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4) -@pytest.mark.skip("Needs GPU") def test_reduction_allreduce(): ctx = tvm.gpu(0) # Test 1 @@ -362,7 +361,7 @@ def test_reduction_allreduce(): s.bind(ax_j, thread_x) s.bind(ax_i, thread_y) - f = tvm.build(s.module, target="cuda") + f = tvm.build(s.mod["main"], target="cuda") a_np = np.random.uniform(size=(16, 16, 16)).astype("float32") a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros((16,), dtype="float32"), ctx) @@ -378,7 +377,7 @@ def test_reduction_allreduce(): _, ax_i, ax_j = s.get_axes(B_block) s.bind(ax_j, thread_x) - f = tvm.build(s.module, target="cuda") + f = tvm.build(s.mod["main"], target="cuda") a_np = np.random.uniform(size=(16, 16, 16)).astype("float32") a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros((16,), dtype="float32"), ctx) @@ -394,7 +393,7 @@ def test_reduction_allreduce(): _, ax_i, ax_j = s.get_axes(B_block) s.bind(ax_i, thread_x) - f = tvm.build(s.module, target="cuda") + f = tvm.build(s.mod["main"], target="cuda") a_np = np.random.uniform(size=(16, 16, 16)).astype("float32") a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros((16,), dtype="float32"), ctx) @@ -422,7 +421,7 @@ def test_reduction_allreduce(): s.reverse_compute_at(B_block, ax_i_rf_o) - f = tvm.build(s.module, target="cuda") + f = tvm.build(s.mod["main"], target="cuda") a_np = np.random.uniform(size=(16, 16, 16)).astype("float32") a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros((16,), dtype="float32"), ctx) diff --git a/tests/python/tir/test_schedule_replace.py b/tests/python/tir/test_schedule_replace.py index 952ea8e46f..1e143ebfbc 100644 --- a/tests/python/tir/test_schedule_replace.py +++ b/tests/python/tir/test_schedule_replace.py @@ -20,6 +20,7 @@ import tvm from tvm import tir +from tvm.ir import IRModule import util @@ -29,7 +30,26 @@ def replace_ir_builder(deep_copy=False, realize=False): new_func = tvm.script.from_source(tvm.script.asscript(func)) s = 
tir.ScheduleState(new_func, debug_mode=True) # The target stmt - target = tvm.tir.Block([], [], [], [], {}, [], "", "target", s.func.body.block.body[1]) + target = tvm.tir.Block([], [], [], [], {}, [], "", "target", s.mod["main"].body.block.body[1]) + if realize: + target = tvm.tir.BlockRealize([], 1, target) + if deep_copy: + target.__setstate__(target.__getstate__()) + + # It's important to collect garbage explicitly to make + sure that there is only one reference to the function + gc.collect() + return s, target + + +def replace_ir_builder_module(deep_copy=False, realize=False): + func = util.element_wise_stmt() + new_func = tvm.script.from_source(tvm.script.asscript(func)) + other_func = tvm.script.from_source(tvm.script.asscript(func)) + mod = IRModule(functions={"main": new_func, "other": other_func}) + s = tir.ScheduleState(mod, debug_mode=True) + # The target stmt + target = tvm.tir.Block([], [], [], [], {}, [], "", "target", s.mod["main"].body.block.body[1]) + if realize: + target = tvm.tir.BlockRealize([], 1, target) + if deep_copy: @@ -52,14 +72,14 @@ def replace_ir_builder_with_opaque(): def test_replace_direct_write0(): s, target = replace_ir_builder(realize=True) - old_hash = s.func.__hash__() - sref = s.get_sref(s.func.body.block.body[1]) + old_hash = s.mod["main"].__hash__() + sref = s.get_sref(s.mod["main"].body.block.body[1]) s.replace(sref, target) # There is no other reference so the AST node can be written directly - assert old_hash == s.func.__hash__() + assert old_hash == s.mod["main"].__hash__() # Check the replaced part is equal to the target - tvm.ir.assert_structural_equal(s.func.body.block.body[1], target) + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) # The target reuses the stmt of the sref, so the sref won't be None assert sref.stmt is not None @@ -67,16 +87,16 @@ def test_replace_direct_write1(): s, target = replace_ir_builder(realize=True) - old_hash = s.func.body.block.body.__hash__() - hold_ref = s.func.body.block.body[1] - sref = s.get_sref(s.func.body.block.body[1]) + old_hash = s.mod["main"].body.block.body.__hash__() + hold_ref = s.mod["main"].body.block.body[1] + sref = s.get_sref(s.mod["main"].body.block.body[1]) s.replace(sref, target) # There is no other reference so the AST node can be written directly - assert old_hash == s.func.body.block.body.__hash__() + assert old_hash == s.mod["main"].body.block.body.__hash__() assert not tvm.ir.structural_equal(hold_ref.body, target) # Check the replaced part is equal to the target - tvm.ir.assert_structural_equal(s.func.body.block.body[1], target) + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) # The target reuses the sref's stmt, so the sref won't be None assert sref.stmt is not None @@ -84,18 +104,18 @@ def test_replace_copy(): s, target = replace_ir_builder(deep_copy=True, realize=True) - old_hash = s.func.__hash__() + old_hash = s.mod["main"].__hash__() # We hold another reference to func - old_func = s.func - sref = s.get_sref(s.func.body.block.body[0]) + old_func = s.mod["main"] + sref = s.get_sref(s.mod["main"].body.block.body[0]) s.replace(sref, target) # We need to copy the whole func to keep the old_func unchanged - assert old_hash != s.func.__hash__() - assert not tvm.ir.structural_equal(old_func.body, s.func.body) + assert old_hash != s.mod["main"].__hash__() + assert not tvm.ir.structural_equal(old_func.body, s.mod["main"].body) assert old_hash == old_func.__hash__() # Check the 
replaced part is equal to the target - tvm.ir.assert_structural_equal(s.func.body.block.body[0], target) + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target) # The replaced AST node will be deleted, so the ref will be None assert sref.stmt is None @@ -103,21 +123,21 @@ def test_replace_partial_copy0(): s, target = replace_ir_builder(deep_copy=True, realize=True) - func_old_hash = s.func.__hash__() - hold_ref = s.func.body.block.body[0] + func_old_hash = s.mod["main"].__hash__() + hold_ref = s.mod["main"].body.block.body[0] ref_old_hash = hold_ref.__hash__() - sref = s.get_sref(s.func.body.block.body[0].body) - other_part_hash = s.func.body.block.body[1].__hash__() + sref = s.get_sref(s.mod["main"].body.block.body[0].body) + other_part_hash = s.mod["main"].body.block.body[1].__hash__() s.replace(sref, target) # The held stmt will not change; a new copy is made - assert ref_old_hash != s.func.body.block.body[0].__hash__() + assert ref_old_hash != s.mod["main"].body.block.body[0].__hash__() assert not tvm.ir.structural_equal(hold_ref.body, target) # The function and the other part of the stmt can be written directly - assert func_old_hash == s.func.__hash__() - assert other_part_hash == s.func.body.block.body[1].__hash__() + assert func_old_hash == s.mod["main"].__hash__() + assert other_part_hash == s.mod["main"].body.block.body[1].__hash__() # Check the replaced part is equal to the target - tvm.ir.assert_structural_equal(s.func.body.block.body[0].body, target) + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body, target) # The replaced AST node will be deleted, so the ref will be None assert sref.stmt is None @@ -125,21 +145,21 @@ def test_replace_partial_copy1(): s, target = replace_ir_builder(deep_copy=True) - func_old_hash = s.func.__hash__() - hold_ref = s.func.body.block.body[0].body - stmt_old_hash = s.func.body.block.body[0].__hash__() - sref = s.get_sref(s.func.body.block.body[0].body.body.block) - other_part_hash = s.func.body.block.body[1].__hash__() + func_old_hash = s.mod["main"].__hash__() + hold_ref = s.mod["main"].body.block.body[0].body + stmt_old_hash = s.mod["main"].body.block.body[0].__hash__() + sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block) + other_part_hash = s.mod["main"].body.block.body[1].__hash__() s.replace(sref, target) # The parent stmt is written in place since there is only one reference - assert stmt_old_hash == s.func.body.block.body[0].__hash__() + assert stmt_old_hash == s.mod["main"].body.block.body[0].__hash__() assert not tvm.ir.structural_equal(hold_ref.body, target) # The function and the other part of the stmt can be written directly - assert func_old_hash == s.func.__hash__() - assert other_part_hash == s.func.body.block.body[1].__hash__() + assert func_old_hash == s.mod["main"].__hash__() + assert other_part_hash == s.mod["main"].body.block.body[1].__hash__() # Check the replaced part is equal to the target - tvm.ir.assert_structural_equal(s.func.body.block.body[0].body.body.block, target) + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body.body.block, target) # The replaced AST node will be deleted, so the ref will be None assert sref.stmt is None @@ -147,24 +167,24 @@ def test_replace_root_write(): s, target = replace_ir_builder() - old_hash = s.func.__hash__() - sref = s.get_sref(s.func.body.block) + old_hash = s.mod["main"].__hash__() + sref = s.get_sref(s.mod["main"].body.block) s.replace(sref, 
target) # Check there is no copy and the new body equals the target - assert old_hash == s.func.__hash__() - tvm.ir.assert_structural_equal(s.func.body.block, target) + assert old_hash == s.mod["main"].__hash__() + tvm.ir.assert_structural_equal(s.mod["main"].body.block, target) def test_replace_root_copy0(): s, target = replace_ir_builder(deep_copy=True) - old_hash = s.func.__hash__() - func_ref = s.func - sref = s.get_sref(s.func.body.block) + old_hash = s.mod["main"].__hash__() + func_ref = s.mod["main"] + sref = s.get_sref(s.mod["main"].body.block) s.replace(sref, target) # Check the new body equals the target - assert old_hash != s.func.__hash__() - tvm.ir.assert_structural_equal(s.func.body.block, target) + assert old_hash != s.mod["main"].__hash__() + tvm.ir.assert_structural_equal(s.mod["main"].body.block, target) # Check the original func remains unchanged assert old_hash == func_ref.__hash__() assert not tvm.ir.structural_equal(func_ref.body, target) @@ -173,25 +193,56 @@ def test_replace_root_copy1(): s, target = replace_ir_builder(deep_copy=True, realize=True) - old_hash = s.func.body.block.__hash__() - func_ref = s.func.body.block - sref = s.get_sref(s.func.body.block.body[0]) + old_hash = s.mod["main"].body.block.__hash__() + func_ref = s.mod["main"].body.block + sref = s.get_sref(s.mod["main"].body.block.body[0]) s.replace(sref, target) # Check the new body equals the target - assert old_hash != s.func.body.block.__hash__() - tvm.ir.assert_structural_equal(s.func.body.block.body[0], target) + assert old_hash != s.mod["main"].body.block.__hash__() + tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target) # Check the original func remains unchanged assert old_hash == func_ref.__hash__() assert not tvm.ir.structural_equal(func_ref.body, target) +def test_replace_root_copy2(): + s, target = replace_ir_builder(deep_copy=True) + + old_hash = s.mod.functions.__hash__() + func_ref = s.mod.functions + sref = s.get_sref(s.mod["main"].body.block) + s.replace(sref, target) + # Check the new body equals the target + assert old_hash != s.mod.functions.__hash__() + tvm.ir.assert_structural_equal(s.mod["main"].body.block, target) + # Check the original func remains unchanged + assert old_hash == func_ref.__hash__() + for _, v in func_ref.items(): + assert not tvm.ir.structural_equal(v.body.block, target) + + +def test_replace_root_copy3(): + s, target = replace_ir_builder(deep_copy=True) + + old_hash = s.mod.__hash__() + func_ref = s.mod + sref = s.get_sref(s.mod["main"].body.block) + s.replace(sref, target) + # Check the new body equals the target + assert old_hash != s.mod.__hash__() + tvm.ir.assert_structural_equal(s.mod["main"].body.block, target) + # Check the original func remains unchanged + assert old_hash == func_ref.__hash__() + assert not tvm.ir.structural_equal(func_ref["main"].body.block, target) + + def test_replace_block_remap(): func = util.element_wise_stmt() s = tir.Schedule(func, debug_mode=True) # The target stmt target = util.matmul_stmt_original().body.block.body.body.body[0].block - sref = s.get_sref(s.module.body.block.body[0].body.body.block) - s.state.replace(sref, target, {target: s.module.body.block.body[0].body.body.block}) + sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block) + s.state.replace(sref, target, {target: s.mod["main"].body.block.body[0].body.body.block}) sref_new = s.get_sref(s.get_block("init")) # Check the original sref has been remapped assert sref.__hash__() == sref_new.__hash__() @@ -201,8 
+252,8 @@ test_replace_block_remap(): def test_replace_block_in_opaque_block(): pass # TODO # s = replace_ir_builder_with_opaque() - # root_hash = s.func.__hash__() - # for_loop = s.func.body.block.body.body.block.body[1].then_case.block.body + # root_hash = s.mod["main"].__hash__() + # for_loop = s.mod["main"].body.block.body.body.block.body[1].then_case.block.body # sref = s.get_sref(for_loop) # new_for_loop = tir.Loop( # loop_var=for_loop.loop_var, @@ -212,10 +263,28 @@ def test_replace_block_in_opaque_block(): # body=tir.Evaluate(0), # ) # s.replace(sref, new_for_loop) - # assert root_hash == s.func.__hash__() + # assert root_hash == s.mod["main"].__hash__() # tvm.ir.assert_structural_equal(sref.stmt, new_for_loop) +def test_replace_ir_module(): + s, target = replace_ir_builder_module(deep_copy=True) + + old_hash = s.mod["main"].__hash__() + other_func_hash = s.mod["other"].__hash__() + func_ref = s.mod["main"] + + sref = s.get_sref(s.mod["main"].body.block) + s.replace(sref, target) + # Check the new body equals the target + assert old_hash != s.mod["main"].__hash__() + tvm.ir.assert_structural_equal(s.mod["main"].body.block, target) + # Check the original func remains unchanged + assert old_hash == func_ref.__hash__() + assert not tvm.ir.structural_equal(func_ref.body, target) + assert other_func_hash == s.mod["other"].__hash__() + + if __name__ == "__main__": test_replace_direct_write0() test_replace_direct_write1() @@ -225,5 +294,8 @@ def test_replace_block_in_opaque_block(): test_replace_root_write() test_replace_root_copy0() test_replace_root_copy1() + test_replace_root_copy2() + test_replace_root_copy3() test_replace_block_remap() test_replace_block_in_opaque_block() + test_replace_ir_module() diff --git a/tests/python/tir/test_schedule_sparse.py b/tests/python/tir/test_schedule_sparse.py index a1635dc8e6..ad9a5b79df 100644 --- a/tests/python/tir/test_schedule_sparse.py +++ b/tests/python/tir/test_schedule_sparse.py @@ -79,13 +79,13 @@ def schedule_sparse_dense_llvm(func): bsr_par = s.get_block("bsr_par") bsr_block = s.get_block("bsr_block") i, j = s.get_axes(bsr_block) - data = s.module.params[1] - jo, ji = s.split(j, factor=s.module.buffer_map[data].shape[1]) + data = s.mod["main"].params[1] + jo, ji = s.split(j, factor=s.mod["main"].buffer_map[data].shape[1]) s.compute_at(bsr_par, ji) s.vectorize(ji) i_jo = s.fuse(i, jo) s.parallel(i_jo) - return s.module + return s.mod["main"] _sparse_dense_implement_tir = { diff --git a/tests/python/tir/test_schedule_tensorize.py b/tests/python/tir/test_schedule_tensorize.py index 7d3822d373..7f9308f59c 100644 --- a/tests/python/tir/test_schedule_tensorize.py +++ b/tests/python/tir/test_schedule_tensorize.py @@ -290,7 +290,7 @@ def test_tensorize_gemm(): s.tensorize(ii, tensor_intrin) - func = tvm.build(s.module) + func = tvm.build(s.mod["main"]) a_np = np.random.uniform(size=(128, 128)).astype("float32") b_np = np.random.uniform(size=(128, 128)).astype("float32") @@ -314,7 +314,7 @@ def test_tensorize_buffer_bind(): s.decompose_reduction(update, ko) tensor_intrin = tvm.tir.TensorIntrin(desc_func, lower_intrin_func) s.tensorize(ii, tensor_intrin) - tvm.ir.assert_structural_equal(tensorized_func, s.module) + tvm.ir.assert_structural_equal(tensorized_func, s.mod["main"]) def test_high_dim_tensorize(): @@ -327,7 +327,7 @@ def test_high_dim_tensorize(): s.reorder(io, jo, ko, ii, ji, ki) tensor_intrin = tvm.tir.TensorIntrin(desc_func, lower_intrin_func) s.tensorize(ii, tensor_intrin) - tvm.ir.assert_structural_equal(tensorized_batch_matmul, 
s.module) + tvm.ir.assert_structural_equal(tensorized_batch_matmul, s.mod["main"]) def test_tensorize_dot_product(): @@ -344,7 +344,7 @@ def test_tensorize_dot_product(): a = tvm.nd.array(a_np) b = tvm.nd.array(b_np) c = tvm.nd.array(np.zeros((1, 4, 4), dtype="float32"), ctx) - func = tvm.build(s.module, target=target) + func = tvm.build(s.mod["main"], target=target) func(a, b, c) tvm.testing.assert_allclose( c.asnumpy(), diff --git a/tests/python/tir/test_schedule_vectorize.py b/tests/python/tir/test_schedule_vectorize.py index ab219d7a74..791a661ada 100644 --- a/tests/python/tir/test_schedule_vectorize.py +++ b/tests/python/tir/test_schedule_vectorize.py @@ -90,7 +90,7 @@ def test_vectorize_normal(): _, _, ji = s.get_axes(B) s.vectorize(ji) mod = tvm.script.create_module({"predicate_vectorize": predicate_vectorize}) - tvm.ir.assert_structural_equal(s.module, mod["predicate_vectorize"]) + tvm.ir.assert_structural_equal(s.mod["main"], mod["predicate_vectorize"]) def test_vectorize_complete(): @@ -105,7 +105,7 @@ def test_vectorize_complete(): mod = tvm.script.create_module( {"element_wise_compute_at_vectorize": element_wise_compute_at_vectorize} ) - tvm.ir.assert_structural_equal(s.module, mod["element_wise_compute_at_vectorize"]) + tvm.ir.assert_structural_equal(s.mod["main"], mod["element_wise_compute_at_vectorize"]) def test_vectorize_fail_on_reduce_var(): @@ -124,7 +124,7 @@ def test_unroll_normal(): _, _, ji = s.get_axes(B) s.unroll(ji) mod = tvm.script.create_module({"predicate_unroll": predicate_unroll}) - tvm.ir.assert_structural_equal(s.module, mod["predicate_unroll"]) + tvm.ir.assert_structural_equal(s.mod["main"], mod["predicate_unroll"]) if __name__ == "__main__":
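For readers tracking the migration above, a minimal sketch (not part of the patch) of the renamed schedule accessors, assuming the element-wise PrimFunc helper from the tests' local util.py:

import tvm
from tvm import tir
import util  # local test helper, as used throughout these files

func = util.element_wise_stmt()  # a tir.PrimFunc
s = tir.Schedule(func, debug_mode=True)
# The schedule now carries an IRModule rather than a bare PrimFunc, so the
# scheduled function is fetched by name:
#   before: tvm.build(s.module, target="llvm")
#   after:  tvm.build(s.mod["main"], target="llvm")
f = tvm.build(s.mod["main"], target="llvm")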
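A second sketch (also not part of the patch) of the IRModule-backed ScheduleState introduced here and the Map-based scopes field, mirroring replace_ir_builder_module and test_element_wise_dependency above; the block name "root" is an assumption for illustration:

import tvm
from tvm import tir
from tvm.ir import IRModule
import util

func = util.element_wise_stmt()
main_func = tvm.script.from_source(tvm.script.asscript(func))
other_func = tvm.script.from_source(tvm.script.asscript(func))
# ScheduleState now accepts a whole IRModule, not only a single PrimFunc:
mod = IRModule(functions={"main": main_func, "other": other_func})
state = tir.ScheduleState(mod, debug_mode=True)

# Dependency queries go through the scopes Map keyed by the scope root's sref,
# replacing the former state.scope(root) accessor:
s = tir.Schedule(func, debug_mode=True)
root = s.get_sref(s.get_block("root"))  # assumed root block name
block_b = s.get_sref(s.get_block("B"))
for edge in s.state.scopes[root].get_successor(block_b):
    print(edge.dst, edge.type)  # e.g. tir.schedule.DepEdge.kRAW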