From 5cf3323c6d99df10304e21db3d5d08e1fa8446ba Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Tue, 6 Jun 2017 10:41:11 -0700
Subject: [PATCH 01/15] Support for batch ComputeOp

---
 include/tvm/operation.h                   |   4 +-
 python/tvm/api.py                         |  13 ++-
 src/op/compute_op.cc                      | 110 +++++++++++++++-------
 src/schedule/schedule_dataflow_rewrite.cc |  20 ++--
 src/schedule/schedule_ops.cc              |   2 +-
 tests/python/unittest/test_lang_tensor.py |  27 ++++++
 6 files changed, 125 insertions(+), 51 deletions(-)

diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index eb0ee37569f1..2941302168ad 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
   /*! \brief IterVar on each reduction axis, if the body is a Reduce */
   Array<IterVar> reduce_axis;
   /*! \brief the compute expression */
-  Expr body;
+  Array<Expr> body;
   /*! \brief constructor */
   ComputeOpNode() {}
   // override functions
@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
   }
   static Operation make(std::string name,
                         Array<IterVar> axis,
-                        Expr body);
+                        Array<Expr> body);

   static constexpr const char* _type_key = "ComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 2ef18d210342..f97ee8461874 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -174,10 +174,19 @@ def compute(shape, fcompute, name="compute"):

     dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
+    if not isinstance(body, (list, tuple)):
+        body = [body]
     body = convert(body)
     op_node = _api_internal._ComputeOp(
         name, dim_var, body)
-    return op_node.output(0)
+    outputs = []
+    num = op_node.num_outputs
+    if num == 1:
+        return op_node.output(0)
+    else:
+        for i in range(num):
+            outputs.append(op_node.output(i))
+        return tuple(outputs)


 def scan(init, update, state_placeholder, inputs=None, name="scan"):
@@ -526,9 +535,9 @@ def _reduce_directly(*args):

     def _make_reduce(expr, axis, where=None):
         expr = convert(expr)
-        dtype = expr.dtype
         code = fcombine.__code__
         assert fcombine.__code__.co_argcount == 2
+        dtype = expr.dtype
         arg_vars = [var(name, dtype) for name in code.co_varnames]
         result = fcombine(*[v for v in arg_vars])
         result = convert(result)
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index a2d3b25e25e0..fcb81d462d4e 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -24,7 +24,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);

 int ComputeOpNode::num_outputs() const {
-  return 1;
+  return body.size();
 }

 Array<IterVar> ComputeOpNode::root_iter_vars() const {
@@ -36,13 +36,14 @@ Array<IterVar> ComputeOpNode::root_iter_vars() const {
   return ret;
 }

-Type ComputeOpNode::output_dtype(size_t i) const {
-  CHECK_EQ(i, 0U);
-  return body.type();
+Type ComputeOpNode::output_dtype(size_t idx) const {
+  CHECK_LT(idx, num_outputs());
+  return body[idx].type();
 }

-Array<Expr> ComputeOpNode::output_shape(size_t i) const {
-  CHECK_EQ(i, 0U);
+Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
+  CHECK_LT(idx, num_outputs());
+  // for now, all outputs of ComputeOp have the same shape
   std::vector<Expr> shape;
   for (size_t i = 0; i < axis.size(); ++i) {
     const Range& r = axis[i]->dom;
@@ -65,18 +66,19 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
     args.push_back(axis.back()->var);
   }

-  return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
+  return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0);
 }

 Operation ComputeOpNode::make(std::string name,
                               Array<IterVar> axis,
-                              Expr body) {
+                              Array<Expr> body) {
   auto n = std::make_shared<ComputeOpNode>();
   n->name = name;
   n->axis = axis;
   n->body = body;
-  if (n->body->is_type<ir::Reduce>()) {
-    n->reduce_axis = n->body.as<ir::Reduce>()->axis;
+  if (n->body[0]->is_type<ir::Reduce>()) {
+    // batch reduction should have the same axis
+    n->reduce_axis = n->body[0].as<ir::Reduce>()->axis;
   }
   return Operation(n);
 }
@@ -85,16 +87,27 @@ Operation ComputeOpNode::make(std::string name,
 Array<Tensor> ComputeOpNode::InputTensors() const {
   Array<Tensor> ret;
   std::unordered_set<Tensor> visited;
-  ir::PostOrderVisit(body, [&ret, &visited](const NodeRef& n) {
-      const ir::Call *call = n.as<ir::Call>();
-      if (call != nullptr && call->func.defined()) {
-        Tensor t = Operation(call->func.node_).output(call->value_index);
-        if (!visited.count(t)) {
-          ret.push_back(t);
-          visited.insert(t);
+  for (auto& e : body) {
+    ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
+        const ir::Call *call = n.as<ir::Call>();
+        if (call != nullptr && call->func.defined()) {
+          Tensor t = Operation(call->func.node_).output(call->value_index);
+          if (!visited.count(t)) {
+            ret.push_back(t);
+            visited.insert(t);
+          }
         }
-      }
-    });
+      });
+  }
+  return ret;
+}
+
+Array<Expr> ReplaceTensor(Array<Expr> exprs,
+                          const std::unordered_map<Tensor, Tensor>& replace) {
+  Array<Expr> ret;
+  for (auto& e : exprs) {
+    ret.push_back(op::ReplaceTensor(e, replace));
+  }
   return ret;
 }

@@ -102,7 +115,7 @@ Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
-  Expr new_body = op::ReplaceTensor(this->body, rmap);
+  Array<Expr> new_body = ReplaceTensor(this->body, rmap);
   if (!new_body.same_as(this->body)) {
     return ComputeOpNode::make(name, axis, new_body);
   } else {
     return self;
@@ -151,13 +164,17 @@ Stmt ComputeOpNode::BuildRealize(
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& realize_body) const {
   CHECK_EQ(self.operator->(), this);
-  Tensor t = self.output(0);
   Halide::Internal::Region bounds;
   for (IterVar iv : this->axis) {
     bounds.push_back(realize_map.at(iv));
   }
-  return ir::Realize::make(t->op, t->value_index, t->dtype,
-                           bounds, const_true(), realize_body);
+  Stmt realize = realize_body;
+  for (int i = self->num_outputs(); i > 0; --i) {
+    Tensor t = self.output(i-1);
+    realize = ir::Realize::make(t->op, t->value_index,
+      t->dtype, bounds, const_true(), realize);
+  }
+  return realize;
 }

 // Build a reduction body.
 void MakeReduction(const ComputeOpNode* op,
                    const Tensor& t,
                    Stmt* init,
                    Stmt* provide) {
-  Stmt no_op = Evaluate::make(0);
-  std::vector<Stmt> nest;
   Array<Expr> args;
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
-  const Reduce* reduce = op->body.as<Reduce>();
+  const Reduce* reduce = op->body[t->value_index].as<Reduce>();
   CHECK(reduce);
   const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
   CHECK(combiner);
@@ -193,6 +208,19 @@ Stmt Substitute(Stmt s,
   return ir::Substitute(s, temp);
 }

+std::vector<Stmt> Substitute(std::vector<Stmt> stmt,
+                             const std::unordered_map<IterVar, Expr>& value_map) {
+  Map<Var, Expr> temp;
+  for (const auto& kv : value_map) {
+    temp.Set(kv.first->var, kv.second);
+  }
+  std::vector<Stmt> ret;
+  for (auto& s : stmt) {
+    ret.push_back(ir::Substitute(s, temp));
+  }
+  return ret;
+}
+
 // Cross Thread reduction marker.
 bool IsCrossThreadReduction(const ComputeOpNode* self,
                             const Stage& stage) {
@@ -289,7 +317,7 @@ Stmt MakeProvide(const ComputeOpNode* op,
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
-  return Provide::make(t->op, t->value_index, op->body, args);
+  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
 }

 Stmt ComputeOpNode::BuildProvide(
@@ -301,12 +329,22 @@ Stmt ComputeOpNode::BuildProvide(
   CHECK_EQ(stage->op.operator->(), this);

   if (IsCrossThreadReduction(this, stage)) {
     // specially handle cross thread reduction.
     return MakeCrossThreadReduction(this, stage, dom_map);
   }
-  Stmt init, provide;
+
+  std::vector<Stmt> inits;
+  std::vector<Stmt> provides;
   if (this->reduce_axis.size() == 0) {
-    provide = MakeProvide(this, stage->op.output(0));
+    for (int i = 0; i < this->num_outputs(); ++i) {
+      provides.push_back(MakeProvide(this, stage->op.output(i)));
+    }
   } else {
-    MakeReduction(this, stage->op.output(0), &init, &provide);
+    for (int i = 0; i < this->num_outputs(); ++i) {
+      Stmt init, provide;
+      MakeReduction(this, stage->op.output(i), &init, &provide);
+      inits.push_back(init);
+      provides.push_back(provide);
+    }
   }
+  // make loop nest
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
@@ -318,9 +356,9 @@ Stmt ComputeOpNode::BuildProvide(
   if (stage->store_predicate.defined()) {
     nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
   }
-  provide = Substitute(provide, value_map);
+  provides = Substitute(provides, value_map);

-  if (init.defined()) {
+  if (!inits.empty()) {
     // try to find the location to insert the initialization.
     // Fuse the initialization and provide loop when possible.
     std::unordered_map<IterVar, int> update_state;
@@ -356,15 +394,15 @@ Stmt ComputeOpNode::BuildProvide(
     auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
     for (auto& e : preds) e = likely(e);
     init_nest.push_back(op::MakeIfNest(preds));
-    init = Substitute(init, init_value_map);
-    init = MergeNest(init_nest, init);
+    inits = Substitute(inits, init_value_map);
+    Stmt init = MergeNest(init_nest, Block::make(inits));
     // common nest
     std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
     std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
-    provide = MergeNest(reduce, provide);
+    Stmt provide = MergeNest(reduce, Block::make(provides));
     return MergeNest(common, Block::make(init, provide));
   } else {
-    return MergeNest(nest, provide);
+    return MergeNest(nest, Block::make(provides));
   }
 }
 }  // namespace tvm
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 25319dc24eff..20f16b910317 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -120,13 +120,13 @@ Tensor Schedule::cache_write(const Tensor& tensor,
     vsub[iv->var.get()] = new_iv->var;
   }
   VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body);
+  Expr body = repl.Mutate(compute->body[tensor->value_index]);
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, new_axis, body);
+      compute->name + "." + scope, new_axis, {body});
   Tensor cache_tensor = cache_op.output(0);
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->axis,
-      cache_tensor(args));
+      {cache_tensor(args)});

   std::unordered_map<Tensor, Tensor> vmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
@@ -214,14 +214,14 @@ void InjectInline(ScheduleNode* sch) {
         for (auto iv : compute->axis) {
           args.push_back(iv->var);
         }
-        body = compute->body;
+        body = compute->body[0];
       }
       for (size_t j = i; j < sch->stages.size(); ++j) {
         Stage s = sch->stages[j];
         const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
         if (compute) {
           if (!new_body[j].defined()) {
-            new_body[j] = s->op.as<ComputeOpNode>()->body;
+            new_body[j] = s->op.as<ComputeOpNode>()->body[0];
           }
           new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
                                    stage->op, args, body).as<ir::Evaluate>()->value;
@@ -241,7 +241,7 @@ void InjectInline(ScheduleNode* sch) {
       Operation op = s->op;
       if (!new_body[i].same_as(compute->body)) {
         op = ComputeOpNode::make(
-            compute->name, compute->axis, new_body[i]);
+            compute->name, compute->axis, {new_body[i]});
       }
       op = op->ReplaceInputs(op, repl);
       if (!op.same_as(s->op)) {
@@ -359,10 +359,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       n->reduce_axis.push_back(IterVar(ncpy));
     }
   }
-  n->body = Reduce::make(reduce->combiner,
-                         VarReplacer(vsub).Mutate(reduce->source),
-                         n->reduce_axis,
-                         predicate);
+  n->body = {Reduce::make(reduce->combiner,
+                          VarReplacer(vsub).Mutate(reduce->source),
+                          n->reduce_axis,
+                          predicate)};
   // refresh relations, keep the un-touched relations.
   Array<IterVarRelation> rels;
   for (IterVarRelation rel : reduce_stage->relations) {
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index edbc2878a5a9..347cf69884b6 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -253,7 +253,7 @@ class SchedulePostProc : public IRMutator {
       // This must be checked for all ops, including scan.
       if (!s->op.same_as(s->origin_op)) {
         for (int i = 0; i < s->op->num_outputs(); ++i) {
-          Tensor target = s->origin_op.output(0);
+          Tensor target = s->origin_op.output(i);
           AddReplace(s->op.output(i), target,
                      target, s->origin_op);
         }
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index e81b15ffa653..cb8d48813e4c 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -118,6 +118,33 @@ def extern_func(ins, outs):
     assert(len(res) == 2)
     assert(res[1].value_index == 1)

+def test_multi_inputs_outputs():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A1')
+    A1 = tvm.placeholder((m, n), name='A2')
+    T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    sch = s.normalize()
+    bounds = schedule.InferBound(sch)
+    stmt = schedule.ScheduleOps(sch, bounds)
+
+def test_multi_inputs_outputs_reduce():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A0')
+    A1 = tvm.placeholder((m, n), name='A1')
+    k = tvm.reduce_axis((0, n), "k")
+    mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
+    myprod = tvm.comm_reducer(lambda x, y: x*y, lambda t: tvm.const(1, dtype=t))
+    T0, T1 = tvm.compute((m,), lambda i: (mysum(A0[i, k], axis=k), myprod(A1[i, k], axis=k)))
+    s = tvm.create_schedule(T1.op)
+
+    sch = s.normalize()
+    bounds = schedule.InferBound(sch)
+    stmt = schedule.ScheduleOps(sch, bounds)
+
 if __name__ == "__main__":
     test_conv1d()
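[Editor's note: patch 01 makes `tvm.compute` return a tuple of tensors when
`fcompute` yields multiple expressions. A minimal usage sketch, mirroring the
test added above (placeholder names are illustrative, not from the patch):

    import tvm

    m = tvm.var('m')
    A = tvm.placeholder((m,), name='A')
    # fcompute returns a tuple, so compute() returns one tensor per element;
    # both tensors share the same op and are distinguished by value_index.
    B0, B1 = tvm.compute((m,), lambda i: (A[i] + 1, A[i] * 2), name='B')
    assert B0.op == B1.op
    assert (B0.value_index, B1.value_index) == (0, 1)
]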
From c75d50f718e46aa8d7e390b03971e632b8a8cf70 Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Thu, 8 Jun 2017 14:49:48 -0700
Subject: [PATCH 02/15] Support for batch ComputeOp

---
 include/tvm/expr.h                          | 10 ++++++
 include/tvm/ir_pass.h                       |  4 +--
 include/tvm/schedule.h                      |  2 +-
 python/tvm/schedule.py                      |  2 +-
 src/op/compute_op.cc                        |  6 ++--
 src/pass/storage_flatten.cc                 |  4 ++-
 src/schedule/auto_inline_elem_wise.cc       |  2 +-
 src/schedule/graph.cc                       | 22 +++++++++----
 src/schedule/graph.h                        |  4 +--
 src/schedule/schedule_dataflow_rewrite.cc   | 31 +++++++++++--------
 tests/python/unittest/test_lang_schedule.py | 10 +++---
 tests/python/unittest/test_lang_tensor.py   | 20 +++++++-----
 tests/python/unittest/test_pass_inline.py   |  4 +--
 .../unittest/test_schedule_schedule_ops.py  |  2 +-
 14 files changed, 78 insertions(+), 45 deletions(-)

diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index c35fd6d4c35c..149e00e7ea3f 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -43,6 +43,16 @@ using Halide::Internal::is_no_op;
 using Halide::likely;
 using Halide::likely_if_innermost;

+/*! \brief whether two array have the same content */
+template<typename T>
+bool IsSame(const Array<T>& a, const Array<T>& b) {
+  if (a.size() != b.size()) return false;
+  for (size_t i = 0; i < a.size(); ++i) {
+    if (!a[i].same_as(b[i])) return false;
+  }
+  return true;
+}
+
 inline Type TVMShapeIndexType() {
   if (std::is_signed<tvm_index_t>::value) {
     return Int(sizeof(tvm_index_t) * 8);
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 7e3238704137..fc0bc1f1abd2 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
 /*!
  * \brief inline all calls of f in stmt.
  *
+ * \param stmt The statement to apply inline optimization.
  * \param f The function reference to be inlined
 * \param args The arguments variable of the function.
- * \param body The defintion body of the function.
- * \param stmt The statement to apply inline optimization.
+ * \param body The definition body of the function.
 * \return The result stmt
 *
 * \note All the passes in this file uses SSA form and outputs SSA form.
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 085730a9bff8..89a8845061c3 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -257,7 +257,7 @@ class Schedule : public NodeRef {
  /*!
   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
   * This will create a new stage that generated the new tensor with axis
-   * as the first dimension. The tensor's body wil be rewriten as a reduction
+   * as the first dimension. The tensor's body will be rewritten as a reduction
   * over the factored tensor.
   *
   * \param tensor The tensor to be factored.
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index a23721514b27..ed62923b443d 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -181,7 +181,7 @@ def rfactor(self, tensor, axis):
        """ Factor a reduction axis in tensor's schedule to be an explicit axis.
        This will create a new stage that generated the new tensor with axis
-        as the first dimension. The tensor's body wil be rewriten as a reduction
+        as the first dimension. The tensor's body will be rewritten as a reduction
        over the factored tensor.

        Parameters
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index fcb81d462d4e..c0f183df11e2 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -116,7 +116,7 @@ Operation ComputeOpNode::ReplaceInputs(
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
   Array<Expr> new_body = ReplaceTensor(this->body, rmap);
-  if (!new_body.same_as(this->body)) {
+  if (!IsSame(new_body, this->body)) {
     return ComputeOpNode::make(name, axis, new_body);
   } else {
     return self;
@@ -140,7 +140,7 @@ void ComputeOpNode::PropBoundToInputs(
       }
     }
   };
-  ir::PostOrderVisit(body, fvisit);
+  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
 }

 void ComputeOpNode::GatherBound(
@@ -253,7 +253,7 @@ Stmt MakeCrossThreadReduction(
   for (IterVar iv : self->axis) {
     args.push_back(iv->var);
   }
-  const Reduce* reduce = self->body.as<Reduce>();
+  const Reduce* reduce = self->body[0].as<Reduce>();
   CHECK(reduce);
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index ff4c5912dae0..f807b92dceaa 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -157,7 +157,9 @@ class StorageFlattener : public IRMutator {
       CHECK_EQ(extern_buf_remap_.size(), 0U);
       for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
         TensorKey key{func, static_cast<int>(i)};
-        CHECK(buf_map_.count(key));
+        CHECK(buf_map_.count(key))
+            << "Cannot find allocated buffer for " << key.f
+            << "(" << key.value_index << ")";
         extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
             buf_map_.at(key).buffer->data;
       }
diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc
index 312be19a0874..9fd073c0ac7a 100644
--- a/src/schedule/auto_inline_elem_wise.cc
+++ b/src/schedule/auto_inline_elem_wise.cc
@@ -46,7 +46,7 @@ class ElemWiseDetector : public ir::IRVisitor {
 bool IsElemWise(const Operation& op) {
   if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
     ElemWiseDetector v = ElemWiseDetector(compute->axis);
-    v.Visit(compute->body);
+    for (auto& e : compute->body) v.Visit(e);
     return v.is_elem_wise_;
   }
   return false;
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 0fcf21def2fc..7afe0f23ea21 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -260,7 +260,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
           }
         }
       };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      for (auto& e: op.as<ComputeOpNode>()->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
     }
   }
   return reach;
@@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
         }
       }
     } else if (op.as<ComputeOpNode>()) {
-      std::unordered_map<const Node*, TensorDimKey> vmap;
+      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
       const auto& axis = op.as<ComputeOpNode>()->axis;
-      Tensor t = op.output(0);
       for (size_t i = 0; i < axis.size(); ++i) {
-        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+        std::vector<TensorDimKey> keys;
+        for (int j = 0; j < op->num_outputs(); ++j) {
+          keys.emplace_back(op.output(j), i);
+        }
+        vmap[axis[i]->var.get()] = std::move(keys);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
          const NodeRef& n) {
@@ -335,7 +340,10 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
            auto it = vmap.find(call->args[i].get());
            TensorDimKey src(call, static_cast<int>(i));
            if (it != vmap.end()) {
-              f_merge_key(it->second, src);
+              const std::vector<TensorDimKey>& keys = it->second;
+              for (const auto& key: keys) {
+                f_merge_key(key, src);
+              }
            } else {
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
@@ -344,7 +352,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
          }
        }
      };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      for (auto& e : op.as<ComputeOpNode>()->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
    }
  }
  ReachGraph reach;
diff --git a/src/schedule/graph.h b/src/schedule/graph.h
index 7908dc9e1de6..50d35355cc64 100644
--- a/src/schedule/graph.h
+++ b/src/schedule/graph.h
@@ -27,7 +27,7 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
 using AttachPath = Map<Operation, Array<IterVar> >;

 /*!
- * \brief The map beteen tensor and operation it feeds to.
+ * \brief The map between tensor and operation it feeds to.
 */
 using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;

@@ -46,7 +46,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
 * The operations contains node which input-reachable from any inputs
 * output reachable to any outputs.
 *
- * The inputs won't be included in the subgraph, the outputs will be inclued.
+ * The inputs won't be included in the subgraph, the outputs will be included.
 *
 * \param outputs The outputs of the subgraph
 * \param inputs The inputs to the subgraph.
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 20f16b910317..981deb726366 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -198,14 +198,14 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {

 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
-  std::vector<Expr> new_body(sch->stages.size());
+  std::vector<Array<Expr>> new_body(sch->stages.size());
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
     if (stage->attach_type == kInline) {
       stage->attach_type = kInlinedAlready;
       Array<Var> args;
-      Expr body;
+      Array<Expr> body;
       {
         // setup args
         const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -214,17 +214,19 @@ void InjectInline(ScheduleNode* sch) {
         for (auto iv : compute->axis) {
           args.push_back(iv->var);
         }
-        body = compute->body[0];
+        body = compute->body;
       }
       for (size_t j = i; j < sch->stages.size(); ++j) {
         Stage s = sch->stages[j];
         const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
         if (compute) {
-          if (!new_body[j].defined()) {
-            new_body[j] = s->op.as<ComputeOpNode>()->body[0];
+          if (!new_body[j].size()) {
+            new_body[j] = s->op.as<ComputeOpNode>()->body;
+          }
+          for (size_t k = 0; k < body.size(); ++k) {
+            new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
+                                          stage->op, args, body[k]).as<ir::Evaluate>()->value);
           }
-          new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
-                                   stage->op, args, body).as<ir::Evaluate>()->value;
         }
       }
     }
@@ -234,19 +236,21 @@ void InjectInline(ScheduleNode* sch) {
   for (size_t i = 0; i < sch->stages.size(); ++i) {
     Stage s = sch->stages[i];
     if (s->attach_type == kInlinedAlready) continue;
-    if (new_body[i].defined()) {
+    if (new_body[i].size()) {
       // Logics from ReplaceDataFlow
       const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
       CHECK(compute);
       Operation op = s->op;
-      if (!new_body[i].same_as(compute->body)) {
+      if (!IsSame(new_body[i], compute->body)) {
         op = ComputeOpNode::make(
-            compute->name, compute->axis, {new_body[i]});
+            compute->name, compute->axis, new_body[i]);
       }
       op = op->ReplaceInputs(op, repl);
       if (!op.same_as(s->op)) {
-        repl[s->op.output(0)] = op.output(0);
-        s->op = op;
+        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+          repl[s->op.output(idx)] = op.output(idx);
+          s->op = op;
+        }
       }
     } else {
       Operation op = s->op->ReplaceInputs(s->op, repl);
@@ -329,7 +333,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
     }
   }
   // predicate generation, copy not touched axis.
-  const Reduce* reduce = compute_op->body.as<Reduce>();
+  int idx = tensor->value_index;
+  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   Expr predicate = reduce->condition;
   std::unordered_map<const Variable*, Expr> vsub;
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index 5fed5a23f750..1b0eac15fe07 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -101,8 +101,8 @@ def test_rfactor():
     s = tvm.create_schedule(B.op)
     BF = s.rfactor(B, k1)
     assert(tuple(BF.shape) == (n, n))
-    assert(set(BF.op.body.axis) == set([k2]))
-    assert(s[B].op.body.axis[0].dom.extent == n)
+    assert(set(BF.op.body[0].axis) == set([k2]))
+    assert(s[B].op.body[0].axis[0].dom.extent == n)
     assert(len(s[B].all_iter_vars) == 2)
     # schedule with splot
     s = tvm.create_schedule(B.op)
@@ -111,9 +111,9 @@ def test_rfactor():
     BF = s.rfactor(B, ki)
     assert(BF.shape[0].value == 4)
     assert(BF.shape[1] == n)
-    assert(BF.op.body.axis[0] == k2)
-    assert(BF.op.body.axis[1].var == ko.var)
-    assert(s[B].op.body.axis[0].dom.extent.value == 4)
+    assert(BF.op.body[0].axis[0] == k2)
+    assert(BF.op.body[0].axis[1].var == ko.var)
+    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)


 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index cb8d48813e4c..ddfcb01c8742 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -126,9 +126,11 @@ def test_multi_inputs_outputs():
     T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
     s = tvm.create_schedule(T0.op)

-    sch = s.normalize()
-    bounds = schedule.InferBound(sch)
-    stmt = schedule.ScheduleOps(sch, bounds)
+    for i in range(len(T0.shape)):
+        assert(T0.shape[i] == T1.shape[i])
+    assert(T0.op == T1.op)
+    assert(T0.value_index == 0)
+    assert(T1.value_index == 1)

 def test_multi_inputs_outputs_reduce():
     m = tvm.var('m')
@@ -138,12 +140,14 @@ def test_multi_inputs_outputs_reduce():
     k = tvm.reduce_axis((0, n), "k")
     mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
     myprod = tvm.comm_reducer(lambda x, y: x*y, lambda t: tvm.const(1, dtype=t))
-    T0, T1 = tvm.compute((m,), lambda i: (mysum(A0[i, k], axis=k), myprod(A1[i, k], axis=k)))
+    T0, T1 = tvm.compute((m,), lambda i: (mysum(A0[i, k], axis=k), myprod(A1[i, k], axis=k)), name='T')
     s = tvm.create_schedule(T1.op)

-    sch = s.normalize()
-    bounds = schedule.InferBound(sch)
-    stmt = schedule.ScheduleOps(sch, bounds)
+    for i in range(len(T0.shape)):
+        assert(T0.shape[i] == T1.shape[i])
+    assert(T0.op == T1.op)
+    assert(T0.value_index == 0)
+    assert(T1.value_index == 1)


 if __name__ == "__main__":
@@ -155,3 +159,5 @@ def test_multi_inputs_outputs_reduce():
     test_scan_multi_out()
     test_extern()
     test_extern_multi_out()
+    test_multi_inputs_outputs()
+    test_multi_inputs_outputs_reduce()
diff --git a/tests/python/unittest/test_pass_inline.py b/tests/python/unittest/test_pass_inline.py
index 1988d54083c7..398c0d34d58d 100644
--- a/tests/python/unittest/test_pass_inline.py
+++ b/tests/python/unittest/test_pass_inline.py
@@ -6,7 +6,7 @@ def test_inline():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))

@@ -25,7 +25,7 @@ def test_inline2():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     def check(op):
         if isinstance(op, tvm.expr.Call):
             assert op.func != T.op
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index eceba04f5b0b..297800e1632d 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -89,7 +89,7 @@ def test_inline_mixed():
     def check(x):
         if isinstance(x, tvm.expr.Call):
             assert x.func != A2
-    tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
+    tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)


 def test_scan_inline1():
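[Editor's note: after patch 02, `ComputeOp.body` is an array even for
single-output ops, so passes and tests index element 0 explicitly to reach
the underlying expression. A small sketch of the updated access pattern
(names illustrative, not from the patch):

    import tvm

    n = tvm.var('n')
    A = tvm.placeholder((n, n), name='A')
    k = tvm.reduce_axis((0, n), 'k')
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    # body is an array; element 0 holds the Reduce expression.
    reduce_expr = B.op.body[0]
    print(reduce_expr.axis)
]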
From 8a224180547d940ac452ad794dfbf31caae2b5d2 Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Thu, 8 Jun 2017 15:27:36 -0700
Subject: [PATCH 03/15] Fix CrossThreadReduction

---
 src/op/compute_op.cc  | 99 +++++++++++++++++++++++--------------------
 src/schedule/graph.cc |  4 +-
 2 files changed, 55 insertions(+), 48 deletions(-)

diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index c0f183df11e2..bcfb79835daa 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -253,62 +253,68 @@ Stmt MakeCrossThreadReduction(
   for (IterVar iv : self->axis) {
     args.push_back(iv->var);
   }
-  const Reduce* reduce = self->body[0].as<Reduce>();
-  CHECK(reduce);
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
   auto conds = op::MakeBoundCheck(
       stage, dom_map, false, std::unordered_set<IterVar>(), value_map);
-  Expr cond = reduce->condition;
-  for (Expr v : conds) {
-    cond = cond && v;
-  }
-  Var res_handle("reduce_temp", Handle());
-  Array<Expr> freduce_args;
-  freduce_args.push_back(reduce->source);
-  freduce_args.push_back(cond);
-  for (IterVar iv : stage->leaf_iter_vars) {
-    if (iv->iter_type == kCommReduce) {
-      auto it = stage->iter_var_attrs.find(iv);
-      if (it != stage->iter_var_attrs.end() &&
-          (*it).second->bind_thread.defined()) {
-        IterVar tv = (*it).second->bind_thread;
-        freduce_args.push_back(tv->var);
+
+  std::vector<Stmt> reduction_bodies;
+  for (size_t idx = 0; idx < self->body.size(); ++idx) {
+    const Reduce* reduce = self->body[idx].as<Reduce>();
+    CHECK(reduce);
+    Expr cond = reduce->condition;
+    for (Expr v : conds) {
+      cond = cond && v;
+    }
+    Var res_handle("reduce_temp"+std::to_string(idx), Handle());
+    Array<Expr> freduce_args;
+    freduce_args.push_back(reduce->source);
+    freduce_args.push_back(cond);
+
+    for (IterVar iv : stage->leaf_iter_vars) {
+      if (iv->iter_type == kCommReduce) {
+        auto it = stage->iter_var_attrs.find(iv);
+        if (it != stage->iter_var_attrs.end() &&
+            (*it).second->bind_thread.defined()) {
+          IterVar tv = (*it).second->bind_thread;
+          freduce_args.push_back(tv->var);
+        }
       }
     }
+    // Checks for the thread.
+    std::vector<Expr> thread_head_check;
+    if (stage->store_predicate.defined()) {
+      thread_head_check.emplace_back(stage->store_predicate);
+    }
+    Type t = reduce->type;
+    Expr pred = const_true(t.lanes());
+    Stmt reduce_body = Store::make(res_handle,
+      Call::make(
+        reduce->type,
+        ir::intrinsic::tvm_thread_allreduce,
+        freduce_args, Call::Intrinsic),
+      0, pred);
+    reduce_body = AttrStmt::make(
+        reduce->combiner,
+        attr::reduce_scope,
+        make_zero(reduce->type),
+        reduce_body);
+    Stmt assign_body = Provide::make(
+        stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
+
+    assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
+    assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
+    Stmt body = Allocate::make(
+        res_handle, reduce->type, {1}, const_true(),
+        Block::make(reduce_body, assign_body));
+    body = AttrStmt::make(
+        res_handle, attr::storage_scope, StringImm::make("local"), body);
+    body = Substitute(body, value_map);
+    reduction_bodies.push_back(body);
   }
-  // Checks for the thread.
-  std::vector<Expr> thread_head_check;
-  if (stage->store_predicate.defined()) {
-    thread_head_check.emplace_back(stage->store_predicate);
-  }
-  Type t = reduce->type;
-  Expr pred = const_true(t.lanes());
-  Stmt reduce_body = Store::make(res_handle,
-    Call::make(
-      reduce->type,
-      ir::intrinsic::tvm_thread_allreduce,
-      freduce_args, Call::Intrinsic),
-    0, pred);
-  reduce_body = AttrStmt::make(
-      reduce->combiner,
-      attr::reduce_scope,
-      make_zero(reduce->type),
-      reduce_body);
-  Stmt assign_body = Provide::make(
-      stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
-  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
-  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
-  Stmt body = Allocate::make(
-      res_handle, reduce->type, {1}, const_true(),
-      Block::make(reduce_body, assign_body));
-  body = AttrStmt::make(
-      res_handle, attr::storage_scope, StringImm::make("local"), body);
-  body = Substitute(body, value_map);
-  return MergeNest(nest, body);
+
+  return MergeNest(nest, Block::make(reduction_bodies));
 }

 Stmt MakeProvide(const ComputeOpNode* op,
@@ -326,6 +332,7 @@ Stmt ComputeOpNode::BuildProvide(
   CHECK_EQ(stage->op.operator->(), this);

   if (IsCrossThreadReduction(this, stage)) {
+    LOG(INFO) << stage;
     // specially handle cross thread reduction.
     return MakeCrossThreadReduction(this, stage, dom_map);
   }
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 7afe0f23ea21..a01bde30de89 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
          }
        }
      };
-      for (auto& e: op.as<ComputeOpNode>()->body) {
+      for (auto& e : op.as<ComputeOpNode>()->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
@@ -341,7 +341,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
          TensorDimKey src(call, static_cast<int>(i));
          if (it != vmap.end()) {
            const std::vector<TensorDimKey>& keys = it->second;
-            for (const auto& key: keys) {
+            for (const auto& key : keys) {
              f_merge_key(key, src);
            }
          } else {

From efa79fbd3deec4e67aab768d05c5f2fbf4b32077 Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Thu, 8 Jun 2017 15:28:52 -0700
Subject: [PATCH 04/15] Fix lint

---
 python/tvm/api.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/tvm/api.py b/python/tvm/api.py
index f97ee8461874..44bb734b6fa1 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -183,10 +183,9 @@ def compute(shape, fcompute, name="compute"):
     num = op_node.num_outputs
     if num == 1:
         return op_node.output(0)
-    else:
-        for i in range(num):
-            outputs.append(op_node.output(i))
-        return tuple(outputs)
+    for i in range(num):
+        outputs.append(op_node.output(i))
+    return tuple(outputs)


 def scan(init, update, state_placeholder, inputs=None, name="scan"):
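[Editor's note: the cross-thread path fixed in patches 03-04 is selected when
a reduction axis is bound to a thread IterVar. A minimal schedule that should
exercise it, under the TVM API of this revision (names illustrative):

    import tvm

    n = tvm.var('n')
    m = tvm.var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), 'k')
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    s = tvm.create_schedule(B.op)
    # Binding the reduce axis to threadIdx.x is what makes
    # IsCrossThreadReduction true, so BuildProvide emits the
    # tvm_thread_allreduce intrinsic handled above.
    s[B].bind(B.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[B].bind(k, tvm.thread_axis("threadIdx.x"))
    fapi = tvm.lower(s, args=[A, B])
]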
From 8869899705fc9b680bcdaead61e42a93a3a594d7 Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Thu, 8 Jun 2017 23:08:37 -0700
Subject: [PATCH 05/15] Add UpdateArray, remove support for batch reduce

---
 include/tvm/expr.h                        |  10 --
 src/op/compute_op.cc                      | 121 ++++++++++------------
 src/pass/ir_mutator.cc                    |  16 +--
 src/pass/ir_util.h                        |  17 +++
 src/schedule/schedule_dataflow_rewrite.cc |   4 +-
 tests/python/unittest/test_lang_tensor.py |  18 ----
 6 files changed, 78 insertions(+), 108 deletions(-)

diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 149e00e7ea3f..c35fd6d4c35c 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -43,16 +43,6 @@ using Halide::Internal::is_no_op;
 using Halide::likely;
 using Halide::likely_if_innermost;

-/*! \brief whether two array have the same content */
-template<typename T>
-bool IsSame(const Array<T>& a, const Array<T>& b) {
-  if (a.size() != b.size()) return false;
-  for (size_t i = 0; i < a.size(); ++i) {
-    if (!a[i].same_as(b[i])) return false;
-  }
-  return true;
-}
-
 inline Type TVMShapeIndexType() {
   if (std::is_signed<tvm_index_t>::value) {
     return Int(sizeof(tvm_index_t) * 8);
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index bcfb79835daa..40145103e39e 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -77,7 +77,8 @@ Operation ComputeOpNode::make(std::string name,
   n->axis = axis;
   n->body = body;
   if (n->body[0]->is_type<ir::Reduce>()) {
-    // batch reduction should have the same axis
+    CHECK_EQ(n->body.size(), 1)
+        << "Only support single reduction expression for now";
     n->reduce_axis = n->body[0].as<ir::Reduce>()->axis;
   }
   return Operation(n);
@@ -102,22 +103,16 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
   return ret;
 }

-Array<Expr> ReplaceTensor(Array<Expr> exprs,
-                          const std::unordered_map<Tensor, Tensor>& replace) {
-  Array<Expr> ret;
-  for (auto& e : exprs) {
-    ret.push_back(op::ReplaceTensor(e, replace));
-  }
-  return ret;
-}
-
 Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
-  Array<Expr> new_body = ReplaceTensor(this->body, rmap);
-  if (!IsSame(new_body, this->body)) {
-    return ComputeOpNode::make(name, axis, new_body);
+  std::function<Expr(Expr)> fupdate = [&rmap] (Expr e) {
+    return op::ReplaceTensor(e, rmap);
+  };
+  Array<Expr> arr = UpdateArray(this->body, fupdate);
+  if (!arr.same_as(this->body)) {
+    return ComputeOpNode::make(name, axis, arr);
   } else {
     return self;
   }
@@ -259,62 +254,57 @@ Stmt MakeCrossThreadReduction(
   auto conds = op::MakeBoundCheck(
       stage, dom_map, false, std::unordered_set<IterVar>(), value_map);

-  std::vector<Stmt> reduction_bodies;
-  for (size_t idx = 0; idx < self->body.size(); ++idx) {
-    const Reduce* reduce = self->body[idx].as<Reduce>();
-    CHECK(reduce);
-    Expr cond = reduce->condition;
-    for (Expr v : conds) {
-      cond = cond && v;
-    }
-    Var res_handle("reduce_temp"+std::to_string(idx), Handle());
-    Array<Expr> freduce_args;
-    freduce_args.push_back(reduce->source);
-    freduce_args.push_back(cond);
-
-    for (IterVar iv : stage->leaf_iter_vars) {
-      if (iv->iter_type == kCommReduce) {
-        auto it = stage->iter_var_attrs.find(iv);
-        if (it != stage->iter_var_attrs.end() &&
-            (*it).second->bind_thread.defined()) {
-          IterVar tv = (*it).second->bind_thread;
-          freduce_args.push_back(tv->var);
-        }
+  const Reduce* reduce = self->body[0].as<Reduce>();
+  CHECK(reduce);
+  Expr cond = reduce->condition;
+  for (Expr v : conds) {
+    cond = cond && v;
+  }
+  Var res_handle("reduce_temp", Handle());
+  Array<Expr> freduce_args;
+  freduce_args.push_back(reduce->source);
+  freduce_args.push_back(cond);
+
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      auto it = stage->iter_var_attrs.find(iv);
+      if (it != stage->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        IterVar tv = (*it).second->bind_thread;
+        freduce_args.push_back(tv->var);
       }
     }
-    // Checks for the thread.
-    std::vector<Expr> thread_head_check;
-    if (stage->store_predicate.defined()) {
-      thread_head_check.emplace_back(stage->store_predicate);
-    }
-    Type t = reduce->type;
-    Expr pred = const_true(t.lanes());
-    Stmt reduce_body = Store::make(res_handle,
-      Call::make(
-        reduce->type,
-        ir::intrinsic::tvm_thread_allreduce,
-        freduce_args, Call::Intrinsic),
-      0, pred);
-    reduce_body = AttrStmt::make(
-        reduce->combiner,
-        attr::reduce_scope,
-        make_zero(reduce->type),
-        reduce_body);
-    Stmt assign_body = Provide::make(
-        stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
-
-    assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
-    assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
-    Stmt body = Allocate::make(
-        res_handle, reduce->type, {1}, const_true(),
-        Block::make(reduce_body, assign_body));
-    body = AttrStmt::make(
-        res_handle, attr::storage_scope, StringImm::make("local"), body);
-    body = Substitute(body, value_map);
-    reduction_bodies.push_back(body);
   }
-
-  return MergeNest(nest, Block::make(reduction_bodies));
+  // Checks for the thread.
+  std::vector<Expr> thread_head_check;
+  if (stage->store_predicate.defined()) {
+    thread_head_check.emplace_back(stage->store_predicate);
+  }
+  Type t = reduce->type;
+  Expr pred = const_true(t.lanes());
+  Stmt reduce_body = Store::make(res_handle,
+    Call::make(
+      reduce->type,
+      ir::intrinsic::tvm_thread_allreduce,
+      freduce_args, Call::Intrinsic),
+    0, pred);
+  reduce_body = AttrStmt::make(
+      reduce->combiner,
+      attr::reduce_scope,
+      make_zero(reduce->type),
+      reduce_body);
+  Stmt assign_body = Provide::make(
+      stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
+
+  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
+  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
+  Stmt body = Allocate::make(
+      res_handle, reduce->type, {1}, const_true(),
+      Block::make(reduce_body, assign_body));
+  body = AttrStmt::make(
+      res_handle, attr::storage_scope, StringImm::make("local"), body);
+  body = Substitute(body, value_map);
+  return MergeNest(nest, body);
 }

 Stmt MakeProvide(const ComputeOpNode* op,
@@ -326,7 +322,6 @@ Stmt ComputeOpNode::BuildProvide(
   CHECK_EQ(stage->op.operator->(), this);

   if (IsCrossThreadReduction(this, stage)) {
-    LOG(INFO) << stage;
     // specially handle cross thread reduction.
     return MakeCrossThreadReduction(this, stage, dom_map);
   }
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index ce179b79188a..7e0cbd034e0b 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -4,6 +4,7 @@
 */
 #include <tvm/ir.h>
 #include <tvm/ir_mutator.h>
+#include "./ir_util.h"

 namespace tvm {
 namespace ir {
@@ -17,19 +18,8 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() {  // NOLINT(*)
 }

 inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
-  std::vector<Expr> new_arr(arr.size());
-  bool changed = false;
-  for (size_t i = 0; i < arr.size(); i++) {
-    Expr old_elem = arr[i];
-    Expr new_elem = m->Mutate(old_elem);
-    if (!new_elem.same_as(old_elem)) changed = true;
-    new_arr[i] = new_elem;
-  }
-  if (!changed) {
-    return arr;
-  } else {
-    return Array<Expr>(new_arr);
-  }
+  std::function<Expr(Expr)> fupdate = [m] (Expr e) { return m->Mutate(e); };
+  return UpdateArray(arr, fupdate);
 }

 inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
index 1982f977365f..495ae240c443 100644
--- a/src/pass/ir_util.h
+++ b/src/pass/ir_util.h
@@ -12,6 +12,23 @@
 namespace tvm {
 namespace ir {

+template<typename T>
+inline Array<T> UpdateArray(Array<T> arr, std::function<T(T)> fupdate) {
+  std::vector<T> new_arr(arr.size());
+  bool changed = false;
+  for (size_t i = 0; i < arr.size(); ++i) {
+    T old_elem = arr[i];
+    T new_elem = fupdate(old_elem);
+    if (!new_elem.same_as(old_elem)) changed = true;
+    new_arr[i] = new_elem;
+  }
+  if (!changed) {
+    return arr;
+  } else {
+    return Array<T>(new_arr);
+  }
+}
+
 /*!
 * \brief combine the nest stmt, whose body is not defined.
 * \param nest A list of For and LetStmt, whose body is not defined.
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 981deb726366..8b6055b0a738 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -199,6 +199,7 @@ void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();

   std::vector<Array<Expr>> new_body(sch->stages.size());
+  std::vector<bool> changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
@@ -224,6 +225,7 @@ void InjectInline(ScheduleNode* sch) {
             new_body[j] = s->op.as<ComputeOpNode>()->body;
           }
           for (size_t k = 0; k < body.size(); ++k) {
+            changed[j] = true;
             new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
                                           stage->op, args, body[k]).as<ir::Evaluate>()->value);
           }
@@ -241,7 +243,7 @@ void InjectInline(ScheduleNode* sch) {
       const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
       CHECK(compute);
       Operation op = s->op;
-      if (!IsSame(new_body[i], compute->body)) {
+      if (changed[i]) {
         op = ComputeOpNode::make(
             compute->name, compute->axis, new_body[i]);
       }
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index ddfcb01c8742..7d0f4ffadbb9 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -132,23 +132,6 @@ def test_multi_inputs_outputs():
     assert(T0.value_index == 0)
     assert(T1.value_index == 1)

-def test_multi_inputs_outputs_reduce():
-    m = tvm.var('m')
-    n = tvm.var('n')
-    A0 = tvm.placeholder((m, n), name='A0')
-    A1 = tvm.placeholder((m, n), name='A1')
-    k = tvm.reduce_axis((0, n), "k")
-    mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
-    myprod = tvm.comm_reducer(lambda x, y: x*y, lambda t: tvm.const(1, dtype=t))
-    T0, T1 = tvm.compute((m,), lambda i: (mysum(A0[i, k], axis=k), myprod(A1[i, k], axis=k)), name='T')
-    s = tvm.create_schedule(T1.op)
-
-    for i in range(len(T0.shape)):
-        assert(T0.shape[i] == T1.shape[i])
-    assert(T0.op == T1.op)
-    assert(T0.value_index == 0)
-    assert(T1.value_index == 1)

 if __name__ == "__main__":
     test_conv1d()
@@ -160,4 +143,3 @@ def test_multi_inputs_outputs_reduce():
     test_extern()
     test_extern_multi_out()
     test_multi_inputs_outputs()
-    test_multi_inputs_outputs_reduce()
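[Editor's note: `UpdateArray`, added in patch 05, is the shared copy-on-write
helper the later patches use to rewrite expression arrays. A Python sketch of
its contract, for illustration only (the real helper is the C++ template in
ir_util.h above):

    def update_array(arr, fupdate):
        new_arr = [fupdate(e) for e in arr]
        # copy-on-write: hand back the original array untouched when no
        # element was actually replaced, so callers can cheaply detect
        # "nothing changed" with an identity check.
        if all(new is old for new, old in zip(new_arr, arr)):
            return arr
        return new_arr
]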
From 67038508d7ad4a9effe55babc49595ca66e47f07 Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Fri, 9 Jun 2017 18:43:07 -0700
Subject: [PATCH 06/15] Tuple input support for reduce

---
 include/tvm/ir.h                          | 28 ++++++----
 python/tvm/api.py                         | 42 +++++++++++----
 src/api/api_ir.cc                         | 10 ++--
 src/lang/expr.cc                          | 15 +++---
 src/lang/ir.cc                            | 46 ++++++++++------
 src/op/compute_op.cc                      | 56 +++++++++++++------
 src/pass/ir_mutator.cc                    |  2 +-
 src/pass/ir_visitor.cc                    |  2 +-
 src/pass/lower_thread_allreduce.cc        | 15 ++++--
 src/schedule/schedule_dataflow_rewrite.cc |  8 +++-
 tests/python/integration/test_reduce.py   | 48 +++++++++++++++++
 11 files changed, 199 insertions(+), 73 deletions(-)

diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index f2f47b8cb4b0..edaee8428c0a 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
 *  binary operator with identity element
 */
 struct CommReducerNode : public Node {
-  /*! \brief The arguments of reducer */
-  Array<Var> args;
+  /*! \brief The left argument of reducer */
+  Array<Var> lhs;
+  /*! \brief The right argument of reducer */
+  Array<Var> rhs;
  /*! \brief The result of reducer */
-  Expr result;
+  Array<Expr> result;
  /*!
   * \brief The identity element of reducer, which leaves other
   *  elements unchanged when combined with it, with respect to
   *  the binary operation of this reducer uses.
   */
-  Expr identity_element;
+  Array<Expr> identity_element;
  /*! \brief Function call operator to combine a and b */
-  Expr operator()(Expr a, Expr b) const;
+  Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
  /*! \brief construct CommReducer from args, result and identity_element */
-  static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
+  static CommReducer make(Array<Var> lhs, Array<Var> rhs,
+                          Array<Expr> result, Array<Expr> identity_element);

  void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("args", &args);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
    v->Visit("result", &result);
    v->Visit("identity_element", &identity_element);
  }
@@ -84,7 +88,7 @@ struct Reduce : public ExprNode<Reduce> {
  /*! \brief The commutative combiner */
  CommReducer combiner;
  /*! \brief The source operand */
-  Expr source;
+  Array<Expr> source;
  /*! \brief The reduction axis */
  Array<IterVar> axis;
  /*!
@@ -92,18 +96,22 @@ struct Reduce : public ExprNode<Reduce> {
   *  Only add the body to reduction if condition is true.
   */
  Expr condition;
+  /*! \brief the index of this reduce node */
+  int value_index;
  /*! \brief construct expr from op and rdom */
  static Expr make(CommReducer combiner,
-                   Expr src,
+                   Array<Expr> src,
                   Array<IterVar> rdom,
-                   Expr condition = const_true());
+                   Expr condition = const_true(),
+                   int value_index = 0);

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("source", &source);
    v->Visit("axis", &axis);
    v->Visit("condition", &condition);
+    v->Visit("value_index", &value_index);
  }
  static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
  static constexpr const char* _type_key = "Reduce";
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 44bb734b6fa1..fda9be29340a 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -533,18 +533,44 @@ def _reduce_directly(*args):
         return res

     def _make_reduce(expr, axis, where=None):
         code = fcombine.__code__
         assert fcombine.__code__.co_argcount == 2
         expr = convert(expr)
+        if isinstance(expr, _collections.Array):
+            size = len(expr)
+            larr = []
+            rarr = []
+            dtypes = []
+            for i in range(size):
+                dtype = expr[i].dtype
+                dtypes.append(dtype)
+                lname = code.co_varnames[0] + '_' + str(i)
+                larr.append(var(lname, dtype))
+                rname = code.co_varnames[1] + '_' + str(i)
+                rarr.append(var(rname, dtype))
+            lhs = convert(larr)
+            rhs = convert(rarr)
+            result = fcombine(lhs, rhs)
+            id_elem = fidentity(*dtypes)
+        else:
+            assert isinstance(expr, _expr.Expr)
+            size = 1
+            dtype = expr.dtype
+            lvar = var(code.co_varnames[0], dtype)
+            rvar = var(code.co_varnames[1], dtype)
+            result = [fcombine(lvar, rvar)]
+            id_elem = [fidentity(dtype)]
+            lhs = convert([lvar])
+            rhs = convert([rvar])
+            expr = convert([expr])
         result = convert(result)
+        id_elem = convert(id_elem)
+        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
         axis = axis if isinstance(axis, list) else [axis]
-        return _make.Reduce(combiner, expr, axis, where)
+        if size == 1:
+            return _make.Reduce(combiner, expr, axis, where, 0)
+        return [_make.Reduce(combiner, expr, axis, where, i)
+                for i in range(size)]

     def reducer(expr, axis, where=None, *args):
         if isinstance(axis, (_schedule.IterVar, list)):
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index 00ab79b19167..f66652b99157 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
  });

 TVM_REGISTER_API("make.CommReducer")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = CommReducerNode::make(args[0], args[1], args[2]);
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = CommReducerNode::make(args[0],
+                                 args[1],
+                                 args[2],
+                                 args[3]);
  });

-
 // make from two arguments
 #define REGISTER_MAKE1(Node)                                 \
  TVM_REGISTER_API("make."#Node)                             \
@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
    *ret = Node::make(a, b);                                 \
  })

-REGISTER_MAKE4(Reduce);
+REGISTER_MAKE5(Reduce);
 REGISTER_MAKE4(AttrStmt);

 REGISTER_MAKE2(IntImm);
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index 795f34fe673f..b0e33fc94721 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
  Var x("x"), y("y");
  Expr result = ir::Add::make(x, y);
  Expr identity_element = make_zero(source.type());
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true));
 }

 Expr max(Expr source, Array<IterVar> rdom) {
  Var x("x"), y("y");
  Expr result = ir::Max::make(x, y);
  Expr identity_element = source.type().min();
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true));
 }

 Expr min(Expr source, Array<IterVar> rdom) {
  Var x("x"), y("y");
  Expr result = ir::Min::make(x, y);
  Expr identity_element = source.type().max();
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true));
 }

 std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index 52ba225253da..89fc57af2751 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -9,6 +9,7 @@
 #include <tvm/base.h>
 #include <tvm/expr.h>
 #include <tvm/ir.h>
+#include "../pass/ir_util.h"

 namespace Halide {
 namespace Internal {
@@ -25,9 +26,8 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
  p->stream << "reduce(combiner="
-            << op->combiner
-            << ", ";
-  p->print(op->source);
+            << op->combiner;
+  p->stream << ", source=" << op->source;
  p->stream << ", axis=" << op->axis;
  if (!is_const(op->condition, 1)) {
    p->stream << ", where=" << op->condition;
  }
  p->stream << ")";
 });

 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
-  p->stream << "comm_reducer(result="
-            << op->result
-            << ", args=" << op->args
-            << ", identity_element="
-            << op->identity_element
+  p->stream << "comm_reducer(result=" << op->result
+            << ", lhs=" << op->lhs
+            << ", rhs=" << op->rhs
+            << ", identity_element=" << op->identity_element
            << ")";
 });
 }  // namespace Internal
@@ -50,23 +49,35 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 namespace tvm {
 namespace ir {

-CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
+CommReducer CommReducerNode::make(Array<Var> lhs,
+                                  Array<Var> rhs,
+                                  Array<Expr> result,
+                                  Array<Expr> identity_element) {
  auto node = std::make_shared<CommReducerNode>();
-  node->args = args;
+  node->lhs = lhs;
+  node->rhs = rhs;
  node->result = result;
  node->identity_element = identity_element;
  return CommReducer(node);
 }

-Expr CommReducerNode::operator()(Expr a, Expr b) const {
+Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
+  CHECK_EQ(a.size(), b.size());
+  CHECK_EQ(lhs.size(), a.size());
+  CHECK_EQ(rhs.size(), b.size());
  Map<Var, Expr> value_map;
-  value_map.Set(args[0], a);
-  value_map.Set(args[1], b);
-  return Substitute(result, value_map);
+  for (size_t i = 0; i < a.size(); ++i) {
+    value_map.Set(lhs[i], a[i]);
+    value_map.Set(rhs[i], b[i]);
+  }
+  std::function<Expr(Expr)> fupdate = [&value_map] (Expr e) {
+    return Substitute(e, value_map);
+  };
+  return UpdateArray(result, fupdate);
 }

-Expr Reduce::make(CommReducer combiner, Expr source,
-                  Array<IterVar> axis, Expr condition) {
+Expr Reduce::make(CommReducer combiner, Array<Expr> source,
+                  Array<IterVar> axis, Expr condition, int value_index) {
  for (size_t i = 0; i < axis.size(); ++i) {
    CHECK_EQ(axis[i]->iter_type, kCommReduce)
        << "Can only take axis created by reduce_axis";
@@ -79,11 +90,12 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
  for (size_t i = 0; i < axis.size(); ++i) {
    CHECK(axis[i].defined());
  }
-  n->type = source.type();
+  n->type = source[value_index].type();
  n->combiner = combiner;
  n->source = source;
  n->axis = axis;
  n->condition = condition;
+  n->value_index = value_index;
  return Expr(n);
 }
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 40145103e39e..13ecef4ba197 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -77,8 +77,6 @@ Operation ComputeOpNode::make(std::string name,
  n->axis = axis;
  n->body = body;
  if (n->body[0]->is_type<ir::Reduce>()) {
-    CHECK_EQ(n->body.size(), 1)
-        << "Only support single reduction expression for now";
    n->reduce_axis = n->body[0].as<ir::Reduce>()->axis;
  }
  return Operation(n);
@@ -174,23 +172,39 @@ Stmt ComputeOpNode::BuildRealize(
 // Build a reduction body.
 void MakeReduction(const ComputeOpNode* op,
-                   const Tensor& t,
-                   Stmt* init,
-                   Stmt* provide) {
+                   const Array<Tensor>& tensors,
+                   std::vector<Stmt>* inits,
+                   std::vector<Stmt>* provides) {
+  CHECK_EQ(inits->size(), 0);
+  CHECK_EQ(provides->size(), 0);
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
-  const Reduce* reduce = op->body[t->value_index].as<Reduce>();
+
+  size_t size = op->body.size();
+  const Reduce* reduce = op->body[0].as<Reduce>();
  CHECK(reduce);
  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
  CHECK(combiner);
-  Expr init_value = combiner->identity_element;
-  Expr update_value = (*combiner)(t(args), reduce->source);
-  *init = Provide::make(t->op, t->value_index, init_value, args);
-  *provide = Provide::make(t->op, t->value_index, update_value, args);
+  Array<Expr> lhs;
+  for (size_t i = 0; i < size; ++i) {
+    lhs.push_back(tensors[i](args));
+  }
+  Array<Expr> init_value = combiner->identity_element;
+  Array<Expr> update_value = (*combiner)(lhs, reduce->source);
+  for (size_t i = 0; i < size; ++i) {
+    Tensor t = tensors[i];
+    inits->emplace_back(Provide::make(
+        t->op, t->value_index, init_value[i], args));
+    provides->emplace_back(Provide::make(
+        t->op, t->value_index, update_value[i], args));
+  }
+
  if (!is_one(reduce->condition)) {
-    *provide = IfThenElse::make(reduce->condition, *provide);
+    for (size_t i = 0; i < size; ++i) {
+      provides->at(i) = IfThenElse::make(reduce->condition, provides->at(i));
+    }
  }
 }
@@ -262,7 +276,11 @@ Stmt MakeCrossThreadReduction(
  Var res_handle("reduce_temp", Handle());
  Array<Expr> freduce_args;
-  freduce_args.push_back(reduce->source);
+  size_t size = reduce->source.size();
+  freduce_args.push_back(make_const(UInt(32), size));
+  for (size_t i = 0; i < size; ++i) {
+    freduce_args.push_back(reduce->source[i]);
+  }
  freduce_args.push_back(cond);
@@ -326,19 +344,19 @@ Stmt ComputeOpNode::BuildProvide(
    return MakeCrossThreadReduction(this, stage, dom_map);
  }

+  size_t size = this->body.size();
  std::vector<Stmt> inits;
  std::vector<Stmt> provides;
  if (this->reduce_axis.size() == 0) {
-    for (int i = 0; i < this->num_outputs(); ++i) {
-      provides.push_back(MakeProvide(this, stage->op.output(i)));
+    for (size_t i = 0; i < size; ++i) {
+      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
    }
  } else {
-    for (int i = 0; i < this->num_outputs(); ++i) {
-      Stmt init, provide;
-      MakeReduction(this, stage->op.output(i), &init, &provide);
-      inits.push_back(init);
-      provides.push_back(provide);
+    Array<Tensor> source;
+    for (size_t i = 0; i < size; ++i) {
+      source.push_back(stage->op.output(i));
    }
+    MakeReduction(this, source, &inits, &provides);
  }
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 7e0cbd034e0b..5beccea7b666 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -313,7 +313,7 @@ DEFINE_BIOP_EXPR_MUTATE_(Or)

 Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
  Array<IterVar> new_axis  = MutateIterVarArr(op->axis, this);
-  Expr new_source  = this->Mutate(op->source);
+  Array<Expr> new_source = MutateArray(op->source, this);
  Expr new_cond = this->Mutate(op->condition);
  if (op->axis.same_as(new_axis) &&
      op->source.same_as(new_source) &&
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index bb1b3678e0d8..bae93f9d00b6 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -133,7 +133,7 @@ DEFINE_BINOP_VISIT_(Or)

 void IRVisitor::Visit_(const Reduce* op) {
  VisitRDom(op->axis, this);
-  this->Visit(op->source);
+  VisitArray(op->source, this);
 }

 void IRVisitor::Visit_(const Cast* op) {
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index 9de57d5b84fa..6391b1ae8437 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -100,15 +100,20 @@ class ThreadAllreduceBuilder : public IRMutator {
  Stmt MakeAllreduce(const Store* op, const Call* call) {
    CHECK(!reduce_combiner_.empty());
    const CommReducerNode *combiner = reduce_combiner_.back();
-    Expr init  = combiner->identity_element;
-    Expr value = call->args[0];
-    Expr cond  = call->args[1];
+    size_t size = combiner->result.size();
+    CHECK_EQ(size, 1)
+        << "for now, only support single argument for allreduce";
+    Expr init = combiner->identity_element[0];
+    const UIntImm *size_of_args = call->args[0].as<UIntImm>();
+    CHECK(size_of_args) << call->args[0]->type_key();
+    Expr value = call->args[1];
+    Expr cond  = call->args[size+1];
    if (!is_one(cond)) {
      value = Select::make(cond, value, init);
    }

    std::unordered_set<const Variable*> reduce_set;
-    for (size_t i = 2; i < call->args.size(); ++i) {
+    for (size_t i = size + 2; i < call->args.size(); ++i) {
      const Variable* v = call->args[i].as<Variable>();
      CHECK(v);
      reduce_set.insert(v);
@@ -196,7 +201,7 @@ class ThreadAllreduceBuilder : public IRMutator {
          type, shared_buf,
          BufIndex(reduce_index + offset, group_index, reduce_extent),
          const_true());
      Expr a = Load::make(type, shared_buf, buf_index, const_true());
-      return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
+      return Store::make(shared_buf, (*combiner)({a}, {b})[0], buf_index, const_true());
    };
    // Step one, check for
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 8b6055b0a738..d8479a2d39b0 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -8,6 +8,7 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 #include "./message_passing.h"
+#include "../pass/ir_util.h"

 namespace tvm {
@@ -366,8 +367,11 @@ Tensor Schedule::rfactor(const Tensor& tensor,
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
+  VarReplacer replacer(vsub);
+  std::function<Expr(Expr)> fupdate =
+      [&replacer] (Expr e) { return replacer.Mutate(e); };
  n->body = {Reduce::make(reduce->combiner,
-                          VarReplacer(vsub).Mutate(reduce->source),
+                          ir::UpdateArray(reduce->source, fupdate),
                          n->reduce_axis,
                          predicate)};
  // refresh relations, keep the un-touched relations.
@@ -413,7 +417,7 @@ Tensor Schedule::rfactor(const Tensor& tensor,
        indices.push_back(v);
      }
      return Reduce::make(reduce->combiner,
-          factor_tensor(indices), {repl_red_axis}, const_true());
+          {factor_tensor(indices)}, {repl_red_axis}, const_true());
    }, old_tensor->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index bf3be27fd3bf..737c3e16d2b3 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -128,7 +128,55 @@ def check_target(device, host="stackvm"):
     check_target("metal")
     check_target("opencl")

+def test_argmax():
+    def fcombine(x, y):
+        lhs = tvm.make.Select((x[1] > y[1]), x[0], y[0])
+        rhs = tvm.make.Select((x[1] > y[1]), x[1], y[1])
+        return [lhs, rhs]
+
+    def fidentity(t0, t1):
+        return [tvm.const(-1, t0), tvm.min_value(t1)]
+
+    argmax = tvm.comm_reducer(fcombine,
+                              fidentity,
+                              name='argmax')
+    dtype = 'int32'
+    m = tvm.var('m')
+    n = tvm.var('n')
+    idx = tvm.placeholder((m, n), name='idx', dtype=dtype)
+    val = tvm.placeholder((m, n), name='val', dtype=dtype)
+    k = tvm.reduce_axis((0, n), 'k')
+    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    def check_target():
+        device = 'cpu'
+        if not tvm.module.enabled(device):
+            print("skip because %s is not enabled.." % device)
+            return
+        ctx = tvm.context(device, 0)
+        fapi = tvm.lower(s, args=[idx, val, T0, T1])
+        fargmax = tvm.build(fapi,
+                            target='llvm',
+                            name="argmax")
+
+        mm = 12
+        nn = 16
+        np_idx = np.repeat(np.arange(nn, dtype=dtype).reshape(1, nn), mm, axis=0)
+        np_val = np.random.randint(low=0, high=100, size=(mm, nn)).astype(dtype)
+        np_res = np.argmax(np_val, axis=1)
+
+        nd_idx = tvm.nd.array(np_idx, ctx)
+        nd_val = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype=dtype), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype=dtype), ctx)
+        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
+        np.testing.assert_allclose(np_res, nd_res0.asnumpy())
+
+    check_target()
+
 if __name__ == "__main__":
     test_rfactor_threads()
     test_rfactor()
     test_reduce_prims()
+    test_argmax()
+ * \brief Construct a new tensor by computing over shape, + * using the computation rule: result_tensor[axis] = fcompute(axis) + * \param shape Shape of the tensor. + * \param fcompute The compute function to create the tensors. + * \param name The optional name of the tensor. + */ +Array compute(Array shape, FBatchCompute fcompute, std::string name = "tensor"); + /*! * \brief Construct new tensors by scan. * diff --git a/python/tvm/api.py b/python/tvm/api.py index fda9be29340a..b41227a2693c 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -566,7 +566,9 @@ def _make_reduce(expr, axis, where=None): result = convert(result) id_elem = convert(id_elem) combiner = _make.CommReducer(lhs, rhs, result, id_elem) - axis = axis if isinstance(axis, list) else [axis] + axis = convert(axis if isinstance(axis, list) else [axis]) + if where is None: + where = convert(True) if size == 1: return _make.Reduce(combiner, expr, axis, where, 0) return [_make.Reduce(combiner, expr, axis, where, i) diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 89fc57af2751..c38f0edf76f7 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -29,9 +29,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) << op->combiner; p->stream << ", source=" << op->source; p->stream << ", axis=" << op->axis; - if (!is_const(op->condition, 1)) { - p->stream << ", where=" << op->condition; - } + p->stream << ", where=" << op->condition; + p->stream << ", value_index=" << op->value_index; p->stream << ")"; }); @@ -70,10 +69,9 @@ Array CommReducerNode::operator()(Array a, Array b) const { value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); } - std::function fupdate = [&value_map] (Expr e) { + return UpdateArray(result, [&value_map] (const Expr& e) { return Substitute(e, value_map); - }; - return UpdateArray(result, fupdate); + }); } Expr Reduce::make(CommReducer combiner, Array source, @@ -90,7 +88,7 @@ Expr Reduce::make(CommReducer combiner, Array source, for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); } - n->type = source[0].type(); + n->type = source[value_index].type(); n->combiner = combiner; n->source = source; n->axis = axis; diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 13ecef4ba197..e8a26ea32cf8 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -69,6 +69,35 @@ Tensor compute(Array shape, FCompute fcompute, std::string name) { return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0); } +Array compute(Array shape, FBatchCompute fcompute, std::string name) { + auto op_node = std::make_shared(); + // compute dimension. 
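The FBatchCompute overload here mirrors what the Python frontend does when a compute lambda returns a tuple. Likewise, a tuple-valued comm_reducer produces one Reduce expression per output, all sharing the combiner, source, axis, and condition, and differing only in value_index. A sketch with a hypothetical sum-and-product reducer, constructed the same way as the argmax reducer in the tests of this series:

import tvm

def fcombine(x, y):
    return x[0] + y[0], x[1] * y[1]

def fidentity(t0, t1):
    return tvm.const(0, t0), tvm.const(1, t1)

# 'sumprod' is illustrative only, not part of this patch series
sumprod = tvm.comm_reducer(fcombine, fidentity, name='sumprod')

m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
k = tvm.reduce_axis((0, n), 'k')
S, P = tvm.compute((m,), lambda i: sumprod((A[i, k], A[i, k]), axis=k), name='T')
# S.op.body holds two Reduce nodes with value_index 0 and 1
assert S.op == P.op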
+ size_t ndim = shape.size(); + std::vector axis; + std::vector args; + for (size_t i = 0; i < ndim; ++i) { + std::ostringstream os; + os << "ax" << i; + axis.emplace_back(IterVarNode::make( + Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar)); + args.push_back(axis.back()->var); + } + + Operation op = ComputeOpNode::make(name, axis, fcompute(args)); + Array outputs; + for (int idx = 0; idx < op->num_outputs(); ++idx) { + outputs.push_back(op.output(idx)); + } + return outputs; +} + +bool CheckReduce(const ir::Reduce* a, const ir::Reduce* b) { + return (a->combiner.same_as(b->combiner)) && + (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && + (a->condition.same_as(b->condition)); +} + Operation ComputeOpNode::make(std::string name, Array axis, Array body) { @@ -77,7 +106,15 @@ Operation ComputeOpNode::make(std::string name, n->axis = axis; n->body = body; if (n->body[0]->is_type()) { - n->reduce_axis = n->body[0].as()->axis; + const ir::Reduce* reduce = n->body[0].as(); + for (size_t i = 1; i < n->body.size(); ++i) { + const ir::Reduce* reduce_ = n->body[i].as(); + CHECK(reduce_); + CHECK(CheckReduce(reduce_, reduce)) + << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; + } + n->reduce_axis = reduce->axis; } return Operation(n); } @@ -105,10 +142,9 @@ Operation ComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - std::function fupdate = [&rmap] (Expr e) { + Array arr = UpdateArray(this->body, [&rmap] (const Expr& e) { return op::ReplaceTensor(e, rmap); - }; - Array arr = UpdateArray(this->body, fupdate); + }); if (!arr.same_as(this->body)) { return ComputeOpNode::make(name, axis, arr); } else { @@ -268,20 +304,30 @@ Stmt MakeCrossThreadReduction( auto conds = op::MakeBoundCheck( stage, dom_map, false, std::unordered_set(), value_map); - const Reduce* reduce = self->body[0].as(); - CHECK(reduce); - Expr cond = reduce->condition; + + size_t size = self->body.size(); + CHECK_GT(size, 0); + std::vector reduces(size); + for (size_t i = 0; i < size; ++i) { + const Reduce* reduce = self->body[i].as(); + CHECK(reduce); + reduces[i] = reduce; + } + Expr cond = reduces[0]->condition; for (Expr v : conds) { cond = cond && v; } - Var res_handle("reduce_temp", Handle()); Array freduce_args; - size_t size = reduce->source.size(); freduce_args.push_back(make_const(UInt(32), size)); for (size_t i = 0; i < size; ++i) { - freduce_args.push_back(reduce->source[i]); + freduce_args.push_back(reduces[0]->source[i]); } freduce_args.push_back(cond); + std::vector res_handles(size); + for (size_t idx = 0; idx < size; ++idx) { + res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle()); + freduce_args.push_back(res_handles[idx]); + } for (IterVar iv : stage->leaf_iter_vars) { if (iv->iter_type == kCommReduce) { @@ -298,29 +344,33 @@ Stmt MakeCrossThreadReduction( if (stage->store_predicate.defined()) { thread_head_check.emplace_back(stage->store_predicate); } - Type t = reduce->type; - Expr pred = const_true(t.lanes()); - Stmt reduce_body = Store::make(res_handle, - Call::make( - reduce->type, + + Stmt reduce_body = Evaluate::make(Call::make( + Handle(), ir::intrinsic::tvm_thread_allreduce, - freduce_args, Call::Intrinsic), - 0, pred); + freduce_args, Call::Intrinsic)); reduce_body = AttrStmt::make( - reduce->combiner, + reduces[0]->combiner, attr::reduce_scope, - make_zero(reduce->type), + make_zero(Handle()), reduce_body); - Stmt assign_body = 
Provide::make( - stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args); - + std::vector assigns(size); + for (size_t idx = 0; idx < size; ++idx) { + Type t = reduces[idx]->type; + assigns[idx] = Provide::make( + stage->op, idx, + Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + } + Stmt assign_body = Block::make(assigns); assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(conds), assign_body); - Stmt body = Allocate::make( - res_handle, reduce->type, {1}, const_true(), - Block::make(reduce_body, assign_body)); - body = AttrStmt::make( - res_handle, attr::storage_scope, StringImm::make("local"), body); + Stmt body = Block::make(reduce_body, assign_body); + for (int idx = size - 1; idx >= 0; --idx) { + body = Allocate::make( + res_handles[idx], reduces[idx]->type, {1}, const_true(), body); + body = AttrStmt::make( + res_handles[idx], attr::storage_scope, StringImm::make("local"), body); + } body = Substitute(body, value_map); return MergeNest(nest, body); } diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 5beccea7b666..40a8e3de286d 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -18,8 +18,7 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) } inline Array MutateArray(Array arr, IRMutator *m) { - std::function fupdate = [m] (Expr e) { return m->Mutate(e); }; - return UpdateArray(arr, fupdate); + return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); }); } inline Array MutateIterVarArr(Array rdom, IRMutator *m) { diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 495ae240c443..5f4ced64dd94 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -12,8 +12,15 @@ namespace tvm { namespace ir { -template -inline Array UpdateArray(Array arr, std::function fupdate) { +/*! + * \brief update array with an unary function + * \param arr array + * \param fupdate an unary function + * \return if update happens, return the new array, else return the + * original array + */ +template +inline Array UpdateArray(Array arr, F fupdate) { std::vector new_arr(arr.size()); bool changed = false; for (size_t i = 0; i < arr.size(); ++i) { diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 6391b1ae8437..18cd46b56053 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -45,12 +45,12 @@ class ThreadAllreduceBuilder : public IRMutator { return IRMutator::Mutate_(op, s); } } - Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt Mutate_(const Evaluate* op, const Stmt& s) final { Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); + op = stmt.as(); const Call* call = op->value.as(); if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) { - return MakeAllreduce(op, call); + return MakeAllreduce(call); } else { return stmt; } @@ -97,23 +97,34 @@ class ThreadAllreduceBuilder : public IRMutator { } }; // make allreduce. 
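Before the rewritten MakeAllreduce below, the argument layout that MakeCrossThreadReduction now emits for the tvm_thread_allreduce intrinsic can be summarized as a sketch; the names here are illustrative only:

# tvm_thread_allreduce(size,               # UIntImm: number of reduction outputs
#                      source_0, ...,      # 'size' source expressions
#                      cond,               # reduction predicate
#                      reduce_temp_0, ..., # 'size' result buffer vars
#                      thread_idx_1, ...)  # reduction thread index vars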
- Stmt MakeAllreduce(const Store* op, const Call* call) { + Stmt MakeAllreduce(const Call* call) { CHECK(!reduce_combiner_.empty()); const CommReducerNode *combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - CHECK_EQ(size, 1) - << "for now, only support single argument for allreduce"; - Expr init = combiner->identity_element[0]; + const UIntImm *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->type_key(); - Expr value = call->args[1]; + CHECK_EQ(size, size_of_args->value); + Array inits = combiner->identity_element; + std::vector values(size); + std::vector types(size); Expr cond = call->args[size+1]; - if (!is_one(cond)) { - value = Select::make(cond, value, init); + for (size_t idx = 0; idx < size; ++idx) { + values[idx] = call->args[1+idx]; + if (!is_one(cond)) { + values[idx] = Select::make(cond, values[idx], inits[idx]); + } + types[idx] = values[idx].type(); + } + std::vector buffers(size); + for (size_t idx = 0; idx < size; ++idx) { + const Variable* buffer = call->args[2+size+idx].as(); + CHECK(buffer); + buffers[idx] = buffer; } std::unordered_set reduce_set; - for (size_t i = size + 2; i < call->args.size(); ++i) { + for (size_t i = 2+2*size; i < call->args.size(); ++i) { const Variable* v = call->args[i].as(); CHECK(v); reduce_set.insert(v); @@ -148,40 +159,50 @@ class ThreadAllreduceBuilder : public IRMutator { int threadx_extent = 1; Expr reduce_index = FlattenThread(vred, &reduce_extent); Expr group_index = FlattenThread(vpar, &group_extent); - Expr pred = const_true(value.type().lanes()); if (reduce_extent == 1) { // special case, no reduction is needed. - return Store::make(op->buffer_var, value, 0, pred); + std::vector stores(size); + for (size_t i = 0; i < size; ++i) { + Expr pred = const_true(types[i].lanes()); + Var buffer_var(call->args[2+size+i].node_); + stores[i] = Store::make(buffer_var, values[i], 0, pred); + } + return Block::make(stores); } // Whether the threadIdx.x is involved in reduction. 
if (vred[0].scope.dim_index == 0) { threadx_extent = vred[0].extent; } - Var shared_buf("red_buf", Handle()); std::vector seq; - seq.emplace_back(Store::make( - shared_buf, value, - BufIndex(reduce_index, group_index, reduce_extent), pred)); + std::vector shared_bufs(size); + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle()); + Expr pred = const_true(types[idx].lanes()); + seq.emplace_back(Store::make( + shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); + } seq.emplace_back(SyncThread("shared")); seq.emplace_back(MakeBufAllreduce( - combiner, value.type(), shared_buf, + combiner, types, shared_bufs, reduce_index, group_index, reduce_extent, threadx_extent)); - CHECK(!load_remap_.count(op->buffer_var.get())); - load_remap_[op->buffer_var.get()] = - Load::make( - value.type(), shared_buf, - BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), - pred); - alloc_remap_[op->buffer_var.get()] = - Allocate::make(shared_buf, value.type(), - {Expr(group_extent), Expr(reduce_extent)}, - pred, Evaluate::make(0)); + for (size_t idx = 0; idx < size; ++idx) { + CHECK(!load_remap_.count(buffers[idx])); + Expr pred = const_true(types[idx].lanes()); + load_remap_[buffers[idx]] = Load::make( + types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred); + alloc_remap_[buffers[idx]] = Allocate::make( + shared_bufs[idx], types[idx], + {Expr(group_extent), Expr(reduce_extent)}, + pred, Evaluate::make(0)); + } return MergeSeq(seq); } // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, - Type type, - Var shared_buf, + const std::vector& types, + Array shared_bufs, Expr reduce_index, Expr group_index, int reduce_extent, @@ -194,14 +215,23 @@ class ThreadAllreduceBuilder : public IRMutator { CHECK_GT(reduce_align, 1); std::vector seq; + size_t size = shared_bufs.size(); Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent); // make reduction auto freduce = [&](int offset) { - Expr b = Load::make( - type, shared_buf, - BufIndex(reduce_index + offset, group_index, reduce_extent), const_true()); - Expr a = Load::make(type, shared_buf, buf_index, const_true()); - return Store::make(shared_buf, (*combiner)({a}, {b})[0], buf_index, const_true()); + Array a, b; + for (size_t i = 0; i < size; ++i) { + b.push_back(Load::make(types[i], shared_bufs[i], + BufIndex(reduce_index + offset, group_index, reduce_extent), + const_true())); + a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true())); + } + Array ret = (*combiner)(a, b); + std::vector stores(size); + for (size_t i = 0; i < size; ++i) { + stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true()); + } + return Block::make(stores); }; // Step one, check for if (reduce_align > reduce_extent) { diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index d8479a2d39b0..a5ed6862f9cd 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -283,7 +283,7 @@ Tensor Schedule::rfactor(const Tensor& tensor, << "Can only factor reduction axis"; Stage reduce_stage = operator[](tensor->op); const ComputeOpNode* compute_op = reduce_stage->op.as(); - CHECK(compute_op) << "Can only factor ComputeOp"; + CHECK(compute_op) << "Can only factor ComputeOp"; ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite(); { size_t axis_pos = FindNodeRef(leaf_vars, axis); 
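As a reminder of what this routine manipulates at the Python level: rfactor splits a reduction into a partial-reduction tensor plus a final reduction over the factored axis. A minimal sketch in the spirit of the existing test_rfactor:

import tvm

n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)
ko, kf = s[B].split(k, factor=16)
# BF holds partial sums over ko for each kf; B then reduces BF over kf
BF = s.rfactor(B, kf)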
@@ -368,12 +368,17 @@ Tensor Schedule::rfactor(const Tensor& tensor, } } VarReplacer replacer(vsub); - std::function fupdate = - [&replacer] (Expr e) { return replacer.Mutate(e); }; - n->body = {Reduce::make(reduce->combiner, - ir::UpdateArray(reduce->source, fupdate), - n->reduce_axis, - predicate)}; + Array new_source = ir::UpdateArray(reduce->source, + [&replacer] (const Expr& e) { return replacer.Mutate(e); }); + std::vector body; + for (size_t idx = 0; idx < reduce->source.size(); ++idx) { + body.emplace_back(Reduce::make(reduce->combiner, + new_source, + n->reduce_axis, + predicate, + idx)); + } + n->body = Array(body); // refresh relations, keep the un-touched relations. Array rels; for (IterVarRelation rel : reduce_stage->relations) { @@ -408,26 +413,42 @@ Tensor Schedule::rfactor(const Tensor& tensor, // Replace the old reduction. IterVar repl_red_axis = reduce_axis( dom_map.at(axis), axis->var->name_hint + ".v"); - Tensor factor_tensor = factor_op.output(0); - Tensor old_tensor = reduce_stage->op.output(0); - Tensor repl_tensor = compute(old_tensor->shape, [&](const Array& i) { + Array factor_tensors; + Array old_tensors; + int size = factor_op->num_outputs(); + for (int idx = 0; idx < size; ++idx) { + factor_tensors.push_back(factor_op.output(idx)); + old_tensors.push_back(reduce_stage->op.output(idx)); + } + Array repl_tensors = compute(old_tensors[0]->shape, + [&](const Array& i) { Array indices; indices.push_back(repl_red_axis->var); for (Var v : i) { indices.push_back(v); } - return Reduce::make(reduce->combiner, - {factor_tensor(indices)}, {repl_red_axis}, const_true()); - }, old_tensor->op->name + ".repl"); + Array factor_exprs; + for (int idx = 0; idx < size; ++idx) { + factor_exprs.push_back(factor_tensors[idx](indices)); + } + Array reductions; + for (int idx = 0; idx < size; ++idx) { + reductions.push_back(Reduce::make(reduce->combiner, + factor_exprs, {repl_red_axis}, const_true(), idx)); + } + return reductions; + }, reduce_stage->op->name + ".repl"); std::unordered_map vmap; - vmap[old_tensor] = repl_tensor; + for (int idx = 0; idx < size; ++idx) { + vmap[old_tensors[idx]] = repl_tensors[idx]; + } ReplaceDataFlow((*this)->stages, &vmap); // revamp the reduction stage. 
- reduce_stage->op = repl_tensor->op; - reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars(); + reduce_stage->op = repl_tensors[0]->op; + reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars(); reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars; reduce_stage->relations = Array(); - return factor_tensor; + return factor_tensors[0]; } } // namespace tvm diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 737c3e16d2b3..1347060ce8c1 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -49,7 +49,6 @@ def check_device(device, host="stackvm"): test_prim(tvm.max, np.amax) - def test_rfactor(): n = tvm.convert(1027) A = tvm.placeholder((n,), name='A') @@ -130,8 +129,8 @@ def check_target(device, host="stackvm"): def test_argmax(): def fcombine(x, y): - lhs = tvm.make.Select((x[1] > y[1]), x[0], y[0]) - rhs = tvm.make.Select((x[1] > y[1]), x[1], y[1]) + lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) return [lhs, rhs] def fidentity(t0, t1): @@ -140,11 +139,10 @@ def fidentity(t0, t1): argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax') - dtype = 'int32' m = tvm.var('m') n = tvm.var('n') - idx = tvm.placeholder((m, n), name='idx', dtype=dtype) - val = tvm.placeholder((m, n), name='val', dtype=dtype) + idx = tvm.placeholder((m, n), name='idx', dtype='int32') + val = tvm.placeholder((m, n), name='val', dtype='float32') k = tvm.reduce_axis((0, n), 'k') T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T') s = tvm.create_schedule(T0.op) @@ -162,21 +160,82 @@ def check_target(): mm = 12 nn = 16 - np_idx = np.repeat(np.arange(nn, dtype=dtype).reshape(1, nn), mm, axis=0) - np_val = np.random.randint(low=0, high=100, size=(mm, nn)).astype(dtype) + np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0) + np_val = np.random.randint(low=0, high=100, size=(mm, nn)).astype('float32') np_res = np.argmax(np_val, axis=1) nd_idx = tvm.nd.array(np_idx, ctx) nd_val = tvm.nd.array(np_val, ctx) - nd_res0 = tvm.nd.array(np.zeros(mm, dtype=dtype), ctx) - nd_res1 = tvm.nd.array(np.zeros(mm, dtype=dtype), ctx) + nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx) + nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx) fargmax(nd_idx, nd_val, nd_res0, nd_res1) np.testing.assert_allclose(np_res, nd_res0.asnumpy()) check_target() + +def test_rfactor_argmax(): + def fcombine(x, y): + lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) + return [lhs, rhs] + + def fidentity(t0, t1): + return [tvm.const(-1, t0), tvm.min_value(t1)] + + argmax = tvm.comm_reducer(fcombine, + fidentity, + name='argmax') + + nn = 1027 + mm = 10 + n = tvm.convert(nn) + m = tvm.convert(mm) + A0 = tvm.placeholder((m, n), name='A0', dtype='int32') + A1 = tvm.placeholder((m, n), name='A1', dtype='float32') + k = tvm.reduce_axis((0, n)) + B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B') + + # schedule + s = tvm.create_schedule(B0.op) + nthread = 16 + ko, kf = s[B0].split(k, factor=nthread) + BF = s.rfactor(B0, kf) + bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread) + s[B0].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B0].bind(ty, tvm.thread_axis("threadIdx.y")) + tx = s[B0].op.reduce_axis[0] + thread_x = tvm.thread_axis("threadIdx.x") + s[B0].bind(tx, thread_x) + s[BF].compute_at(s[B0], tx) + 
s[B0].set_store_predicate(thread_x.var.equal(0)) + + def check_target(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." % device) + return + ctx = tvm.context(device, 0) + fapi = tvm.lower(s, args=[A0, A1, B0, B1]) + fargmax = tvm.build(fapi, + target=device, + name="argmax") + + np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0) + np_val = np.random.randint(low=0, high=100000, size=(mm, nn)).astype('float32') + np_res = np.argmax(np_val, axis=1) + + nd_idx = tvm.nd.array(np_idx, ctx) + nd_val = tvm.nd.array(np_val, ctx) + nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx) + nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx) + fargmax(nd_idx, nd_val, nd_res0, nd_res1) + np.testing.assert_allclose(np_res, nd_res0.asnumpy()) + + check_target("cuda") + if __name__ == "__main__": test_rfactor_threads() test_rfactor() test_reduce_prims() test_argmax() + test_rfactor_argmax() From 371a6c00e421ed9d2feda1854b9fea84b9f44214 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 10 Jun 2017 13:28:28 -0700 Subject: [PATCH 08/15] Small fix --- src/schedule/schedule_dataflow_rewrite.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index a5ed6862f9cd..94ee7681c978 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -432,9 +432,11 @@ Tensor Schedule::rfactor(const Tensor& tensor, factor_exprs.push_back(factor_tensors[idx](indices)); } Array reductions; + Array axis = {repl_red_axis}; + Expr cond = const_true(); for (int idx = 0; idx < size; ++idx) { reductions.push_back(Reduce::make(reduce->combiner, - factor_exprs, {repl_red_axis}, const_true(), idx)); + factor_exprs, axis, cond, idx)); } return reductions; }, reduce_stage->op->name + ".repl"); From 7b3ba6179322f511a1e5f6b1aa1650b8b3d957fe Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 10 Jun 2017 13:57:07 -0700 Subject: [PATCH 09/15] Small fix --- src/op/compute_op.cc | 4 ++-- tests/python/integration/test_reduce.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index e8a26ea32cf8..62646a631b71 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -211,8 +211,8 @@ void MakeReduction(const ComputeOpNode* op, const Array& tensors, std::vector* inits, std::vector* provides) { - CHECK_EQ(inits->size(), 0); - CHECK_EQ(provides->size(), 0); + CHECK(inits->empty()); + CHECK(provides->empty()); Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 1347060ce8c1..aff35c469ebe 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -161,7 +161,7 @@ def check_target(): mm = 12 nn = 16 np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0) - np_val = np.random.randint(low=0, high=100, size=(mm, nn)).astype('float32') + np_val = np.random.uniform(size=(mm, nn)).astype('float32') np_res = np.argmax(np_val, axis=1) nd_idx = tvm.nd.array(np_idx, ctx) @@ -221,7 +221,7 @@ def check_target(device): name="argmax") np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0) - np_val = np.random.randint(low=0, high=100000, size=(mm, nn)).astype('float32') + np_val = np.random.uniform(size=(mm, nn)).astype('float32') np_res = np.argmax(np_val, axis=1) nd_idx = 
tvm.nd.array(np_idx, ctx) From 0fa045273a02bff16e0f6359d9be2458e1b66f53 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 10 Jun 2017 17:27:58 -0700 Subject: [PATCH 10/15] Change return type of rfactor to Array --- include/tvm/ir.h | 4 +- include/tvm/schedule.h | 6 +-- python/tvm/schedule.py | 7 +++- src/lang/expr.cc | 6 +-- src/op/compute_op.cc | 51 +++++++++-------------- src/pass/ir_mutator.cc | 3 +- src/pass/ir_util.h | 2 + src/pass/lower_thread_allreduce.cc | 2 +- src/schedule/graph.cc | 2 +- src/schedule/schedule_dataflow_rewrite.cc | 6 +-- tests/python/integration/test_reduce.py | 12 +++--- tests/python/unittest/test_lang_tensor.py | 24 +++++++++++ 12 files changed, 71 insertions(+), 54 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index edaee8428c0a..dafb3e09d528 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -103,8 +103,8 @@ struct Reduce : public ExprNode { static Expr make(CommReducer combiner, Array src, Array rdom, - Expr condition = const_true(), - int value_index = 0); + Expr condition, + int value_index); void VisitAttrs(AttrVisitor* v) final { v->Visit("dtype", &type); diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 654200abeb08..9f8a4bd51f2f 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -257,10 +257,10 @@ class Schedule : public NodeRef { * * \param tensor The tensor to be factored. * \param axis The reduction axis in tensor's schedule to be factored. - * \return The created factored tensor. + * \return The created factored tensors. */ - Tensor rfactor(const Tensor& tensor, - const IterVar& axis); + Array rfactor(const Tensor& tensor, + const IterVar& axis); /*! * \brief Normalize the schedule. * This is needed before bound inference. diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index d5baada62366..5425375c82d9 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -193,10 +193,13 @@ def rfactor(self, tensor, axis): Returns ------- - tfactor : Tensor + tfactor : Tensor or Array The created factored tensor. 
""" - return _api_internal._ScheduleRFactor(self, tensor, axis) + factored = _api_internal._ScheduleRFactor(self, tensor, axis) + if len(factored) == 1: + return factored[0] + return factored @register_node diff --git a/src/lang/expr.cc b/src/lang/expr.cc index b0e33fc94721..9e0feb44479f 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -52,7 +52,7 @@ Expr sum(Expr source, Array rdom) { Expr identity_element = make_zero(source.type()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true)); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } Expr max(Expr source, Array rdom) { @@ -61,7 +61,7 @@ Expr max(Expr source, Array rdom) { Expr identity_element = source.type().min(); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true)); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } Expr min(Expr source, Array rdom) { @@ -70,7 +70,7 @@ Expr min(Expr source, Array rdom) { Expr identity_element = source.type().max(); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true)); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 62646a631b71..6baa38ad7f66 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -209,14 +209,13 @@ Stmt ComputeOpNode::BuildRealize( // Build a reduction body. void MakeReduction(const ComputeOpNode* op, const Array& tensors, - std::vector* inits, - std::vector* provides) { - CHECK(inits->empty()); - CHECK(provides->empty()); + Stmt* init, + Stmt* provide) { Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } + std::vector inits, provides; size_t size = op->body.size(); const Reduce* reduce = op->body[0].as(); @@ -231,16 +230,15 @@ void MakeReduction(const ComputeOpNode* op, Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; - inits->emplace_back(Provide::make( + inits.emplace_back(Provide::make( t->op, t->value_index, init_value[i], args)); - provides->emplace_back(Provide::make( + provides.emplace_back(Provide::make( t->op, t->value_index, update_value[i], args)); } - + *init = Block::make(inits); + *provide = Block::make(provides); if (!is_one(reduce->condition)) { - for (size_t i = 0; i < size; ++i) { - provides->at(i) = IfThenElse::make(reduce->condition, provides->at(i)); - } + *provide = IfThenElse::make(reduce->condition, *provide); } } @@ -253,19 +251,6 @@ Stmt Substitute(Stmt s, return ir::Substitute(s, temp); } -std::vector Substitute(std::vector stmt, - const std::unordered_map& value_map) { - Map temp; - for (const auto& kv : value_map) { - temp.Set(kv.first->var, kv.second); - } - std::vector ret; - for (auto& s : stmt) { - ret.push_back(ir::Substitute(s, temp)); - } - return ret; -} - // Cross Thread reduction marker. 
bool IsCrossThreadReduction(const ComputeOpNode* self, const Stage& stage) { @@ -395,18 +380,20 @@ Stmt ComputeOpNode::BuildProvide( } size_t size = this->body.size(); - std::vector inits; - std::vector provides; + Stmt init; + Stmt provide; if (this->reduce_axis.size() == 0) { + std::vector provides; for (size_t i = 0; i < size; ++i) { provides.emplace_back(MakeProvide(this, stage->op.output(i))); } + provide = Block::make(provides); } else { Array source; for (size_t i = 0; i < size; ++i) { source.push_back(stage->op.output(i)); } - MakeReduction(this, source, &inits, &provides); + MakeReduction(this, source, &init, &provide); } // make loop nest @@ -420,9 +407,9 @@ Stmt ComputeOpNode::BuildProvide( if (stage->store_predicate.defined()) { nest.emplace_back(op::MakeIfNest({stage->store_predicate})); } - provides = Substitute(provides, value_map); + provide = Substitute(provide, value_map); - if (!inits.empty()) { + if (init.defined()) { // try to find the location to insert the initialization. // Fuse the initialization and provide loop when possible. std::unordered_map update_state; @@ -458,15 +445,15 @@ Stmt ComputeOpNode::BuildProvide( auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map); for (auto& e : preds) e = likely(e); init_nest.push_back(op::MakeIfNest(preds)); - inits = Substitute(inits, init_value_map); - Stmt init = MergeNest(init_nest, Block::make(inits)); + init = Substitute(init, init_value_map); + init = MergeNest(init_nest, init); // common nest std::vector > common(nest.begin(), nest.begin() + begin_loop + 1); std::vector > reduce(nest.begin() + begin_loop + 1, nest.end()); - Stmt provide = MergeNest(reduce, Block::make(provides)); + provide = MergeNest(reduce, provide); return MergeNest(common, Block::make(init, provide)); } else { - return MergeNest(nest, Block::make(provides)); + return MergeNest(nest, provide); } } } // namespace tvm diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 40a8e3de286d..b12f6648dffc 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -319,7 +319,8 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { op->condition.same_as(new_cond)) { return e; } else { - return Reduce::make(op->combiner, new_source, new_axis, new_cond); + return Reduce::make( + op->combiner, new_source, new_axis, new_cond, op->value_index); } } diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 5f4ced64dd94..472b408e32d5 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -16,6 +16,8 @@ namespace ir { * \brief update array with an unary function * \param arr array * \param fupdate an unary function + * \tparam T type of array element + * \tparam F type of the unary function * \return if update happens, return the new array, else return the * original array */ diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 3e2de17d356b..9763cc4b68bc 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -124,7 +124,7 @@ class ThreadAllreduceBuilder : public IRMutator { } std::unordered_set reduce_set; - for (size_t i = 2+2*size; i < call->args.size(); ++i) { + for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) { const Variable* v = call->args[i].as(); CHECK(v); reduce_set.insert(v); diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index a01bde30de89..da0aeb0eccaa 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -323,7 +323,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } } else if 
(op.as()) { - std::unordered_map> vmap; + std::unordered_map > vmap; const auto& axis = op.as()->axis; for (size_t i = 0; i < axis.size(); ++i) { std::vector keys; diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 94ee7681c978..d24ba17bf6e3 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -275,8 +275,8 @@ Schedule Schedule::normalize() { } // Handle reduction factor. -Tensor Schedule::rfactor(const Tensor& tensor, - const IterVar& axis) { +Array Schedule::rfactor(const Tensor& tensor, + const IterVar& axis) { (*this)->InvalidateCache(); using ir::Reduce; CHECK_EQ(axis->iter_type, kCommReduce) @@ -451,6 +451,6 @@ Tensor Schedule::rfactor(const Tensor& tensor, reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars(); reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars; reduce_stage->relations = Array(); - return factor_tensors[0]; + return factor_tensors; } } // namespace tvm diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index aff35c469ebe..ffc4b79b58d2 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -131,10 +131,10 @@ def test_argmax(): def fcombine(x, y): lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) - return [lhs, rhs] + return lhs, rhs def fidentity(t0, t1): - return [tvm.const(-1, t0), tvm.min_value(t1)] + return tvm.const(-1, t0), tvm.min_value(t1) argmax = tvm.comm_reducer(fcombine, fidentity, @@ -178,10 +178,10 @@ def test_rfactor_argmax(): def fcombine(x, y): lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) - return [lhs, rhs] + return lhs, rhs def fidentity(t0, t1): - return [tvm.const(-1, t0), tvm.min_value(t1)] + return tvm.const(-1, t0), tvm.min_value(t1) argmax = tvm.comm_reducer(fcombine, fidentity, @@ -200,14 +200,14 @@ def fidentity(t0, t1): s = tvm.create_schedule(B0.op) nthread = 16 ko, kf = s[B0].split(k, factor=nthread) - BF = s.rfactor(B0, kf) + BF0, BF1 = s.rfactor(B0, kf) bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread) s[B0].bind(bx, tvm.thread_axis("blockIdx.x")) s[B0].bind(ty, tvm.thread_axis("threadIdx.y")) tx = s[B0].op.reduce_axis[0] thread_x = tvm.thread_axis("threadIdx.x") s[B0].bind(tx, thread_x) - s[BF].compute_at(s[B0], tx) + s[BF0.op].compute_at(s[B0], tx) s[B0].set_store_predicate(thread_x.var.equal(0)) def check_target(device): diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 7d0f4ffadbb9..89000aede8f6 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -132,6 +132,29 @@ def test_multi_inputs_outputs(): assert(T0.value_index == 0) assert(T1.value_index == 1) +def test_multi_different_deps(): + m = tvm.var('m') + n = tvm.var('n') + A0 = tvm.placeholder((m, n), name='A1') + A1 = tvm.placeholder((m, n), name='A2') + B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B') + C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C') + + s = tvm.create_schedule(C.op) + xo, xi = s[C].split(C.op.axis[0], factor=10) + s[B0.op].compute_at(s[C], xo) + sch = s.normalize() + bounds = tvm.schedule.InferBound(sch) + stmt = tvm.schedule.ScheduleOps(sch, bounds) + + def is_B1_realize(x): + if isinstance(x, tvm.stmt.Realize) and \ + x.func == B1.op and x.value_index == 1: + ret.append(x) + ret = [] 
+ tvm.ir_pass.PostOrderVisit(stmt, is_B1_realize) + + assert stmt.node == C.op and len(ret) == 1 if __name__ == "__main__": test_conv1d() @@ -143,3 +166,4 @@ def test_multi_inputs_outputs(): test_extern() test_extern_multi_out() test_multi_inputs_outputs() + test_multi_different_deps() From 3739cb7080f22c7f3bc634465857f5fe2ae6b8bc Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 10 Jun 2017 17:29:37 -0700 Subject: [PATCH 11/15] Fix lint --- python/tvm/schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 5425375c82d9..b3d622e8e039 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -196,7 +196,7 @@ def rfactor(self, tensor, axis): tfactor : Tensor or Array The created factored tensor. """ - factored = _api_internal._ScheduleRFactor(self, tensor, axis) + factored = _api_internal._ScheduleRFactor(self, tensor, axis) if len(factored) == 1: return factored[0] return factored From 0e7258eac6286b08574c7c31900983304c583c85 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 10 Jun 2017 18:39:42 -0700 Subject: [PATCH 12/15] Improve --- include/tvm/ir.h | 9 +++++---- python/tvm/api.py | 15 +++++---------- python/tvm/schedule.py | 6 ++---- src/lang/ir.cc | 6 +++--- src/op/compute_op.cc | 10 +++++----- src/pass/lower_thread_allreduce.cc | 2 +- 6 files changed, 21 insertions(+), 27 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index dafb3e09d528..834b1baf364f 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -300,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; /*! * \brief See pesudo code * - * Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond, - * Var thread_idx1, thread_idx2...) { + * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond, + * Var reduce_temp0, .., Var thread_idx1, ...) { * // constraint by the other thread_idx remain the same. - * return reduce(combiner, value, cond, - * over [thread_idx1, thread_idx2] passed by any caller) + * // reduce_temp is used to save intermediate result. + * reduce_temp0, ... 
= reduce(combiner, source0, ..., cond + * over [thread_idx1, thread_idx2] passed by any caller) * } */ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; diff --git a/python/tvm/api.py b/python/tvm/api.py index b41227a2693c..7ea6a8e81e6b 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -179,13 +179,9 @@ def compute(shape, fcompute, name="compute"): body = convert(body) op_node = _api_internal._ComputeOp( name, dim_var, body) - outputs = [] num = op_node.num_outputs - if num == 1: - return op_node.output(0) - for i in range(num): - outputs.append(op_node.output(i)) - return tuple(outputs) + outputs = tuple(op_node.output(i) for i in range(num)) + return outputs[0] if num == 1 else outputs def scan(init, update, state_placeholder, inputs=None, name="scan"): @@ -569,10 +565,9 @@ def _make_reduce(expr, axis, where=None): axis = convert(axis if isinstance(axis, list) else [axis]) if where is None: where = convert(True) - if size == 1: - return _make.Reduce(combiner, expr, axis, where, 0) - return [_make.Reduce(combiner, expr, axis, where, i) - for i in range(size)] + outputs = tuple(_make.Reduce(combiner, expr, axis, where, i) + for i in range(size)) + return outputs[0] if size == 1 else outputs def reducer(expr, axis, where=None, *args): if isinstance(axis, (_schedule.IterVar, list)): diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index b3d622e8e039..e9c8a179c95f 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -193,13 +193,11 @@ def rfactor(self, tensor, axis): Returns ------- - tfactor : Tensor or Array + tfactor : Tensor or Array of Tensor The created factored tensor. """ factored = _api_internal._ScheduleRFactor(self, tensor, axis) - if len(factored) == 1: - return factored[0] - return factored + return factored[0] if len(factored) == 1 else factored @register_node diff --git a/src/lang/ir.cc b/src/lang/ir.cc index c38f0edf76f7..e7903333562f 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -89,9 +89,9 @@ Expr Reduce::make(CommReducer combiner, Array source, CHECK(axis[i].defined()); } n->type = source[value_index].type(); - n->combiner = combiner; - n->source = source; - n->axis = axis; + n->combiner = std::move(combiner); + n->source = std::move(source); + n->axis = std::move(axis); n->condition = condition; n->value_index = value_index; return Expr(n); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 6baa38ad7f66..be594a6b6e4a 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -91,7 +91,7 @@ Array compute(Array shape, FBatchCompute fcompute, std::string nam return outputs; } -bool CheckReduce(const ir::Reduce* a, const ir::Reduce* b) { +bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) { return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && (a->axis.same_as(b->axis)) && @@ -110,7 +110,7 @@ Operation ComputeOpNode::make(std::string name, for (size_t i = 1; i < n->body.size(); ++i) { const ir::Reduce* reduce_ = n->body[i].as(); CHECK(reduce_); - CHECK(CheckReduce(reduce_, reduce)) + CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should " << "have the same attribute except value_index"; } @@ -350,11 +350,11 @@ Stmt MakeCrossThreadReduction( assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(conds), assign_body); Stmt body = Block::make(reduce_body, assign_body); - for (int idx = size - 1; idx >= 0; --idx) { + for (size_t idx = size; idx != 0; --idx) { body = Allocate::make( - 
res_handles[idx], reduces[idx]->type, {1}, const_true(), body); + res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body); body = AttrStmt::make( - res_handles[idx], attr::storage_scope, StringImm::make("local"), body); + res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body); } body = Substitute(body, value_map); return MergeNest(nest, body); diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 9763cc4b68bc..1e59723d59d5 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -202,7 +202,7 @@ class ThreadAllreduceBuilder : public IRMutator { // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, const std::vector& types, - Array shared_bufs, + const Array& shared_bufs, Expr reduce_index, Expr group_index, int reduce_extent, From 2c42e95ae071f20029f2fccc7d971ac31171b28d Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 10 Jun 2017 22:10:22 -0700 Subject: [PATCH 13/15] Add tutorial --- tests/python/unittest/test_lang_tensor.py | 16 ++-- tutorials/python/tuple_inputs_operation.py | 91 ++++++++++++++++++++++ 2 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 tutorials/python/tuple_inputs_operation.py diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 89000aede8f6..9160baec3789 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -118,11 +118,11 @@ def extern_func(ins, outs): assert(len(res) == 2) assert(res[1].value_index == 1) -def test_multi_inputs_outputs(): +def test_tuple_inputs(): m = tvm.var('m') n = tvm.var('n') - A0 = tvm.placeholder((m, n), name='A1') - A1 = tvm.placeholder((m, n), name='A2') + A0 = tvm.placeholder((m, n), name='A0') + A1 = tvm.placeholder((m, n), name='A1') T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T') s = tvm.create_schedule(T0.op) @@ -132,7 +132,7 @@ def test_multi_inputs_outputs(): assert(T0.value_index == 0) assert(T1.value_index == 1) -def test_multi_different_deps(): +def test_tuple_with_different_deps(): m = tvm.var('m') n = tvm.var('n') A0 = tvm.placeholder((m, n), name='A1') @@ -147,12 +147,12 @@ def test_multi_different_deps(): bounds = tvm.schedule.InferBound(sch) stmt = tvm.schedule.ScheduleOps(sch, bounds) - def is_B1_realize(x): + def get_B1_realize(x): if isinstance(x, tvm.stmt.Realize) and \ x.func == B1.op and x.value_index == 1: ret.append(x) ret = [] - tvm.ir_pass.PostOrderVisit(stmt, is_B1_realize) + tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize) assert stmt.node == C.op and len(ret) == 1 @@ -165,5 +165,5 @@ def is_B1_realize(x): test_scan_multi_out() test_extern() test_extern_multi_out() - test_multi_inputs_outputs() - test_multi_different_deps() + test_tuple_inputs() + test_tuple_with_different_deps() diff --git a/tutorials/python/tuple_inputs_operation.py b/tutorials/python/tuple_inputs_operation.py new file mode 100644 index 000000000000..b89bb46f8601 --- /dev/null +++ b/tutorials/python/tuple_inputs_operation.py @@ -0,0 +1,91 @@ +""" +Operation with Tuple Inputs +=========================== +**Author**: `Ziheng Jiang `_ + +In this tutorial, we will introduce the usage of tuple input in TVM. 
+""" +from __future__ import absolute_import, print_function + +import tvm +import numpy as np + +###################################################################### +# Describe Batchwise Computation +# ------------------------------ +# For operators which have the same shape, we can put them together as +# the input of :any:`tvm.compute`, if we wish they can be scheduled +# together in the next schedule procedure. +# +n = tvm.var("n") +m = tvm.var("m") +A0 = tvm.placeholder((m, n), name='A0') +A1 = tvm.placeholder((m, n), name='A1') +B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B') + +# The generated IR code would be: +s = tvm.create_schedule(B0.op) +print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True)) + +###################################################################### +# Describe Reduction with Collaborative Inputs +# -------------------------------------------- +# Sometimes, we requires multiple inputs to express some reduction +# operators, and the inputs will collaborate together, e.g. :code:`argmax`. +# In the reduction procedure, :code:`argmax` need to compare the value of +# operands, also need to keep the index of operand. It can be expressed +# with :any:`comm_reducer` as below: + +# x and y are the operands of reduction, both of them is a tuple of index +# and value. +def fcombine(x, y): + lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + +# our identity element also need to be a tuple, so `fidentity` accepts +# two types as inputs. +def fidentity(t0, t1): + return tvm.const(-1, t0), tvm.min_value(t1) + +argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax') + +# describe the reduction computation +m = tvm.var('m') +n = tvm.var('n') +idx = tvm.placeholder((m, n), name='idx', dtype='int32') +val = tvm.placeholder((m, n), name='val', dtype='int32') +k = tvm.reduce_axis((0, n), 'k') +T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T') + +# the generated IR code would be: +s = tvm.create_schedule(T0.op) +print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True)) + +###################################################################### +# Schedule Operation with Tuple Inputs +# ------------------------------------ +# It is worth mentioning that although you will get multiple outputs +# with one batch operation, but they can only be scheduled together +# in terms of operation. + +n = tvm.var("n") +m = tvm.var("m") +A0 = tvm.placeholder((m, n), name='A0') +B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B') +A1 = tvm.placeholder((m, n), name='A1') +C = tvm.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C') + +s = tvm.create_schedule(C.op) +s[B0].compute_at(s[C], C.op.axis[0]) +# as you can see in the below generated IR code: +print(tvm.lower(s, [A0, A1, C], simple_mode=True)) + +###################################################################### +# Summary +# ------- +# This tutorial introduces the usage of tuple inputs operation. +# +# - Describe normal batchwise computation. +# - Describe reduction operation with tuple inputs. +# - Notice that you can only schedule computation in terms of operation instead of tensor. 
From 8491288c325fb022e75fa1ee0d341b9dbf932e01 Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Sat, 10 Jun 2017 22:58:49 -0700
Subject: [PATCH 14/15] Improve tutorial

---
 tutorials/python/reduction.py              |  2 ++
 tutorials/python/tuple_inputs_operation.py | 18 ++++++++++++++----
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/tutorials/python/reduction.py b/tutorials/python/reduction.py
index 1bdc1f9b8e75..69d86596704b 100644
--- a/tutorials/python/reduction.py
+++ b/tutorials/python/reduction.py
@@ -125,6 +125,8 @@
     b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

 ######################################################################
+# .. _general-reduction:
+#
 # Define General Commutative Reduction Operation
 # ----------------------------------------------
 # Besides the built-in reduction operations like :any:`tvm.sum`,
diff --git a/tutorials/python/tuple_inputs_operation.py b/tutorials/python/tuple_inputs_operation.py
index b89bb46f8601..fb5fc6a8d214 100644
--- a/tutorials/python/tuple_inputs_operation.py
+++ b/tutorials/python/tuple_inputs_operation.py
@@ -1,9 +1,13 @@
 """
-Operation with Tuple Inputs
-===========================
+Compute and Reduction with Tuple Inputs
+=======================================
 **Author**: `Ziheng Jiang `_

-In this tutorial, we will introduce the usage of tuple input in TVM.
+Often we want to compute multiple outputs with the same shape within
+a single loop or perform reduction that involves multiple values like
+:code:`argmax`. These problems can be addressed by tuple inputs.
+
+In this tutorial, we will introduce the usage of tuple inputs in TVM.
 """
 from __future__ import absolute_import, print_function

@@ -14,7 +18,7 @@
 # Describe Batchwise Computation
 # ------------------------------
 # For operators which have the same shape, we can put them together as
-# the input of :any:`tvm.compute`, if we wish they can be scheduled
+# the inputs of :any:`tvm.compute`, if we wish they can be scheduled
 # together in the next schedule procedure.
 #
@@ -62,6 +66,12 @@ def fidentity(t0, t1):
 s = tvm.create_schedule(T0.op)
 print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))

+######################################################################
+# .. note::
+#
+#   For those who are not familiar with reduction, please refer to
+#   :ref:`general-reduction`.
+
 ######################################################################
 # Schedule Operation with Tuple Inputs
 # ------------------------------------
From a3c596ba7f9fdb89a8bce4ebd913149fce02548e Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Sat, 10 Jun 2017 23:17:51 -0700
Subject: [PATCH 15/15] Improve tutorial

---
 tutorials/python/reduction.py              | 6 ++++++
 tutorials/python/tuple_inputs_operation.py | 2 ++
 2 files changed, 8 insertions(+)

diff --git a/tutorials/python/reduction.py b/tutorials/python/reduction.py
index 69d86596704b..e7295cb927a3 100644
--- a/tutorials/python/reduction.py
+++ b/tutorials/python/reduction.py
@@ -142,6 +142,12 @@
 k = tvm.reduce_axis((0, m), name='k')
 B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')

+######################################################################
+# .. note::
+#
+#   Sometimes we would like to perform reduction that involves multiple
+#   values like :code:`argmax`, which can be done by tuple inputs.
+#   See :ref:`reduction-with-tuple-inputs` for more detail.
######################################################################
# Summary
diff --git a/tutorials/python/tuple_inputs_operation.py b/tutorials/python/tuple_inputs_operation.py
index fb5fc6a8d214..8c101a59e86e 100644
--- a/tutorials/python/tuple_inputs_operation.py
+++ b/tutorials/python/tuple_inputs_operation.py
@@ -32,6 +32,8 @@
 print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))

 ######################################################################
+# .. _reduction-with-tuple-inputs:
+#
 # Describe Reduction with Collaborative Inputs
 # --------------------------------------------
 # Sometimes we require multiple inputs to express some reduction
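A closing sketch tying the series together: with the Array-returning rfactor from patch 10, a tuple reduction factors into one tensor per reduction output, as test_rfactor_argmax exercises on the GPU; condensed here into a CPU-style schedule:

import tvm

def fcombine(x, y):
    lhs = tvm.make.Select(x[1] >= y[1], x[0], y[0])
    rhs = tvm.make.Select(x[1] >= y[1], x[1], y[1])
    return lhs, rhs

def fidentity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), 'k')
B0, B1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='B')

s = tvm.create_schedule(B0.op)
ko, kf = s[B0].split(k, factor=16)
# rfactor now returns one factored tensor per reduction output
BF0, BF1 = s.rfactor(B0, kf)
# scheduling is still per operation, so placing BF0's op positions
# both factored outputs at once
s[BF0.op].compute_at(s[B0], s[B0].op.reduce_axis[0])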