diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 210ed53a7904..43f2379a0b56 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object { */ virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) = 0; + /*! + * \brief Sample a compute-at location of the given block + * \param block_rv The block whose compute-at location is to be sampled + * \param decision The sampling decision + * \return The sampled loop where the input block is to be computed at + */ + virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index d3fbf0f91214..7bc3b697945d 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1361,6 +1361,13 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_ */ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; +/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ +constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; + +/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ +constexpr const char* meta_schedule_random_compute_producer = + "meta_schedule.random_compute_producer"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index be5c0e0b620b..9ad3c0627ea9 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -18,3 +18,4 @@ """ from .auto_inline import AutoInline from .schedule_rule import PyScheduleRule, ScheduleRule +from .random_compute_location import RandomComputeLocation diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py new file mode 100644 index 000000000000..2355b0bfa8e5 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rule that randomly select a compute-at location for a free block""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.RandomComputeLocation") +class RandomComputeLocation(ScheduleRule): + """A rule that randomly select a compute-at location for a free block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py new file mode 100644 index 000000000000..10e31e7213cb --- /dev/null +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.tir import Schedule +from tvm.tir.schedule import Trace + + +def check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b261fd0a7518..7d352f156a31 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -369,6 +369,32 @@ def sample_perfect_tile( ) ) + @type_checked + def sample_compute_location( + self, + block: BlockRV, + decision: Optional[int] = None, + ) -> LoopRV: + """Sample a compute-at location of the given block + + Parameters + ---------- + block : BlockRV + The block whose compute-at location is to be sampled + decision : Optional[int] + The sampling decision + + Returns + ------- + result : LoopRV + The sampled loop where the input block is to be computed at + """ + return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member + self, + block, + decision, + ) + ########## Schedule: Get blocks & loops ########## @type_checked def get_block( diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc new file mode 100644 index 000000000000..957ad89af106 --- /dev/null +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class RandomComputeLocationNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + if (!CheckConditions(sch, block_rv)) { + return {sch}; + } + + // Step 1. If the producer of the input block needs a random compute-at location (specified by + // the annotation), we collect the producer first, and transform the producer block later. + // - The reason we collect the producer before transforming the input block is that, if the + // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer + // access the input block. Hence we collect its producer ahead of time. + // - Note that only single producer is allowed in this case. + Array producers{nullptr}; + if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, + true)) { + producers = sch->GetProducers(block_rv); + sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); + ICHECK_EQ(producers.size(), 1); + } + + // Step 2. Transform the input block. + tir::Schedule res = RandomlyComputeAt(sch, block_rv); + + // Step 3. Transform the producer block if compute-location sampling is needed. + if (producers.defined()) { + res = RandomlyComputeAt(res, producers[0]); + } + + return {res}; + } + + private: + bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { + tir::StmtSRef block_sref = sch->GetSRef(block_rv); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + + // Cond 1. The block is not the root block. + if (block_sref->parent == nullptr) { + return false; + } + // Cond 2. The block should be the direct child block of the root block. + if (GetScopeRoot(sch->state(), block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false) + ->parent != nullptr) { + return false; + } + // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child + // block. + Array loop_srefs = tir::GetLoops(block_sref); + if (loop_srefs.empty()) { + return false; + } + if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { + return false; + } + // Cond 5. The block is not tiled. We check this condition by examine the block's annotation. + if (tir::HasBeenMultiLevelTiled(block_sref)) { + return false; + } + // Cond 6. The block has at lease one consumer. + if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { + return false; + } + return true; + } + + /*! + * \brief Keep sampling a compute-at location for the input block until success. + * \param sch The TIR schedule + * \param block_rv The block whose compute-at location is to be sampled + * \return The TIR schedule after transformation + */ + tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); + sch->ComputeAt(block_rv, compute_at_loc, true); + return sch; + } + + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::RandomComputeLocation() { + return ScheduleRule(make_object()); +} + +TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") + .set_body_typed(ScheduleRule::RandomComputeLocation); +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 1070833be19d..9622e2dcd318 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -267,6 +267,39 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to + * \param loop_sref The loop to be checked + * \return The IterVarType of the specific loop + */ +IterVarType GetLoopIterType(const StmtSRef& loop_sref); + +/*! + * \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree + * \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried + * \return The lowest common ancestor of the input block srefs or loop srefs + * \note The input array is required to have at least one sref + */ +StmtSRef GetSRefLowestCommonAncestor(const Array& srefs); + +/*! + * \brief Checks if the given block has been applied by multi-level tiling. We check this by + * examine the block's annotation. + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has been multi-level tiled. + */ +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); + +/*! + * \brief Collect all the feasible compute-at locations of the input block + * \param self The schedule state + * \param block_sref The block whose compute-at locations are to be collected + * \return All the feasible compute-at locations of the input block, given as an array of loop srefs + * and an array of their indices among the outer loops of the input block + */ +std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, + const StmtSRef& block_sref); + /******** Producer-consumer relation ********/ /*! diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 36a1d05f4cf2..052097314ee2 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -646,6 +646,158 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +IterVarType GetLoopIterType(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const Var& loop_var = loop->loop_var; + int n_spatial = 0; + int n_reduce = 0; + int n_other = 0; + auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { + if (const auto* realize = obj.as()) { + const BlockNode* block = realize->block.get(); + // Number of block vars and their bindings + ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); + size_t n = realize->iter_values.size(); + for (size_t i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + // Categorize the current block var + int* ref = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + ref = &n_spatial; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + ref = &n_reduce; + } else { + ref = &n_other; + } + // Visit the binding to see if `loop_var` appears + PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void { + if (obj.same_as(loop_var)) { + (*ref) += 1; + } + }); + } + return false; + } + return true; + }; + PreOrderVisit(loop->body, f_visit); + if (n_other) { + return IterVarType::kOpaque; + } else if (n_spatial && n_reduce) { + return IterVarType::kOpaque; + } else if (n_reduce) { + return IterVarType::kCommReduce; + } else { + return IterVarType::kDataPar; + } +} + +StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { + CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref"; + + std::unordered_map sref_visited_cnt; + for (const StmtSRef& sref : srefs) { + const StmtSRefNode* p = sref.get(); + while (p != nullptr) { + ++sref_visited_cnt[p]; + p = p->parent; + } + } + size_t n_sref = srefs.size(); + const StmtSRefNode* p = srefs[0].get(); + while (p != nullptr && sref_visited_cnt[p] != n_sref) { + p = p->parent; + } + ICHECK(p != nullptr); + return GetRef(p); +} + +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).defined(); +} + +std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, + const StmtSRef& block_sref) { + Array location_srefs; + std::vector location_indices; + + // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can + // be inlined. + if (CanComputeInline(self, block_sref)) { + location_srefs.push_back(StmtSRef::InlineMark()); + location_indices.push_back(-2); + } + location_srefs.push_back(StmtSRef::RootMark()); + location_indices.push_back(-1); + + // Step 2. If the block has no consumer, there is no more candidate. + Array consumers = GetConsumers(self, block_sref); + if (consumers.empty()) { + return std::make_pair(location_srefs, location_indices); + } + + // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If + // such a loop cannot be found, there is no more candidate and we just return. + StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers) + : GetRef(consumers[0]->parent); + if (loop_boundary->StmtAs() == nullptr) { + return std::make_pair(location_srefs, location_indices); + } + + // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position + // of the boundary loop reveals the number of possible additional candidates. + Array loop_srefs = GetLoops(consumers[0]); + size_t lca_pos = + std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); + ICHECK_LT(lca_pos, loop_srefs.size()); + size_t n_candidate = lca_pos + 1; + + // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This + // position is used for removing the unwanted candidates from the perspective of performance. + std::vector loop_iter_types; + loop_iter_types.reserve(n_candidate); + int i_last_datapar = -1; + for (size_t i = 0; i < n_candidate; ++i) { + // TODO(siyuan): improve the performance + IterVarType iter_type = GetLoopIterType(loop_srefs[i]); + loop_iter_types.push_back(iter_type); + if (iter_type == IterVarType::kDataPar) { + i_last_datapar = i; + } + } + // Step 6. Check and add the candidates in turn according to the following rules: + // - skip the unit loops (loops with extent 1); + // - do not consider the data-parallel loops after a not-data-parallel loop; + // - do not consider the trailing not-data-parallel loops. + location_srefs.reserve(n_candidate + 2); + location_indices.reserve(n_candidate + 2); + bool visited_reduce = false; + for (size_t i = 0; i < n_candidate; ++i) { + const int64_t* loop_extent = GetLoopIntExtent(loop_srefs[i]); + if (loop_extent != nullptr && *loop_extent == 1) { + continue; + } + + if (loop_iter_types[i] == IterVarType::kDataPar) { + if (visited_reduce) { + break; + } + } else { + visited_reduce = true; + if (static_cast(i) > i_last_datapar) { + break; + } + } + if (CanComputeAt(self, block_sref, loop_srefs[i], true)) { + location_srefs.push_back(loop_srefs[i]); + location_indices.push_back(i); + } + } + + return std::make_pair(location_srefs, location_indices); +} + /******** Producer-consumer relation ********/ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 9e5b6f949feb..9f8dc6dd2daf 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -242,6 +242,15 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int throw; } +LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV( + tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); + TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index d420728a9e3c..96cb0f728835 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -86,6 +86,8 @@ class ConcreteScheduleNode : public ScheduleNode { Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; + LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 45efd9f76cef..f0b38af01b5f 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -98,6 +98,17 @@ TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, Optional>* decision); +/*! + * \brief Sample a compute-at location of the given block + * \param self The schedule state + * \param rand_state The random state + * \param block_sref The sref of the block whose compute-at location is to be sampled + * \param decision The sampling decision + * \return The sampled loop where the input block is to be computed at + */ +TVM_DLL tir::StmtSRef SampleComputeLocation( + tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, + const tir::StmtSRef& block_sref, Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6d944b38d46a..0e767825573f 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -354,6 +354,40 @@ std::vector SamplePerfectTile( return result; } +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const StmtSRef& block_sref, Optional* decision) { + // Step 1. Collect all possible compute-at locations. + Array location_srefs; + std::vector location_indices; + std::tie(location_srefs, location_indices) = CollectComputeLocation(self, block_sref); + ICHECK_EQ(location_srefs.size(), location_indices.size()); + + // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the + // location candidates. Otherwise, pick the location before the previous decision. + // Step 3. If there was not a previous decision, sample a decision from the collected locations. + if (decision->defined()) { + int64_t old_decision = Downcast(*decision)->value; + auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision); + int idx = it - location_indices.begin(); + + if (it != location_indices.end() && *it == old_decision) { + *decision = Integer(old_decision); + return location_srefs[idx]; + } else if (it != location_indices.begin()) { + *decision = Integer(location_indices[idx - 1]); + return location_srefs[idx - 1]; + } else { + *decision = Integer(-1); + return StmtSRef::RootMark(); + } + } else { + int sampled_idx = SampleInt(rand_state, 0, location_indices.size()); + *decision = Integer(location_indices[sampled_idx]); + return location_srefs[sampled_idx]; + } +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -418,8 +452,38 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleComputeLocation"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 1; + + static LoopRV UnpackedApplyToSchedule(Schedule sch, // + BlockRV block_rv, // + Optional decision) { + return sch->SampleComputeLocation(block_rv, decision); + } + + static String UnpackedAsPython(Array outputs, // + String block_rv, // + Optional decision) { + PythonAPICall py("sample_compute_location"); + py.Input("block", block_rv); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); +TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 75939f00b8f4..6e33862c07ca 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -125,6 +125,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") .set_body_method(&ScheduleNode::SamplePerfectTile); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") + .set_body_method(&ScheduleNode::SampleComputeLocation); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b4d1ba01e93e..da7a2641b162 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -73,6 +73,20 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n return results; } +LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, + this->GetSRef(block_rv), &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 5ce4763f117f..b35f1b6e17bb 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -51,6 +51,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; + LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 4df335079f93..2acab384af0b 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -308,6 +308,18 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& an return result.defined() && result.value() == ann_val; } +/*! + * \brief Check if a Block/For has a specific pair of annotation key and values + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be checked + * \param ann_val The boolean annotation value to be checked + * \return Whether a Block/For has a specific pair of annotation key and values + */ +inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { + Optional result = GetAnn(sref, ann_key); + return result.defined() && result.value()->value == ann_val; +} + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py new file mode 100644 index 000000000000..92c7da922c39 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.schedule_rule import RandomComputeLocation +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Add: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + # body + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("move"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([A_cached[vi, vj, vk]]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads([A_cached[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_random_compute_location(): + expected = [ + [ + 'b0 = sch.get_block(name="move", func_name="main")', + "l1 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=1)", + ] + ] + mod = Add + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=RandomComputeLocation(), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_random_compute_location() diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 5d2676e41d1c..4a4cd6c6c2b9 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -37,6 +37,67 @@ def elementwise(a: T.handle, b: T.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 +@T.prim_func +def tiled_conv2d_with_padding( + inputs: T.Buffer[(1, 224, 224, 3), "float32"], + weight: T.Buffer[(7, 7, 3, 64), "float32"], + conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227, + inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1], + T.float32(0), + dtype="float32", + ) + for ( + i0_0, + i1_0, + i2_0, + i3_0, + i0_1_1, + i1_1_1, + i2_1_1, + i3_1_1, + i4_0, + i5_0, + i6_0, + i0_2, + i1_2, + i2_2, + i3_2, + i4_1, + i5_1, + i6_1, + i0_3, + i1_3, + i2_3, + i3_3, + ) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64): + with T.block("conv2d_nhwc"): + n = T.axis.spatial(1, 0) + h = T.axis.spatial(112, i1_1_1 * 56 + i1_3) + w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3) + co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1]) + T.reads( + conv2d_nhwc[n, h, w, co], + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], + weight[rh, rw, rc, co], + ) + T.writes(conv2d_nhwc[n, h, w, co]) + with T.init(): + conv2d_nhwc[n, h, w, co] = T.float32(0) + conv2d_nhwc[n, h, w, co] = ( + conv2d_nhwc[n, h, w, co] + + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + ) + + # pylint: enable=no-member,invalid-name,unused-variable @@ -116,5 +177,21 @@ def test_sample_perfect_tile_composite(): verify_trace_roundtrip(sch, mod=elementwise) +def test_sample_compute_location(): + n = 100 + sch = tir.Schedule(tiled_conv2d_with_padding, seed=42, debug_mask="all") + pad_input = sch.get_block("PadInput") + decision_dict = dict() + for _ in range(n): + _ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name + decision = sch.trace.decisions[sch.trace.insts[-1]] + decision_dict[decision] = decision_dict[decision] + 1 if decision in decision_dict else 1 + + n_candidates = 8 + expected_rate = 1.0 / n_candidates + for _, cnt in decision_dict.items(): + assert (expected_rate - 0.03) * n <= cnt <= (expected_rate + 0.03) * n + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))