diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 0cf157147423..67ec3ed12b05 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -411,14 +411,33 @@ def storage_align(self, stage_id, it, factor, offset):
         it : Iterator
         factor : Int
         offset : Int
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.clear_cache()
+
+    def tensorize(self, stage_id, it, ti_func_name):
+        """ The `ti_func_name` corresponds to a globally registered function
+        that returns a TensorIntrin
+
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to be tensorized
+        it : Iterator
+            The target iterator
+        ti_func_name : Str
+            Tensorize intrinsic function name
 
         Returns
         -------
-        state : State
-            The updated state
+        res_it : Iterator
+            The tensorized Iterator
         """
-        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.state_object, res = _ffi_api.StateTensorize(self.state_object,
+                                                         stage_id, it,
+                                                         ti_func_name)
         self.clear_cache()
+        return res
 
     def __str__(self):
         return str(self.state_object)
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
index f8d3f419dcb4..89b4afd84e86 100644
--- a/python/tvm/ansor/task_scheduler.py
+++ b/python/tvm/ansor/task_scheduler.py
@@ -248,6 +248,8 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol
             else:
                 raise ValueError("Invalid strategy: " + self.strategy)
 
+            if self.verbose >= 1:
+                print("Next tuning task: %d" % task_idx)
             self.tune_task(task_idx)
 
     def tune_task(self, task_idx):
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index de3b98a5106b..5ca0c8503662 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -1086,7 +1086,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const {
         new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second,
                                                iter->iter_type, iter->annotation,
-                                               &iter->ori_iters));
+                                               &iter->ori_iters,
+                                               iter->attr));
       } else {
         LOG(FATAL) << "Infer bound fails";
       }
@@ -1161,6 +1162,8 @@ std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
       ps->ApplyToSchedule(stages, stage_to_axes, &schedule);
     } else if (auto ps = step.as<StorageAlignStepNode>()) {
       ps->ApplyToSchedule(stages, stage_to_axes);
+    } else if (auto ps = step.as<TensorizeStepNode>()) {
+      ps->ApplyToSchedule(stages, stage_to_axes);
     } else {
       LOG(FATAL) << "Invalid Step";
     }
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc
index 77361dbf837c..b6e6d854e3e5 100644
--- a/src/ansor/loop_state.cc
+++ b/src/ansor/loop_state.cc
@@ -39,7 +39,8 @@ TVM_REGISTER_NODE_TYPE(IteratorNode);
 // Maker for other classes
 Iterator IteratorNode::make(std::string name, Range range,
                             IteratorType iter_type, IteratorAnnotation annotation,
-                            const std::vector<Iterator>* ori_iters) {
+                            const std::vector<Iterator>* ori_iters,
+                            std::string attr) {
   auto node = make_object<IteratorNode>();
   node->name = std::move(name);
   node->range = std::move(range);
@@ -48,6 +49,7 @@ Iterator IteratorNode::make(std::string name, Range range,
   if (ori_iters != nullptr) {
     node->ori_iters = *ori_iters;
   }
+  node->attr = std::move(attr);
   return Iterator(node);
 }
 
@@ -310,6 +312,15 @@ void State::storage_align(int stage_id, const Iterator& it, int factor,
   return DoStorageAlignStep(step);
 }
 
+Iterator State::tensorize(int stage_id, const Iterator& it,
+                          std::string ti_func_name) {
+  const Stage& stage = operator->()->stages[stage_id];
+  TensorizeStep step = TensorizeStepNode::make(
+      stage_id, GetIndex(stage->iters, it), ti_func_name);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoTensorizeStep(step);
+}
+
 // Steps' implementations
 void State::DoReorderStep(const ReorderStep& step) {
   const Stage& stage = operator->()->stages[step->stage_id];
@@ -509,8 +520,10 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) {
   const Stage& stage = operator->()->stages[step->stage_id];
   Iterator it = stage->iters[step->iter_id];
 
+  CHECK_EQ(it->annotation, IteratorAnnotation::kNone);
   Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type,
-                                       step->annotation, &it->ori_iters);
+                                       step->annotation, &it->ori_iters,
+                                       it->attr);
   Stage new_stage = stage;
   new_stage.CopyOnWrite()->iters[step->iter_id] = new_it;
   StateNode* pstate = CopyOnWrite();
@@ -538,7 +551,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) {
       new_iters.push_back(it);
     } else {
       new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type,
-                                             it->annotation, &it->ori_iters));
+                                             it->annotation, &it->ori_iters,
+                                             it->attr));
     }
   }
 
@@ -559,7 +573,8 @@ void State::DoComputeRootStep(const ComputeRootStep& step) {
   std::vector<Iterator> new_iters;
   for (const Iterator& it : stage->iters) {
     new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type,
-                                           it->annotation, &it->ori_iters));
+                                           it->annotation, &it->ori_iters,
+                                           it->attr));
   }
 
   // update attach map
@@ -747,6 +762,18 @@ void State::DoStorageAlignStep(const StorageAlignStep& step) {
   stage->storage_offset = step->offset;
 }
 
+Iterator State::DoTensorizeStep(const TensorizeStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+  Iterator it = stage->iters[step->iter_id];
+  Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type,
+      IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name);
+  Stage new_stage = stage;
+  new_stage.CopyOnWrite()->iters[step->iter_id] = new_it;
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[step->stage_id] = std::move(new_stage);
+  return new_it;
+}
+
 void State::DoStep(const Step& step, const ComputeDAG& dag) {
   if (auto ps = step.as<ReorderStepNode>()) {
     DoReorderStep(GetRef<ReorderStep>(ps));
@@ -776,6 +803,8 @@ void State::DoStep(const Step& step, const ComputeDAG& dag) {
     DoRfactorStep(GetRef<RfactorStep>(ps), dag);
   } else if (auto ps = step.as<StorageAlignStepNode>()) {
     DoStorageAlignStep(GetRef<StorageAlignStep>(ps));
+  } else if (auto ps = step.as<TensorizeStepNode>()) {
+    DoTensorizeStep(GetRef<TensorizeStep>(ps));
   } else {
     LOG(FATAL) << "Invalid step: " << step;
   }
@@ -854,15 +883,22 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state,
       case kThreadY:
         *os << "gpu.threadIdx.y ";
         break;
+      case kTensorized:
+        *os << "tensorize ";
+        break;
+      default:
+        LOG(FATAL) << "Invalid Annotation " << iter->annotation; break;
     }
     if (iter->range.defined()) {
       *os << iter->name << " (" << iter->range->min << ","
-          << iter->range->extent << ")"
-          << "\n";
+          << iter->range->extent << ")";
     } else {
-      *os << iter->name << " (None)"
-          << "\n";
+      *os << iter->name << " (None)";
     }
+    if (!iter->attr.empty()) {
+      *os << " " << iter->attr;
+    }
+    *os << "\n";
 
     indent += 2;
   }
@@ -1174,6 +1210,13 @@ TVM_REGISTER_GLOBAL("ansor.StateStorageAlign")
   return state;
 });
 
+TVM_REGISTER_GLOBAL("ansor.StateTensorize")
+.set_body_typed([](State state, int stage_id, const Iterator& it,
+                   std::string ti_func) {
+  const auto& res = state.tensorize(stage_id, it, ti_func);
+  return Array<ObjectRef>{state, res};
+});
+
 TVM_REGISTER_GLOBAL("ansor.StateEqual")
 .set_body_typed([](State state1, State state2) {
   return std::equal_to<State>()(state1, state2);
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
index 90ba48cd92ac..6eef404ae272 100644
--- a/src/ansor/loop_state.h
+++ b/src/ansor/loop_state.h
@@ -74,7 +74,8 @@ enum IteratorType {
 /*! \brief The type of an iterator's annotation */
 enum IteratorAnnotation {
   kNone, kUnroll, kVectorize, kParallel,
-  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
+  kVThread, kBlockX, kThreadX, kBlockY, kThreadY,
+  kTensorized
 };
 
 class Iterator;
@@ -90,14 +91,17 @@ class IteratorNode : public Object {
   IteratorType iter_type;
   IteratorAnnotation annotation;
   std::vector<Iterator> ori_iters;  // The original iterators before fusion
+  std::string attr;
 
   static Iterator make(std::string name, Range range,
                        IteratorType iter_type, IteratorAnnotation annotation,
-                       const std::vector<Iterator>* ori_iters = nullptr);
+                       const std::vector<Iterator>* ori_iters = nullptr,
+                       std::string attr = "");
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("range", &range);
+    v->Visit("attr", &attr);
   }
 
   static constexpr const char *_type_key = "ansor.Iterator";
@@ -115,6 +119,7 @@ class FuseStep; class AnnotationStep;
 class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep;
 class CacheReadStep; class CacheWriteStep;
 class PragmaStep; class RfactorStep; class StorageAlignStep;
+class TensorizeStep;
 
 /*!
  * \brief A stage in the compute declaration
@@ -254,19 +259,21 @@ class State : public ObjectRef {
   Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
   Iterator bind_thread(int stage_id, const Iterator& it,
                        IteratorAnnotation thread_type);
+  Iterator tensorize(int stage_id, const Iterator& it,
+                     std::string ti_func_name);
   void compute_at(int stage_id, int target_stage_id,
                   const Iterator& target_iter);
   void compute_root(int stage_id);
   void compute_inline(int stage_id);
+  void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
+  void storage_align(int stage_id, const Iterator& it, int factor, int offset);
   int cache_read(int stage_id, const std::string& scope_name,
                  const std::vector<int>& reader_stage_ids,
                  const ComputeDAG& task_dag);
   int cache_write(int stage_id, const std::string& scope_name,
                   const ComputeDAG& task_dag);
-  void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
   int rfactor(int stage_id, const Iterator& it, int factor_iter_id,
               const ComputeDAG& task_dag);
-  void storage_align(int stage_id, const Iterator& it, int factor, int offset);
 
   /* Do transform steps
    * Note: The following functions only change loop state but do not change transform_history.
@@ -278,14 +285,15 @@ class State : public ObjectRef {
   std::vector<Iterator> DoFollowFusedSplitStep(const FollowFusedSplitStep& step);
   Iterator DoFuseStep(const FuseStep& step);
   Iterator DoAnnotationStep(const AnnotationStep& step);
+  Iterator DoTensorizeStep(const TensorizeStep& step);
   void DoComputeAtStep(const ComputeAtStep& step);
   void DoComputeRootStep(const ComputeRootStep& step);
   void DoComputeInlineStep(const ComputeInlineStep& step);
+  void DoPragmaStep(const PragmaStep& step);
+  void DoStorageAlignStep(const StorageAlignStep& step);
   int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag);
   int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag);
-  void DoPragmaStep(const PragmaStep& step);
   int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag);
-  void DoStorageAlignStep(const StorageAlignStep& step);
 
   // General do step functions with a runtime dynamic dispatcher
   void DoStep(const Step& step, const ComputeDAG& dag);
diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
index 4a045d31a487..7e022e3be3c3 100644
--- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc
+++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
@@ -751,6 +751,11 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy,
       continue;
     }
 
+    if (HasAnnotationIter(stage, IteratorAnnotation::kThreadX)) {
+      // Skip if this stage has already done thread bind
+      continue;
+    }
+
     std::vector<Iterator> to_fuse;
 
     // This stage has not been tiled, but in GPU schedule, we must tile it
@@ -861,10 +866,16 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy,
          !HasCacheWriteStage((*state), stage_id - 1)) ||
         (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) &&
          HasCacheWriteStage((*state), stage_id - 2))) {
+      const Stage& target_stage = (*state)->stages[stage_id];
+      if (HasAnnotationIter(target_stage, IteratorAnnotation::kThreadX) ||
+          HasAnnotationIter(target_stage, IteratorAnnotation::kTensorized)) {
+        // Skip if this stage has already done thread bind or has been
+        // tensorized
+        continue;
+      }
       // Get spatial_split_step_ids from the root stage
       std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
       std::vector<int> spatial_split_step_ids;
-      const Stage& target_stage = (*state)->stages[stage_id];
       GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers);
       CHECK_EQ(consumers.size(), 1);
       int target_stage_id = OperationToStage(*consumers.begin(), (*state));
@@ -1129,6 +1140,11 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy,
       continue;
     }
 
+    if (HasAnnotationIter(stage, IteratorAnnotation::kTensorized)) {
+      // Skip if this stage has been tensorized
+      continue;
+    }
+
     // try to fuse and vectorize the space iterators in the inner most tile
     int cum_length_prod = 1;
@@ -1224,7 +1240,7 @@ int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy,
           n--;
         }
 
-      } else if (stage->op->attrs.count(policy->always_unroll_key)) {
+      } else if (stage->op->attrs.count(policy->always_unroll_key)) {  // Special unroll policy
         auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs,
                                                       policy->always_unroll_key);
diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h
index 3d0611173c94..472e90771879 100644
--- a/src/ansor/search_policy/utils.h
+++ b/src/ansor/search_policy/utils.h
@@ -143,6 +143,16 @@ inline bool HasReduceIter(const Stage& stage) {
   return false;
 }
 
+// Return whether the stage has specific annotated iterators
+inline bool HasAnnotationIter(const Stage& stage,
+                              IteratorAnnotation type) {
+  for (const auto& iter : stage->iters) {
+    if (iter->annotation == type) {
+      return true;
+    }
+  }
+  return false;
+}
+
 // Return whether an op needs multi level tiling
 inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state,
                                   const te::Operation& op) {
diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc
index b03acb1edc3c..ed5d4b868c27 100644
--- a/src/ansor/serialization.cc
+++ b/src/ansor/serialization.cc
@@ -167,6 +167,11 @@ struct Handler<std::vector<::tvm::ansor::Step> > {
       writer->WriteArrayItem(ps->iter_id);
       writer->WriteArrayItem(ps->factor);
       writer->WriteArrayItem(ps->offset);
+    } else if (auto ps = data[i].as<::tvm::ansor::TensorizeStepNode>()) {
+      writer->WriteArrayItem(std::string("TS"));
+      writer->WriteArrayItem(ps->stage_id);
+      writer->WriteArrayItem(ps->iter_id);
+      writer->WriteArrayItem(ps->ti_func_name);
     } else {
       LOG(FATAL) << "Invalid step: " << data[i];
     }
@@ -179,7 +184,7 @@ struct Handler<std::vector<::tvm::ansor::Step> > {
                    std::vector<::tvm::ansor::Step> * data) {
     std::vector<int> int_list;
     bool s, inner_to_outer, factor_or_nparts;
-    std::string name, scope_name, pragma_type;
+    std::string name, scope_name, pragma_type, ti_func_name;
     int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent;
     int level, factor_iter_id, factor, offset;
@@ -311,6 +316,15 @@ struct Handler<std::vector<::tvm::ansor::Step> > {
       reader->Read(&offset);
       data->push_back(::tvm::ansor::StorageAlignStepNode::make(
           stage_id, iter_id, factor, offset));
+    } else if (name == "TS") {
+      s = reader->NextArrayItem(); CHECK(s);
+      reader->Read(&stage_id);
+      s = reader->NextArrayItem(); CHECK(s);
+      reader->Read(&iter_id);
+      s = reader->NextArrayItem(); CHECK(s);
+      reader->Read(&ti_func_name);
+      data->push_back(::tvm::ansor::TensorizeStepNode::make(
+          stage_id, iter_id, ti_func_name));
     } else {
       LOG(FATAL) << "Invalid step format";
     }
diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc
index 3f59ff736e9d..b0e67a481ae3 100644
--- a/src/ansor/transform_step.cc
+++ b/src/ansor/transform_step.cc
@@ -26,6 +26,7 @@
 #include "transform_step.h"
 #include
+#include <tvm/runtime/registry.h>
 #include
 #include "utils.h"
@@ -801,5 +802,40 @@ std::string StorageAlignStepNode::PrintAsPythonAPI(
   return ss.str();
 }
 
+/********** Tensorize **********/
+TensorizeStep TensorizeStepNode::make(int stage_id, int iter_id,
+                                      std::string ti_func_name) {
+  auto node = make_object<TensorizeStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->ti_func_name = ti_func_name;
+  return TensorizeStep(node);
+}
+
+void TensorizeStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                        StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector<IterVar>& axes = (*stage_to_axes)[stage];
+  auto func = tvm::runtime::Registry::Get(ti_func_name);
+  CHECK(func != nullptr) << "Cannot find the tensorize intrinsic func";
+  tvm::te::TensorIntrin res = (*func)();
+  CHECK(res.defined()) << "Tensorize intrinsic func must return a "
+                       << "tvm::te::TensorIntrin object";
+  stage.tensorize(axes[iter_id], res);
+}
+
+std::string TensorizeStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->func_name()) << "].tensorize("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", "
+     << ti_func_name << "())\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
 }  // namespace ansor
 }  // namespace tvm
diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h
index 8240623ae3b1..9af14429bf61 100644
--- a/src/ansor/transform_step.h
+++ b/src/ansor/transform_step.h
@@ -23,17 +23,18 @@
  *
  * \Note How to add a new transform step.
  * Take fuse for example:
- * 1. Define class FuseStepNode, FuseStep in transform_steps.h, and implement its make function
- *    in FuseStepNode::make(...) transform_steps.cc
- * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
- *    - In these two functions you need to lower this step with tvm's schedule API
- * 3. Implement State::fuse and State::DoFuseStep.
+ * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its make function
+ *    `FuseStepNode::make(...)` in `transform_steps.cc`
+ * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ *    - In these two functions you need to lower this step with tvm's te schedule API
+ * 3. Implement `State::fuse` and `State::DoFuseStep`.
  *    - In these two functions you need to incrementally update all data structures in State with
  *      CopyOnWrite style
- * 4. Add you step to ComputeDAG::ReplaySteps and make sure it works.
+ * 4. Add your step to `ComputeDAG::ReplaySteps` and make sure it works.
  * 5. Add serialization support in `struct Handler<std::vector<::tvm::ansor::Step> >`
- *    (in serialization.cc)
+ *    in `serialization.cc`
  * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file)
+ * 7. Add its corresponding Python API to `loop_state.py` and necessary unit tests
  */
 
 #ifndef TVM_ANSOR_TRANSFORM_STEP_H_
@@ -365,6 +366,29 @@ class StorageAlignStepNode: public StepNode {
 };
 TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode);
 
+/*! \brief Tensorize step that corresponds to te::Schedule::tensorize
+ * \Note This step takes a globally registered function name as input.
+ */
+class TensorizeStepNode: public StepNode {
+ public:
+  int iter_id;
+  std::string ti_func_name;
+
+  static TensorizeStep make(int stage_id, int iter_id,
+                            std::string ti_func_name);
+
+  void ApplyToSchedule(std::vector<te::Stage> *stages,
+                       StageToAxesMap *stage_to_axes) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.TensorizeStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object);
+};
+TVM_DEFINE_COW_OBJECT_REF(TensorizeStep, Step, TensorizeStepNode);
 
 }  // namespace ansor
 }  // namespace tvm
 
@@ -451,6 +475,11 @@ struct hash<::tvm::ansor::Step> {
           ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
           ::dmlc::HashCombine(std::hash<int>()(ps->factor),
                               ps->offset))));
+    } else if (auto ps = step.as<::tvm::ansor::TensorizeStepNode>()) {
+      return ::dmlc::HashCombine(15,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+                              ps->ti_func_name)));
     } else {
       LOG(FATAL) << "Invalid step";
     }
diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py
index 612d320036d8..a2c09aafc07b 100644
--- a/tests/python/unittest/test_ansor_loop_state.py
+++ b/tests/python/unittest/test_ansor_loop_state.py
@@ -17,6 +17,7 @@
 
 """Test loop state and schedule primitives"""
 
+import tvm
 from tvm import ansor, te
 import topi
 
@@ -468,9 +469,46 @@ def test_rfactor():
         "        C.repl = ...\n"
 
 
+@tvm._ffi.register_func
+def test_intrin_gemv():
+    m = 16
+    l = 64
+    a = te.placeholder((l,), name='a')
+    b = te.placeholder((l, m), name='b')
+    k = te.reduce_axis((0, l), name='k')
+    c = te.compute((m,), lambda i: te.sum(a[k] * b[k, i], axis=k), name='c')
+    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A",
+                             offset_factor=1, strides=[1])
+    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B",
+                             offset_factor=1, strides=[te.var("s0"), 1])
+    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C",
+                             offset_factor=1, strides=[1])
+    def intrin_func(ins, outs):
+        ib = tvm.tir.ir_builder.create()
+        aa, bb = ins
+        cc = outs[0]
+        ib.emit(tvm.tir.call_extern("float32", "gemv_update",
+                                    cc.access_ptr("w"),
+                                    aa.access_ptr("r"),
+                                    bb.access_ptr("r")))
+        return ib.get()
+    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
+
+def test_tensorize():
+    dag = ansor.ComputeDAG(matmul_ansor_test(1024, 512, 64))
+    s0 = dag.get_init_state()
+    C = 2
+
+    its = s0.split(C, s0.stages[C].iters[1], [16])
+    s0.tensorize(C, its[1], "test_intrin_gemv")
+
+    sch, tensors = dag.apply_steps_from_state(s0)
+    tvm.lower(sch, tensors, simple_mode=True)
+
+
 if __name__ == "__main__":
     test_split_fuse_reorder_annotation()
     test_follow_split_follow_fused_split()
     test_compute_at_root_inline()
     test_cache_read_write()
     test_rfactor()
+    test_tensorize()
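
Reviewer's usage note (not part of the patch): replaying a TensorizeStep is equivalent to a hand-written te schedule, since TensorizeStepNode::ApplyToSchedule simply looks up the registered function by name, calls it to obtain a te::TensorIntrin, and invokes te::Stage::tensorize on the chosen axis. The sketch below assumes `test_intrin_gemv` from the unit test above is in scope, and uses a plain row-major matmul whose shapes mirror `matmul_ansor_test(1024, 512, 64)`; the helper's exact definition is an assumption here.

    import tvm
    from tvm import te

    # Hand-written equivalent of the two steps taken in test_tensorize():
    # split the second axis of C by 16, then tensorize the inner part with
    # the gemv intrinsic (m=16, l=64) returned by test_intrin_gemv().
    M, N, K = 1024, 512, 64
    A = te.placeholder((M, K), name="A")
    B = te.placeholder((K, N), name="B")
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
                   name="C")

    s = te.create_schedule(C.op)
    jo, ji = s[C].split(C.op.axis[1], factor=16)  # same as s0.split(C, iters[1], [16])
    s[C].tensorize(ji, test_intrin_gemv())        # what ApplyToSchedule performs
    print(tvm.lower(s, [A, B, C], simple_mode=True))

This is also the line that PrintAsPythonAPI records for the step, i.e. `s[c].tensorize(ji, test_intrin_gemv())`: the step stores only the function name, so the printed schedule stays replayable as long as the same name is registered via `tvm._ffi.register_func`.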