From 595dc9485fce8f65c2f79e562fe3c82a91c7aa23 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sat, 11 Feb 2017 23:17:47 -0800
Subject: [PATCH] [LANG] Introduce Scan, Bugfix Canonical

---
 include/tvm/ir.h                            |  35 ++-
 include/tvm/operation.h                     |  64 ++++
 python/tvm/api.py                           |  55 +++-
 python/tvm/tensor.py                        |   9 +-
 src/api/api_lang.cc                         |   9 +
 src/arithmetic/canonical.cc                 |  11 +-
 src/codegen/codegen_cuda.cc                 |   2 +-
 src/codegen/codegen_cuda.h                  |   2 +-
 src/lang/operation.cc                       |  87 ++++++
 src/pass/inject_virtual_thread.cc           |  20 +-
 src/pass/storage_flatten.cc                 |  34 --
 src/schedule/bound.cc                       | 125 +++++++-
 src/schedule/graph.cc                       |  22 +-
 src/schedule/schedule_lang.cc               |   2 +
 src/schedule/schedule_ops.cc                | 290 ++++++++++++++++--
 tests/python/integration/test_scan.py       |  54 ++++
 tests/python/unittest/test_lang_tensor.py   |  14 +
 tests/python/unittest/test_pass_simplify.py |  10 +-
 .../unittest/test_schedule_schedule_ops.py  |  48 ++-
 19 files changed, 776 insertions(+), 117 deletions(-)
 create mode 100644 tests/python/integration/test_scan.py

diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index d6a258053e11..e6aa692af379 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* Min = "Min";
 };
 
-/*! \brief namespace of possible attribute sin AttrStmt.type_key */
-namespace attr {
 /*!
- * \brief Mark scope of iteration variable, used by Schedule.
+ * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
  */
-constexpr const char* scope = "scope";
+struct TensorKey {
+  FunctionRef f;
+  int value_index;
+
+  inline bool operator==(const TensorKey& other) const {
+    return f == other.f && value_index == other.value_index;
+  }
+  inline std::string GetName() const {
+    if (f->num_outputs() == 1) return f->func_name();
+    std::ostringstream os;
+    os << f->func_name() << ".v" << value_index;
+    return os.str();
+  }
+};
+
+/*! \brief namespace of possible attributes in AttrStmt.type_key */
+namespace attr {
+// The above attr does not pass to ir stage.
 /*!
  * \brief Mark launching extent of thread, used by device API.
  */
@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
 }  // namespace ir
 }  // namespace tvm
 
+namespace std {
+template <>
+struct hash<::tvm::ir::TensorKey> {
+  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
+    size_t lhs = k.f.hash();
+    size_t rhs = static_cast<size_t>(k.value_index);
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
+
 #endif  // TVM_IR_H_
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index a48d0e5b8e6e..1d16c3428279 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
 };
 
+/*!
+ * \brief Symbolic scan.
+ */
+class ScanOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar to scan over */
+  IterVar scan_axis;
+  /*! \brief the initialization tensors */
+  Array<Tensor> init;
+  /*! \brief the update functions represented by tensors */
+  Array<Tensor> update;
+  /*! \brief The placeholder to refer as states in update. */
+  Array<Tensor> state_placeholder;
+  /*!
+   * \brief Spatial axis to indicate spatial dimension of each output.
+   *  They correspond to the flattened spatial axes of the outputs.
+   *
+   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
+   *  These are auxiliary data structures for storing the result of bound inference.
+   *  They do not correspond to splittable iterations, thus the name comes
+   *  with underscore.
+   */
+  Array<IterVar> spatial_axis_;
+  /*! \brief constructor */
+  ScanOpNode() {}
+  // override behavior.
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("scan_axis", &scan_axis);
+    v->Visit("init", &init);
+    v->Visit("update", &update);
+    v->Visit("state_placeholder", &state_placeholder);
+    v->Visit("spatial_axis_", &spatial_axis_);
+  }
+  static Operation make(std::string name,
+                        IterVar axis,
+                        Array<Tensor> init,
+                        Array<Tensor> update,
+                        Array<Tensor> state_placeholder);
+
+  static constexpr const char* _type_key = "ScanOp";
+  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
+};
+
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;
 
@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
  */
 Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
 
+/*!
+ * \brief Construct new tensors by scan over scan_axis.
+ *
+ * \param scan_axis The iteration representing the scan.
+ * \param init The initialization tensors of the first K steps.
+ * \param update The update tensors indicating the updated result after each timestamp.
+ * \param state_placeholder The placeholder for the states.
+ * \param name The optional name of the tensor.
+ */
+Array<Tensor> Scan(IterVar scan_axis,
+                   Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   std::string name = "scan");
+
 // same as compute, specialized for different fcompute function
 inline Tensor Compute(Array<Expr> shape,
                       std::function<Expr(Var)> f,
diff --git a/python/tvm/api.py b/python/tvm/api.py
index bb1a563b23fa..2c3f544836d4 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -14,6 +14,7 @@
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
+from . import tensor as _tensor
 from . import collections as _collections
 
 int32 = "int32"
@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
     shape: Tuple of Expr
         The shape of the tensor
 
-
     fcompute: lambda function of *indices-> value
         Specifies the input source expression
 
@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
     body = convert(body)
     op_node = _api_internal._ComputeOp(
         name, dim_var, body)
-    return _api_internal._Tensor(
-        shape, body.dtype, op_node, 0)
+    return op_node.output(0)
+
+
+def scan(axis, init, update, state_placeholder, name="scan"):
+    """Construct new tensors by scanning over axis.
+
+    Parameters
+    ----------
+    axis: IterVar
+        The scanning axis.
+
+    init: Tensor or list of Tensor
+        The initial condition of the first init.shape[0] timestamps
+
+    update: Tensor or list of Tensor
+        The update rule of the scan given by symbolic tensor.
+
+    state_placeholder: Tensor or list of Tensor
+        The placeholder variables used by update.
+
+    name: str, optional
+        The name hint of the tensor
+
+    Returns
+    -------
+    tensor: tensor.Tensor or list of Tensor
+        The created tensor or tensors
+
+    Example
+    -------
+    # The following code is equivalent to numpy.cumsum
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), name="t")
+    X = tvm.placeholder((m, n), name="X")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+    """
+    if isinstance(init, _tensor.Tensor):
+        init = [init]
+    if isinstance(update, _tensor.Tensor):
+        update = [update]
+    if isinstance(state_placeholder, _tensor.Tensor):
+        state_placeholder = [state_placeholder]
+    if len(init) != len(update) or len(init) != len(state_placeholder):
+        raise ValueError("init, update, state_placeholder must have same length")
+    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
+    res = [op.output(i) for i in range(len(update))]
+    return (res[0] if len(res) == 1 else res)
 
 
 def Buffer(shape, dtype=None,
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index 47a7ec88c7ef..2dbab96defe8 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -74,12 +74,17 @@ def output(self, index):
         """
         return _api_internal._OpGetOutput(self, index)
 
+@register_node
+class PlaceholderOp(Operation):
+    """Placeholder operation."""
+    pass
+
 @register_node
 class ComputeOp(Operation):
     """Compute operation."""
     pass
 
 @register_node
-class PlaceholderOp(Operation):
-    """Placeholder operation."""
+class ScanOp(Operation):
+    """Scan operation."""
     pass
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 769345fc415e..32fcc41a1593 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
                             args[2]);
   });
 
+TVM_REGISTER_API(_ScanOp)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = ScanOpNode::make(args[0],
+                            args[1],
+                            args[2],
+                            args[3],
+                            args[4]);
+  });
+
 TVM_REGISTER_API(_OpGetOutput)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Operation().output(
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 8ae8ed47e0d5..ae95b04a5305 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
                   const ComExpr& sumb,
                   int bscale) {
     std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
-    n->base = suma->base + sumb->base;
+    n->base = suma->base + sumb->base * bscale;
     // merge of suma and sumb;
     size_t i = 0, j = 0;
     while (i < suma->elem.size() && j < sumb->elem.size()) {
@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
   // convert sum to expr
   Expr Sum2Expr(const ComExpr& com, Type t) {
     Expr vsum;
-    if (com->base != 0) {
+    if (com->base > 0) {
       vsum = make_const(t, com->base);
     }
     for (const ComExprEntry& e : com->elem) {
@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
         }
       }
     }
+    if (com->base < 0) {
+      if (vsum.defined()) {
+        vsum = Sub::make(vsum, make_const(t, -com->base));
+      } else {
+        vsum = make_const(t, com->base);
+      }
+    }
     for (const ComExprEntry& e : com->elem) {
       if (e.scale < 0) {
         Expr v = e.value;
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index c4c5d99f35ad..c526ec8d0587 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
     const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
     code = f(code).operator std::string();
   }
-  LOG(INFO) << code;
+
   std::string ptx;
   if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
     const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index 428f9ffddd2e..641c28f95ee7 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
  private:
   // magic number to add pragma unroll to it.
   // used to generate code that is compact but still unrolls.
-  int max_auto_unroll_{8};
+  int max_auto_unroll_{1025};
 };
 
 }  // namespace codegen
diff --git a/src/lang/operation.cc b/src/lang/operation.cc
index 95c292e48dd2..9e16f1c1ba38 100644
--- a/src/lang/operation.cc
+++ b/src/lang/operation.cc
@@ -5,6 +5,7 @@
 #include <tvm/operation.h>
 #include <tvm/tensor.h>
 #include <tvm/ir.h>
+#include <tvm/ir_pass.h>
 #include <memory>
 
 namespace tvm {
@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+// Scan
+inline bool prove_equal(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
+int ScanOpNode::num_outputs() const {
+  return update.size();
+}
+Array<IterVar> ScanOpNode::root_iter_vars() const {
+  return Array<IterVar>{scan_axis};
+}
+
+Type ScanOpNode::output_dtype(size_t i) const {
+  return update[i]->dtype;
+}
+
+Array<Expr> ScanOpNode::output_shape(size_t i) const {
+  CHECK_LT(i, state_placeholder.size());
+  return state_placeholder[i]->shape;
+}
+
+Operation ScanOpNode::make(std::string name,
+                           IterVar axis,
+                           Array<Tensor> init,
+                           Array<Tensor> update,
+                           Array<Tensor> state_placeholder) {
+  auto n = std::make_shared<ScanOpNode>();
+  CHECK_EQ(init.size(), update.size());
+  CHECK_EQ(init.size(), state_placeholder.size());
+
+  for (size_t i = 0; i < init.size(); ++i) {
+    CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
+    CHECK_EQ(init[i]->dtype, update[i]->dtype);
+    CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
+        << "init.shape[0] need to match scan_axis.dom.min";
+    CHECK(prove_equal(
+        state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
+        << "state_placeholder.shape[0] need to match"
+        << " scan_axis.dom.min + scan_axis.dom.extent";
+    CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
+        << "The dimension of init need to match state_placeholder";
+    CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
+        << "The update.ndim need to be state_placeholder.ndim - 1";
+    for (size_t k = 0; k < update[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
+      // setup spatial axis
+      std::ostringstream spatial_name;
+      spatial_name << name << ".out" << i << ".i" << k + 1;
+      n->spatial_axis_.push_back(
+          IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
+                  spatial_name.str()));
+    }
+    for (size_t k = 1; k < init[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          init[i]->shape[k], state_placeholder[i]->shape[k]));
+    }
+  }
+
+  n->name = name;
+  n->scan_axis = axis;
+  n->init = init;
+  n->update = update;
+  n->state_placeholder = state_placeholder;
+  return Operation(n);
+}
+
+Array<Tensor> Scan(IterVar scan_axis,
+                   Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   std::string name) {
+  Operation op = ScanOpNode::make(
+      name, scan_axis, init, update, state_placeholder);
+  Array<Tensor> res;
+  for (int i = 0; i < op->num_outputs(); ++i) {
+    res.push_back(op.output(i));
+  }
+  return res;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
+    p->stream << "scan(" << op->name << ", " << op << ")";
+});
+
 }  // namespace tvm
diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
index 0a9f5b38ff62..2ca1d7c4158a 100644
--- a/src/pass/inject_virtual_thread.cc
+++ b/src/pass/inject_virtual_thread.cc
@@ -191,20 +191,16 @@ class VTInjector : public IRMutator {
   }
   // Attribute
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
-    if (op->type_key == attr::scope) {
-      return Mutate(op->body);
+    Expr value = Mutate(op->value);
+    if (visit_touched_var_) {
+      return InjectVTLoop(s, true);
     } else {
-      Expr value = Mutate(op->value);
-      if (visit_touched_var_) {
-        return InjectVTLoop(s, true);
+      Stmt body = Mutate(op->body);
+      if (value.same_as(op->value) &&
+          body.same_as(op->body)) {
+        return s;
       } else {
-        Stmt body = Mutate(op->body);
-        if (value.same_as(op->value) &&
-            body.same_as(op->body)) {
-          return s;
-        } else {
-          return AttrStmt::make(op->node, op->type_key, value, body);
-        }
+        return AttrStmt::make(op->node, op->type_key, value, body);
       }
     }
   }
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 944a8c0a496d..e7a881640ece 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -11,40 +11,6 @@
 namespace tvm {
 namespace ir {
 
-// key of function buffer
-struct TensorKey {
-  FunctionRef f;
-  int value_index;
-
-  inline bool operator==(const TensorKey& other) const {
-    return f == other.f && value_index == other.value_index;
-  }
-  inline std::string GetName() const {
-    if (f->num_outputs() == 1) return f->func_name();
-    std::ostringstream os;
-    os << f->func_name() << ".v" << value_index;
-    return os.str();
-  }
-};
-
-}  // namespace ir
-}  // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::ir::TensorKey> {
-  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
-    size_t lhs = k.f.hash();
-    size_t rhs = static_cast<size_t>(k.value_index);
-    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
-    return lhs;
-  }
-};
-}  // namespace std
-
-namespace tvm {
-namespace ir {
-
 using Halide::Internal::Region;
 using runtime::StorageScope;
 using runtime::ThreadScope;
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 4514d02282b8..88729a3ce42a 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
   return ir::Simplify((a + b - 1) / b);
 }
 
+inline bool prove_equal(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
 // Downward message passing algorithm on stage schedule s,
 // pass the range state down from the root to the leaves
 // after this pass, every IterVar in the stage hyper graph will have a range(domain)
@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
       if (r->outer->dom.defined()) {
         state[r->outer] = r->outer->dom;
       } else {
-        CHECK(!state.count(r->outer));
-        state[r->outer] = Range::make_with_min_extent(
-            0, DivCeil(range_parent->extent, r->factor));
+        if (!state.count(r->outer)) {
+          state[r->outer] = Range::make_with_min_extent(
+              0, DivCeil(range_parent->extent, r->factor));
+        } else {
+          Expr outer_ext = DivCeil(range_parent->extent, r->factor);
+          Range outer_rng = state.at(r->outer);
+          bool match = is_zero(outer_rng->min);
+          if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
+          CHECK(match)
+              << "IterVar is used in two places as outer scope,"
+              << " cannot prove their extents are the same";
+        }
       }
     } else {
       CHECK(r->outer->dom.defined());
@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
   }
 }
 
+// All the itervars that are needed to compute the output bound of op.
+// For most ops, it is root_iter_vars.
+// For Scan, it also contains the additional spatial axes.
+Array<IterVar> OutputRelatedIterVars(const Operation& op) {
+  if (op.as<ScanOpNode>()) {
+    const ScanOpNode* scan = op.as<ScanOpNode>();
+    Array<IterVar> ret{scan->scan_axis};
+    for (IterVar iv : scan->spatial_axis_) {
+      ret.push_back(iv);
+    }
+    return ret;
+  } else {
+    return op->root_iter_vars();
+  }
+}
 
 /*! \brief temporary data structure to store Tensor domain */
 struct TensorDom {
@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
       }
     };
     ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+  } else if (op.as<ScanOpNode>()) {
+    const ScanOpNode* scan = op.as<ScanOpNode>();
+    size_t sp_idx = 0;
+    for (size_t i = 0; i < scan->init.size(); ++i) {
+      TensorDom* init_dom = nullptr;
+      TensorDom* update_dom = nullptr;
+      if (out->count(scan->init[i])) {
+        init_dom = &out->at(scan->init[i]);
+      }
+      if (out->count(scan->update[i])) {
+        update_dom = &out->at(scan->update[i]);
+      }
+      // first dimension, always needed.
+      if (init_dom) {
+        init_dom->data[0].push_back(IntSet::range(
+            Range::make_with_min_extent(0, scan->init[i]->shape[0])));
+      }
+      // The update dimensions
+      for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+        IterVar sp_ax = scan->spatial_axis_[sp_idx];
+        if (init_dom) {
+          init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
+        }
+        if (update_dom) {
+          update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
+        }
+      }
+    }
   } else if (op.as<PlaceholderOpNode>()) {
     // do nothing
   } else {
@@ -221,14 +277,49 @@
   }
 }
 
-void InferOpBound(const Operation& op,
-                  const std::unordered_map<Tensor, TensorDom>& tmap,
-                  std::unordered_map<IterVar, Range>* rmap) {
+// Given the bound of the output of op,
+// pass the bound to the related axes in op.
+void GatherOpBound(const ScanOpNode* scan,
+                   const Operation& op,
+                   const std::unordered_map<Tensor, TensorDom>& tmap,
+                   std::unordered_map<IterVar, Range>* rmap) {
+  CHECK(!rmap->count(scan->scan_axis));
+  std::vector<Tensor> output(op->num_outputs());
+  for (size_t i = 0; i < output.size(); ++i) {
+    output[i] = op.output(i);
+  }
+  // Update for time axis.
+  std::vector<IntSet> time_dom;
+  for (size_t i = 0; i < output.size(); ++i) {
+    const TensorDom& d = tmap.at(output[i]);
+    time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
+  }
+  LOG(INFO) << time_dom.size();
+  CHECK(!rmap->count(scan->scan_axis));
+  Range sdom = scan->scan_axis->dom;
+  Range r = arith::Union(time_dom).cover_range(sdom);
+  (*rmap)[scan->scan_axis] = Range::make_with_min_extent(
+      sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
+  // Update for spatial axis.
+  size_t sp_idx = 0;
+  for (size_t i = 0; i < output.size(); ++i) {
+    for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = scan->spatial_axis_[sp_idx];
+      CHECK(!rmap->count(sp_ax));
+      // By default, we always need all spatial axes,
+      // unless an axis only refers back to itself as a fixed point.
+      // TODO(tqchen): Add fix point detection.
+      (*rmap)[sp_ax] = sp_ax->dom;
+    }
+  }
+}
+
+void GatherOpBound(const Operation& op,
+                   const std::unordered_map<Tensor, TensorDom>& tmap,
+                   std::unordered_map<IterVar, Range>* rmap) {
   if (op.as<ComputeOpNode>()) {
-    auto root_iter_vars = op->root_iter_vars();
     const ComputeOpNode* compute = op.as<ComputeOpNode>();
     const TensorDom& tdom = tmap.at(op.output(0));
-
     for (size_t i = 0; i < compute->axis.size(); ++i) {
       Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
       CHECK(!rmap->count(compute->axis[i]));
@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
       CHECK(!rmap->count(compute->reduce_axis[i]));
       (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
     }
+  } else if (op.as<ScanOpNode>()) {
+    GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
   } else if (op.as<PlaceholderOpNode>()) {
     // do nothing
   } else {
@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
                     std::unordered_map<IterVar, Range>* rmap) {
   if (stage->attach_type == kInline) return;
   if (stage->attach_type == kRoot || stage->attach_type == kNone) {
-    auto root_iter_vars = stage->op->root_iter_vars();
-    for (auto iv : root_iter_vars) {
+    for (auto iv : OutputRelatedIterVars(stage->op)) {
      CHECK(iv->dom.defined());
      CHECK(!rmap->count(iv));
      (*rmap)[iv] = iv->dom;
@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
     PassUp(parent, *rmap, &up_state);
 
     std::unordered_map<const Variable*, IntSet> dom_map;
-    for (auto iv : parent->op->root_iter_vars()) {
-      Range r = up_state.at(iv).cover_range(iv->dom);
+    for (auto iv : OutputRelatedIterVars(parent->op)) {
+      Range r;
+      if (up_state.count(iv)) {
+        r = up_state.at(iv).cover_range(iv->dom);
+      } else {
+        r = iv->dom;
+      }
       if (relax_set.size() != 0) {
         dom_map[iv->var.get()] = EvalSet(r, relax_set);
       } else {
@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
     CHECK(found)
         << "Invalid Schedule, cannot find the producer " << stage->op
        << " along the loop nest specified by compute_at of consumer " << op;
-    for (auto iv : op->root_iter_vars()) {
+    for (auto iv : OutputRelatedIterVars(op)) {
       Range r = rmap->at(iv);
       dom_map[iv->var.get()] = EvalSet(r, relax_set);
     }
     BoundProp(op, dom_map, &tmap);
   }
-  InferOpBound(stage->op, tmap, rmap);
+  GatherOpBound(stage->op, tmap, rmap);
 }
 
 FeedGraph CreateFeedGraph(const Schedule& sch) {
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 33272fceb222..f1047bf95ac9 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
       if (call != nullptr && call->func.defined()) {
         Operation call_op(call->func.node_);
         deps.push_back(call_op.output(call->value_index));
-        if (call_op.defined() && visited.count(call_op.get()) == 0) {
-          visited.insert(call_op.get());
-          stack.push_back(call_op);
-        }
       }
     };
     ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
-    rmap.Set(op, deps);
+  } else if (op.as<ScanOpNode>()) {
+    const ScanOpNode* scan = op.as<ScanOpNode>();
+    for (Tensor t : scan->init) {
+      deps.push_back(t);
+    }
+    for (Tensor t : scan->update) {
+      deps.push_back(t);
+    }
   } else if (op.as<PlaceholderOpNode>()) {
-    // empty set of deps
-    rmap.Set(op, deps);
   } else {
     LOG(FATAL) << "unknown Operation" << op->type_key();
   }
+  rmap.Set(op, deps);
+  for (Tensor t : deps) {
+    if (t->op.defined() && visited.count(t->op.get()) == 0) {
+      visited.insert(t->op.get());
+      stack.push_back(t->op);
+    }
+  }
 }
 return rmap;
}
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index d7d514b0c75e..3975e4e9033c 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT
 
 Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
   StageNode* self = operator->();
+  CHECK(!self->op.as<ScanOpNode>())
+      << "Cannot reorder axis of scan";
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
   std::vector<size_t> pos;
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index ed4ad7011a4e..c69381967ec2 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -7,7 +7,9 @@
 #include <tvm/ir.h>
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
-
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
 #include "../pass/ir_util.h"
 #include "../arithmetic/compute_expr.h"
 #include "./graph.h"
@@ -18,6 +20,12 @@ namespace schedule {
 using namespace arith;
 using namespace ir;
 
+// Two private scope marks
+namespace attr {
+constexpr const char* loop_scope = "loop_scope";
+constexpr const char* scan_scope = "scan_scope";
+}  // namespace attr
+
 /*!
 * \brief message passing to find if IterVar is related to reduction.
 * \param s The stage to be used.
@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch,
       value_map[iv] = iv->var;
       continue;
     }
-
     Range dom = dom_map.at(iv);
     // initialize the offset and loop_level
     Var var = iv->var;
@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch,
     if (!reduce_init_loop) {
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, ir::attr::scope, iv->var, no_op));
+          AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
     }
   }
   // message passing to get offset of root iter vars.
@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s,
     init = Substitute(init, init_value_map);
     init = MergeNest(init_nest, init);
     // common nest
-    std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop);
-    std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop, nest.end());
+    std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
+    std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
     provide = MergeNest(reduce, provide);
     return MergeNest(
         common, Block::make(init, provide));
@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op,
                        bounds, make_const(Bool(1), true), body);
 }
 
+Stmt MakeRealize(const ScanOpNode* op,
+                 const Map<IterVar, Range>& dom_map,
+                 const std::vector<Tensor>& tensors,
+                 Stmt body) {
+  Range sdom = dom_map.at(op->scan_axis);
+  Range tdom = Range::make_with_min_extent(
+      0, ir::Simplify(sdom->extent + sdom->min));
+  size_t sp_idx = 0;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    const Tensor& t = tensors[i];
+    CHECK_EQ(static_cast<size_t>(t->value_index), i);
+    Halide::Internal::Region bounds;
+    bounds.push_back(tdom);
+    for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = op->spatial_axis_[sp_idx];
+      bounds.push_back(dom_map.at(sp_ax));
+    }
+    body = Realize::make(t->op, t->value_index, t->dtype,
+                         bounds, make_const(Bool(1), true), body);
+  }
+  return body;
+}
+
 void MakeReduction(const ComputeOpNode* op,
                    const std::vector<Tensor>& tensors,
@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s,
   Stmt init, provide;
 
   const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+  const ScanOpNode* scan = s->op.as<ScanOpNode>();
   if (compute) {
     if (compute->reduce_axis.size() == 0) {
       provide = MakeProvide(compute, tensors);
     } else {
       MakeReduction(compute, tensors, &init, &provide);
     }
+  } else if (scan) {
+    // Provide is done by the sub operations.
+    provide = AttrStmt::make(
+        s->op, attr::scan_scope, scan->scan_axis->var,
+        Evaluate::make(0));
   } else {
     LOG(FATAL) << "not supported op " << s->op->type_key();
   }
 
   producer = ProducerConsumer::make(s->op, true, producer);
 
   Stmt pipeline = producer;
-  if (consumer.defined()) {
+  // check if consumer is nop.
+  bool is_no_op{false};
+  const Evaluate* ev = consumer.as<Evaluate>();
+  if (ev && ev->value.as<IntImm>()) is_no_op = true;
+
+  if (consumer.defined() && !is_no_op) {
     consumer = ProducerConsumer::make(s->op, false, consumer);
     pipeline = Block::make(producer, consumer);
   }
@@ -404,47 +445,103 @@
   if (s->op.as<ComputeOpNode>()) {
     pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
                            dom_map, tensors, pipeline);
+  } else if (s->op.as<ScanOpNode>()) {
+    pipeline = MakeRealize(s->op.as<ScanOpNode>(),
+                           dom_map, tensors, pipeline);
   } else {
     LOG(FATAL) << "not supported op";
-    return Stmt();
   }
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt::make(
-      s->op, "realize_scope",
+      s->op, ir::attr::realize_scope,
       StringImm::make(s->scope),
       pipeline);
   return pipeline;
 }
 
 // inject the operator's realization on the stmt.
-class InjectRealize : public IRMutator {
+class InjectAttach : public IRMutator {
  public:
-  InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
-      : schedule(schedule), dom_map(dom_map) {}
+  InjectAttach(const Stage& stage,
+               const Map<IterVar, Range>& dom_map)
+      : stage_(stage), dom_map_(dom_map) {}
 
   Stmt Mutate(Stmt stmt) final {
     CHECK(stmt.defined());
     stmt = IRMutator::Mutate(stmt);
     const AttrStmt* op = stmt.as<AttrStmt>();
     if (op != nullptr &&
-        op->type_key == "scope") {
-      if (op->node == schedule->attach_ivar) {
+        op->type_key == attr::loop_scope) {
+      if (op->node == stage_->attach_ivar) {
         CHECK(!found_attach);
         found_attach = true;
         stmt = AttrStmt::make(
             op->node, op->type_key, op->value,
-            MakePipeline(schedule, dom_map,
-                         IRMutator::Mutate(op->body)));
+            MakePipeline(stage_, dom_map_, op->body));
       }
     }
     return stmt;
   }
+  // whether attach point is found
+  bool found_attach{false};
 
+ private:
   // the operations to be carried
-  Stage schedule;
+  const Stage& stage_;
   // domain map
-  Map<IterVar, Range> dom_map;
+  const Map<IterVar, Range>& dom_map_;
+};
+
+// inject the operator's realization on the stmt.
+class InjectScanStep : public IRMutator {
+ public:
+  InjectScanStep(const Stage& stage,
+                 const Operation& scan_op,
+                 const Map<IterVar, Range>& dom_map,
+                 bool is_init)
+      : stage_(stage), scan_op_(scan_op),
+        dom_map_(dom_map), is_init_(is_init) {}
+
+  Stmt Mutate(Stmt stmt) final {
+    CHECK(stmt.defined());
+    stmt = IRMutator::Mutate(stmt);
+    if (is_init_) {
+      const ProducerConsumer* op = stmt.as<ProducerConsumer>();
+      if (op != nullptr &&
+          op->is_producer &&
+          op->func.same_as(scan_op_)) {
+        stmt = ProducerConsumer::make(
+            op->func, true,
+            MakePipeline(stage_, dom_map_, op->body));
+        found_attach = true;
+      }
+    } else {
+      // update
+      const AttrStmt* op = stmt.as<AttrStmt>();
+      if (op != nullptr &&
+          op->type_key == attr::scan_scope) {
+        if (op->node.same_as(scan_op_)) {
+          found_attach = true;
+          stmt = AttrStmt::make(
+              op->node, op->type_key, op->value,
+              MakePipeline(stage_, dom_map_, op->body));
+        }
+      }
+    }
+    return stmt;
+  }
+
   // whether attach point is found
   bool found_attach{false};
+
+ private:
+  // the operations to be carried
+  const Stage& stage_;
+  const Operation& scan_op_;
+  // domain map
+  const Map<IterVar, Range>& dom_map_;
+  // whether it is init.
+  bool is_init_;
 };
 
 Stmt InjectInline(const Operation op, Stmt body) {
@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) {
   return Inline(body, op, args, compute->body);
 }
 
+// Postprocessing of schedule op.
+// Replace the init and update's expressions by the scan's buffer.
+class SchedulePostProc : public IRMutator {
+ public:
+  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
+    if (to_remove_.count(op->func.get())) {
+      return this->Mutate(op->body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
+    if (!HasSideEffect(op->value)) {
+      var_value_[op->var.get()] = Mutate(op->value);
+      return this->Mutate(op->body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->type_key == attr::loop_scope) {
+      return this->Mutate(op->body);
+    } else if (op->type_key == attr::scan_scope) {
+      const ScanOpNode* scan = op->node.as<ScanOpNode>();
+      CHECK(scan);
+      var_value_[scan->scan_axis->var.get()] = op->value;
+      return this->Mutate(op->body);
+    } else if (op->type_key == ir::attr::realize_scope) {
+      if (to_remove_.count(op->node.get())) {
+        return this->Mutate(op->body);
+      }
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Stmt Mutate_(const Realize* op, const Stmt& s) final {
+    TensorKey key{op->func, op->value_index};
+    if (replace_.count(key)) {
+      return this->Mutate(op->body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Stmt Mutate_(const Provide* op, const Stmt& s) final {
+    TensorKey key{op->func, op->value_index};
+    auto it = replace_.find(key);
+    if (it != replace_.end()) {
+      const Tensor& dst = it->second.first;
+      Stmt ret = Provide::make(
+          dst->op, dst->value_index, op->value,
+          RewriteArgs(it->second.second, op->args));
+      return IRMutator::Mutate_(ret.as<Provide>(), ret);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    if (op != nullptr && op->call_type == Call::Halide) {
+      TensorKey key{op->func, op->value_index};
+      auto it = replace_.find(key);
+      if (it != replace_.end()) {
+        const Tensor& dst = it->second.first;
+        Expr ret = Call::make(
+            op->type, dst->op->name,
+            RewriteArgs(it->second.second, op->args),
+            op->call_type, dst->op, dst->value_index);
+        return IRMutator::Mutate_(ret.as<Call>(), ret);
+      }
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    auto it = var_value_.find(op);
+    if (it != var_value_.end()) {
+      return it->second;
+    } else {
+      return e;
+    }
+  }
+
+  void Init(const Schedule& sch) {
+    for (Stage s : sch->stages) {
+      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+      if (!scan) continue;
+      for (size_t i = 0; i < scan->update.size(); ++i) {
+        Tensor t = s->op.output(i);
+        AddReplace(scan->init[i], t, Expr());
+        AddReplace(scan->update[i], t, scan->scan_axis->var);
+        AddReplace(scan->state_placeholder[i], t, Expr());
+      }
+    }
+  }
+
+ private:
+  void AddReplace(Tensor src, Tensor dst, Expr head_idx) {
+    replace_[TensorKey{src->op, src->value_index}]
+        = std::make_pair(dst, head_idx);
+    to_remove_.insert(src->op.get());
+  }
+  Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
+    if (!head.defined()) return args;
+    Array<Expr> new_args{head};
+    for (Expr e : args) {
+      new_args.push_back(e);
+    }
+    return new_args;
+  }
+  // The scan value
+  std::unordered_map<const Variable*, Expr> var_value_;
+  // buffer replacement
+  std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_;
+  // replaced functions
+  std::unordered_set<const Node*> to_remove_;
+};
+
 Stmt ScheduleOps(
     Schedule sch, Map<IterVar, Range> dom_map) {
   Stmt body = Stmt();
+  // scan init and scan updates
+  std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
+  for (Stage s : sch->stages) {
+    const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (!scan) continue;
+    for (Tensor t : scan->init) {
+      if (scan_attach.count(t->op)) {
+        CHECK(scan_attach.at(t->op).first.same_as(s->op))
+            << "Scan init tensor can only belong to one scan";
+      } else {
+        scan_attach[t->op] = std::make_pair(s->op, true);
+      }
+    }
+    for (Tensor t : scan->update) {
+      if (scan_attach.count(t->op)) {
+        CHECK(scan_attach.at(t->op).first.same_as(s->op))
+            << "Scan update tensor can only belong to one scan";
+      } else {
+        scan_attach[t->op] = std::make_pair(s->op, false);
+      }
+    }
+  }
   // reverse the post DFS order.
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage s = sch->stages[i - 1];
     // no need to specify placeholder op.
     if (s->op.as<PlaceholderOpNode>()) continue;
-    if (s->attach_type == kInline) {
+    if (scan_attach.count(s->op)) {
+      CHECK(s->attach_type == kNone ||
+            s->attach_type == kInline)
+          << "Cannot specify compute_at for scan's init/update";
+      CHECK(body.defined());
+      const auto& p = scan_attach.at(s->op);
+      InjectScanStep mu(s, p.first, dom_map, p.second);
+      body = mu.Mutate(body);
+      CHECK(mu.found_attach)
+          << "did not find attachment point for scan.init/update";
+    } else if (s->attach_type == kInline) {
       body = InjectInline(s->op, body);
     } else if (s->attach_type == kRoot || s-> attach_type == kNone) {
       body = MakePipeline(s, dom_map, body);
     } else if (s->attach_type == kScope) {
       CHECK(body.defined());
-      InjectRealize mutator(s, dom_map);
+      InjectAttach mutator(s, dom_map);
       body = mutator.Mutate(body);
       CHECK(mutator.found_attach)
           << "did not find attachment point";
     }
   }
-  return body;
+  SchedulePostProc post_proc;
+  post_proc.Init(sch);
+  return post_proc.Mutate(body);
 }
 
 }  // namespace schedule
diff --git a/tests/python/integration/test_scan.py b/tests/python/integration/test_scan.py
new file mode 100644
index 000000000000..38cd832f2e43
--- /dev/null
+++ b/tests/python/integration/test_scan.py
@@ -0,0 +1,54 @@
+import tvm
+import numpy as np
+
+def test_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), name="t")
+    X = tvm.placeholder((m, n), name="X")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+
+    # schedule
+    s = tvm.Schedule(res.op)
+    num_thread = 256
+    block_x = tvm.IterVar(thread_tag="blockIdx.x")
+    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
+    _, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
+    _, x = s[s_init].split(x, outer=thread_x)
+    _, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
+    _, x = s[s_update].split(x, outer=thread_x)
+
+    # one line to build the function.
+    def check_device(target):
+        codes = []
+        fscan = tvm.build(s, [X, res],
+                          target, record_codes=codes,
+                          name="myscan")
+        if target == "cuda":
+            ctx = tvm.gpu(0)
+        else:
+            ctx = tvm.cl(0)
+        if not ctx.enabled:
+            return
+
+        for c in codes[1:]:
+            print(c)
+        # launch the kernel.
+        n = 1024
+        m = 10
+        a_np = np.random.uniform(size=(m, n)).astype(res.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros((m, n), dtype=res.dtype), ctx)
+        fscan(a, b)
+        np.testing.assert_allclose(
+            b.asnumpy(), np.cumsum(a_np, axis=0))
+
+    tvm.init_opencl()
+    check_device("cuda")
+
+
+if __name__ == "__main__":
+    test_scan()
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 9d9115f5c2ea..3459e80e918c 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -34,6 +34,20 @@ def test_tensor_reduce():
     assert(str(C_loaded) == str(C))
 
 
+def test_tensor_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), "t")
+    x = tvm.placeholder((m, n))
+    s = tvm.placeholder((m, n))
+    res = tvm.scan(t,
+                   tvm.compute((1, n), lambda _, i: x[0, i]),
+                   tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
+                   s)
+    assert tuple(res.shape) == (m, n)
+
+
 if __name__ == "__main__":
     test_tensor()
     test_tensor_reduce()
+    test_tensor_scan()
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
index 9002b9686675..197c4b1f4777 100644
--- a/tests/python/unittest/test_pass_simplify.py
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -18,9 +18,15 @@ def test_simplify():
                            tvm.make.Load(dtype, Ab.data, i + 4) + 1,
                            (j + 1) * 4 - 4 * j + i),
                        None)))
-    print(stmt)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
-    print(stmt)
+
+
+def test_basic():
+    m = tvm.Var('m')
+    ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
+    assert str(ret.value) == "(m - 1)"
+
 
 if __name__ == "__main__":
+    test_basic()
     test_simplify()
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 9689a1c34fc4..278d1cc53be6 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -6,13 +6,11 @@ def test_schedule0():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
-
     s = tvm.Schedule(A1.op)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    print(stmt)
+
 
 def test_schedule1():
     m = tvm.Var('m')
@@ -25,7 +23,7 @@ def test_schedule1():
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    print(stmt)
+
 
 def test_schedule2():
     m = tvm.Var('m')
@@ -40,25 +38,45 @@ def test_schedule2():
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    l = tvm.Var("l")
+    t = tvm.IterVar((1, m), name="t")
+    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+
+    assert tuple(res.shape) == (m, n)
+    s = tvm.Schedule(res.op)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[res.op.scan_axis].min.value == 1)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
+
 
 def test_auto_inline():
-  m = tvm.Var('m')
-  n = tvm.Var('n')
-  A = tvm.placeholder((m, n), name='A')
-  B = tvm.placeholder((m, n), name='B')
-  C = tvm.placeholder((m, n), name='C')
-  T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
-  T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
+    m = tvm.Var('m')
+    n = tvm.Var('n')
+    A = tvm.placeholder((m, n), name='A')
+    B = tvm.placeholder((m, n), name='B')
+    C = tvm.placeholder((m, n), name='C')
+    T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
+    T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
 
-  s = tvm.Schedule(T2.op)
-  tvm.schedule.AutoInlineElemWise(s)
-  bounds = tvm.schedule.InferBound(s)
-  stmt = tvm.schedule.ScheduleOps(s, bounds)
-  print(stmt)
+    s = tvm.Schedule(T2.op)
+    tvm.schedule.AutoInlineElemWise(s)
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
 
 if __name__ == "__main__":
+    test_schedule_scan()
     test_schedule0()
     test_schedule1()
     test_schedule2()
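
Usage note (appended for review, not part of the patch): the sketch below strings the pieces of this change together end to end, combining the cumsum construction from the `tvm.scan` docstring with the `normalize`/`InferBound`/`ScheduleOps` lowering path exercised by `test_schedule_scan`. It assumes the Python API exactly as introduced in this revision; `tvm.Var`, `tvm.IterVar`, `tvm.placeholder`, `tvm.compute`, and `tvm.Schedule` are the pre-0.1 spellings used throughout the patch.

    import tvm

    # Cumulative sum over the first axis, the running example of this patch.
    m = tvm.Var("m")
    n = tvm.Var("n")
    t = tvm.IterVar((1, m), name="t")    # scan axis; timestep 0 comes from init
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))    # placeholder for the carried state
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((n,), lambda i: s_state[t - 1, i] + X[t, i])
    res = tvm.scan(t, s_init, s_update, s_state)

    # Lowering: bound inference assigns the scan axis its [1, m) range,
    # then ScheduleOps stitches init and update into a single scan loop.
    s = tvm.Schedule(res.op)
    s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    print(stmt)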