diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 1d16c3428279e..745277308c704 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -136,7 +136,7 @@ using FCompute = std::function& i)>; * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -Tensor Placeholder(Array shape, +Tensor placeholder(Array shape, Type dtype = Float(32), std::string name = "placeholder"); @@ -147,7 +147,7 @@ Tensor Placeholder(Array shape, * \param fcompute The compute function to create the tensor. * \param name The optional name of the tensor. */ -Tensor Compute(Array shape, FCompute fcompute, std::string name = "tensor"); +Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor"); /*! * \brief Construct new tensors by scan over scan_axis. @@ -158,36 +158,36 @@ Tensor Compute(Array shape, FCompute fcompute, std::string name = "tensor" * \param state_placeholder The placeholder for the states. * \param name The optional name of the tensor. */ -Array Scan(IterVar scan_axis, +Array scan(IterVar scan_axis, Array init, Array update, Array state_placeholder, std::string name = "scan"); // same as compute, specialized for different fcompute function -inline Tensor Compute(Array shape, +inline Tensor compute(Array shape, std::function f, std::string name = "tensor") { FCompute fc = [f] (const Array& i) { return f(i[0]); }; - return Compute(shape, fc, name); + return compute(shape, fc, name); } -inline Tensor Compute(Array shape, +inline Tensor compute(Array shape, std::function f, std::string name = "tensor") { FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; - return Compute(shape, fc, name); + return compute(shape, fc, name); } -inline Tensor Compute(Array shape, +inline Tensor compute(Array shape, std::function f, std::string name = "tensor") { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; - return Compute(shape, fc, name); + return compute(shape, fc, name); } -inline Tensor Compute(Array shape, +inline Tensor compute(Array shape, std::function f, std::string name = "tensor") { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2], i[3]); }; - return Compute(shape, fc, name); + return compute(shape, fc, name); } } // namespace tvm diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index b8c24903f88f7..18407567744a3 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -131,6 +131,13 @@ class Stage : public NodeRef { IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, Expr x_factor, Expr y_factor); + /*! + * \brief Specify thread launching group in + * outer most scope of the stage. + * This is only valid for composite operators. + * \param threads The threads to be launched. + */ + Stage& outermost_threads(Array threads); /*! * \brief Vectorize iteration. * \param var The axis to be vectorized. @@ -179,6 +186,28 @@ class Schedule : public NodeRef { Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } + /*! + * \brief create a cache read of original tensor for readers. + * This will mutate the body of the readers. + * A new stage will be created for the tensor. + * \param tensor The tensor cached. + * \param scope The scope of the cache. + * \param readers The readers to redirect to the tensor. + * \return The created tensor. + */ + Tensor cache_read(const Tensor& tensor, + const std::string& scope, + const Array& readers); + /*! + * \brief Create a cache write tensor for producing tensor. + * The the tensor will take over body of original tensor op. + * The original tensor's body will be changed to an identity read + * from the corresponding cache. + * \param tensor The tensor to be produced. + * \param scope The scope of the storage. + * \return The created tensor. + */ + Tensor cache_write(const Tensor& tensor, const std::string& scope); /*! * \brief Normalize the schedule. * This is needed before bound inference. @@ -193,6 +222,11 @@ class Schedule : public NodeRef { * \return the pointer to the internal node container */ inline const ScheduleNode* operator->() const; + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline ScheduleNode* operator->(); // declare container type using ContainerType = ScheduleNode; }; @@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef { */ class StageNode : public Node { public: - /*! \brief The operation to be scheduled */ - Operation op; /*! \brief The thread scope level of the stage */ std::string scope; + /*! \brief The operation of stage, can be different from original op. */ + Operation op; + /*! + * \brief The original operator. + * The op field can change during schedule to alternate the dataflow, + * while origin_op remains fixed. + */ + Operation origin_op; /*! \brief All the nodes in the iter var */ Array all_iter_vars; /*! @@ -255,6 +295,11 @@ class StageNode : public Node { * Operations can only be performed in leaves. */ Array leaf_iter_vars; + /*! + * \brief Specify threads to be launched at the stage. + * This is only valid for composite ops such as Scan. + */ + Array outermost_threads; /*! \brief The relation bwteen of IterVars */ Array relations; /*! \brief additional attributes about iter var. */ @@ -265,17 +310,22 @@ class StageNode : public Node { IterVar attach_ivar; /*! \brief The stage this node attaches to */ Stage attach_stage; + /*! \brief Whether this is an output stage */ + bool is_output{false}; void VisitAttrs(AttrVisitor* v) final { v->Visit("scope", &scope); v->Visit("op", &op); + v->Visit("origin_op", &origin_op); v->Visit("all_iter_vars", &all_iter_vars); v->Visit("leaf_iter_vars", &leaf_iter_vars); + v->Visit("outermost_threads", &outermost_threads); v->Visit("relations", &relations); v->Visit("iter_var_attrs", &iter_var_attrs); v->Visit("attach_type", &attach_type); v->Visit("attach_ivar", &attach_ivar); v->Visit("attach_stage", &attach_stage); + v->Visit("is_output", &is_output); } static constexpr const char* _type_key = "Stage"; @@ -285,18 +335,18 @@ class StageNode : public Node { /*! \brief node container for schedule */ class ScheduleNode : public Node { public: - /*! \brief The root operations */ - Array roots; + /*! \brief The output operations in original data flow graph */ + Array outputs; /*! - * \brief list of all stages for non-placeholder ops - * The stage are ordered in PostDFS order of their op. + * \brief list of all stages for non-placeholder ops. + * The stages are sorted in dependency order. */ Array stages; /*! \brief map of operation to the stages */ Map stage_map; void VisitAttrs(AttrVisitor* v) final { - v->Visit("roots", &roots); + v->Visit("outputs", &outputs); v->Visit("stages", &stages); v->Visit("stage_map", &stage_map); } @@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() { inline bool Stage::is_scheduled() const { const StageNode* n = operator->(); - return !(n->relations.empty() && n->attach_type == kNone); + return !(n->relations.empty() && n->attach_type == kNone && + n->all_iter_vars.same_as(n->leaf_iter_vars)); } inline const ScheduleNode* Schedule::operator->() const { return static_cast(node_.get()); } +inline ScheduleNode* Schedule::operator->() { + return static_cast(node_.get()); +} inline const IterVarRelationNode* IterVarRelation::operator->() const { return static_cast(node_.get()); diff --git a/python/tvm/build.py b/python/tvm/build.py index 4704efe76face..40cb92b458aa7 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -63,7 +63,6 @@ def build(sch, arg_list.append(x) else: raise ValueError("args must be Tensor, Buffer or Var") - # lowering bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 842b9a6054255..f0db2562d372c 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -4,6 +4,7 @@ from ._ctypes._node import NodeBase, register_node from . import _api_internal from . import tensor as _tensor +from . import collections as _collections @register_node class Buffer(NodeBase): @@ -41,6 +42,53 @@ def normalize(self): """ _api_internal._ScheduleNormalize(self) + def cache_read(self, tensor, scope, readers): + """Create a cache read of original tensor for readers. + + This will mutate the body of the readers. + A new cache stage will be created for the tensor. + Call this before doing any split/fuse schedule. + + Parameters + ---------- + tensor : Tensor + The tensor to be cached. + scope : str + The scope of cached + readers : list of Tensor or Operation + The readers to read the cache. + + Returns + ------- + cache : Tensor + The created cache tensor. + """ + if isinstance(readers, (_tensor.Tensor, _tensor.Operation)): + readers = [readers] + readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers] + return _api_internal._ScheduleCacheRead(self, tensor, scope, readers) + + def cache_write(self, tensor, scope): + """Create a cache write of original tensor, before storing into tensor. + + This will mutate the body of the tensor. + A new cache stage will created before feed into the tensor. + + Parameters + ---------- + tensor : Tensor + The tensor to be feed to. + scope : str + The scope of cached + + Returns + ------- + cache : Tensor + The created cache tensor. + """ + return _api_internal._ScheduleCacheWrite(self, tensor, scope) + + @register_node class Stage(NodeBase): """A Stage represents schedule for one operation.""" @@ -104,6 +152,18 @@ def set_scope(self, scope): """ return _api_internal._StageSetScope(self, scope) + def outermost_threads(self, threads): + """Force launch threads at outermost scope of the stage. + + Parameters + ---------- + threads : list of threads + The threads to be launched. + """ + if isinstance(threads, _collections.IterVar): + threads = [threads] + _api_internal._StageOutermostThreads(self, threads) + def compute_at(self, parent, scope): """Attach the stage at parent's scope diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 32fcc41a1593a..ea49bbae18cf1 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash) TVM_REGISTER_API(_Placeholder) .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Placeholder(args[0], + *ret = placeholder(args[0], args[1], args[2]); }); @@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile) *ret = Array({x_outer, y_outer, x_inner, y_inner}); }); +TVM_REGISTER_API(_StageOutermostThreads) + .set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .outermost_threads(args[1]); + }); + TVM_REGISTER_API(_StageUnroll) .set_body([](TVMArgs args, TVMRetValue* ret) { args[0].operator Stage() @@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize) .normalize(); }); +TVM_REGISTER_API(_ScheduleCacheRead) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Schedule() + .cache_read(args[1], args[2], args[3]); + }); + +TVM_REGISTER_API(_ScheduleCacheWrite) +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0].operator Schedule() + .cache_write(args[1], args[2]); + }); + } // namespace tvm diff --git a/src/lang/operation.cc b/src/lang/operation.cc index 9e16f1c1ba386..ddc4770f0bb99 100644 --- a/src/lang/operation.cc +++ b/src/lang/operation.cc @@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name, -Tensor Placeholder(Array shape, Type dtype, std::string name) { +Tensor placeholder(Array shape, Type dtype, std::string name) { return PlaceholderOpNode::make(name, shape, dtype).output(0); } @@ -82,7 +82,7 @@ Array ComputeOpNode::output_shape(size_t i) const { return Array(shape); } -Tensor Compute(Array shape, FCompute fcompute, std::string name) { +Tensor compute(Array shape, FCompute fcompute, std::string name) { auto op_node = std::make_shared(); // compute dimension. size_t ndim = shape.size(); @@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name, return Operation(n); } -Array Scan(IterVar scan_axis, +Array scan(IterVar scan_axis, Array init, Array update, Array state_placeholder, diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index 66816c955acb7..36ca32d07631f 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -6,9 +6,11 @@ #include namespace tvm { -namespace ir { +namespace schedule { + +using namespace ir; -class ElemWiseDetector : public IRVisitor { +class ElemWiseDetector : public ir::IRVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} @@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor { } for (size_t i = 0; i < axis_.size(); ++i) { - // const Variable *v1 = axis_[i]->var.as(); - // const Variable *v2 = axis[i].as(); if (!axis[i].same_as(axis_[i]->var)) { - // if (!(v1 && v2) || (v1 != v2)) { is_elem_wise_ = false; return; } @@ -52,22 +51,10 @@ bool IsElemWise(const Operation& op) { return false; } -} // namespace ir - -namespace schedule { - void AutoInlineElemWise(Schedule sch) { for (Stage s : sch->stages) { - if (!s.is_scheduled() && ir::IsElemWise(s->op)) { - bool is_root = false; - for (auto r : sch->roots) { - if (r == s->op) { - is_root = true; - break; - } - } - if (!is_root) - s.compute_inline(); + if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) { + s.compute_inline(); } } } diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 88729a3ce42a3..4724d97627a74 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan, const TensorDom& d = tmap.at(output[i]); time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end()); } - LOG(INFO) << time_dom.size(); CHECK(!rmap->count(scan->scan_axis)); Range sdom = scan->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); @@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op, const ComputeOpNode* compute = op.as(); const TensorDom& tdom = tmap.at(op.output(0)); for (size_t i = 0; i < compute->axis.size(); ++i) { - Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom); + Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom); CHECK(!rmap->count(compute->axis[i])); (*rmap)[compute->axis[i]] = r; } @@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage, direct_consume_by_parent = true; } } + } else { + LOG(INFO) << "not in feed graph consumer = " << stage->op; } } // The relax set @@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage, } FeedGraph CreateFeedGraph(const Schedule& sch) { - auto g = CreateReadGraph(sch->roots); + Array roots; + for (Operation op : sch->outputs) { + roots.push_back(sch->stage_map[op]->op); + } + auto g = CreateReadGraph(roots); FeedGraph fg; for (auto kv : g) { for (Tensor t : kv.second) { @@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) { Map InferBound(const Schedule& sch) { FeedGraph feed_graph = CreateFeedGraph(sch); AttachPath attach_path = CreateAttachPath(sch); + std::unordered_map ret; for (size_t i = sch->stages.size(); i != 0; --i) { const Stage& stage = sch->stages[i - 1]; diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 3975e4e9033c1..b18ae28e54754 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -3,6 +3,8 @@ * \file schedule.cc */ #include +#include +#include #include "./graph.h" namespace tvm { @@ -10,7 +12,8 @@ namespace tvm { namespace { // find first occurance location in leaf -size_t FindIterVar(ArrayNode* array_node, const IterVar& v) { +template +size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Node* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { if (array_node->data[i].get() == n) return i; @@ -19,10 +22,10 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) { } size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) { - size_t pos = FindIterVar(leaf_vars, v); + size_t pos = FindNodeRef(leaf_vars, v); if (pos < leaf_vars->data.size()) return pos; - if (FindIterVar(all_vars, v) < all_vars->data.size()) { + if (FindNodeRef(all_vars, v) < all_vars->data.size()) { LOG(FATAL) << "Operate on iter var " << v << "that has already been splitted"; } else { @@ -68,8 +71,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Stage::Stage(Operation op) { auto n = std::make_shared(); n->op = op; + n->origin_op = op; n->all_iter_vars = op->root_iter_vars(); - n->leaf_iter_vars = op->root_iter_vars(); + n->leaf_iter_vars = n->all_iter_vars; node_ = n; } @@ -89,7 +93,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) } } CHECK(found) - << "Cannot find the specified axis in parent stage's leaf_iter_vars"; + << "Cannot find the axis in parent's leaf_iter_vars or outermost_threads"; return *this; } @@ -176,13 +180,63 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent, return *this; } +Stage& Stage::outermost_threads(Array threads) { + StageNode* self = operator->(); + CHECK(self->op.as()) + << "outermost_threads is only valid for composite ops such as ScanOp"; + CHECK_EQ(self->outermost_threads.size(), 0U) + << "Already set outermost_threads"; + ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); + ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); + std::vector > temp; + for (IterVar iv : threads) { + temp.push_back(iv.node_); + } + leaf_vars->data.insert( + leaf_vars->data.begin(), temp.begin(), temp.end()); + all_vars->data.insert( + all_vars->data.end(), temp.begin(), temp.end()); + (*this)->outermost_threads = threads; + return *this; +} + +inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) { + ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); + ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); + FindLeafVar(all_vars, leaf_vars, var); + auto it = self->iter_var_attrs.find(var); + if (it != self->iter_var_attrs.end()) { + CHECK_EQ((*it).second->iter_type, attr->iter_type) + << "IterVar's is already set to " + << (*it).second << " instead of " << attr; + } else { + self->iter_var_attrs.Set(var, attr); + } +} + +Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + SetAttr(operator->(), var, IterVarAttr(kVectorized)); + return *this; +} + +Stage& Stage::unroll(IterVar var) { // NOLINT(*) + SetAttr(operator->(), var, IterVarAttr(kUnrolled)); + return *this; +} + Schedule::Schedule(Array ops) { auto n = std::make_shared(); - n->roots = ops; - auto g = schedule::CreateReadGraph(n->roots); - Array post_order = schedule::PostDFSOrder(n->roots, g); + n->outputs = ops; + auto g = schedule::CreateReadGraph(n->outputs); + Array post_order = schedule::PostDFSOrder(n->outputs, g); + // output set. + std::unordered_set output_set; + for (Operation x : ops) { + output_set.insert(x); + } for (Operation op : post_order) { Stage stage(op); + stage->is_output = output_set.count(op); n->stages.push_back(stage); n->stage_map.Set(op, stage); } @@ -237,7 +291,7 @@ void Schedule::normalize() { ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); for (IterVar iv : root_iter_vars) { - size_t idx = FindIterVar(leaf_vars, iv); + size_t idx = FindNodeRef(leaf_vars, iv); if (idx < leaf_vars->data.size()) { // insert rebase IterVar rebased(Range(), iv->var->name_hint + ".rb"); @@ -262,30 +316,6 @@ IterVarAttr::IterVarAttr(IterVarType t) { node_ = n; } -inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) { - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - FindLeafVar(all_vars, leaf_vars, var); - auto it = self->iter_var_attrs.find(var); - if (it != self->iter_var_attrs.end()) { - CHECK_EQ((*it).second->iter_type, attr->iter_type) - << "IterVar's is already set to " - << (*it).second << " instead of " << attr; - } else { - self->iter_var_attrs.Set(var, attr); - } -} - -Stage& Stage::vectorize(IterVar var) { // NOLINT(*) - SetAttr(operator->(), var, IterVarAttr(kVectorized)); - return *this; -} - -Stage& Stage::unroll(IterVar var) { // NOLINT(*) - SetAttr(operator->(), var, IterVarAttr(kUnrolled)); - return *this; -} - TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(IterVarAttrNode); TVM_REGISTER_NODE_TYPE(SplitNode); @@ -293,4 +323,190 @@ TVM_REGISTER_NODE_TYPE(FuseNode); TVM_REGISTER_NODE_TYPE(RebaseNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); +using ir::TensorKey; + +// The replacer of cache. +class TensorReplacer : public ir::IRMutator { + public: + TensorReplacer(const std::unordered_map& vmap) + : vmap_(vmap) {} + Expr Mutate_(const ir::Call* op, const Expr& e) { + if (op->call_type == ir::Call::Halide) { + ir::TensorKey key{op->func, op->value_index}; + auto it = vmap_.find(key); + if (it != vmap_.end()) { + Expr ret = ir::Call::make( + op->type, it->second->op->name, op->args, + op->call_type, it->second->op, it->second->value_index); + found = true; + return IRMutator::Mutate_(ret.as(), ret); + } + } + return IRMutator::Mutate_(op, e); + } + + // whether it is found. + bool found{false}; + + private: + const std::unordered_map& vmap_; +}; + +class VarReplacer : public ir::IRMutator { + public: + explicit VarReplacer( + const std::unordered_map& vsub) + : vsub_(vsub) {} + Expr Mutate_(const Variable* op, const Expr& e) { + auto it = vsub_.find(op); + if (it != vsub_.end()) return it->second; + return e; + } + + private: + const std::unordered_map& vsub_; +}; + +// Replace data flow appears in all stages given the tensor change. +// Also update vmap if subsequent dataflow need to be replaced. +void ReplaceDataFlow(const Array& stages, + std::unordered_map* vmap) { + for (Stage s : stages) { + if (s->op.as()) { + const ComputeOpNode* compute = s->op.as(); + TensorReplacer repl(*vmap); + Expr body = repl.Mutate(compute->body); + if (repl.found) { + Operation op = ComputeOpNode::make( + compute->name, compute->axis, body); + (*vmap)[TensorKey{s->op, 0}] = op.output(0); + s->op = op; + } + } else if (s->op.as()) { + const ScanOpNode* scan = s->op.as(); + std::shared_ptr n = + std::make_shared(*scan); + // copy on write semantics ganrantees correctness + for (size_t i = 0; i < n->init.size(); ++i) { + TensorKey key{n->init[i]->op, n->init[i]->value_index}; + if (vmap->count(key)) { + n->init.Set(i, vmap->at(key)); + } + } + for (size_t i = 0; i < n->update.size(); ++i) { + TensorKey key{n->update[i]->op, n->update[i]->value_index}; + if (vmap->count(key)) { + n->update.Set(i, vmap->at(key)); + } + } + if (!n->init.same_as(scan->init) || + !n->update.same_as(scan->update)) { + Operation op(n); + for (int i = 0; i < op->num_outputs(); ++i) { + (*vmap)[TensorKey{s->op, i}] = op.output(i); + } + s->op = op; + } + } else if (s->op.as()) { + } else { + LOG(FATAL) << "unhandled problem"; + } + } +} + +Tensor Schedule::cache_read(const Tensor& tensor, + const std::string& scope, + const Array& readers) { + // create identity mapping. + std::ostringstream os; + os << tensor->op->name; + if (tensor->op->num_outputs() != 1) { + os << ".v" << tensor->value_index; + } + os << "." << scope; + + Tensor cache = compute(tensor->shape, [&tensor](const Array& i) { + return tensor(Array(i.begin(), i.end())); + }, os.str()); + std::unordered_map vsub; + vsub[TensorKey{tensor->op, tensor->value_index}] = cache; + + std::unordered_map vmap; + for (Operation op : readers) { + const ComputeOpNode* compute = op.as(); + CHECK(compute) + << "cache read only take ComputeOp as readers"; + Stage s = operator[](op); + compute = s->op.as(); + + TensorReplacer repl(vsub); + Expr body = repl.Mutate(compute->body); + CHECK(repl.found) + << "Cannot find " << tensor + << " in the body of specified reader" << op; + Operation repl_op = ComputeOpNode::make( + compute->name, compute->axis, body); + vmap[TensorKey{s->op, 0}] = repl_op.output(0); + s->op = repl_op; + } + ReplaceDataFlow((*this)->stages, &vmap); + ArrayNode* stages = (*this)->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, operator[](tensor->op)); + Stage cache_stage = Stage(cache->op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos + 1, + cache_stage.node_); + (*this)->stage_map.Set(cache->op, cache_stage); + return cache; +} + +Tensor Schedule::cache_write(const Tensor& tensor, + const std::string& scope) { + Stage orig_stage = operator[](tensor->op); + const ComputeOpNode* compute = tensor->op.as(); + CHECK(compute) + << "cache write only take ComputeOp as writers"; + CHECK(!orig_stage.is_scheduled()) + << "Create cache_write before doing split/fuse/reorder"; + compute = orig_stage->op.as(); + CHECK(compute); + Array args; + Array new_axis; + std::unordered_map vsub; + for (IterVar iv : compute->axis) { + args.push_back(iv->var); + IterVar new_iv(iv->dom, iv->var->name_hint + ".c"); + new_axis.push_back(new_iv); + vsub[iv->var.get()] = new_iv->var; + } + VarReplacer repl(vsub); + Expr body = repl.Mutate(compute->body); + Operation cache_op = ComputeOpNode::make( + compute->name + "." + scope, new_axis, body); + Tensor cache_tensor = cache_op.output(0); + Operation orig_new_op = ComputeOpNode::make( + compute->name, compute->axis, + cache_tensor(args)); + + std::unordered_map vmap; + vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0); + ReplaceDataFlow((*this)->stages, &vmap); + + // mutate orig stage + orig_stage->op = orig_new_op; + orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); + orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; + // create schedule for new cached stage. + ArrayNode* stages = (*this)->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, orig_stage); + Stage cache_stage = Stage(cache_op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos, + cache_stage.node_); + (*this)->stage_map.Set(cache_op, cache_stage); + return cache_tensor; +} + } // namespace tvm diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index c69381967ec24..aa7c383635efa 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -23,7 +23,8 @@ using namespace ir; // Two private scope marks namespace attr { constexpr const char* loop_scope = "loop_scope"; -constexpr const char* scan_scope = "scan_scope"; +constexpr const char* scan_update_scope = "scan_update_scope"; +constexpr const char* scan_init_scope = "scan_init_scope"; } // namespace attr /*! @@ -280,23 +281,31 @@ Stmt MakeLoop(const Stage& s, if (init.defined()) { // try to find the location to insert the initialization. // Fuse the initialization and provide loop when possible. - std::unordered_map reduce_state; + std::unordered_map update_state; const ComputeOpNode* compute = s->op.as(); - for (IterVar iv : compute->reduce_axis) { - reduce_state[iv] = 2; - } - for (IterVar iv : compute->axis) { - reduce_state[iv] = 1; + const ScanOpNode* scan = s->op.as(); + if (compute) { + for (IterVar iv : compute->reduce_axis) { + update_state[iv] = 2; + } + for (IterVar iv : compute->axis) { + update_state[iv] = 1; + } + } else if (scan) { + update_state[scan->scan_axis] = 2; + for (IterVar iv : s->outermost_threads) { + update_state[iv] = 1; + } } // find which iter var is related to reduction and which is related to axis. - PassDownFlag(s, &reduce_state); + PassDownFlag(s, &update_state); auto leaf_iter_vars = s->leaf_iter_vars; std::unordered_map init_value_map; // first first loop that is related to reduction. size_t begin_loop = leaf_iter_vars.size(); for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { auto iv = leaf_iter_vars[i]; - int flag = reduce_state.at(iv); + int flag = update_state.at(iv); if ((flag & 2) != 0) { begin_loop = i; break; } @@ -304,7 +313,7 @@ Stmt MakeLoop(const Stage& s, } // skip loops that does not relates to axis. std::unordered_map skip_iter; - for (auto kv : reduce_state) { + for (auto kv : update_state) { int flag = kv.second; if ((flag & 1) == 0) skip_iter[kv.first] = true; } @@ -422,7 +431,10 @@ Stmt MakePipeline(const Stage& s, } else if (scan) { // Provide is done by the sub operations. provide = AttrStmt::make( - s->op, attr::scan_scope, scan->scan_axis->var, + s->op, attr::scan_update_scope, scan->scan_axis->var, + Evaluate::make(0)); + init = AttrStmt::make( + s->op, attr::scan_init_scope, 0, Evaluate::make(0)); } else { LOG(FATAL) << "not supported op " << s->op->type_key(); @@ -472,7 +484,9 @@ class InjectAttach : public IRMutator { const AttrStmt* op = stmt.as(); if (op != nullptr && op->type_key == attr::loop_scope) { - if (op->node == stage_->attach_ivar) { + CHECK_NE(producer_.size(), 0U); + if (op->node == stage_->attach_ivar && + producer_.back() == stage_->attach_stage->op.get()) { CHECK(!found_attach); found_attach = true; stmt = AttrStmt::make( @@ -482,6 +496,16 @@ class InjectAttach : public IRMutator { } return stmt; } + Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final { + if (op->is_producer) { + producer_.push_back(op->func.get()); + Stmt ret = IRMutator::Mutate_(op, s); + producer_.pop_back(); + return ret; + } else { + return IRMutator::Mutate_(op, s); + } + } // whether attach point is found bool found_attach{false}; @@ -490,6 +514,8 @@ class InjectAttach : public IRMutator { const Stage& stage_; // domain map const Map& dom_map_; + // internal stack about realization scope. + std::vector producer_; }; // inject the operator's realization on the stmt. @@ -505,27 +531,16 @@ class InjectScanStep : public IRMutator { Stmt Mutate(Stmt stmt) final { CHECK(stmt.defined()); stmt = IRMutator::Mutate(stmt); - if (is_init_) { - const ProducerConsumer* op = stmt.as(); - if (op != nullptr && - op->is_producer && - op->func.same_as(scan_op_)) { - stmt = ProducerConsumer::make( - op->func, true, - MakePipeline(stage_, dom_map_, op->body)); + // update + const AttrStmt* op = stmt.as(); + if (op != nullptr && + ((op->type_key == attr::scan_update_scope && !is_init_) || + (op->type_key == attr::scan_init_scope && is_init_))) { + if (op->node.same_as(scan_op_)) { found_attach = true; - } - } else { - // update - const AttrStmt* op = stmt.as(); - if (op != nullptr && - op->type_key == attr::scan_scope) { - if (op->node.same_as(scan_op_)) { - found_attach = true; - stmt = AttrStmt::make( - op->node, op->type_key, op->value, - MakePipeline(stage_, dom_map_, op->body)); - } + stmt = AttrStmt::make( + op->node, op->type_key, op->value, + MakePipeline(stage_, dom_map_, op->body)); } } return stmt; @@ -561,8 +576,15 @@ Stmt InjectInline(const Operation op, Stmt body) { class SchedulePostProc : public IRMutator { public: Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final { - if (to_remove_.count(op->func.get())) { - return this->Mutate(op->body); + auto it = replace_op_.find(op->func.get()); + if (it != replace_op_.end()) { + Stmt body = this->Mutate(op->body); + if (it->second.defined()) { + return ProducerConsumer::make( + it->second, op->is_producer, body); + } else { + return body; + } } else { return IRMutator::Mutate_(op, s); } @@ -579,14 +601,23 @@ class SchedulePostProc : public IRMutator { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { if (op->type_key == attr::loop_scope) { return this->Mutate(op->body); - } else if (op->type_key == attr::scan_scope) { + } else if (op->type_key == attr::scan_init_scope) { + return this->Mutate(op->body); + } else if (op->type_key == attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); CHECK(scan); var_value_[scan->scan_axis->var.get()] = op->value; return this->Mutate(op->body); } else if (op->type_key == ir::attr::realize_scope) { - if (to_remove_.count(op->node.get())) { - return this->Mutate(op->body); + auto it = replace_op_.find(op->node.get()); + if (it != replace_op_.end()) { + if (it->second.defined()) { + Stmt ret = AttrStmt::make( + it->second, op->type_key, op->value, op->body); + return this->Mutate_(ret.as(), ret); + } else { + return this->Mutate(op->body); + } } } return IRMutator::Mutate_(op, s); @@ -594,8 +625,16 @@ class SchedulePostProc : public IRMutator { Stmt Mutate_(const Realize* op, const Stmt& s) final { TensorKey key{op->func, op->value_index}; - if (replace_.count(key)) { - return this->Mutate(op->body); + auto it = replace_realize_.find(key); + if (it != replace_realize_.end()) { + if (it->second.defined()) { + Stmt ret = Realize::make( + it->second->op, it->second->value_index, + op->type, op->bounds, op->condition, op->body); + return this->Mutate_(ret.as(), ret); + } else { + return this->Mutate(op->body); + } } else { return IRMutator::Mutate_(op, s); } @@ -603,8 +642,8 @@ class SchedulePostProc : public IRMutator { Stmt Mutate_(const Provide* op, const Stmt& s) final { TensorKey key{op->func, op->value_index}; - auto it = replace_.find(key); - if (it != replace_.end()) { + auto it = replace_buffer_.find(key); + if (it != replace_buffer_.end()) { const Tensor& dst = it->second.first; Stmt ret = Provide::make( dst->op, dst->value_index, op->value, @@ -616,10 +655,10 @@ class SchedulePostProc : public IRMutator { } Expr Mutate_(const Call* op, const Expr& e) final { - if (op != nullptr && op->call_type == Call::Halide) { + if (op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; - auto it = replace_.find(key); - if (it != replace_.end()) { + auto it = replace_buffer_.find(key); + if (it != replace_buffer_.end()) { const Tensor& dst = it->second.first; Expr ret = Call::make( op->type, dst->op->name, @@ -642,22 +681,32 @@ class SchedulePostProc : public IRMutator { void Init(const Schedule& sch) { for (Stage s : sch->stages) { - const ScanOpNode* scan = s->op.as(); - if (!scan) continue; - for (size_t i = 0; i < scan->update.size(); ++i) { - Tensor t = s->op.output(i); - AddReplace(scan->init[i], t, Expr()); - AddReplace(scan->update[i], t, scan->scan_axis->var); - AddReplace(scan->state_placeholder[i], t, Expr()); + if (s->op.as()) { + const ScanOpNode* scan = s->op.as(); + for (size_t i = 0; i < scan->update.size(); ++i) { + Tensor t = s->origin_op.output(i); + AddReplace(scan->init[i], t, Expr()); + AddReplace(scan->update[i], t, scan->scan_axis->var); + AddReplace(scan->state_placeholder[i], t, Expr()); + } + } else if (!s->op.same_as(s->origin_op)) { + Tensor target = s->origin_op.output(0); + AddReplace(s->op.output(0), target, + Expr(), target, s->origin_op); } } } private: - void AddReplace(Tensor src, Tensor dst, Expr head_idx) { - replace_[TensorKey{src->op, src->value_index}] - = std::make_pair(dst, head_idx); - to_remove_.insert(src->op.get()); + void AddReplace(Tensor src, + Tensor dst, + Expr head_idx, + Tensor repl_realize = Tensor(), + Operation repl_op = Operation()) { + TensorKey key{src->op, src->value_index}; + replace_buffer_[key] = std::make_pair(dst, head_idx); + replace_realize_[key] = repl_realize; + replace_op_[src->op.get()] = repl_op; } Array RewriteArgs(Expr head, Array args) { if (!head.defined()) return args; @@ -670,9 +719,11 @@ class SchedulePostProc : public IRMutator { // The scan value std::unordered_map var_value_; // buffer replacement - std::unordered_map > replace_; - // replaced functions - std::unordered_set to_remove_; + std::unordered_map > replace_buffer_; + // buffere realization to be replaced + std::unordered_map replace_realize_; + // replace producer consumer. + std::unordered_map replace_op_; }; Stmt ScheduleOps( @@ -724,7 +775,9 @@ Stmt ScheduleOps( InjectAttach mutator(s, dom_map); body = mutator.Mutate(body); CHECK(mutator.found_attach) - << "did not find attachment point"; + << "did not find attachment point for " << s << " in" + << s->attach_stage->op << " x " + << body; } } SchedulePostProc post_proc; diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index a62a0b09af0ea..c33d4c6aeed8e 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -6,10 +6,10 @@ TEST(Tensor, Basic) { using namespace tvm; Var m("m"), n("n"), l("l"); - Tensor A = Placeholder({m, l}, Float(32), "A"); - Tensor B = Placeholder({n, l}, Float(32), "B"); + Tensor A = placeholder({m, l}, Float(32), "A"); + Tensor B = placeholder({n, l}, Float(32), "B"); - auto C = Compute({m, n}, [&](Var i, Var j) { + auto C = compute({m, n}, [&](Var i, Var j) { return A[i][j]; }, "C"); @@ -20,11 +20,11 @@ TEST(Tensor, Basic) { TEST(Tensor, Reduce) { using namespace tvm; Var m("m"), n("n"), l("l"); - Tensor A = Placeholder({m, l}, Float(32), "A"); - Tensor B = Placeholder({n, l}, Float(32), "B"); + Tensor A = placeholder({m, l}, Float(32), "A"); + Tensor B = placeholder({n, l}, Float(32), "B"); IterVar rv(Range{0, l}, "k"); - auto C = Compute({m, n}, [&](Var i, Var j) { + auto C = compute({m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index feab008f08c4c..587f42d956db3 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -22,21 +22,14 @@ def test_gemm(): l = n A = tvm.placeholder((n, l), name='A') B = tvm.placeholder((m, l), name='B') - AA = tvm.compute(A.shape, lambda *i : A(*i), name="AA") - BB = tvm.compute(B.shape, lambda *i : B(*i), name="BB") k = tvm.IterVar((0, l), name='k') - CC = tvm.compute( + C = tvm.compute( (n, m), - lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k), + lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k), name='CC') - C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C") - # schedule s = tvm.Schedule(C.op) xtile, ytile = 32, 32 - s[AA].set_scope("shared") - s[BB].set_scope("shared") - scale = 8 num_thread = 8 block_factor = scale * num_thread @@ -45,6 +38,9 @@ def test_gemm(): block_y = tvm.IterVar(thread_tag="blockIdx.y") thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y") + CC = s.cache_write(C, "local") + AA = s.cache_read(A, "shared", [CC]) + BB = s.cache_read(B, "shared", [CC]) _, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y) _, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x) s[C].reorder(block_y, block_x, yi, xi) @@ -64,8 +60,10 @@ def test_gemm(): _, xi = s[BB].split(xi, outer=thread_x) max_auto_unroll_step = 0 + print("x") # lowering test s.normalize() + print("x>>") def check_device(target): codes = [] diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index 69011e9257f5f..c413a220cc8bb 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -22,13 +22,13 @@ def test_schedule_create(): json_str = tvm.save_json(s) s_loaded = tvm.load_json(json_str) assert isinstance(s_loaded, tvm.schedule.Schedule) - assert(str(s_loaded.roots[0].body) == str(s.roots[0].body)) + assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body)) # pickle unpickle dump = pkl.dumps(s) s_loaded = pkl.loads(dump) assert isinstance(s_loaded, tvm.schedule.Schedule) - assert(str(s_loaded.roots[0].body) == str(s.roots[0].body)) + assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body)) def test_reorder(): m = tvm.Var('m') diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 278d1cc53be68..625bee5964141 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -74,6 +74,20 @@ def test_auto_inline(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_cache(): + m = tvm.Var('m') + n = tvm.Var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C') + + s = tvm.Schedule(C.op) + AA = s.cache_read(A, "shared", readers=[C]) + CC = s.cache_write(C, "shared") + s[AA].compute_at(s[CC], CC.op.axis[0]) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + if __name__ == "__main__": test_schedule_scan() @@ -81,3 +95,4 @@ def test_auto_inline(): test_schedule1() test_schedule2() test_auto_inline() + test_schedule_cache()